1+ using Base. Meta
2+
13# These are some macros (and supporting functions) to make it easier to define rules.
24"""
35 @scalar_rule(f(x₁, x₂, ...),
@@ -198,7 +200,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
198200 end
199201end
200202
201- # For context on why this is important, see
203+ # For context on why this is important, see
202204# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
203205" Declares properly hygenic inputs for propagation expressions"
204206_propagator_inputs (n) = [esc (gensym (Symbol (:Δ , i))) for i in 1 : n]
@@ -307,11 +309,11 @@ macro non_differentiable(sig_expr)
307309 unconstrained_args = _unconstrain .(constrained_args)
308310
309311 primal_invoke = if ! has_vararg
310- :($ (primal_name)($ (unconstrained_args... ); kwargs ... ))
312+ :($ (primal_name)($ (unconstrained_args... )))
311313 else
312314 normal_args = unconstrained_args[1 : end - 1 ]
313315 var_arg = unconstrained_args[end ]
314- :($ (primal_name)($ (normal_args... ), $ (var_arg). .. ; kwargs ... ))
316+ :($ (primal_name)($ (normal_args... ), $ (var_arg). .. ))
315317 end
316318
317319 quote
@@ -320,11 +322,19 @@ macro non_differentiable(sig_expr)
320322 end
321323end
322324
325+ " changes `f(x,y)` into `f(x,y; kwargs....)`"
326+ function _with_kwargs_expr (call_expr:: Expr )
327+ @assert isexpr (call_expr, :call )
328+ return Expr (
329+ :call , call_expr. args[1 ], Expr (:parameters , :(kwargs... )), call_expr. args[2 : end ]. ..
330+ )
331+ end
332+
323333function _nondiff_frule_expr (primal_sig_parts, primal_invoke)
324334 return esc (:(
325335 function ChainRulesCore. frule ($ (gensym (:_ )), $ (primal_sig_parts... ); kwargs... )
326336 # Julia functions always only have 1 output, so return a single DoesNotExist()
327- return ($ primal_invoke, DoesNotExist ())
337+ return ($ ( _with_kwargs_expr ( primal_invoke)) , DoesNotExist ())
328338 end
329339 ))
330340end
@@ -349,11 +359,15 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
349359 Expr (:call , propagator_name (primal_name, :pullback ), :_ ),
350360 Expr (:tuple , DoesNotExist (), Expr (:(... ), tup_expr))
351361 )
352- return esc (:(
353- function ChainRulesCore. rrule ($ (primal_sig_parts... ); kwargs... )
362+ return esc (quote
363+ # Manually defined kw version to save compiler work. See explanation in rules.jl
364+ function (:: Core.kwftype (typeof (ChainRulesCore. rrule)))(kwargs:: Any , rrule:: typeof (ChainRulesCore. rrule), $ (primal_sig_parts... ))
365+ return ($ (_with_kwargs_expr (primal_invoke)), $ pullback_expr)
366+ end
367+ function ChainRulesCore. rrule ($ (primal_sig_parts... ))
354368 return ($ primal_invoke, $ pullback_expr)
355369 end
356- ) )
370+ end )
357371end
358372
359373
0 commit comments