Skip to content

Commit 07ccefa

Browse files
feat: add new derivative rule syntax
1 parent bda71ae commit 07ccefa

File tree

2 files changed

+298
-3
lines changed

2 files changed

+298
-3
lines changed

src/Symbolics.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import TermInterface: maketerm, iscall, operation, arguments, metadata
2828
import SymbolicUtils: Term, Add, Mul, Sym, Div, BasicSymbolic, Const,
2929
FnType, @rule, Rewriters, substitute, symtype, shape, unwrap, unwrap_const,
3030
promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv, BSImpl, scalarize,
31-
Operator, _iszero, _isone, search_variables, search_variables!
31+
Operator, _iszero, _isone, search_variables, search_variables!, ArgsT, ROArgsT
3232

3333
using SymbolicUtils.Code
3434

@@ -74,6 +74,7 @@ const COMMON_ONE = SymbolicUtils.one_of_vartype(VartypeT)
7474
const COMMON_ZERO = SymbolicUtils.zero_of_vartype(VartypeT)
7575
const SymbolicT = BasicSymbolic{VartypeT}
7676
const SArgsT = SymbolicUtils.ArgsT{VartypeT}
77+
const SConst = SymbolicUtils.BSImpl.Const{VartypeT}
7778
const SSym = SymbolicUtils.Sym{VartypeT}
7879
const STerm = SymbolicUtils.Term{VartypeT}
7980

@@ -152,8 +153,8 @@ include("linearity.jl")
152153
using DiffRules, SpecialFunctions, NaNMath
153154

154155

155-
export Differential, expand_derivatives, is_derivative
156-
156+
export Differential, expand_derivatives, is_derivative, @register_derivative, @derivative_rule
157+
include("register_derivatives.jl")
157158
include("diff.jl")
158159

159160
export SymbolicsSparsityDetector

src/register_derivatives.jl

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
"""
2+
derivative_rule(::typeof(f), ::Val{NArgs}, args::ArgsT{VartypeT}, ::Val{I})
3+
4+
Define the derivative rule for `f` with `Nargs` arguments `args` with respect to the `I`th
5+
argument. Do not define this function directly. Prefer using
6+
[`@register_derivative`](@ref). Instead of calling this function directly, prefer
7+
[`@derivative_rule`](@ref).
8+
"""
9+
function derivative_rule end
10+
11+
"""
12+
@register_derivative fn(args...) Ith_arg derivative
13+
14+
Register a symbolic derivative for a function. This typically accompanies a call to
15+
[`@register_symbolic`](@ref) or [`@register_array_symbolic`](@ref) and defines how
16+
[`expand_derivatives`](@ref) will behave when it tries to differentiate the registered
17+
function.
18+
19+
The first argument to the macro is a call to the function whose derivative is being
20+
defined. The call cannot have keyword arguments or default arguments. The call must have
21+
either an exact number of arguments or a single variadic argument. For example, `f(a)`,
22+
`f(a, b)`, `f(a, b, c)` and `f(args...)` are valid signatures. `f(a, b, args...)` is
23+
invalid. If an exact number of arguments is provided, the defined derivative is specific
24+
to that number of arguments. If the variadic signature is used, the defined derivative
25+
is valid for all numbers of arguments. In case multiple derivatives are registered for
26+
the same function, they must have different numbers of arguments. A derivative for an
27+
exact number of arguments is more specific than a variadic definition. For example,
28+
`@register_derivatives f(a, b) #...` is more specific than
29+
`@register_derivatives f(args...) #...` for a 2-argument call to `f`. The arguments
30+
can be referred to with their declared names inside the derivative definition.
31+
32+
The second argument to the macro is the argument with respect to which the derivative
33+
rule is defined. For example, `@register_derivative f(a, b) 2 #...` is a derivative rule
34+
with respect to the second argument of `f`. Mathematically, it represents
35+
``\\frac{ \\partial f(a, b) }{ \\partial b }``. To define a generic derivative, this
36+
argument can be an identifier. For example, `@register_derivative f(a, b) I #...` makes
37+
`I` available in the derivative definition as the index of the argument with respect to
38+
which the derivative is being taken.
39+
40+
The third argument to the macro is the derivative expression. This should be a symbolically
41+
traceable expression returning the derivative of the specified function with respect to
42+
the specified argument. In case of a variadic definition, the identifier `Nargs` is available
43+
to denote the number of arguments provided to the function. In case the variadic form is
44+
used, the arguments are available as a read-only array (mutation will error). Mutating
45+
the array is unsafe and undefined behavior.
46+
47+
!!! note
48+
For functions that return arrays (such as those registered via `@register_array_symbolic`)
49+
the returned expression must be the Jacobian. Currently, support for differentiating array
50+
functions is considered experimental.
51+
52+
!!! warning
53+
The derivative expression MUST return a symbolic value, or `nothing` if the derivative is
54+
not defined. In case the result is a non-symbolic value, such as a constant derivative or
55+
Jacobian of array functions, the result MUST be wrapped in `Symbolics.SConst(..)`.
56+
57+
Following are example definitions of derivatives:
58+
59+
```julia
60+
@register_derivative sin(x) 1 cos(x)
61+
@register_derivative max(x, y) 2 ifelse(x >= y, 0, 1)
62+
@register_derivative min(args...) I begin
63+
error("The rule for the derivative of `min` with \$Nargs arguments w.r.t the \$I-th argument is undefined.")
64+
end
65+
@register_derivative (foo::MyCallableStruct)(args...) I begin
66+
error("Oops! Didn't implement the derivative for \$foo")
67+
end
68+
```
69+
"""
70+
macro register_derivative(f::Expr, I::Union{Symbol, Int}, body)
71+
@assert Meta.isexpr(f, :call) """
72+
Incorrect `@register_derivative` syntax. The function must be provided as a call \
73+
signature. Got `$f` which is not a call signature.
74+
"""
75+
fnhead = f.args[1]
76+
fncallargs = @view f.args[2:end]
77+
is_struct_der = Meta.isexpr(fnhead, :(::))
78+
if is_struct_der
79+
@assert length(fnhead.args) == 2 """
80+
Incorrect `@register_derivative` syntax. Registering derivatives of callable \
81+
structs requires providing a name for the struct. For example, instead of
82+
`@register_derivative (::MyStruct) # ...` use \
83+
`@register_derivative (x::MyStruct) # ...`.
84+
"""
85+
end
86+
@assert !any(Base.Fix2(Meta.isexpr, :kw), fncallargs) """
87+
Incorrect `@register_derivative` syntax. The function cannot have default arguments.
88+
"""
89+
@assert !Meta.isexpr(fncallargs[1], :parameters) """
90+
Incorrect `@register_derivative` syntax. The function cannot have keyword arguments.
91+
"""
92+
93+
is_varargs = Meta.isexpr(fncallargs[end], :...)
94+
if is_varargs
95+
@assert length(fncallargs) == 1 """
96+
Incorrect `@register_derivative` syntax. The function call signature must either \
97+
be a single variadic argument `@register_derivative foo(args...) #...` or a \
98+
concrete number of arguments `@register_derivative foo(arg1, arg2, arg3) # ...`.
99+
"""
100+
end
101+
102+
derhead = Expr(:call, :($Symbolics.derivative_rule), is_struct_der ? fnhead : :(::($typeof($fnhead))))
103+
Nargs = is_varargs ? :Nargs : length(fncallargs)
104+
push!(derhead.args, :(::Val{$Nargs}))
105+
args_name = gensym(:args)
106+
push!(derhead.args, :($args_name::$SymbolicUtils.ROArgsT{$VartypeT}))
107+
push!(derhead.args, :(::Val{$I}))
108+
109+
if is_varargs || I isa Symbol
110+
derhead = Expr(:where, derhead)
111+
is_varargs && push!(derhead.args, Nargs)
112+
I isa Symbol && push!(derhead.args, I)
113+
end
114+
115+
if is_varargs
116+
unpack = Expr(:(=), fncallargs[1].args[1], args_name)
117+
else
118+
unpack = Expr(:tuple)
119+
append!(unpack.args, fncallargs)
120+
unpack = Expr(:(=), unpack, args_name)
121+
end
122+
123+
return esc(Expr(:function, derhead, Expr(:block, unpack, body)))
124+
end
125+
126+
# Pre-defined derivatives
127+
import DiffRules
128+
for (modu, fun, arity) DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath))
129+
fun in [:*, :+, :abs, :mod, :rem, :max, :min] && continue # special
130+
for i 1:arity
131+
132+
expr = if arity == 1
133+
DiffRules.diffrule(modu, fun, :(args[1]))
134+
else
135+
DiffRules.diffrule(modu, fun, ntuple(k->:(args[$k]), arity)...)[i]
136+
end
137+
138+
# Using the macro here doesn't work somehow.
139+
@eval function derivative_rule(::typeof($modu.$fun), ::Val{$arity}, args::SymbolicUtils.ArgsT{VartypeT}, ::Val{$i})
140+
$SConst($expr)
141+
end
142+
end
143+
end
144+
145+
Base.@propagate_inbounds function _derivative_rule_proxy(f, args::NTuple{N, SymbolicT}, ::Val{I}) where {N, I}
146+
_derivative_rule_proxy(f, Val{N}(), args, Val{I}())
147+
end
148+
Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::NTuple{N, SymbolicT}, ::Val{I}) where {N, I}
149+
_derivative_rule_proxy(f, Val{N}(), SymbolicUtils.ArgsT{VartypeT}(args), Val{I}())
150+
end
151+
Base.@propagate_inbounds function _derivative_rule_proxy(f, args::Tuple, ::Val{I}) where {I}
152+
_derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}())
153+
end
154+
Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::Tuple{Vararg{Any, N}}, ::Val{I}) where {N, I}
155+
args = ntuple(BSImpl.Const{VartypeT} Base.Fix1(getindex, args), Val{N}())
156+
_derivative_rule_proxy(f, Val{N}(), args, Val{I}())
157+
end
158+
Base.@propagate_inbounds function _derivative_rule_proxy(f, args::ROArgsT{VartypeT}, ::Val{I}) where {I}
159+
@inbounds _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}())
160+
end
161+
Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::ROArgsT{VartypeT}, ::Val{I}) where {N, I}
162+
@boundscheck checkbounds(args, N)
163+
derivative_rule(f, Val{N}(), args, Val{I}())
164+
end
165+
Base.@propagate_inbounds function _derivative_rule_proxy(f, args::ArgsT{VartypeT}, ::Val{I}) where {I}
166+
@inbounds _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}())
167+
end
168+
Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::ArgsT{VartypeT}, ::Val{I}) where {N, I}
169+
@boundscheck checkbounds(args, N)
170+
_derivative_rule_proxy(f, Val{N}(), ROArgsT{VartypeT}(args), Val{I}())
171+
end
172+
Base.@propagate_inbounds function _derivative_rule_proxy(f, args::AbstractArray{SymbolicT}, ::Val{I}) where {I}
173+
@inbounds _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}())
174+
end
175+
Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::AbstractArray{SymbolicT}, ::Val{I}) where {N, I}
176+
@boundscheck checkbounds(args, N)
177+
_derivative_rule_proxy(f, Val{N}(), ArgsT{VartypeT}(args), Val{I}())
178+
end
179+
Base.@propagate_inbounds function _derivative_rule_proxy(f, args::AbstractArray, ::Val{I}) where {I}
180+
@inbounds _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}())
181+
end
182+
Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::AbstractArray, ::Val{I}) where {N, I}
183+
@boundscheck checkbounds(args, N)
184+
_args = ArgsT{VartypeT}()
185+
sizehint!(_args, N)
186+
for a in args
187+
push!(_args, BSImpl.Const{VartypeT}(a))
188+
end
189+
_derivative_rule_proxy(f, Val{N}(), _args, Val{I}())
190+
end
191+
192+
"""
193+
@derivative_rule f(args...) I
194+
195+
Query Symbolics.jl's derivative rule system for the derivative of `f(args...)` with respect to
196+
`args[I]`. Returns a symbolic result representing the derivative. In case the derivative rule is
197+
not defined, evaluates to `nothing`.
198+
199+
The first argument to the macro must be a valid function call syntax. Splatting of arguments is
200+
permitted. The second argument must be an expression or literal evaluating to the index of the
201+
argument with respect to which the derivative is required.
202+
203+
The derivative rule can dispatch statically if `f`, the number of arguments and `I` are known
204+
at compile time. Example invocations are:
205+
206+
```julia
207+
# static dispatch
208+
@derivative_rule sin(x) 1
209+
# static dispatch if `xs` is a tuple
210+
@derivative_rule max(xs...) 2
211+
# static dispatch if `y` and `w` are tuples, and `N + 2K` is a compile-time constant
212+
@derivative_rule foo(x, y..., z, w...) (N + 2K)
213+
```
214+
"""
215+
macro derivative_rule(f, I)
216+
@assert Meta.isexpr(f, :call) """
217+
Incorrect `@derivative_rule` syntax. The function must be provided as a call \
218+
signature. Got `$f` which is not a call signature.
219+
"""
220+
fnhead = f.args[1]
221+
fncallargs = @view f.args[2:end]
222+
result = Expr(:call, _derivative_rule_proxy, fnhead)
223+
if length(fncallargs) == 1 && Meta.isexpr(fncallargs[1], :...)
224+
push!(result.args, fncallargs[1].args[1])
225+
elseif any(Base.Fix2(Meta.isexpr, :...), fncallargs)
226+
args = Expr(:tuple)
227+
append!(args.args, fncallargs)
228+
push!(result.args, args)
229+
else
230+
push!(result.args, :(Val{$(length(fncallargs))}()))
231+
args = Expr(:tuple)
232+
append!(args.args, fncallargs)
233+
push!(result.args, args)
234+
end
235+
push!(result.args, :(Val{$I}()))
236+
return esc(result)
237+
end
238+
239+
@register_derivative +(args...) I COMMON_ONE
240+
@register_derivative *(args...) I begin
241+
if I == 1
242+
SymbolicUtils.mul_worker(VartypeT, view(args, 2:Nargs))
243+
elseif I == Nargs
244+
SymbolicUtils.mul_worker(VartypeT, view(args, 1:(Nargs-1)))
245+
else
246+
t1 = SymbolicUtils.mul_worker(VartypeT, view(args, 1:(Nargs-1)))
247+
t2 = SymbolicUtils.mul_worker(VartypeT, view(args, 2:Nargs))
248+
t1 * t2
249+
end
250+
end
251+
@register_derivative one(x) 1 COMMON_ZERO
252+
253+
"""
254+
$(SIGNATURES)
255+
256+
Calculate the derivative of the op `O` with respect to its argument with index
257+
`idx`.
258+
259+
# Examples
260+
261+
```jldoctest label1
262+
julia> using Symbolics
263+
264+
julia> @variables x y;
265+
266+
julia> Symbolics.derivative_idx(Symbolics.value(sin(x)), 1)
267+
cos(x)
268+
```
269+
270+
Note that the function does not recurse into the operation's arguments, i.e., the
271+
chain rule is not applied:
272+
273+
```jldoctest label1
274+
julia> myop = Symbolics.value(sin(x) * y^2)
275+
sin(x)*(y^2)
276+
277+
julia> typeof(Symbolics.operation(myop)) # Op is multiplication function
278+
typeof(*)
279+
280+
julia> Symbolics.derivative_idx(myop, 1) # wrt. sin(x)
281+
y^2
282+
283+
julia> Symbolics.derivative_idx(myop, 2) # wrt. y^2
284+
sin(x)
285+
```
286+
"""
287+
@inline derivative_idx(::Any, ::Any) = COMMON_ZERO
288+
function derivative_idx(O::VartypeT, idx::Int)
289+
iscall(O) || return COMMON_ZERO
290+
f = operation(O)
291+
args = arguments(O)
292+
return @derivative_rule f(args...) idx
293+
end
294+

0 commit comments

Comments
 (0)