@@ -117,18 +117,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117117 @assert Meta. isexpr (call, :call )
118118
119119 # Annotate all arguments in the signature as scalars
120- inputs = map (call. args[2 : end ]) do arg
121- esc (Meta. isexpr (arg, :(:: )) ? arg : Expr (:(:: ), arg, :Number ))
122- end
123-
120+ inputs = esc .(_constrain_and_name .(call. args[2 : end ], :Number ))
124121 # Remove annotations and escape names for the call
125- for (i, arg) in enumerate (call. args)
126- if Meta. isexpr (arg, :(:: ))
127- call. args[i] = esc (first (arg. args))
128- else
129- call. args[i] = esc (arg)
130- end
131- end
122+ call. args[2 : end ] .= _unconstrain .(call. args[2 : end ])
123+ call. args = esc .(call. args)
132124
133125 # For consistency in code that follows we make all partials tuple expressions
134126 partials = map (partials) do partial
@@ -143,6 +135,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
143135 return call, setup_stmts, inputs, partials
144136end
145137
138+
146139function scalar_frule_expr (f, call, setup_stmts, inputs, partials)
147140 n_outputs = length (partials)
148141 n_inputs = length (inputs)
@@ -178,7 +171,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
178171
179172 # Δs is the input to the propagator rule
180173 # because this is a pull-back there is one per output of function
181- Δs = [Symbol (string ( :Δ , i) ) for i in 1 : n_outputs]
174+ Δs = [Symbol (:Δ , i) for i in 1 : n_outputs]
182175
183176 # 1 partial derivative per input
184177 pullback_returns = map (1 : n_inputs) do input_i
@@ -189,7 +182,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
189182 # Multi-output functions have pullbacks with a tuple input that will be destructured
190183 pullback_input = n_outputs == 1 ? first (Δs) : Expr (:tuple , Δs... )
191184 pullback = quote
192- function $ (propagator_name (f, :pullback ))($ pullback_input)
185+ function $ (esc ( propagator_name (f, :pullback ) ))($ pullback_input)
193186 return (NO_FIELDS, $ (pullback_returns... ))
194187 end
195188 end
@@ -215,16 +208,14 @@ function propagation_expr(Δs, ∂s, _conj = false)
215208 ∂s = map (esc, ∂s)
216209 n∂s = length (∂s)
217210
218- # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
219- # literals.
211+ # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals.
220212 ∂_mul_Δs = if _conj
221213 ntuple (i-> :(conj ($ (∂s[i])) * $ (Δs[i])), n∂s)
222214 else
223215 ntuple (i-> :($ (∂s[i]) * $ (Δs[i])), n∂s)
224216 end
225217
226- # Avoiding the extra `+` operation, it is potentially expensive for vector
227- # mode AD.
218+ # Avoiding the extra `+` operation, it is potentially expensive for vector mode AD.
228219 sumed_∂_mul_Δs = if n∂s > 1
229220 # we use `@.` to broadcast `*` and `+`
230221 :(@. + ($ (∂_mul_Δs... )))
@@ -258,3 +249,143 @@ This is able to deal with fairly complex expressions for `f`:
258249propagator_name (f:: Expr , propname:: Symbol ) = propagator_name (f. args[end ], propname)
259250propagator_name (fname:: Symbol , propname:: Symbol ) = Symbol (fname, :_ , propname)
260251propagator_name (fname:: QuoteNode , propname:: Symbol ) = propagator_name (fname. value, propname)
252+
253+ """
254+ @non_differentiable(signature_expression)
255+
256+ A helper to make it easier to declare that a method is not not differentiable.
257+ This is a short-hand for defining an [`frule`](@ref) and [`rrule`](@ref) that
258+ return [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial
259+ itself which is `NO_FIELDS`)
260+
261+ Keyword arguments should not be included.
262+
263+ ```jldoctest
264+ julia> @non_differentiable Base.:(==)(a, b)
265+
266+ julia> _, pullback = rrule(==, 2.0, 3.0);
267+
268+ julia> pullback(1.0)
269+ (Zero(), DoesNotExist(), DoesNotExist())
270+ ```
271+
272+ You can place type-constraints in the signature:
273+ ```jldoctest
274+ julia> @non_differentiable Base.length(xs::Union{Number, Array})
275+
276+ julia> frule((Zero(), 1), length, [2.0, 3.0])
277+ (2, DoesNotExist())
278+ ```
279+
280+ !!! warning
281+ This helper macro covers only the simple common cases.
282+ It does not support Varargs, or `where`-clauses.
283+ For these you can declare the `rrule` and `frule` directly
284+
285+ """
286+ macro non_differentiable (sig_expr)
287+ Meta. isexpr (sig_expr, :call ) || error (" Invalid use of `@non_differentiable`" )
288+ for arg in sig_expr. args
289+ _isvararg (arg) && error (" @non_differentiable does not support Varargs like: $arg " )
290+ end
291+
292+ primal_name, orig_args = Iterators. peel (sig_expr. args)
293+
294+ constrained_args = _constrain_and_name .(orig_args, :Any )
295+ primal_sig_parts = [:(:: typeof ($ primal_name)), constrained_args... ]
296+
297+ unconstrained_args = _unconstrain .(constrained_args)
298+ primal_invoke = Expr (:call , esc (primal_name), esc .(unconstrained_args)... )
299+
300+ quote
301+ $ (_nondiff_frule_expr (primal_sig_parts, primal_invoke))
302+ $ (_nondiff_rrule_expr (primal_sig_parts, primal_invoke))
303+ end
304+ end
305+
306+ function _nondiff_frule_expr (primal_sig_parts, primal_invoke)
307+ return Expr (
308+ :(= ),
309+ Expr (:call , :(ChainRulesCore. frule), esc (:_ ), esc .(primal_sig_parts)... ),
310+ # Julia functions always only have 1 output, so just return a single DoesNotExist()
311+ Expr (:tuple , primal_invoke, DoesNotExist ()),
312+ )
313+ end
314+
315+ function _nondiff_rrule_expr (primal_sig_parts, primal_invoke)
316+ num_primal_inputs = length (primal_sig_parts) - 1
317+ primal_name = first (primal_invoke. args)
318+ pullback_expr = Expr (
319+ :function ,
320+ Expr (:call , esc (propagator_name (primal_name, :pullback )), esc (:_ )),
321+ Expr (:tuple , NO_FIELDS, ntuple (_-> DoesNotExist (), num_primal_inputs)... )
322+ )
323+ rrule_defn = Expr (
324+ :(= ),
325+ Expr (:call , :(ChainRulesCore. rrule), esc .(primal_sig_parts)... ),
326+ Expr (:tuple , primal_invoke, pullback_expr),
327+ )
328+ return rrule_defn
329+ end
330+
331+
332+ # ##########
333+ # Helpers
334+
335+ """
336+ _isvararg(expr)
337+
338+ returns true if the expression could represent a vararg
339+
340+ ```jldoctest
341+ julia> ChainRulesCore._isvararg(:(x...))
342+ true
343+
344+ julia> ChainRulesCore._isvararg(:(x::Int...))
345+ true
346+
347+ julia> ChainRulesCore._isvararg(:(::Int...))
348+ true
349+
350+ julia> ChainRulesCore._isvararg(:(x::Vararg))
351+ true
352+
353+ julia> ChainRulesCore._isvararg(:(x::Vararg{Int}))
354+ true
355+
356+ julia> ChainRulesCore._isvararg(:(::Vararg))
357+ true
358+
359+ julia> ChainRulesCore._isvararg(:(::Vararg{Int}))
360+ true
361+
362+ julia> ChainRulesCore._isvararg(:(x))
363+ false
364+ ````
365+ """
366+ _isvararg (expr) = false
367+ function _isvararg (expr:: Expr )
368+ Meta. isexpr (expr, :... ) && return true
369+ if Meta. isexpr (expr, :(:: ))
370+ constraint = last (expr. args)
371+ constraint == :Vararg && return true
372+ Meta. isexpr (constraint, :curly ) && first (constraint. args) == :Vararg && return true
373+ end
374+ return false
375+ end
376+
377+
378+ " turn both `a` and `a::S` into `a`"
379+ _unconstrain (arg:: Symbol ) = arg
380+ function _unconstrain (arg:: Expr )
381+ Meta. isexpr (arg, :(:: ), 2 ) && return arg. args[1 ] # drop constraint.
382+ error (" malformed arguments: $arg " )
383+ end
384+
385+ " turn both `a` and `::constraint` into `a::constraint` etc"
386+ function _constrain_and_name (arg:: Expr , _)
387+ Meta. isexpr (arg, :(:: ), 2 ) && return arg # it is already fine.
388+ Meta. isexpr (arg, :(:: ), 1 ) && return Expr (:(:: ), gensym (), arg. args[1 ]) # add name
389+ error (" malformed arguments: $arg " )
390+ end
391+ _constrain_and_name (name:: Symbol , constraint) = Expr (:(:: ), name, constraint) # add type
0 commit comments