Skip to content

Commit 2ead553

Browse files
committed
Avoid unnecessary keyword arguments check in fallback rrule
The expansion `rrule(::Any, ::Vararg{Any}; kwargs...)` actually generates two methods. One along the lines of: ``` (::typeof(Core.kwfunc(rrule))(kwargs, ::typeof(rrule), ::Any, ::Vararg{Any}) = nothing ``` and the other that just calls it: ``` rrule(a::Any, b::Vararg{Any}) = Core.kwfunc(rrule)(NamedTuple{}(), a, b...) ``` The compiler handles this fallback well, since it's used all over the place, but the cost to infer it is non-zero. Of course, in the AD use case, this fallback method is visited literally on every call, so saving a tiny amount of inference/compile time actually leads to noticable improvements over a whole AD problem.
1 parent 01b956f commit 2ead553

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

src/rule_definition_tools.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Base.Meta
2+
13
# These are some macros (and supporting functions) to make it easier to define rules.
24
"""
35
@scalar_rule(f(x₁, x₂, ...),
@@ -198,7 +200,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
198200
end
199201
end
200202

201-
# For context on why this is important, see
203+
# For context on why this is important, see
202204
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
203205
"Declares properly hygenic inputs for propagation expressions"
204206
_propagator_inputs(n) = [esc(gensym(Symbol(, i))) for i in 1:n]
@@ -307,11 +309,11 @@ macro non_differentiable(sig_expr)
307309
unconstrained_args = _unconstrain.(constrained_args)
308310

309311
primal_invoke = if !has_vararg
310-
:($(primal_name)($(unconstrained_args...); kwargs...))
312+
:($(primal_name)($(unconstrained_args...)))
311313
else
312314
normal_args = unconstrained_args[1:end-1]
313315
var_arg = unconstrained_args[end]
314-
:($(primal_name)($(normal_args...), $(var_arg)...; kwargs...))
316+
:($(primal_name)($(normal_args...), $(var_arg)...))
315317
end
316318

317319
quote
@@ -320,11 +322,18 @@ macro non_differentiable(sig_expr)
320322
end
321323
end
322324

325+
"changes `f(x,y)` into `f(x,y; kwargs....)`"
326+
function _with_kwargs_expr(call_expr::Expr)
327+
@assert isexpr(call_expr, :call)
328+
Expr(:call, call_expr.args[1], Expr(:parameters, :(kwargs...)),
329+
call_expr.args[2:end]...)
330+
end
331+
323332
function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
324333
return esc(:(
325334
function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...)
326335
# Julia functions always only have 1 output, so return a single DoesNotExist()
327-
return ($primal_invoke, DoesNotExist())
336+
return ($(_with_kwargs_expr(primal_invoke)), DoesNotExist())
328337
end
329338
))
330339
end
@@ -349,11 +358,16 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
349358
Expr(:call, propagator_name(primal_name, :pullback), :_),
350359
Expr(:tuple, DoesNotExist(), Expr(:(...), tup_expr))
351360
)
352-
return esc(:(
353-
function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...)
361+
return esc(quote
362+
# Manully defined kw version to save compiler work.
363+
# See rules.jl
364+
function (::Core.kwftype(typeof(rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
365+
return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr)
366+
end
367+
function ChainRulesCore.rrule($(primal_sig_parts...))
354368
return ($primal_invoke, $pullback_expr)
355369
end
356-
))
370+
end)
357371
end
358372

359373

src/rules.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,16 @@ true
103103
104104
See also: [`frule`](@ref), [`@scalar_rule`](@ref)
105105
"""
106-
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
106+
rrule(::Any, ::Vararg{Any}) = nothing
107+
108+
# Manual fallback for keyword arguments. Usually this would be generated by
109+
#
110+
# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
111+
#
112+
# However - the fallback method is so hot that we want to avoid any extra code
113+
# that would be required to have the automatically generated method package up
114+
# the keyword arguments (which the optimizer will throw away, but the compiler
115+
# still has to manually analyze). Manually declare this method with an
116+
# explicitly empty body to save the compiler that work.
117+
118+
(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing

0 commit comments

Comments
 (0)