Skip to content

Commit 129a596

Browse files
committed
Reorganization
1 parent 080c8a4 commit 129a596

File tree

10 files changed

+206
-347
lines changed

10 files changed

+206
-347
lines changed

README.md

Lines changed: 0 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -173,150 +173,3 @@ lotka_sol = solve(lotka_prob)
173173
Notice how cleanly the ```composed!``` function can pass variables from one function to another with no array index juggling in sight. This is especially useful for large models as it becomes harder to keep track top-level model array position when adding new or deleting old components from the model. We could go further and compose ```composed!``` with other components ad (practically) infinitum with no mental bookkeeping.
174174

175175
The main benefit, however, is now our differential equations are unit testable. Both ```lorenz``` and ```lotka``` can be run as their own ```ODEProblem``` with ```f``` set to zero to see the unforced response.
176-
177-
178-
### Control of a sliding block
179-
In this example, we'll build a model of a block sliding on a surface and use ```ComponentArray```s to easily switch between coulomb and equivalent viscous damping models. The block is controlled by pushing and pulling a spring attached to it and we will use feedback through a PID controller to try to track a reference signal. For simplification, we are using the velocity of the block directly for the derivative term, rather than taking a filtered derivative of the error signal. We are also setting a deadzone on the friction force with exponential decay to zero velocity to get rid of simulation chatter during the static friction regime.
180-
181-
```julia
182-
using ComponentArrays
183-
using DifferentialEquations
184-
using Interact: @manipulate
185-
using Parameters: @unpack
186-
using Plots
187-
188-
## Setup
189-
const g = 9.80665
190-
191-
maybe_apply(f::Function, x, p, t) = f(x, p, t)
192-
maybe_apply(f, x, p, t) = f
193-
194-
# Allows functions of form f(x,p,t) to be applied and passed in as inputs
195-
function simulator(func; kwargs...)
196-
simfun(dx, x, p, t) = func(dx, x, p, t; map(f->maybe_apply(f, x, p, t), (;kwargs...))...)
197-
simfun(x, p, t) = func(x, p, t; map(f->maybe_apply(f, x, p, t), (;kwargs...))...)
198-
return simfun
199-
end
200-
201-
softsign(x) = tanh(1e3x)
202-
203-
204-
## Dynamics update functions
205-
# Sliding block with viscous friction
206-
function viscous_block!(D, vars, p, t; u=0.0)
207-
@unpack m, c, k = p
208-
@unpack v, x = vars
209-
210-
D.x = v
211-
D.v = (-c*v + k*(u-x))/m
212-
return x
213-
end
214-
215-
# Sliding block with coulomb friction
216-
function coulomb_block!(D, vars, p, t; u=0.0)
217-
@unpack m, μ, k = p
218-
@unpack v, x = vars
219-
220-
D.x = v
221-
a = -μ*g*softsign(v) + k*(u-x)/m
222-
D.v = abs(a)<1e-3 && abs(v)<1e-3 ? -10v : a
223-
return x
224-
end
225-
226-
function PID_controller!(D, vars, p, t; err=0.0, v=0.0)
227-
@unpack kp, ki, kd = p
228-
@unpack x = vars
229-
230-
D.x = err
231-
return ki*x + kp*err + kd*v
232-
end
233-
234-
function feedback_sys!(D, components, p, t; ref=0.0)
235-
@unpack ctrl, plant = components
236-
237-
u = p.ctrl.fun(D.ctrl, ctrl, p.ctrl.params, t; err=ref-plant.x, v=-plant.v)
238-
return p.plant.fun(D.plant, plant, p.plant.params, t; u=u)
239-
end
240-
241-
step_input(;time=1.0, mag=1.0) = (x,p,t) -> t>time ? mag : 0
242-
sine_input(;mag=1.0, period=10.0) = (x,p,t) -> mag*sin(t*2π/period)
243-
244-
245-
## Interactive GUI for switching out plant models and varying PID gains
246-
@manipulate for kp in 0:0.01:15,
247-
ki in 0:0.01:15,
248-
kd in 0:0.01:15,
249-
damping in Dict(
250-
"Coulomb" => coulomb_block!,
251-
"Viscous" => viscous_block!,
252-
),
253-
reference in Dict(
254-
"Sine" => sine_input,
255-
"Step" => step_input,
256-
),
257-
magnitude in 0:0.01:10, # pop-pop!
258-
period in 1:0.01:30,
259-
plot_v in false
260-
261-
# Inputs
262-
tspan = (0.0, 30.0)
263-
264-
ctrl_fun = PID_controller!
265-
# plant_fun = coulomb_block!
266-
267-
ref = if reference==sine_input
268-
reference(period=period, mag=magnitude)
269-
else
270-
reference(mag=magnitude)
271-
end
272-
273-
m = 50.0
274-
μ = 0.1
275-
ω = 2π/period
276-
c = 4*μ*m*g/*ω*magnitude) # Viscous equivalent damping
277-
k = 50.0
278-
279-
plant_p = (m=m, μ=μ, c=c, k=k)
280-
ctrl_p = (kp=kp, ki=ki, kd=kd)
281-
282-
plant_ic = (v=0, x=0)
283-
ctrl_ic = (;x=0)
284-
285-
286-
287-
# Set up and solve
288-
sys_p = (
289-
ctrl = (
290-
params = ctrl_p,
291-
fun = ctrl_fun,
292-
),
293-
plant = (
294-
params = plant_p,
295-
fun = damping,
296-
),
297-
)
298-
sys_ic = ComponentArray(ctrl=ctrl_ic, plant=plant_ic)
299-
sys_fun = ODEFunction(simulator(feedback_sys!, ref=ref), syms=[:u, :v, :x])
300-
sys_prob = ODEProblem(sys_fun, sys_ic, tspan, sys_p)
301-
302-
sol = solve(sys_prob, Tsit5())
303-
304-
305-
# Plot
306-
t = tspan[1]:0.1:tspan[2]
307-
lims = magnitude*[-1, 1]
308-
plotvars = plot_v ? [3, 2] : [3]
309-
strip = plot(t, ref.(0, 0, t), ylim=1.2lims, label="r(t)")
310-
plot!(strip, sol, vars=plotvars)
311-
phase = plot(ref.(0, 0, t), map(x->x.plant.x, sol(t).u),
312-
xlim=lims,
313-
ylim=1.2lims,
314-
legend=false,
315-
xlabel="r(t)",
316-
ylabel="x(t)",
317-
)
318-
plot(strip, phase, layout=(2, 1), size=(700, 800))
319-
320-
end
321-
```
322-
<img src="assets/coulomb_control.png" alight="middle" />

src/ComponentArrays.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ const FlatOrColonIdx = Union{FlatIdx, Colon}
99

1010

1111
include("utils.jl")
12-
export fastindices
12+
export fastindices # Deprecated
1313

1414
include("lazyarray.jl")
1515

@@ -21,13 +21,19 @@ include("componentindex.jl")
2121
include("componentarray.jl")
2222
export ComponentArray, ComponentVector, ComponentMatrix, getaxes, getdata, valkeys
2323

24-
include("set_get.jl")
24+
include("array_interface.jl")
25+
# Base methods: parent, size, elsize, axes, reinterpret, hcat, vcat, permutedims, IndexStyle, to_indices, to_index, getindex, setindex!, view, pointer, unsafe_convert, strides, stride
26+
# ArrayInterface methods: strides, size, lu_instance, parent_type
27+
28+
include("namedtuple_interface.jl")
29+
# Base methods: hash, ==, keys, haskey, propertynames, getproperty, setproperty!
2530

2631
include("similar_convert_copy.jl")
32+
# Base methods: similar, zero, copy, copyto!, deepcopy, convert (to Array and NamedTuple), promote
2733

2834
include("broadcasting.jl")
29-
30-
include("math.jl")
35+
# Base methods: BroadcastStyle, convert(to Broadcasted{Nothing}), similar, map, dataids
36+
# Broadcast methods: BroadcastStyle, broadcasted, broadcast_unalias
3137

3238
include("show.jl")
3339

src/array_interface.jl

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
Base.parent(x::ComponentArray) = getfield(x, :data)
2+
3+
Base.size(x::ComponentArray) = size(getdata(x))
4+
ArrayInterface.size(A::ComponentArray) = ArrayInterface.size(parent(A))
5+
6+
Base.elsize(x::Type{<:ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = Base.elsize(A)
7+
8+
Base.axes(x::ComponentArray) = axes(getdata(x))
9+
10+
Base.reinterpret(::Type{T}, x::ComponentArray, args...) where T = ComponentArray(reinterpret(T, getdata(x), args...), getaxes(x))
11+
12+
Base.hcat(x::CV...) where {CV<:ComponentVector} = ComponentArray(reduce(hcat, getdata.(x)), getaxes(x[1])[1], FlatAxis())
13+
14+
Base.vcat(x::ComponentVector, y::AbstractVector) = vcat(getdata(x), y)
15+
Base.vcat(x::AbstractVector, y::ComponentVector) = vcat(x, getdata(y))
16+
function Base.vcat(x::ComponentVector, y::ComponentVector)
17+
if reduce((accum, key) -> accum || (key in keys(x)), keys(y); init=false)
18+
return vcat(getdata(x), getdata(y))
19+
else
20+
data_x, data_y = getdata.((x, y))
21+
ax_x, ax_y = only.(getaxes.((x, y)))
22+
ax_y = reindex(ax_y, length(x))
23+
idxmap_x, idxmap_y = indexmap.((ax_x, ax_y))
24+
return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)))
25+
end
26+
end
27+
Base.vcat(x::CV...) where {CV<:AdjOrTransComponentArray} = ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1]))
28+
Base.vcat(x::ComponentVector...) = reduce(vcat, x)
29+
Base.vcat(x::ComponentVector, args...) = vcat(getdata(x), getdata.(args)...)
30+
Base.vcat(x::ComponentVector, args::Vararg{AbstractVector{T}, N}) where {T,N} = vcat(getdata(x), getdata.(args)...)
31+
32+
function Base.permutedims(x::ComponentArray, dims)
33+
axs = getaxes(x)
34+
return ComponentArray(permutedims(getdata(x), dims), map(i->axs[i], dims)...)
35+
end
36+
37+
## Indexing
38+
Base.IndexStyle(::Type{<:ComponentArray{T,N,<:A,<:Axes}}) where {T,N,A,Axes} = IndexStyle(A)
39+
40+
# Since we aren't really using the standard approach to indexing, this will forward things to
41+
# the correct methods
42+
Base.to_indices(x::ComponentArray, i::Tuple) = i
43+
Base.to_indices(x::ComponentArray, i::Tuple{Vararg{Union{Integer, CartesianIndex}, N}}) where N = i
44+
Base.to_indices(x::ComponentArray, i::Tuple{Vararg{Int64}}) where N = i
45+
Base.to_index(x::ComponentArray, i) = i
46+
47+
# Get AbstractAxis index
48+
@inline Base.getindex(::AbstractAxis, idx::FlatIdx) = ComponentIndex(idx)
49+
@inline Base.getindex(ax::AbstractAxis, ::Colon) = ComponentIndex(:, ax)
50+
@inline Base.getindex(::AbstractAxis{IdxMap}, s::Symbol) where IdxMap =
51+
ComponentIndex(getproperty(IdxMap, s))
52+
53+
# Get ComponentArray index
54+
Base.@propagate_inbounds Base.getindex(x::ComponentArray, idx::CartesianIndex) = getdata(x)[idx]
55+
Base.@propagate_inbounds Base.getindex(x::ComponentArray, idx::FlatIdx...) = getdata(x)[idx...]
56+
Base.@propagate_inbounds function Base.getindex(x::ComponentArray, idx::FlatOrColonIdx...)
57+
axs = map((ax, i) -> getindex(ax, i).ax, getaxes(x), idx)
58+
axs = remove_nulls(axs...)
59+
return ComponentArray(getdata(x)[idx...], axs...)
60+
end
61+
Base.@propagate_inbounds Base.getindex(x::ComponentArray, ::Colon) = getdata(x)[:]
62+
@inline Base.getindex(x::ComponentArray, ::Colon...) = x
63+
Base.@propagate_inbounds Base.getindex(x::ComponentArray, idx...) = getindex(x, toval.(idx)...)
64+
@inline Base.getindex(x::ComponentArray, idx::Val...) = _getindex(x, idx...)
65+
66+
# Set ComponentArray index
67+
@inline Base.setindex!(x::ComponentArray, v, idx::FlatIdx...) = setindex!(getdata(x), v, idx...)
68+
Base.@propagate_inbounds Base.setindex!(x::ComponentArray, v, ::Colon) = setindex!(getdata(x), v, :)
69+
Base.@propagate_inbounds Base.setindex!(x::ComponentArray, v, idx...) = setindex!(x, v, toval.(idx)...)
70+
@inline Base.setindex!(x::ComponentArray, v, idx::Val...) = _setindex!(x, v, idx...)
71+
72+
# Explicitly view
73+
Base.@propagate_inbounds Base.view(x::ComponentArray, idx::ComponentArrays.FlatIdx...) = view(getdata(x), idx...)
74+
Base.@propagate_inbounds Base.view(x::ComponentArray, idx...) = _getindex(x, toval.(idx)...)
75+
76+
# Generated get and set index methods to do all of the heavy lifting in the type domain
77+
@generated function _getindex(x::ComponentArray, idx...)
78+
ci = getindex.(getaxes(x), getval.(idx))
79+
inds = map(i -> i.idx, ci)
80+
axs = map(i -> i.ax, ci)
81+
axs = remove_nulls(axs...)
82+
# the index must be valid after computing `ci`
83+
:(Base.@_inline_meta; ComponentArray(Base.maybeview(getdata(x), $inds...), $axs...))
84+
end
85+
86+
@generated function _setindex!(x::ComponentArray, v, idx...)
87+
ci = getindex.(getaxes(x), getval.(idx))
88+
inds = map(i -> i.idx, ci)
89+
# the index must be valid after computing `ci`
90+
return :(Base.@_inline_meta; setindex!(getdata(x), v, $inds...))
91+
end
92+
93+
## Linear Algebra
94+
Base.pointer(x::ComponentArray{T,N,A,Axes}) where {T,N,A<:DenseArray,Axes} = pointer(getdata(x))
95+
96+
Base.unsafe_convert(::Type{Ptr{T}}, x::ComponentArray{T,N,A,Axes}) where {T,N,A,Axes} = Base.unsafe_convert(Ptr{T}, getdata(x))
97+
98+
Base.strides(x::ComponentArray) = strides(getdata(x))
99+
ArrayInterface.strides(A::ComponentArray) = ArrayInterface.strides(parent(A))
100+
for f in [:device, :stride_rank, :contiguous_axis, :contiguous_batch_size, :dense_dims]
101+
@eval ArrayInterface.$f(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = ArrayInterface.$f(A)
102+
end
103+
104+
Base.stride(x::ComponentArray, k) = stride(getdata(x), k)
105+
Base.stride(x::ComponentArray, k::Int64) = stride(getdata(x), k)
106+
107+
ArrayInterface.lu_instance(jac_prototype::ComponentArray) = ArrayInterface.lu_instance(getdata(jac_prototype))
108+
109+
ArrayInterface.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A
110+
111+
112+
113+
# While there are some cases where these were faster, it is going to be almost impossible to
114+
# to keep up with method ambiguity errors due to other array types overloading *, /, and \.
115+
# Leaving these here and commented out for now, but will delete them later.
116+
117+
# # Avoid slower fallback
118+
# for f in [:(*), :(/), :(\)]
119+
# @eval begin
120+
# # The normal stuff
121+
# Base.$f(x::ComponentArray, y::AbstractArray) = $f(getdata(x), y)
122+
# Base.$f(x::AbstractArray, y::ComponentArray) = $f(x, getdata(y))
123+
# Base.$f(x::ComponentArray, y::ComponentArray) = $f(getdata(x), getdata(y))
124+
125+
# # A bunch of special cases to avoid ambiguous method errors
126+
# Base.$f(x::ComponentArray, y::AbstractMatrix) = $f(getdata(x), y)
127+
# Base.$f(x::AbstractMatrix, y::ComponentArray) = $f(x, getdata(y))
128+
129+
# Base.$f(x::ComponentArray, y::AbstractVector) = $f(getdata(x), y)
130+
# Base.$f(x::AbstractVector, y::ComponentArray) = $f(x, getdata(y))
131+
# end
132+
# end
133+
134+
# # Adjoint/transpose special cases
135+
# for f in [:(*), :(/)]
136+
# @eval begin
137+
# Base.$f(x::Adjoint, y::ComponentArray) = $f(getdata(x), getdata(y))
138+
# Base.$f(x::Transpose, y::ComponentArray) = $f(getdata(x), getdata(y))
139+
140+
# Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector) where T = $f(x, getdata(y))
141+
# Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector) where T = $f(x, getdata(y))
142+
143+
# Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentMatrix{T,A,Axes}) where {T,A,Axes} = $f(x, getdata(y))
144+
# Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentMatrix{T,A,Axes}) where {T,A,Axes} = $f(x, getdata(y))
145+
146+
# Base.$f(x::Adjoint{T,<:AbstractMatrix{T}}, y::ComponentVector) where {T} = $f(x, getdata(y))
147+
# Base.$f(x::Transpose{T,<:AbstractMatrix{T}}, y::ComponentVector) where {T} = $f(x, getdata(y))
148+
149+
# Base.$f(x::ComponentArray, y::Adjoint{T,<:AbstractVector{T}}) where T = $f(getdata(x), y)
150+
# Base.$f(x::ComponentArray, y::Transpose{T,<:AbstractVector{T}}) where T = $f(getdata(x), y)
151+
152+
# Base.$f(x::ComponentArray, y::Adjoint{T,<:ComponentVector}) where T = $f(getdata(x), getdata(y))
153+
# Base.$f(x::ComponentArray, y::Transpose{T,<:ComponentVector}) where T = $f(getdata(x), getdata(y))
154+
155+
# # There seems to be a new method in Julia > v.1.4 that specializes on this
156+
# Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Number,A,Axes} = $f(x, getdata(y))
157+
# Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Number,A,Axes} = $f(x, getdata(y))
158+
159+
# Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Real,A,Axes} = $f(getdata(x), getdata(y))
160+
# Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Real,A,Axes} = $f(getdata(x), getdata(y))
161+
# end
162+
# end

src/axis.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ const NullOrFlatView{Inds,IdxMap} = ViewAxis{Inds,IdxMap,<:NullorFlatAxis}
7878

7979
viewindex(::ViewAxis{Inds,IdxMap}) where {Inds,IdxMap} = Inds
8080
viewindex(::Type{<:ViewAxis{Inds,IdxMap}}) where {Inds,IdxMap} = Inds
81+
viewindex(i) = i
8182

8283

8384

@@ -132,4 +133,10 @@ const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Sh
132133
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
133134

134135

135-
Base.merge(axs::Axis...) = Axis(merge(indexmap.(axs)...))
136+
Base.merge(axs::Axis...) = Axis(merge(indexmap.(axs)...))
137+
138+
Base.lastindex(ax::AbstractAxis) = last(viewindex(last(indexmap(ax))))
139+
140+
reindex(i, offset) = i .+ offset
141+
reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax)))
142+
reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax))

src/broadcasting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function Base.similar(bc::BC.Broadcasted{<:CAStyle{<:BC.Unknown, Axes, N}}, T::T
6969
end
7070

7171

72-
Base.Broadcast.broadcasted(f, x::ComponentArray) = ComponentArray(map(f, getdata(x)), getaxes(x))
72+
BC.broadcasted(f, x::ComponentArray) = ComponentArray(map(f, getdata(x)), getaxes(x))
7373

7474
# Need a special case here because `map` doesn't follow same rules as normal broadcasting. To be safe and avoid ambiguities,
7575
# we'll just handle the case where everything is a ComponentArray. Else it falls back to a plain Array output.

0 commit comments

Comments
 (0)