@@ -117,12 +117,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117117 @assert Meta. isexpr (call, :call )
118118
119119 # Annotate all arguments in the signature as scalars
120- inputs = map (call. args[2 : end ]) do arg
121- esc (Meta. isexpr (arg, :(:: )) ? arg : Expr (:(:: ), arg, :Number ))
122- end
120+ inputs = _constrain_and_name .(call. args[2 : end ], :Number )
123121
124122 # Remove annotations and escape names for the call
125- call = _without_constraints (call)
123+ call. args = _unconstrain . (call. args )
126124 call. args = esc .(call. args)
127125
128126 # For consistency in code that follows we make all partials tuple expressions
@@ -138,14 +136,20 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
138136 return call, setup_stmts, inputs, partials
139137end
140138
141- " turn `foo(a, b::S)` into `foo(a, b)`"
142- function _without_constraints (call_expr)
143- return Expr (
144- :call ,
145- (Meta. isexpr (arg, :(:: )) ? first (arg. args) : arg for arg in call_expr. args). ..
146- )
139+ " turn both `a` and `a::S` into `a`"
140+ _unconstrain (arg:: Symbol ) = arg
141+ function _unconstrain (arg:: Expr )
142+ Meta. isexpr (arg, :(:: ), 2 ) && return arg. args[1 ] # dop constraint.
143+ error (" malformed arguments: $arg " )
147144end
148145
146+ " turn both `a` and `::Number` into `a::Number` into `a::Number` etc"
147+ function _constrain_and_name (arg:: Expr , default_constraint)
148+ Meta. isexpr (arg, :(:: ), 2 ) && return arg # it is already fine.
149+ Meta. isexpr (arg, :(:: ), 1 ) && return Expr (:(:: ), gensym (), arg. args[1 ]) # add name
150+ error (" malformed arguments: $arg " )
151+ end
152+ _constrain_and_name (name:: Symbol , constraint) = Expr (:(:: ), name, constraint) # add type
149153
150154function scalar_frule_expr (f, call, setup_stmts, inputs, partials)
151155 n_outputs = length (partials)
@@ -264,32 +268,37 @@ propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
264268propagator_name (fname:: QuoteNode , propname:: Symbol ) = propagator_name (fname. value, propname)
265269
266270
267- macro @non_differentiable (call_expr)
268- Meta. isexpr (:call , call_expr) || error (" Invalid use of `@non_differentiable`" )
271+ macro non_differentiable (call_expr)
272+ Meta. isexpr (call_expr, :call ) || error (" Invalid use of `@non_differentiable`" )
273+ primal_name, orig_args = Iterators. peel (call_expr. args)
269274
270- primal_call = _without_constraints (call_expr)
271- primal_call. args = esc .(primal_call. args)
275+ constrained_args = _constrain_and_name .(orig_args, :Any )
276+ unconstrained_args = _unconstrain .(constrained_args)
277+ primal_invoke = Expr (:call , esc (primal_name), esc .(unconstrained_args)... )
278+
279+
280+ primal_sig_parts = [:(:: typeof ($ primal_name)), constrained_args... ]
272281
273282 # TODO Move to frule helper
274283 frule_defn = Expr (
275284 :(= ),
276- Expr (:call , :(ChainRulesCore. frule), :_ , call_expr . args ... ),
285+ Expr (:call , :(ChainRulesCore. frule), esc ( :_ ), esc .(primal_sig_parts) ... ),
277286 # How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that
278287 # returns `DoesNotExist()` for every position.
279- Expr (:tuple , primal_call , DoesNotExist ())
288+ Expr (:tuple , primal_invoke , DoesNotExist ())
280289 )
281290
282291 # TODO Move to rrule helper
283- primal_name = first (primal_call . args)
292+
284293 pullback_expr = Expr (
285- :( = ) ,
286- Expr (:call , propagator_name (primal_name, :pullback ), :_ ),
287- Expr (:tuple , NO_FIELDS, (DoesNotExist () for _ in primal_call . args[ 2 : end ] ). .. )
294+ :function ,
295+ Expr (:call , esc ( propagator_name (primal_name, :pullback )), esc ( :_ ) ),
296+ Expr (:tuple , NO_FIELDS, (DoesNotExist () for _ in constrained_args ). .. )
288297 )
289298 rrule_defn = Expr (
290299 :(= ),
291- Expr (:call , :(ChainRulesCore. rrule), call_expr . args ... ),
292- Expr (:tuple , primal_call , pullback_expr),
300+ Expr (:call , :(ChainRulesCore. rrule), esc .(primal_sig_parts) ... ),
301+ Expr (:tuple , primal_invoke , pullback_expr),
293302 )
294303
295304 quote
0 commit comments