Skip to content

Commit 7f35d07

Browse files
committed
WIP outline nondifferentiable macro (untested)
1 parent 0fe9da8 commit 7f35d07

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

src/rule_definition_tools.jl

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,8 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
122122
end
123123

124124
# Remove annotations and escape names for the call
125-
for (i, arg) in enumerate(call.args)
126-
if Meta.isexpr(arg, :(::))
127-
call.args[i] = esc(first(arg.args))
128-
else
129-
call.args[i] = esc(arg)
130-
end
131-
end
125+
call = _without_constraints(call)
126+
call.args = esc.(call.args)
132127

133128
# For consistency in code that follows we make all partials tuple expressions
134129
partials = map(partials) do partial
@@ -143,6 +138,15 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
143138
return call, setup_stmts, inputs, partials
144139
end
145140

141+
"turn `foo(a, b::S)` into `foo(a, b)`"
142+
function _without_constraints(call_expr)
143+
return Expr(
144+
:call,
145+
(Meta.isexpr(arg, :(::)) ? first(arg.args) : arg for arg in call_expr.args)...
146+
)
147+
end
148+
149+
146150
function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
147151
n_outputs = length(partials)
148152
n_inputs = length(inputs)
@@ -258,3 +262,39 @@ This is able to deal with fairly complex expressions for `f`:
258262
propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname)
259263
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
260264
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)
265+
266+
267+
macro @non_differentiable(call_expr)
268+
Meta.isexpr(:call, call_expr) || error("Invalid use of `@non_differentiable`")
269+
270+
primal_call = _without_constraints(call_expr)
271+
primal_call.args = esc.(primal_call.args)
272+
273+
# TODO Move to frule helper
274+
frule_defn = Expr(
275+
:(=),
276+
Expr(:call, :(ChainRulesCore.frule), :_, call_expr.args...),
277+
# How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that
278+
# returns `DoesNotExist()` for every position.
279+
Expr(:tuple, primal_call, DoesNotExist())
280+
)
281+
282+
# TODO Move to rrule helper
283+
primal_name = first(primal_call.args)
284+
pullback_expr = Expr(
285+
:(=),
286+
Expr(:call, propagator_name(primal_name, :pullback), :_),
287+
Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in primal_call.args[2:end])...)
288+
)
289+
rrule_defn = Expr(
290+
:(=),
291+
Expr(:call, :(ChainRulesCore.rrule), call_expr.args...),
292+
Expr(:tuple, primal_call, pullback_expr),
293+
)
294+
295+
quote
296+
$frule_defn
297+
$rrule_defn
298+
end
299+
end
300+

test/rule_definition_tools.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@testset "rule_definition_tools.jl" begin
2+
3+
@testset "@nondifferentiable" begin
4+
5+
end
6+
end

0 commit comments

Comments
 (0)