Skip to content

Commit 30b4759

Browse files
feat: implement remake_buffer for Tuples
1 parent b70f98f commit 30b4759

File tree

4 files changed

+52
-24
lines changed

4 files changed

+52
-24
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ authors = ["Aayush Sabharwal <aayush.sabharwal@gmail.com> and contributors"]
44
version = "0.3.13"
55

66
[deps]
7+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
78
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
910
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
1011
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1112

1213
[compat]
14+
Accessors = "0.1.36"
1315
Aqua = "0.8"
1416
ArrayInterface = "7.9"
1517
MacroTools = "0.5.13"

src/SymbolicIndexingInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import MacroTools
44
using RuntimeGeneratedFunctions
55
import StaticArraysCore: MArray, similar_type
66
import ArrayInterface
7+
using Accessors: @reset
78

89
RuntimeGeneratedFunctions.init(@__MODULE__)
910

src/remake.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,19 @@ function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)
3030
end
3131
return newbuffer
3232
end
33+
34+
mutable struct TupleRemakeWrapper
35+
t::Tuple
36+
end
37+
38+
function set_parameter!(sys::TupleRemakeWrapper, val, idx)
39+
tp = sys.t
40+
@reset tp[idx] = val
41+
sys.t = tp
42+
end
43+
44+
function remake_buffer(sys, oldbuffer::Tuple, vals::Dict)
45+
wrap = TupleRemakeWrapper(oldbuffer)
46+
setu(sys, collect(keys(vals)))(wrap, values(vals))
47+
return wrap.t
48+
end

test/remake_test.jl

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,23 @@ using StaticArrays
44
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
55

66
for (buf, newbuf, newvals) in [
7-
# standard operation
8-
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
9-
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
10-
# buffer type "demotion"
11-
([1.0, 2.0, 3.0], [2, 3, 4],
12-
Dict(:x => 2, :y => 3, :z => 4))
13-
# buffer type promotion
14-
([1, 2, 3], [2.0, 3.0, 4.0],
15-
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
16-
# value type promotion
17-
([1, 2, 3], [2.0, 3.0, 4.0],
18-
Dict(:x => 2, :y => 3.0, :z => 4.0))
19-
# standard operation
20-
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
21-
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
22-
# buffer type "demotion"
23-
([1.0, 2.0, 3.0], [2, 3, 4],
24-
Dict(:a => 2, :b => 3, :c => 4))
25-
# buffer type promotion
26-
([1, 2, 3], [2.0, 3.0, 4.0],
27-
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
28-
# value type promotion
29-
([1, 2, 3], [2, 3.0, 4.0],
30-
Dict(:a => 2, :b => 3.0, :c => 4.0))]
7+
# standard operation
8+
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
9+
# buffer type "demotion"
10+
([1.0, 2.0, 3.0], [2, 3, 4], Dict(:x => 2, :y => 3, :z => 4)),
11+
# buffer type promotion
12+
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2.0, :y => 3.0, :z => 4.0)),
13+
# value type promotion
14+
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:x => 2, :y => 3.0, :z => 4.0)),
15+
# standard operation
16+
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
17+
# buffer type "demotion"
18+
([1.0, 2.0, 3.0], [2, 3, 4], Dict(:a => 2, :b => 3, :c => 4)),
19+
# buffer type promotion
20+
([1, 2, 3], [2.0, 3.0, 4.0], Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
21+
# value type promotion
22+
([1, 2, 3], [2, 3.0, 4.0], Dict(:a => 2, :b => 3.0, :c => 4.0))
23+
]
3124
for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}]
3225
buf = arrType(buf)
3326
newbuf = arrType(newbuf)
@@ -38,3 +31,19 @@ for (buf, newbuf, newvals) in [
3831
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
3932
end
4033
end
34+
35+
# Tuples not allowed for state
36+
for (buf, newbuf, newvals) in [
37+
# standard operation
38+
((1.0, 2.0, 3.0), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
39+
# buffer type "demotion"
40+
((1.0, 2.0, 3.0), (2, 3, 4), Dict(:a => 2, :b => 3, :c => 4)),
41+
# buffer type promotion
42+
((1, 2, 3), (2.0, 3.0, 4.0), Dict(:a => 2.0, :b => 3.0, :c => 4.0)),
43+
# value type promotion
44+
((1, 2, 3), (2, 3.0, 4.0), Dict(:a => 2, :b => 3.0, :c => 4.0))
45+
]
46+
_newbuf = remake_buffer(sys, buf, newvals)
47+
@test newbuf == _newbuf # test values
48+
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
49+
end

0 commit comments

Comments
 (0)