Skip to content

Commit 467d913

Browse files
authored
Merge branch 'master' into mz/canonicalize
2 parents 58ea459 + 63de708 commit 467d913

File tree

4 files changed

+56
-9
lines changed

4 files changed

+56
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.30"
3+
version = "0.9.32"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/rule_definition_tools.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Base.Meta
2+
13
# These are some macros (and supporting functions) to make it easier to define rules.
24
"""
35
@scalar_rule(f(x₁, x₂, ...),
@@ -198,7 +200,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
198200
end
199201
end
200202

201-
# For context on why this is important, see
203+
# For context on why this is important, see
202204
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
203205
"Declares properly hygenic inputs for propagation expressions"
204206
_propagator_inputs(n) = [esc(gensym(Symbol(, i))) for i in 1:n]
@@ -307,11 +309,11 @@ macro non_differentiable(sig_expr)
307309
unconstrained_args = _unconstrain.(constrained_args)
308310

309311
primal_invoke = if !has_vararg
310-
:($(primal_name)($(unconstrained_args...); kwargs...))
312+
:($(primal_name)($(unconstrained_args...)))
311313
else
312314
normal_args = unconstrained_args[1:end-1]
313315
var_arg = unconstrained_args[end]
314-
:($(primal_name)($(normal_args...), $(var_arg)...; kwargs...))
316+
:($(primal_name)($(normal_args...), $(var_arg)...))
315317
end
316318

317319
quote
@@ -320,11 +322,19 @@ macro non_differentiable(sig_expr)
320322
end
321323
end
322324

325+
"changes `f(x,y)` into `f(x,y; kwargs....)`"
326+
function _with_kwargs_expr(call_expr::Expr)
327+
@assert isexpr(call_expr, :call)
328+
return Expr(
329+
:call, call_expr.args[1], Expr(:parameters, :(kwargs...)), call_expr.args[2:end]...
330+
)
331+
end
332+
323333
function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
324334
return esc(:(
325335
function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...)
326336
# Julia functions always only have 1 output, so return a single DoesNotExist()
327-
return ($primal_invoke, DoesNotExist())
337+
return ($(_with_kwargs_expr(primal_invoke)), DoesNotExist())
328338
end
329339
))
330340
end
@@ -349,11 +359,15 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
349359
Expr(:call, propagator_name(primal_name, :pullback), :_),
350360
Expr(:tuple, DoesNotExist(), Expr(:(...), tup_expr))
351361
)
352-
return esc(:(
353-
function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...)
362+
return esc(quote
363+
# Manually defined kw version to save compiler work. See explanation in rules.jl
364+
function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
365+
return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr)
366+
end
367+
function ChainRulesCore.rrule($(primal_sig_parts...))
354368
return ($primal_invoke, $pullback_expr)
355369
end
356-
))
370+
end)
357371
end
358372

359373

src/rules.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,16 @@ true
103103
104104
See also: [`frule`](@ref), [`@scalar_rule`](@ref)
105105
"""
106-
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
106+
rrule(::Any, ::Vararg{Any}) = nothing
107+
108+
# Manual fallback for keyword arguments. Usually this would be generated by
109+
#
110+
# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
111+
#
112+
# However - the fallback method is so hot that we want to avoid any extra code
113+
# that would be required to have the automatically generated method package up
114+
# the keyword arguments (which the optimizer will throw away, but the compiler
115+
# still has to manually analyze). Manually declare this method with an
116+
# explicitly empty body to save the compiler that work.
117+
118+
(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing

test/rule_definition_tools.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,24 @@ end
251251

252252

253253
end
254+
255+
256+
module IsolatedModuleForTestingScoping
257+
using Test
258+
# need to make sure macros work in something that hasn't imported all exports
259+
# all that matters is that the following don't error, since they will resolve at
260+
# parse time
261+
using ChainRulesCore: ChainRulesCore
262+
263+
@testset "@non_differentiable" begin
264+
# this is
265+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317
266+
fixed(x) = :abc
267+
ChainRulesCore.@non_differentiable fixed(x)
268+
end
269+
270+
@testset "@scalar_rule" begin
271+
my_id(x) = x
272+
ChainRulesCore.@scalar_rule(my_id(x), 1.0)
273+
end
274+
end

0 commit comments

Comments
 (0)