Skip to content

Commit 431dc91

Browse files
refactor: update to new derivative rule syntax
1 parent 07ccefa commit 431dc91

File tree

5 files changed

+60
-154
lines changed

5 files changed

+60
-154
lines changed

src/diff.jl

Lines changed: 10 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,15 @@ function Base.showerror(io::IO, err::DerivativeNotDefinedError)
225225
# and code fences
226226
err_str = Markdown.parse("""
227227
Derivative of `$(err.expr)` with respect to its $(err.i)-th argument is not defined.
228-
Define a derivative by adding a method to `Symbolics.derivative`:
228+
Define a derivative by using `@register_derivative`:
229229
230230
```julia
231-
function Symbolics.derivative(::typeof($op), args::NTuple{$nargs, Any}, ::Val{$(err.i)})
231+
@register_derivative $op(args...) $(err.i) begin
232232
# ...
233233
end
234234
```
235235
236-
Refer to the documentation for `Symbolics.derivative` and the
236+
Refer to the documentation for `@register_derivative` and the
237237
"[Adding Analytical Derivatives](@ref)" section of the docs for further information.
238238
""")
239239
show(io, MIME"text/plain"(), err_str)
@@ -321,12 +321,12 @@ function executediff(D::Differential, arg::BasicSymbolic{VartypeT}; simplify=fal
321321
# We know `D.x` is in `arg`, so the derivative is not identically zero.
322322
# `arg` cannot be `D.x` since, that would have also early exited.
323323
for (i, a) in enumerate(inner_args)
324-
der = derivative_idx(arr, i)
324+
der = derivative_idx(arr, i)::Union{Nothing, SymbolicT}
325325
if isequal(a, D.x)
326-
der isa NoDeriv && return D(arg)
326+
der === nothing && return D(arg)
327327
push!(summed_args, der[idx])
328328
continue
329-
elseif der isa NoDeriv
329+
elseif der === nothing
330330
push!(summed_args, Differential(a)(arg) * executediff(D, a))
331331
else
332332
push!(summed_args, der[idx] * executediff(D, a))
@@ -381,8 +381,8 @@ function executediff(D::Differential, arg::BasicSymbolic{VartypeT}; simplify=fal
381381
for (i, iarg) in enumerate(inner_args)
382382
t2 = executediff(D, iarg; simplify, throw_no_derivative)::SymbolicT
383383
_iszero(t2) && continue
384-
t = derivative_idx(arg, i)::Union{NoDeriv, SymbolicT}
385-
if t isa NoDeriv
384+
t = derivative_idx(arg, i)::Union{Nothing, SymbolicT}
385+
if t === nothing
386386
throw_no_derivative && throw(DerivativeNotDefinedError(arg, i))
387387
t = D(arg)
388388
end
@@ -506,89 +506,6 @@ function expand_derivatives(n::Complex{Num}, simplify=false; kwargs...)
506506
end
507507
expand_derivatives(x, simplify=false; kwargs...) = x
508508

509-
# Don't specialize on the function here
510-
"""
511-
$(SIGNATURES)
512-
513-
Calculate the derivative of the op `O` with respect to its argument with index
514-
`idx`.
515-
516-
# Examples
517-
518-
```jldoctest label1
519-
julia> using Symbolics
520-
521-
julia> @variables x y;
522-
523-
julia> Symbolics.derivative_idx(Symbolics.value(sin(x)), 1)
524-
cos(x)
525-
```
526-
527-
Note that the function does not recurse into the operation's arguments, i.e., the
528-
chain rule is not applied:
529-
530-
```jldoctest label1
531-
julia> myop = Symbolics.value(sin(x) * y^2)
532-
sin(x)*(y^2)
533-
534-
julia> typeof(Symbolics.operation(myop)) # Op is multiplication function
535-
typeof(*)
536-
537-
julia> Symbolics.derivative_idx(myop, 1) # wrt. sin(x)
538-
y^2
539-
540-
julia> Symbolics.derivative_idx(myop, 2) # wrt. y^2
541-
sin(x)
542-
```
543-
"""
544-
derivative_idx(O::Any, ::Any) = COMMON_ZERO
545-
function derivative_idx(O::BasicSymbolic, idx)
546-
iscall(O) || return COMMON_ZERO
547-
res = derivative(operation(O), (arguments(O)...,), Val(idx))
548-
if res isa NoDeriv
549-
return res
550-
else
551-
return Const{VartypeT}(res)
552-
end
553-
end
554-
555-
# Indicate that no derivative is defined.
556-
struct NoDeriv
557-
end
558-
559-
"""
560-
Symbolics.derivative(::typeof(f), args::NTuple{N, Any}, ::Val{i})
561-
562-
Return the derivative of `f(args...)` with respect to `args[i]`. `N` should be the number
563-
of arguments that `f` takes and `i` is the argument with respect to which the derivative
564-
is taken. The result can be a numeric value (if the derivative is constant) or a symbolic
565-
expression. This function is useful for defining derivatives of custom functions registered
566-
via `@register_symbolic`, to be used when calling `expand_derivatives`.
567-
"""
568-
derivative(f, args, v) = NoDeriv()
569-
570-
# Pre-defined derivatives
571-
import DiffRules
572-
for (modu, fun, arity) DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath))
573-
fun in [:*, :+, :abs, :mod, :rem, :max, :min] && continue # special
574-
for i 1:arity
575-
576-
expr = if arity == 1
577-
DiffRules.diffrule(modu, fun, :(args[1]))
578-
else
579-
DiffRules.diffrule(modu, fun, ntuple(k->:(args[$k]), arity)...)[i]
580-
end
581-
@eval derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i}) = $expr
582-
end
583-
end
584-
585-
derivative(::typeof(+), args::NTuple{N,Any}, ::Val) where {N} = 1
586-
derivative(::typeof(*), args::NTuple{N,Any}, ::Val{i}) where {N,i} = *(deleteat!(collect(args), i)...)
587-
derivative(::typeof(one), args::Tuple{<:Any}, ::Val) = 0
588-
589-
derivative(f::Function, x::Union{Num, <:BasicSymbolic}) = derivative(f(x), x)
590-
derivative(::Function, x::Any) = TypeError(:derivative, "2nd argument", Union{Num, <:BasicSymbolic}, x) |> throw
591-
592509
function count_order(x)
593510
@assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!"
594511
n = 1
@@ -666,6 +583,8 @@ function derivative(O, var; simplify=false, kwargs...)
666583
Num(expand_derivatives(Differential(var)(unwrap(O)), simplify; kwargs...))
667584
end
668585
end
586+
derivative(f::Function, var::Union{SymbolicT, Num}) = derivative(f(var), var)
587+
derivative(::Function, x::Any) = throw(TypeError(:derivative, "2nd argument", Union{Num, SymbolicT}, x))
669588

670589
"""
671590
$(SIGNATURES)

src/extra_functions.jl

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,16 @@ for (T1, T2) in Iterators.product([Number, BasicSymbolic{VartypeT}, Num], [Integ
77
end
88
end
99

10-
derivative(::typeof(sign), args::NTuple{1,Any}, ::Val{1}) = 0
11-
12-
derivative(::typeof(signbit), args::NTuple{1,Any}, ::Val{1}) = 0
13-
derivative(::typeof(abs), args::NTuple{1,Any}, ::Val{1}) = ifelse(signbit(args[1]),-one(args[1]),one(args[1]))
14-
15-
function derivative(::typeof(min), args::NTuple{2,Any}, ::Val{1})
16-
x, y = args
17-
ifelse(x < y, one(x), zero(x))
18-
end
19-
function derivative(::typeof(min), args::NTuple{2,Any}, ::Val{2})
20-
x, y = args
21-
ifelse(x < y, zero(y), one(y))
22-
end
23-
function derivative(::typeof(max), args::NTuple{2,Any}, ::Val{1})
24-
x, y = args
25-
ifelse(x > y, one(x), zero(x))
26-
end
27-
function derivative(::typeof(max), args::NTuple{2,Any}, ::Val{2})
28-
x, y = args
29-
ifelse(x > y, zero(y), one(y))
30-
end
31-
32-
function derivative(::Union{typeof(ceil),typeof(floor),typeof(factorial)}, args::NTuple{1,Any}, ::Val{1})
33-
zero(args[1])
34-
end
10+
@register_derivative sign(x) 1 COMMON_ZERO
11+
@register_derivative signbit(x) 1 COMMON_ZERO
12+
@register_derivative abs(x) 1 ifelse(signbit(x),-one(x),one(x))
13+
@register_derivative min(x, y) 1 ifelse(x < y, one(x), zero(x))
14+
@register_derivative min(x, y) 2 ifelse(x < y, zero(y), one(y))
15+
@register_derivative max(x, y) 1 ifelse(x > y, one(x), zero(x))
16+
@register_derivative max(x, y) 2 ifelse(x > y, zero(y), one(y))
17+
@register_derivative ceil(x) 1 COMMON_ZERO
18+
@register_derivative floor(x) 1 COMMON_ZERO
19+
@register_derivative factorial(x) 1 COMMON_ZERO
3520

3621
@register_symbolic Base.rand(x)
3722
@register_symbolic Base.randn(x)
@@ -45,13 +30,7 @@ for (T1, T2, T3) in Iterators.product(Iterators.repeated((Num, BasicSymbolic{Var
4530
end
4631
end
4732

48-
function derivative(::typeof(Base.clamp), args::NTuple{3, Any}, ::Val{1})
49-
x, l, h = args
50-
T = promote_type(symtype(x), symtype(l), symtype(h))
51-
z = zero(T)
52-
o = one(T)
53-
ifelse(x<l, z, ifelse(x>h, z, o))
54-
end
33+
@register_derivative clamp(x, l, h) 1 ifelse(x < l, COMMON_ZERO, ifelse(x > h, COMMON_ZERO, COMMON_ONE))
5534

5635
for T1 in [Real, Num, BasicSymbolic{VartypeT}], T2 in [AbstractArray, Arr, BasicSymbolic{VartypeT}]
5736
if T1 != Num && T2 != Arr
@@ -78,11 +57,10 @@ end
7857

7958
LinearAlgebra.norm(x::Num, p::Real) = abs(x)
8059

81-
derivative(::typeof(<), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
82-
derivative(::typeof(<=), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
83-
derivative(::typeof(>), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
84-
derivative(::typeof(>=), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
85-
derivative(::typeof(==), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
86-
derivative(::typeof(!=), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
87-
88-
derivative(::typeof(expinti), args::NTuple{1,Any}, ::Val{1}) = exp(args[1])/args[1]
60+
@register_derivative <(x, y) I COMMON_ZERO
61+
@register_derivative <=(x, y) I COMMON_ZERO
62+
@register_derivative >(x, y) I COMMON_ZERO
63+
@register_derivative >=(x, y) I COMMON_ZERO
64+
@register_derivative ==(x, y) I COMMON_ZERO
65+
@register_derivative !=(x, y) I COMMON_ZERO
66+
@register_derivative expinti(x) 1 exp(x) / x

src/register_derivatives.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ macro register_derivative(f::Expr, I::Union{Symbol, Int}, body)
123123
return esc(Expr(:function, derhead, Expr(:block, unpack, body)))
124124
end
125125

126+
# Fallback
127+
@register_derivative (f::Any)(args...) I nothing
128+
126129
# Pre-defined derivatives
127130
import DiffRules
128131
for (modu, fun, arity) DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath))
@@ -136,7 +139,7 @@ for (modu, fun, arity) ∈ DiffRules.diffrules(; filter_modules=(:Base, :Special
136139
end
137140

138141
# Using the macro here doesn't work somehow.
139-
@eval function derivative_rule(::typeof($modu.$fun), ::Val{$arity}, args::SymbolicUtils.ArgsT{VartypeT}, ::Val{$i})
142+
@eval function derivative_rule(::typeof($modu.$fun), ::Val{$arity}, args::ROArgsT{VartypeT}, ::Val{$i})
140143
$SConst($expr)
141144
end
142145
end
@@ -285,7 +288,7 @@ sin(x)
285288
```
286289
"""
287290
@inline derivative_idx(::Any, ::Any) = COMMON_ZERO
288-
function derivative_idx(O::VartypeT, idx::Int)
291+
function derivative_idx(O::SymbolicT, idx::Int)
289292
iscall(O) || return COMMON_ZERO
290293
f = operation(O)
291294
args = arguments(O)

src/solver/solve_helpers.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ end
4141
SymbolicUtils.promote_symtype(::typeof(ssqrt), ::Type{T}) where {T} = T
4242
SymbolicUtils.promote_shape(::typeof(ssqrt), @nospecialize(sh::SymbolicUtils.ShapeT)) = sh
4343

44-
derivative(::typeof(ssqrt), args...) = substitute(derivative(sqrt, args...), sqrt => ssqrt)
44+
@register_derivative ssqrt(x) I begin
45+
substitute(@derivative_rule(sqrt(x), I), sqrt => ssqrt)
46+
end
4547

4648
function scbrt(n)
4749
n = unwrap(n)
@@ -62,7 +64,9 @@ end
6264

6365
SymbolicUtils.promote_symtype(::typeof(scbrt), ::Type{T}) where {T} = T
6466
SymbolicUtils.promote_shape(::typeof(scbrt), @nospecialize(sh::SymbolicUtils.ShapeT)) = sh
65-
derivative(::typeof(scbrt), args...) = substitute(derivative(cbrt, args...), cbrt => scbrt)
67+
@register_derivative scbrt(x) I begin
68+
substitute(@derivative_rule(cbrt(x), I), cbrt => scbrt)
69+
end
6670

6771
function slog(n)
6872
n = unwrap(n)
@@ -82,7 +86,9 @@ end
8286
SymbolicUtils.promote_symtype(::typeof(slog), ::Type{T}) where {T} = T
8387
SymbolicUtils.promote_shape(::typeof(slog), @nospecialize(sh::SymbolicUtils.ShapeT)) = sh
8488

85-
derivative(::typeof(slog), args...) = substitute(derivative(log, args...), log => slog)
89+
@register_derivative slog(x) I begin
90+
substitute(@derivative_rule(log(x), I), log => slog)
91+
end
8692

8793
const RootsOf = (SymbolicUtils.@syms roots_of(poly,var))[1]
8894

test/diff.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -630,25 +630,25 @@ end
630630
test_equal.(Symbolics.unwrap.(Symbolics.jacobian([f], vp) .- Symbolics.jacobian([f], p)), 0)
631631
end
632632

633+
@register_array_symbolic foobar(x, y) begin
634+
size = (2,)
635+
eltype = Real
636+
ndims = 1
637+
end
638+
@register_array_symbolic dfoobar1(x, y) begin
639+
size = (2,)
640+
eltype = Real
641+
ndims = 1
642+
end
643+
@register_array_symbolic dfoobar2(x, y) begin
644+
size = (2,)
645+
eltype = Real
646+
ndims = 1
647+
end
648+
@register_derivative foobar(x, y) 1 dfoobar1(x, y)
649+
@register_derivative foobar(x, y) 2 dfoobar2(x, y)
650+
633651
@testset "Derivatives of indexed array expressions" begin
634-
@register_array_symbolic foobar(x, y) begin
635-
size = (2,)
636-
eltype = Real
637-
ndims = 1
638-
end
639-
@register_array_symbolic dfoobar1(x, y) begin
640-
size = (2,)
641-
eltype = Real
642-
ndims = 1
643-
end
644-
@register_array_symbolic dfoobar2(x, y) begin
645-
size = (2,)
646-
eltype = Real
647-
ndims = 1
648-
end
649-
Symbolics.derivative(::typeof(foobar), args::NTuple{2, Any}, ::Val{1}) = dfoobar1(args...)
650-
Symbolics.derivative(::typeof(foobar), args::NTuple{2, Any}, ::Val{2}) = dfoobar2(args...)
651-
652652
@variables x y
653653
ex = foobar(x + 2y, y)
654654
@test ex isa Symbolics.Arr

0 commit comments

Comments
 (0)