@@ -295,7 +295,8 @@ macro non_differentiable(sig_expr)
295295 primal_sig_parts = [:(:: typeof ($ primal_name)), constrained_args... ]
296296
297297 unconstrained_args = _unconstrain .(constrained_args)
298- primal_invoke = Expr (:call , esc (primal_name), esc .(unconstrained_args)... )
298+
299+ primal_invoke = :($ (primal_name)($ (unconstrained_args... ); kwargs... ))
299300
300301 quote
301302 $ (_nondiff_frule_expr (primal_sig_parts, primal_invoke))
@@ -304,28 +305,27 @@ macro non_differentiable(sig_expr)
304305end
305306
306307function _nondiff_frule_expr (primal_sig_parts, primal_invoke)
307- return Expr (
308- :( = ),
309- Expr ( :call , :(ChainRulesCore . frule), esc ( :_ ), esc .(primal_sig_parts) ... ),
310- # Julia functions always only have 1 output, so just return a single DoesNotExist()
311- Expr ( :tuple , primal_invoke, DoesNotExist ()),
312- )
308+ return esc (: (
309+ function ChainRulesCore . frule ( $ ( gensym ( :_ )), $ (primal_sig_parts ... ); kwargs ... )
310+ # Julia functions always only have 1 output, so return a single DoesNotExist()
311+ return ( $ primal_invoke, DoesNotExist () )
312+ end
313+ ))
313314end
314315
315316function _nondiff_rrule_expr (primal_sig_parts, primal_invoke)
316317 num_primal_inputs = length (primal_sig_parts) - 1
317318 primal_name = first (primal_invoke. args)
318319 pullback_expr = Expr (
319320 :function ,
320- Expr (:call , esc ( propagator_name (primal_name, :pullback )), esc ( :_ ) ),
321+ Expr (:call , propagator_name (primal_name, :pullback ), :_ ),
321322 Expr (:tuple , NO_FIELDS, ntuple (_-> DoesNotExist (), num_primal_inputs)... )
322323 )
323- rrule_defn = Expr (
324- :(= ),
325- Expr (:call , :(ChainRulesCore. rrule), esc .(primal_sig_parts)... ),
326- Expr (:tuple , primal_invoke, pullback_expr),
327- )
328- return rrule_defn
324+ return esc (:(
325+ function ChainRulesCore. rrule ($ (primal_sig_parts... ); kwargs... )
326+ return ($ primal_invoke, $ pullback_expr)
327+ end
328+ ))
329329end
330330
331331
0 commit comments