@@ -135,20 +135,6 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
135135 return call, setup_stmts, inputs, partials
136136end
137137
138- " turn both `a` and `a::S` into `a`"
139- _unconstrain (arg:: Symbol ) = arg
140- function _unconstrain (arg:: Expr )
141- Meta. isexpr (arg, :(:: ), 2 ) && return arg. args[1 ] # dop constraint.
142- error (" malformed arguments: $arg " )
143- end
144-
145- " turn both `a` and `::Number` into `a::Number` into `a::Number` etc"
146- function _constrain_and_name (arg:: Expr , default_constraint)
147- Meta. isexpr (arg, :(:: ), 2 ) && return arg # it is already fine.
148- Meta. isexpr (arg, :(:: ), 1 ) && return Expr (:(:: ), gensym (), arg. args[1 ]) # add name
149- error (" malformed arguments: $arg " )
150- end
151- _constrain_and_name (name:: Symbol , constraint) = Expr (:(:: ), name, constraint) # add type
152138
153139function scalar_frule_expr (f, call, setup_stmts, inputs, partials)
154140 n_outputs = length (partials)
@@ -264,10 +250,47 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna
264250propagator_name (fname:: Symbol , propname:: Symbol ) = Symbol (fname, :_ , propname)
265251propagator_name (fname:: QuoteNode , propname:: Symbol ) = propagator_name (fname. value, propname)
266252
253+ """
254+ @non_differentiable(signature_expression)
255+
256+ A helper to make it easier to declare that a method is not not differentiable.
257+ This is a short-hand for defining a [`frule`](@ref) and an [`rrule`](@ref)+ pullback that
258+ returns [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial
259+ itself which is `NO_FIELDS`)
267260
268- macro non_differentiable (call_expr)
269- Meta. isexpr (call_expr, :call ) || error (" Invalid use of `@non_differentiable`" )
270- primal_name, orig_args = Iterators. peel (call_expr. args)
261+ The usage is to put the macro before a function signature.
262+ Keyword arguments should not be included.
263+
264+ ```jldoctest
265+ julia> @non_differentiable Base.:(==)(a, b)
266+
267+ julia> _, pullback = rrule(==, 2.0, 3.0);
268+
269+ julia> pullback(1.0)
270+ (Zero(), DoesNotExist(), DoesNotExist())
271+ ```
272+
273+ You can place type-constraints in the signature:
274+ ```jldoctest
275+ julia> @non_differentiable Base.length(xs::Union{Number, Array})
276+
277+ julia> frule((Zero(), 1), length, [2.0, 3.0])
278+ (2, DoesNotExist())
279+ ```
280+
281+ !!! warning
282+ This helper macro covers only the simple common cases.
283+ It does not support Varargs, or `where`-clauses.
284+ For these you can declare the `rrule` and `frule` directly
285+
286+ """
287+ macro non_differentiable (sig_expr)
288+ Meta. isexpr (sig_expr, :call ) || error (" Invalid use of `@non_differentiable`" )
289+ for arg in sig_expr. args
290+ _isvararg (arg) && error (" @non_differentiable does not support Varargs like: $arg " )
291+ end
292+
293+ primal_name, orig_args = Iterators. peel (sig_expr. args)
271294
272295 constrained_args = _constrain_and_name .(orig_args, :Any )
273296 primal_sig_parts = [:(:: typeof ($ primal_name)), constrained_args... ]
@@ -304,4 +327,66 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
304327 Expr (:tuple , primal_invoke, pullback_expr),
305328 )
306329 return rrule_defn
307- end
330+ end
331+
332+
333+ # ##########
334+ # Helpers
335+
336+ """
337+ _isvararg(expr)
338+
339+ returns true if the expression could represent a vararg
340+
341+ ```jldoctest
342+ julia> ChainRulesCore._isvararg(:(x...))
343+ true
344+
345+ julia> ChainRulesCore._isvararg(:(x::Int...))
346+ true
347+
348+ julia> ChainRulesCore._isvararg(:(::Int...))
349+ true
350+
351+ julia> ChainRulesCore._isvararg(:(x::Vararg))
352+ true
353+
354+ julia> ChainRulesCore._isvararg(:(x::Vararg{Int}))
355+ true
356+
357+ julia> ChainRulesCore._isvararg(:(::Vararg))
358+ true
359+
360+ julia> ChainRulesCore._isvararg(:(::Vararg{Int}))
361+ true
362+
363+ julia> ChainRulesCore._isvararg(:(x))
364+ false
365+ ````
366+ """
367+ _isvararg (expr) = false
368+ function _isvararg (expr:: Expr )
369+ Meta. isexpr (expr, :... ) && return true
370+ if Meta. isexpr (expr, :(:: ))
371+ constraint = last (expr. args)
372+ constraint == :Vararg && return true
373+ Meta. isexpr (constraint, :curly ) && first (constraint. args) == :Vararg && return true
374+ end
375+ return false
376+ end
377+
378+
379+ " turn both `a` and `a::S` into `a`"
380+ _unconstrain (arg:: Symbol ) = arg
381+ function _unconstrain (arg:: Expr )
382+ Meta. isexpr (arg, :(:: ), 2 ) && return arg. args[1 ] # dop constraint.
383+ error (" malformed arguments: $arg " )
384+ end
385+
386+ " turn both `a` and `::Number` into `a::Number` into `a::Number` etc"
387+ function _constrain_and_name (arg:: Expr , default_constraint)
388+ Meta. isexpr (arg, :(:: ), 2 ) && return arg # it is already fine.
389+ Meta. isexpr (arg, :(:: ), 1 ) && return Expr (:(:: ), gensym (), arg. args[1 ]) # add name
390+ error (" malformed arguments: $arg " )
391+ end
392+ _constrain_and_name (name:: Symbol , constraint) = Expr (:(:: ), name, constraint) # add type
0 commit comments