Skip to content

Commit 376bfdf

Browse files
committed
Fix #211 making scalar_rule's frule return a Composite not a tuple
1 parent 7582999 commit 376bfdf

File tree

4 files changed

+23
-19
lines changed

4 files changed

+23
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.12"
3+
version = "0.9.13"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/rule_definition_tools.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,12 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
148148
propagation_expr(Δs, ∂s)
149149
end
150150
if n_outputs > 1
151-
# For forward-mode we only return a tuple if output actually a tuple.
152-
pushforward_returns = Expr(:tuple, pushforward_returns...)
151+
# For forward-mode we return a Composite if output actually a tuple.
152+
pushforward_returns = Expr(
153+
:call, :(ChainRulesCore.Composite{typeof($(esc()))}), pushforward_returns...
154+
)
153155
else
154-
pushforward_returns = pushforward_returns[1]
156+
pushforward_returns = first(pushforward_returns)
155157
end
156158

157159
return quote

test/rule_definition_tools.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,20 @@ end
110110
)
111111
end
112112
end
113+
114+
@testset "@scalar_rule" begin
115+
@testset "@scalar_rule with multiple output" begin
116+
simo(x) = (x, 2x)
117+
@scalar_rule(simo(x), 1f0, 2f0)
118+
119+
y, simo_pb = rrule(simo, π)
120+
121+
@test simo_pb((10f0, 20f0)) == (NO_FIELDS, 50f0)
122+
123+
y, ẏ = frule((NO_FIELDS, 50f0), simo, π)
124+
@test y == (π, 2π)
125+
# test with === because type also must match
126+
@test=== Composite{typeof(y)}(50f0, 100f0)
127+
end
128+
end
113129
end

test/rules.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,4 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
136136
@test ∂self == NO_FIELDS
137137
@test ∂x j′vp(central_fdm(5, 1), complex_times, Ω̄, x)[1]
138138
end
139-
end
140-
141-
142-
simo(x) = (x, 2x)
143-
@scalar_rule(simo(x), 1, 2)
144-
145-
@testset "@scalar_rule with multiple inputs" begin
146-
y, simo_pb = rrule(simo, π)
147-
148-
@test simo_pb((10, 20)) == (NO_FIELDS, 50)
149-
150-
y, ẏ = frule((NO_FIELDS, 50), simo, π)
151-
@test y == (π, 2π)
152-
@test== (50, 100)
153-
end
139+
end

0 commit comments

Comments
 (0)