Skip to content

Commit 15577d3

Browse files
committed
Document and finish
1 parent 0d5b1d8 commit 15577d3

File tree

2 files changed

+129
-19
lines changed

2 files changed

+129
-19
lines changed

src/rule_definition_tools.jl

Lines changed: 103 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -135,20 +135,6 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
135135
return call, setup_stmts, inputs, partials
136136
end
137137

138-
"turn both `a` and `a::S` into `a`"
139-
_unconstrain(arg::Symbol) = arg
140-
function _unconstrain(arg::Expr)
141-
Meta.isexpr(arg, :(::), 2) && return arg.args[1] # dop constraint.
142-
error("malformed arguments: $arg")
143-
end
144-
145-
"turn both `a` and `::Number` into `a::Number` into `a::Number` etc"
146-
function _constrain_and_name(arg::Expr, default_constraint)
147-
Meta.isexpr(arg, :(::), 2) && return arg # it is already fine.
148-
Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name
149-
error("malformed arguments: $arg")
150-
end
151-
_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type
152138

153139
function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
154140
n_outputs = length(partials)
@@ -264,10 +250,47 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna
264250
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
265251
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)
266252

253+
"""
254+
@non_differentiable(signature_expression)
255+
256+
A helper to make it easier to declare that a method is not not differentiable.
257+
This is a short-hand for defining a [`frule`](@ref) and an [`rrule`](@ref)+ pullback that
258+
returns [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial
259+
itself which is `NO_FIELDS`)
267260
268-
macro non_differentiable(call_expr)
269-
Meta.isexpr(call_expr, :call) || error("Invalid use of `@non_differentiable`")
270-
primal_name, orig_args = Iterators.peel(call_expr.args)
261+
The usage is to put the macro before a function signature.
262+
Keyword arguments should not be included.
263+
264+
```jldoctest
265+
julia> @non_differentiable Base.:(==)(a, b)
266+
267+
julia> _, pullback = rrule(==, 2.0, 3.0);
268+
269+
julia> pullback(1.0)
270+
(Zero(), DoesNotExist(), DoesNotExist())
271+
```
272+
273+
You can place type-constraints in the signature:
274+
```jldoctest
275+
julia> @non_differentiable Base.length(xs::Union{Number, Array})
276+
277+
julia> frule((Zero(), 1), length, [2.0, 3.0])
278+
(2, DoesNotExist())
279+
```
280+
281+
!!! warning
282+
This helper macro covers only the simple common cases.
283+
It does not support Varargs, or `where`-clauses.
284+
For these you can declare the `rrule` and `frule` directly
285+
286+
"""
287+
macro non_differentiable(sig_expr)
288+
Meta.isexpr(sig_expr, :call) || error("Invalid use of `@non_differentiable`")
289+
for arg in sig_expr.args
290+
_isvararg(arg) && error("@non_differentiable does not support Varargs like: $arg")
291+
end
292+
293+
primal_name, orig_args = Iterators.peel(sig_expr.args)
271294

272295
constrained_args = _constrain_and_name.(orig_args, :Any)
273296
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]
@@ -304,4 +327,66 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
304327
Expr(:tuple, primal_invoke, pullback_expr),
305328
)
306329
return rrule_defn
307-
end
330+
end
331+
332+
333+
###########
334+
# Helpers
335+
336+
"""
337+
_isvararg(expr)
338+
339+
returns true if the expression could represent a vararg
340+
341+
```jldoctest
342+
julia> ChainRulesCore._isvararg(:(x...))
343+
true
344+
345+
julia> ChainRulesCore._isvararg(:(x::Int...))
346+
true
347+
348+
julia> ChainRulesCore._isvararg(:(::Int...))
349+
true
350+
351+
julia> ChainRulesCore._isvararg(:(x::Vararg))
352+
true
353+
354+
julia> ChainRulesCore._isvararg(:(x::Vararg{Int}))
355+
true
356+
357+
julia> ChainRulesCore._isvararg(:(::Vararg))
358+
true
359+
360+
julia> ChainRulesCore._isvararg(:(::Vararg{Int}))
361+
true
362+
363+
julia> ChainRulesCore._isvararg(:(x))
364+
false
365+
````
366+
"""
367+
_isvararg(expr) = false
368+
function _isvararg(expr::Expr)
369+
Meta.isexpr(expr, :...) && return true
370+
if Meta.isexpr(expr, :(::))
371+
constraint = last(expr.args)
372+
constraint == :Vararg && return true
373+
Meta.isexpr(constraint, :curly) && first(constraint.args) == :Vararg && return true
374+
end
375+
return false
376+
end
377+
378+
379+
"turn both `a` and `a::S` into `a`"
380+
_unconstrain(arg::Symbol) = arg
381+
function _unconstrain(arg::Expr)
382+
Meta.isexpr(arg, :(::), 2) && return arg.args[1] # dop constraint.
383+
error("malformed arguments: $arg")
384+
end
385+
386+
"turn both `a` and `::Number` into `a::Number` into `a::Number` etc"
387+
function _constrain_and_name(arg::Expr, default_constraint)
388+
Meta.isexpr(arg, :(::), 2) && return arg # it is already fine.
389+
Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name
390+
error("malformed arguments: $arg")
391+
end
392+
_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type

test/rule_definition_tools.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
@testset "rule_definition_tools.jl" begin
2-
32
@testset "@non_differentiable" begin
43
@testset "nondiff_2_1" begin
54
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
@@ -35,6 +34,32 @@
3534

3635
@test rrule(nonembed_identity, 2.0) == nothing
3736
end
37+
38+
@testset "Pointy UnionAll constraints" begin
39+
pointy_identity(x) = x
40+
@non_differentiable pointy_identity(::Vector{<:AbstractString})
41+
42+
@test frule((Zero(), 1.2), pointy_identity, ["2"]) == (["2"], DoesNotExist())
43+
@test frule((Zero(), 1.2), pointy_identity, 2.0) == nothing
44+
45+
res, pullback = rrule(pointy_identity, ["2"])
46+
@test res == ["2"]
47+
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())
48+
49+
@test rrule(pointy_identity, 2.0) == nothing
50+
end
51+
52+
@testset "Not supported (Yet)" begin
53+
# Varargs are not supported
54+
@test_throws Exception @macroexpand(@non_differentiable vararg1(xs...))|
55+
@test_throws Exception @macroexpand(@non_differentiable vararg1(xs::Vararg))
56+
57+
# Where clauses are not supported.
58+
@test_throws Exception @macroexpand(
59+
@non_differentiable where_identity(::Vector{T}) where T<:AbstractString
60+
)
61+
end
62+
3863
end
3964
end
4065

0 commit comments

Comments
 (0)