Skip to content

Commit 39f1caf

Browse files
authored
Merge pull request #217 from JuliaDiff/ox/nondiff_kw
Add kwarg support to at-nondifferentiable
2 parents dc7e159 + 185e518 commit 39f1caf

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.8"
3+
version = "0.9.9"
44

55
[deps]
66
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

src/rule_definition_tools.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ macro non_differentiable(sig_expr)
295295
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]
296296

297297
unconstrained_args = _unconstrain.(constrained_args)
298-
primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...)
298+
299+
primal_invoke = :($(primal_name)($(unconstrained_args...); kwargs...))
299300

300301
quote
301302
$(_nondiff_frule_expr(primal_sig_parts, primal_invoke))
@@ -304,28 +305,27 @@ macro non_differentiable(sig_expr)
304305
end
305306

306307
function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
307-
return Expr(
308-
:(=),
309-
Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...),
310-
# Julia functions always only have 1 output, so just return a single DoesNotExist()
311-
Expr(:tuple, primal_invoke, DoesNotExist()),
312-
)
308+
return esc(:(
309+
function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...)
310+
# Julia functions always only have 1 output, so return a single DoesNotExist()
311+
return ($primal_invoke, DoesNotExist())
312+
end
313+
))
313314
end
314315

315316
function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
316317
num_primal_inputs = length(primal_sig_parts) - 1
317318
primal_name = first(primal_invoke.args)
318319
pullback_expr = Expr(
319320
:function,
320-
Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)),
321+
Expr(:call, propagator_name(primal_name, :pullback), :_),
321322
Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...)
322323
)
323-
rrule_defn = Expr(
324-
:(=),
325-
Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...),
326-
Expr(:tuple, primal_invoke, pullback_expr),
327-
)
328-
return rrule_defn
324+
return esc(:(
325+
function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...)
326+
return ($primal_invoke, $pullback_expr)
327+
end
328+
))
329329
end
330330

331331

test/rule_definition_tools.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,31 @@ end
7373
@test rrule(pointy_identity, 2.0) == nothing
7474
end
7575

76+
@testset "kwargs" begin
77+
kw_demo(x; kw=2.0) = x + kw
78+
@non_differentiable kw_demo(::Any)
79+
80+
@testset "not setting kw" begin
81+
@assert kw_demo(1.5) == 3.5
82+
83+
res, pullback = rrule(kw_demo, 1.5)
84+
@test res == 3.5
85+
@test pullback(4.1) == (NO_FIELDS, DoesNotExist())
86+
87+
@test frule((Zero(), 11.1), kw_demo, 1.5) == (3.5, DoesNotExist())
88+
end
89+
90+
@testset "setting kw" begin
91+
@assert kw_demo(1.5; kw=3.0) == 4.5
92+
93+
res, pullback = rrule(kw_demo, 1.5; kw=3.0)
94+
@test res == 4.5
95+
@test pullback(1.1) == (NO_FIELDS, DoesNotExist())
96+
97+
@test frule((Zero(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, DoesNotExist())
98+
end
99+
end
100+
76101
@testset "Not supported (Yet)" begin
77102
# Varargs are not supported
78103
@test_macro_throws ErrorException @non_differentiable vararg1(xs...)

0 commit comments

Comments
 (0)