Skip to content

Commit 0d5b1d8

Browse files
committed
Finish testing and cleaning code on non_differentiable macro
1 parent fc42649 commit 0d5b1d8

File tree

2 files changed

+58
-31
lines changed

2 files changed

+58
-31
lines changed

src/rule_definition_tools.jl

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117117
@assert Meta.isexpr(call, :call)
118118

119119
# Annotate all arguments in the signature as scalars
120-
inputs = _constrain_and_name.(call.args[2:end], :Number)
121-
120+
inputs = esc.(_constrain_and_name.(call.args[2:end], :Number))
122121
# Remove annotations and escape names for the call
123-
call.args = _unconstrain.(call.args)
122+
call.args[2:end] .= _unconstrain.(call.args[2:end])
124123
call.args = esc.(call.args)
125124

126125
# For consistency in code that follows we make all partials tuple expressions
@@ -186,7 +185,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
186185

187186
# Δs is the input to the propagator rule
188187
# because this is a pull-back there is one per output of function
189-
Δs = [Symbol(string(, i)) for i in 1:n_outputs]
188+
Δs = [Symbol(, i) for i in 1:n_outputs]
190189

191190
# 1 partial derivative per input
192191
pullback_returns = map(1:n_inputs) do input_i
@@ -197,7 +196,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
197196
# Multi-output functions have pullbacks with a tuple input that will be destructured
198197
pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...)
199198
pullback = quote
200-
function $(propagator_name(f, :pullback))($pullback_input)
199+
function $(esc(propagator_name(f, :pullback)))($pullback_input)
201200
return (NO_FIELDS, $(pullback_returns...))
202201
end
203202
end
@@ -223,16 +222,14 @@ function propagation_expr(Δs, ∂s, _conj = false)
223222
∂s = map(esc, ∂s)
224223
n∂s = length(∂s)
225224

226-
# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
227-
# literals.
225+
# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals.
228226
∂_mul_Δs = if _conj
229227
ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s)
230228
else
231229
ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
232230
end
233231

234-
# Avoiding the extra `+` operation, it is potentially expensive for vector
235-
# mode AD.
232+
# Avoiding the extra `+` operation, it is potentially expensive for vector mode AD.
236233
sumed_∂_mul_Δs = if n∂s > 1
237234
# we use `@.` to broadcast `*` and `+`
238235
:(@. +($(∂_mul_Δs...)))
@@ -273,37 +270,38 @@ macro non_differentiable(call_expr)
273270
primal_name, orig_args = Iterators.peel(call_expr.args)
274271

275272
constrained_args = _constrain_and_name.(orig_args, :Any)
273+
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]
274+
276275
unconstrained_args = _unconstrain.(constrained_args)
277276
primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...)
278-
279-
280-
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]
277+
278+
quote
279+
$(_nondiff_frule_expr(primal_sig_parts, primal_invoke))
280+
$(_nondiff_rrule_expr(primal_sig_parts, primal_invoke))
281+
end
282+
end
281283

282-
# TODO Move to frule helper
283-
frule_defn = Expr(
284+
function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
285+
return Expr(
284286
:(=),
285287
Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...),
286-
# How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that
287-
# returns `DoesNotExist()` for every position.
288+
# Julia functions always only have 1 output, so just return a single DoesNotExist()
288289
Expr(:tuple, primal_invoke, DoesNotExist())
289290
)
291+
end
290292

291-
# TODO Move to rrule helper
292-
293+
function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
294+
num_primal_inputs = length(primal_sig_parts) - 1
295+
primal_name = first(primal_invoke.args)
293296
pullback_expr = Expr(
294297
:function,
295298
Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)),
296-
Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in constrained_args)...)
299+
Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...)
297300
)
298301
rrule_defn = Expr(
299302
:(=),
300303
Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...),
301304
Expr(:tuple, primal_invoke, pullback_expr),
302305
)
303-
304-
quote
305-
$frule_defn
306-
$rrule_defn
307-
end
308-
end
309-
306+
return rrule_defn
307+
end

test/rule_definition_tools.jl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,40 @@
11
@testset "rule_definition_tools.jl" begin
22

3-
@testset "@nondifferentiable" begin
3+
@testset "@non_differentiable" begin
4+
@testset "nondiff_2_1" begin
5+
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
6+
@non_differentiable nondiff_2_1(::Any, ::Any)
7+
@test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist())
8+
res, pullback = rrule(nondiff_2_1, 3, 2)
9+
@test res == 7.5
10+
@test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist())
11+
end
412

5-
end
6-
end
13+
@testset "nondiff_1_2" begin
14+
nondiff_1_2(x) = (5.0, 3.0)
15+
@non_differentiable nondiff_1_2(::Any)
16+
@test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist())
17+
res, pullback = rrule(nondiff_1_2, 3.1)
18+
@test res == (5.0, 3.0)
19+
@test isequal(
20+
pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)),
21+
(NO_FIELDS, DoesNotExist()),
22+
)
23+
end
24+
25+
@testset "specific signature" begin
26+
nonembed_identity(x) = x
27+
@non_differentiable nonembed_identity(::Integer)
728

29+
@test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist())
30+
@test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing
831

9-
Base.remove_linenums!(@macroexpand @non_differentiable println(io::IO))
32+
res, pullback = rrule(nonembed_identity, 2)
33+
@test res == 2
34+
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())
35+
36+
@test rrule(nonembed_identity, 2.0) == nothing
37+
end
38+
end
39+
end
1040

11-
@non_differentiable println(io::IO)

0 commit comments

Comments
 (0)