Skip to content

Commit fc42649

Browse files
committed
fix code generated
1 parent 7f35d07 commit fc42649

File tree

3 files changed

+37
-23
lines changed

3 files changed

+37
-23
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using MuladdMacro: @muladd
44

55
export on_new_rule, refresh_rules # generation tools
66
export frule, rrule # core function
7-
export @scalar_rule, @thunk # definition helper macros
7+
export @non_differentiable, @scalar_rule, @thunk # definition helper macros
88
export canonicalize, extern, unthunk # differential operations
99
# differentials
1010
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk

src/rule_definition_tools.jl

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117117
@assert Meta.isexpr(call, :call)
118118

119119
# Annotate all arguments in the signature as scalars
120-
inputs = map(call.args[2:end]) do arg
121-
esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number))
122-
end
120+
inputs = _constrain_and_name.(call.args[2:end], :Number)
123121

124122
# Remove annotations and escape names for the call
125-
call = _without_constraints(call)
123+
call.args = _unconstrain.(call.args)
126124
call.args = esc.(call.args)
127125

128126
# For consistency in code that follows we make all partials tuple expressions
@@ -138,14 +136,20 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
138136
return call, setup_stmts, inputs, partials
139137
end
140138

141-
"turn `foo(a, b::S)` into `foo(a, b)`"
142-
function _without_constraints(call_expr)
143-
return Expr(
144-
:call,
145-
(Meta.isexpr(arg, :(::)) ? first(arg.args) : arg for arg in call_expr.args)...
146-
)
139+
"turn both `a` and `a::S` into `a`"
140+
_unconstrain(arg::Symbol) = arg
141+
function _unconstrain(arg::Expr)
142+
Meta.isexpr(arg, :(::), 2) && return arg.args[1] # dop constraint.
143+
error("malformed arguments: $arg")
147144
end
148145

146+
"turn both `a` and `::Number` into `a::Number` into `a::Number` etc"
147+
function _constrain_and_name(arg::Expr, default_constraint)
148+
Meta.isexpr(arg, :(::), 2) && return arg # it is already fine.
149+
Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name
150+
error("malformed arguments: $arg")
151+
end
152+
_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type
149153

150154
function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
151155
n_outputs = length(partials)
@@ -264,32 +268,37 @@ propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
264268
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)
265269

266270

267-
macro @non_differentiable(call_expr)
268-
Meta.isexpr(:call, call_expr) || error("Invalid use of `@non_differentiable`")
271+
macro non_differentiable(call_expr)
272+
Meta.isexpr(call_expr, :call) || error("Invalid use of `@non_differentiable`")
273+
primal_name, orig_args = Iterators.peel(call_expr.args)
269274

270-
primal_call = _without_constraints(call_expr)
271-
primal_call.args = esc.(primal_call.args)
275+
constrained_args = _constrain_and_name.(orig_args, :Any)
276+
unconstrained_args = _unconstrain.(constrained_args)
277+
primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...)
278+
279+
280+
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]
272281

273282
# TODO Move to frule helper
274283
frule_defn = Expr(
275284
:(=),
276-
Expr(:call, :(ChainRulesCore.frule), :_, call_expr.args...),
285+
Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...),
277286
# How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that
278287
# returns `DoesNotExist()` for every position.
279-
Expr(:tuple, primal_call, DoesNotExist())
288+
Expr(:tuple, primal_invoke, DoesNotExist())
280289
)
281290

282291
# TODO Move to rrule helper
283-
primal_name = first(primal_call.args)
292+
284293
pullback_expr = Expr(
285-
:(=),
286-
Expr(:call, propagator_name(primal_name, :pullback), :_),
287-
Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in primal_call.args[2:end])...)
294+
:function,
295+
Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)),
296+
Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in constrained_args)...)
288297
)
289298
rrule_defn = Expr(
290299
:(=),
291-
Expr(:call, :(ChainRulesCore.rrule), call_expr.args...),
292-
Expr(:tuple, primal_call, pullback_expr),
300+
Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...),
301+
Expr(:tuple, primal_invoke, pullback_expr),
293302
)
294303

295304
quote

test/rule_definition_tools.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,8 @@
44

55
end
66
end
7+
8+
9+
Base.remove_linenums!(@macroexpand @non_differentiable println(io::IO))
10+
11+
@non_differentiable println(io::IO)

0 commit comments

Comments
 (0)