@@ -122,13 +122,8 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
122122 end
123123
124124 # Remove annotations and escape names for the call
125- for (i, arg) in enumerate (call. args)
126- if Meta. isexpr (arg, :(:: ))
127- call. args[i] = esc (first (arg. args))
128- else
129- call. args[i] = esc (arg)
130- end
131- end
125+ call = _without_constraints (call)
126+ call. args = esc .(call. args)
132127
133128 # For consistency in code that follows we make all partials tuple expressions
134129 partials = map (partials) do partial
@@ -143,6 +138,15 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
143138 return call, setup_stmts, inputs, partials
144139end
145140
141+ " turn `foo(a, b::S)` into `foo(a, b)`"
142+ function _without_constraints (call_expr)
143+ return Expr (
144+ :call ,
145+ (Meta. isexpr (arg, :(:: )) ? first (arg. args) : arg for arg in call_expr. args). ..
146+ )
147+ end
148+
149+
146150function scalar_frule_expr (f, call, setup_stmts, inputs, partials)
147151 n_outputs = length (partials)
148152 n_inputs = length (inputs)
@@ -258,3 +262,39 @@ This is able to deal with fairly complex expressions for `f`:
258262propagator_name (f:: Expr , propname:: Symbol ) = propagator_name (f. args[end ], propname)
259263propagator_name (fname:: Symbol , propname:: Symbol ) = Symbol (fname, :_ , propname)
260264propagator_name (fname:: QuoteNode , propname:: Symbol ) = propagator_name (fname. value, propname)
265+
266+
267+ macro @non_differentiable (call_expr)
268+ Meta. isexpr (:call , call_expr) || error (" Invalid use of `@non_differentiable`" )
269+
270+ primal_call = _without_constraints (call_expr)
271+ primal_call. args = esc .(primal_call. args)
272+
273+ # TODO Move to frule helper
274+ frule_defn = Expr (
275+ :(= ),
276+ Expr (:call , :(ChainRulesCore. frule), :_ , call_expr. args... ),
277+ # How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that
278+ # returns `DoesNotExist()` for every position.
279+ Expr (:tuple , primal_call, DoesNotExist ())
280+ )
281+
282+ # TODO Move to rrule helper
283+ primal_name = first (primal_call. args)
284+ pullback_expr = Expr (
285+ :(= ),
286+ Expr (:call , propagator_name (primal_name, :pullback ), :_ ),
287+ Expr (:tuple , NO_FIELDS, (DoesNotExist () for _ in primal_call. args[2 : end ]). .. )
288+ )
289+ rrule_defn = Expr (
290+ :(= ),
291+ Expr (:call , :(ChainRulesCore. rrule), call_expr. args... ),
292+ Expr (:tuple , primal_call, pullback_expr),
293+ )
294+
295+ quote
296+ $ frule_defn
297+ $ rrule_defn
298+ end
299+ end
300+
0 commit comments