Skip to content

Commit 063783b

Browse files
Merge pull request #65 from SciML/as/batched-getu
feat: add `BatchedInterface`
2 parents f43b850 + 6dba9ff commit 063783b

File tree

9 files changed

+433
-27
lines changed

9 files changed

+433
-27
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ jobs:
1212
strategy:
1313
matrix:
1414
group:
15-
- All
15+
- Core
16+
- Downstream
1617
version:
1718
- '1'
1819
steps:
@@ -32,3 +33,5 @@ jobs:
3233
${{ runner.os }}-
3334
- uses: julia-actions/julia-buildpkg@v1
3435
- uses: julia-actions/julia-runtest@v1
36+
env:
37+
GROUP: ${{ matrix.group }}

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ Accessors = "0.1.36"
1515
Aqua = "0.8"
1616
ArrayInterface = "7.9"
1717
MacroTools = "0.5.13"
18-
RuntimeGeneratedFunctions = "0.5"
18+
Pkg = "1"
19+
RuntimeGeneratedFunctions = "0.5.12"
1920
SafeTestsets = "0.0.1"
2021
StaticArrays = "1.9"
2122
StaticArraysCore = "1.4"
@@ -24,9 +25,10 @@ julia = "1.10"
2425

2526
[extras]
2627
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
28+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2729
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2830
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2931
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3032

3133
[targets]
32-
test = ["Aqua", "Test", "SafeTestsets", "StaticArrays"]
34+
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"]

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,10 @@ symbolic_evaluate
9090
SymbolCache
9191
ProblemState
9292
```
93+
94+
### Batched Queries and Updates
95+
96+
```@docs
97+
BatchedInterface
98+
associated_systems
99+
```

src/SymbolicIndexingInterface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ include("parameter_indexing.jl")
3131
export state_values, set_state!, current_time, getu, setu
3232
include("state_indexing.jl")
3333

34+
export BatchedInterface, associated_systems
35+
include("batched_interface.jl")
36+
3437
export ProblemState
3538
include("problem_state.jl")
3639

src/batched_interface.jl

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
"""
2+
struct BatchedInterface{S <: AbstractVector, I}
3+
function BatchedInterface(syssyms::Tuple...)
4+
5+
A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref).
6+
Given `Tuple`s, where the first element of each tuple is a system and the second an
7+
array of symbols (either variables or parameters) in the system, `BatchedInterface` will
8+
compute the union of all symbols and associate each symbol with the first system with
9+
which it occurs.
10+
11+
For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and
12+
`s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will
13+
associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had
14+
associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also
15+
be retained internally.
16+
17+
`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref),
18+
[`variable_index`](@ref) to query the order of symbols in the union.
19+
20+
See [`getu`](@ref) and [`setu`](@ref) for further details.
21+
22+
See also: [`associated_systems`](@ref).
23+
"""
24+
struct BatchedInterface{S <: AbstractVector, I, T}
25+
"Order of symbols in the union."
26+
symbol_order::S
27+
"Index of the system each symbol in the union is associated with."
28+
associated_systems::Vector{Int}
29+
"Index of symbol in the system it is associated with."
30+
associated_indexes::I
31+
"Whether the symbol is a state in the system it is associated with."
32+
isstate::BitVector
33+
"Map from system to indexes of its symbols in the union."
34+
system_to_symbol_subset::Vector{Vector{Int}}
35+
"Map from system to indexes of its symbols in the system."
36+
system_to_symbol_indexes::Vector{Vector{T}}
37+
"Map from system to whether each of its symbols is a state in the system."
38+
system_to_isstate::Vector{BitVector}
39+
end
40+
41+
function BatchedInterface(syssyms::Tuple...)
42+
symbol_order = []
43+
associated_systems = Int[]
44+
associated_indexes = []
45+
isstate = BitVector()
46+
system_to_symbol_subset = Vector{Int}[]
47+
system_to_symbol_indexes = []
48+
system_to_isstate = BitVector[]
49+
for (i, (sys, syms)) in enumerate(syssyms)
50+
symbol_subset = Int[]
51+
symbol_indexes = []
52+
system_isstate = BitVector()
53+
allsyms = []
54+
for sym in syms
55+
if symbolic_type(sym) === NotSymbolic()
56+
error("Only symbolic variables allowed in BatchedInterface.")
57+
end
58+
if symbolic_type(sym) === ArraySymbolic()
59+
append!(allsyms, collect(sym))
60+
else
61+
push!(allsyms, sym)
62+
end
63+
end
64+
for sym in allsyms
65+
if !is_variable(sys, sym) && !is_parameter(sys, sym)
66+
error("Only variables and parameters allowed in BatchedInterface.")
67+
end
68+
if !any(isequal(sym), symbol_order)
69+
push!(symbol_order, sym)
70+
push!(associated_systems, i)
71+
push!(isstate, is_variable(sys, sym))
72+
if isstate[end]
73+
push!(associated_indexes, variable_index(sys, sym))
74+
else
75+
push!(associated_indexes, parameter_index(sys, sym))
76+
end
77+
end
78+
push!(symbol_subset, findfirst(isequal(sym), symbol_order))
79+
push!(system_isstate, is_variable(sys, sym))
80+
push!(symbol_indexes,
81+
system_isstate[end] ? variable_index(sys, sym) : parameter_index(sys, sym))
82+
end
83+
push!(system_to_symbol_subset, symbol_subset)
84+
push!(system_to_symbol_indexes, identity.(symbol_indexes))
85+
push!(system_to_isstate, system_isstate)
86+
end
87+
symbol_order = identity.(symbol_order)
88+
associated_indexes = identity.(associated_indexes)
89+
system_to_symbol_indexes = identity.(system_to_symbol_indexes)
90+
91+
return BatchedInterface{typeof(symbol_order), typeof(associated_indexes),
92+
eltype(eltype(system_to_symbol_indexes))}(
93+
symbol_order, associated_systems, associated_indexes, isstate,
94+
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
95+
end
96+
97+
variable_symbols(bi::BatchedInterface) = bi.symbol_order
98+
variable_index(bi::BatchedInterface, sym) = findfirst(isequal(sym), bi.symbol_order)
99+
is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing
100+
101+
"""
102+
associated_systems(bi::BatchedInterface)
103+
104+
Return an array of integers of the same length as `variable_symbols(bi)` where each value
105+
is the index of the system associated with the corresponding symbol in
106+
`variable_symbols(bi)`.
107+
"""
108+
associated_systems(bi::BatchedInterface) = bi.associated_systems
109+
110+
"""
111+
getu(bi::BatchedInterface)
112+
113+
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
114+
return a function which takes `n` corresponding problems and returns an array of the values
115+
of the symbols in the union. The returned function can also be passed an `AbstractArray` of
116+
the appropriate `eltype` and size as its first argument, in which case the operation will
117+
populate the array in-place with the values of the symbols in the union.
118+
119+
Note that all of the problems passed to the function returned by `getu` must satisfy
120+
`is_timeseries(prob) === NotTimeseries()`.
121+
122+
The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is
123+
obtained from the problem corresponding to the associated system (i.e. the problem at
124+
index `associated_systems(bi)[i]`).
125+
126+
See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref),
127+
[`NotTimeseries`](@ref).
128+
"""
129+
function getu(bi::BatchedInterface)
130+
numprobs = length(bi.system_to_symbol_subset)
131+
probnames = [Symbol(:prob, i) for i in 1:numprobs]
132+
133+
fnbody = quote end
134+
for (i, (prob, idx, isstate)) in enumerate(zip(
135+
bi.associated_systems, bi.associated_indexes, bi.isstate))
136+
symname = Symbol(:sym, i)
137+
getter = isstate ? state_values : parameter_values
138+
probname = probnames[prob]
139+
push!(fnbody.args, :($symname = $getter($probname, $idx)))
140+
end
141+
142+
oop_expr = Expr(:vect)
143+
for i in 1:length(bi.symbol_order)
144+
push!(oop_expr.args, Symbol(:sym, i))
145+
end
146+
147+
iip_expr = quote end
148+
for i in 1:length(bi.symbol_order)
149+
symname = Symbol(:sym, i)
150+
push!(iip_expr.args, :(out[$i] = $symname))
151+
end
152+
153+
oopfn = Expr(
154+
:function,
155+
Expr(:tuple, probnames...),
156+
quote
157+
$fnbody
158+
$oop_expr
159+
end
160+
)
161+
iipfn = Expr(
162+
:function,
163+
Expr(:tuple, :out, probnames...),
164+
quote
165+
$fnbody
166+
$iip_expr
167+
out
168+
end
169+
)
170+
171+
return let oop = @RuntimeGeneratedFunction(oopfn),
172+
iip = @RuntimeGeneratedFunction(iipfn)
173+
174+
_getter(probs...) = oop(probs...)
175+
_getter(out::AbstractArray, probs...) = iip(out, probs...)
176+
_getter
177+
end
178+
end
179+
180+
"""
181+
setu(bi::BatchedInterface)
182+
183+
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
184+
return a function which takes `n` corresponding problems and an array of the values, and
185+
updates each of the problems with the values of the corresponding symbols.
186+
187+
Note that all of the problems passed to the function returned by `setu` must satisfy
188+
`is_timeseries(prob) === NotTimeseries()`.
189+
190+
Note that if any subset of the `n` systems share common symbols (among those passed to
191+
`BatchedInterface`) then all of the corresponding problems in the subset will be updated
192+
with the values of the common symbols.
193+
194+
See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
195+
"""
196+
function setu(bi::BatchedInterface)
197+
numprobs = length(bi.system_to_symbol_subset)
198+
probnames = [Symbol(:prob, i) for i in 1:numprobs]
199+
200+
full_update_fnexpr = let fnbody = quote end
201+
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
202+
probname = probnames[sys_idx]
203+
for (idx_in_subset, idx_in_union) in enumerate(subset)
204+
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
205+
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
206+
setter = isstate ? set_state! : set_parameter!
207+
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
208+
end
209+
# also run hook
210+
if !all(bi.system_to_isstate[sys_idx])
211+
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
212+
for idx_in_subset in 1:length(subset)
213+
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
214+
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
215+
end
216+
end
217+
push!(fnbody.args, :(return vals))
218+
Expr(
219+
:function,
220+
Expr(:tuple, probnames..., :vals),
221+
fnbody
222+
)
223+
end
224+
225+
partial_update_fnexpr = let fnbody = quote end
226+
curfnbody = fnbody
227+
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
228+
newcurfnbody = if sys_idx == 1
229+
Expr(:if, :(idx == $sys_idx))
230+
else
231+
Expr(:elseif, :(idx == $sys_idx))
232+
end
233+
push!(curfnbody.args, newcurfnbody)
234+
curfnbody = newcurfnbody
235+
236+
ifbody = quote end
237+
push!(curfnbody.args, ifbody)
238+
239+
probname = :prob
240+
for (idx_in_subset, idx_in_union) in enumerate(subset)
241+
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
242+
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
243+
setter = isstate ? set_state! : set_parameter!
244+
push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
245+
end
246+
# also run hook
247+
if !all(bi.system_to_isstate[sys_idx])
248+
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
249+
for idx_in_subset in 1:length(subset)
250+
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
251+
push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
252+
end
253+
end
254+
push!(curfnbody.args, :(error("Invalid problem index $idx")))
255+
push!(fnbody.args, :(return nothing))
256+
Expr(
257+
:function,
258+
Expr(:tuple, :prob, :idx, :vals),
259+
fnbody
260+
)
261+
end
262+
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
263+
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr)
264+
265+
setter!(args...) = full_update(args...)
266+
setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals)
267+
setter!
268+
end
269+
end

test/batched_interface_test.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using SymbolicIndexingInterface
2+
3+
syss = [
4+
SymbolCache([:x, :y, :z], [:a, :b, :c], :t),
5+
SymbolCache([:z, :w, :v], [:c, :e, :f]),
6+
SymbolCache([:w, :x, :u], [:e, :a, :f])
7+
]
8+
syms = [
9+
[:x, :z, :b, :c],
10+
[:z, :w, :c, :f],
11+
[:w, :x, :e, :a]
12+
]
13+
probs = [
14+
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]),
15+
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.4, 0.5, 0.6]),
16+
ProblemState(; u = [7.0, 8.0, 9.0], p = [0.7, 0.8, 0.9])
17+
]
18+
19+
@test_throws ErrorException BatchedInterface((syss[1], [:x, 3]))
20+
@test_throws ErrorException BatchedInterface((syss[1], [:(x + y)]))
21+
@test_throws ErrorException BatchedInterface((syss[1], [:t]))
22+
23+
bi = BatchedInterface(zip(syss, syms)...)
24+
@test variable_symbols(bi) == [:x, :z, :b, :c, :w, :f, :e, :a]
25+
@test variable_index.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
26+
[8, 3, 4, 7, 6, 1, nothing, 2, 5, nothing, nothing]
27+
@test is_variable.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
28+
Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0]
29+
@test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3]
30+
31+
getter = getu(bi)
32+
@test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]
33+
buf = zeros(8)
34+
@inferred getter(buf, probs...)
35+
@test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]
36+
37+
setter! = setu(bi)
38+
buf .*= 100
39+
setter!(probs..., buf)
40+
41+
@test state_values(probs[1]) == [100.0, 2.0, 300.0]
42+
# :a isn't updated here because it wasn't part of the symbols associated with syss[1] (syms[1])
43+
@test parameter_values(probs[1]) == [0.1, 20.0, 30.0]
44+
@test state_values(probs[2]) == [300.0, 500.0, 6.0]
45+
# Similarly for :e
46+
@test parameter_values(probs[2]) == [30.0, 0.5, 60.0]
47+
@test state_values(probs[3]) == [500.0, 100.0, 9.0]
48+
# Similarly for :f
49+
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9]
50+
51+
buf ./= 100
52+
setter!(probs[1], 1, buf)
53+
@test state_values(probs[1]) == [1.0, 2.0, 3.0]
54+
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]
55+
56+
@test_throws ErrorException setter!(probs[1], 4, buf)

test/downstream/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

0 commit comments

Comments
 (0)