Skip to content

Commit 9328125

Browse files
authored
Merge pull request #308 from Keno/kf/fallbackvarargs
Avoid unnecessary keyword arguments check in fallback rrule
2 parents 74201e6 + b43c8a6 commit 9328125

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

src/rule_definition_tools.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Base.Meta
2+
13
# These are some macros (and supporting functions) to make it easier to define rules.
24
"""
35
@scalar_rule(f(x₁, x₂, ...),
@@ -198,7 +200,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
198200
end
199201
end
200202

201-
# For context on why this is important, see
203+
# For context on why this is important, see
202204
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
203205
"Declares properly hygenic inputs for propagation expressions"
204206
_propagator_inputs(n) = [esc(gensym(Symbol(, i))) for i in 1:n]
@@ -307,11 +309,11 @@ macro non_differentiable(sig_expr)
307309
unconstrained_args = _unconstrain.(constrained_args)
308310

309311
primal_invoke = if !has_vararg
310-
:($(primal_name)($(unconstrained_args...); kwargs...))
312+
:($(primal_name)($(unconstrained_args...)))
311313
else
312314
normal_args = unconstrained_args[1:end-1]
313315
var_arg = unconstrained_args[end]
314-
:($(primal_name)($(normal_args...), $(var_arg)...; kwargs...))
316+
:($(primal_name)($(normal_args...), $(var_arg)...))
315317
end
316318

317319
quote
@@ -320,11 +322,19 @@ macro non_differentiable(sig_expr)
320322
end
321323
end
322324

325+
"changes `f(x,y)` into `f(x,y; kwargs....)`"
326+
function _with_kwargs_expr(call_expr::Expr)
327+
@assert isexpr(call_expr, :call)
328+
return Expr(
329+
:call, call_expr.args[1], Expr(:parameters, :(kwargs...)), call_expr.args[2:end]...
330+
)
331+
end
332+
323333
function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
324334
return esc(:(
325335
function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...)
326336
# Julia functions always only have 1 output, so return a single DoesNotExist()
327-
return ($primal_invoke, DoesNotExist())
337+
return ($(_with_kwargs_expr(primal_invoke)), DoesNotExist())
328338
end
329339
))
330340
end
@@ -349,11 +359,15 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
349359
Expr(:call, propagator_name(primal_name, :pullback), :_),
350360
Expr(:tuple, DoesNotExist(), Expr(:(...), tup_expr))
351361
)
352-
return esc(:(
353-
function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...)
362+
return esc(quote
363+
# Manually defined kw version to save compiler work. See explanation in rules.jl
364+
function (::Core.kwftype(typeof(rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
365+
return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr)
366+
end
367+
function ChainRulesCore.rrule($(primal_sig_parts...))
354368
return ($primal_invoke, $pullback_expr)
355369
end
356-
))
370+
end)
357371
end
358372

359373

src/rules.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,16 @@ true
103103
104104
See also: [`frule`](@ref), [`@scalar_rule`](@ref)
105105
"""
106-
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
106+
rrule(::Any, ::Vararg{Any}) = nothing
107+
108+
# Manual fallback for keyword arguments. Usually this would be generated by
109+
#
110+
# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
111+
#
112+
# However - the fallback method is so hot that we want to avoid any extra code
113+
# that would be required to have the automatically generated method package up
114+
# the keyword arguments (which the optimizer will throw away, but the compiler
115+
# still has to manually analyze). Manually declare this method with an
116+
# explicitly empty body to save the compiler that work.
117+
118+
(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing

0 commit comments

Comments
 (0)