@@ -117,10 +117,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117117 @assert Meta. isexpr (call, :call )
118118
119119 # Annotate all arguments in the signature as scalars
120- inputs = _constrain_and_name .(call. args[2 : end ], :Number )
121-
120+ inputs = esc .(_constrain_and_name .(call. args[2 : end ], :Number ))
122121 # Remove annotations and escape names for the call
123- call. args = _unconstrain .(call. args)
122+ call. args[ 2 : end ] . = _unconstrain .(call. args[ 2 : end ] )
124123 call. args = esc .(call. args)
125124
126125 # For consistency in code that follows we make all partials tuple expressions
@@ -186,7 +185,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
186185
187186 # Δs is the input to the propagator rule
188187 # because this is a pull-back there is one per output of function
189- Δs = [Symbol (string ( :Δ , i) ) for i in 1 : n_outputs]
188+ Δs = [Symbol (:Δ , i) for i in 1 : n_outputs]
190189
191190 # 1 partial derivative per input
192191 pullback_returns = map (1 : n_inputs) do input_i
@@ -197,7 +196,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
197196 # Multi-output functions have pullbacks with a tuple input that will be destructured
198197 pullback_input = n_outputs == 1 ? first (Δs) : Expr (:tuple , Δs... )
199198 pullback = quote
200- function $ (propagator_name (f, :pullback ))($ pullback_input)
199+ function $ (esc ( propagator_name (f, :pullback ) ))($ pullback_input)
201200 return (NO_FIELDS, $ (pullback_returns... ))
202201 end
203202 end
@@ -223,16 +222,14 @@ function propagation_expr(Δs, ∂s, _conj = false)
223222 ∂s = map (esc, ∂s)
224223 n∂s = length (∂s)
225224
226- # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
227- # literals.
225+ # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals.
228226 ∂_mul_Δs = if _conj
229227 ntuple (i-> :(conj ($ (∂s[i])) * $ (Δs[i])), n∂s)
230228 else
231229 ntuple (i-> :($ (∂s[i]) * $ (Δs[i])), n∂s)
232230 end
233231
234- # Avoiding the extra `+` operation, it is potentially expensive for vector
235- # mode AD.
232+ # Avoiding the extra `+` operation, it is potentially expensive for vector mode AD.
236233 sumed_∂_mul_Δs = if n∂s > 1
237234 # we use `@.` to broadcast `*` and `+`
238235 :(@. + ($ (∂_mul_Δs... )))
@@ -273,37 +270,38 @@ macro non_differentiable(call_expr)
273270 primal_name, orig_args = Iterators. peel (call_expr. args)
274271
275272 constrained_args = _constrain_and_name .(orig_args, :Any )
273+ primal_sig_parts = [:(:: typeof ($ primal_name)), constrained_args... ]
274+
276275 unconstrained_args = _unconstrain .(constrained_args)
277276 primal_invoke = Expr (:call , esc (primal_name), esc .(unconstrained_args)... )
278-
279-
280- primal_sig_parts = [:(:: typeof ($ primal_name)), constrained_args... ]
277+
278+ quote
279+ $ (_nondiff_frule_expr (primal_sig_parts, primal_invoke))
280+ $ (_nondiff_rrule_expr (primal_sig_parts, primal_invoke))
281+ end
282+ end
281283
282- # TODO Move to frule helper
283- frule_defn = Expr (
284+ function _nondiff_frule_expr (primal_sig_parts, primal_invoke)
285+ return Expr (
284286 :(= ),
285287 Expr (:call , :(ChainRulesCore. frule), esc (:_ ), esc .(primal_sig_parts)... ),
286- # How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that
287- # returns `DoesNotExist()` for every position.
288+ # Julia functions always only have 1 output, so just return a single DoesNotExist()
288289 Expr (:tuple , primal_invoke, DoesNotExist ())
289290 )
291+ end
290292
291- # TODO Move to rrule helper
292-
293+ function _nondiff_rrule_expr (primal_sig_parts, primal_invoke)
294+ num_primal_inputs = length (primal_sig_parts) - 1
295+ primal_name = first (primal_invoke. args)
293296 pullback_expr = Expr (
294297 :function ,
295298 Expr (:call , esc (propagator_name (primal_name, :pullback )), esc (:_ )),
296- Expr (:tuple , NO_FIELDS, ( DoesNotExist () for _ in constrained_args ). .. )
299+ Expr (:tuple , NO_FIELDS, ntuple (_ -> DoesNotExist (), num_primal_inputs )... )
297300 )
298301 rrule_defn = Expr (
299302 :(= ),
300303 Expr (:call , :(ChainRulesCore. rrule), esc .(primal_sig_parts)... ),
301304 Expr (:tuple , primal_invoke, pullback_expr),
302305 )
303-
304- quote
305- $ frule_defn
306- $ rrule_defn
307- end
308- end
309-
306+ return rrule_defn
307+ end
0 commit comments