|
| 1 | +""" |
| 2 | + derivative_rule(::typeof(f), ::Val{NArgs}, args::ArgsT{VartypeT}, ::Val{I}) |
| 3 | +
|
| 4 | +Define the derivative rule for `f` with `Nargs` arguments `args` with respect to the `I`th |
| 5 | +argument. Do not define this function directly. Prefer using |
| 6 | +[`@register_derivative`](@ref). Instead of calling this function directly, prefer |
| 7 | +[`@derivative_rule`](@ref). |
| 8 | +""" |
| 9 | +function derivative_rule end |
| 10 | + |
| 11 | +""" |
| 12 | + @register_derivative fn(args...) Ith_arg derivative |
| 13 | +
|
| 14 | +Register a symbolic derivative for a function. This typically accompanies a call to |
| 15 | +[`@register_symbolic`](@ref) or [`@register_array_symbolic`](@ref) and defines how |
| 16 | +[`expand_derivatives`](@ref) will behave when it tries to differentiate the registered |
| 17 | +function. |
| 18 | +
|
| 19 | +The first argument to the macro is a call to the function whose derivative is being |
| 20 | +defined. The call cannot have keyword arguments or default arguments. The call must have |
| 21 | +either an exact number of arguments or a single variadic argument. For example, `f(a)`, |
| 22 | +`f(a, b)`, `f(a, b, c)` and `f(args...)` are valid signatures. `f(a, b, args...)` is |
| 23 | +invalid. If an exact number of arguments is provided, the defined derivative is specific |
| 24 | +to that number of arguments. If the variadic signature is used, the defined derivative |
| 25 | +is valid for all numbers of arguments. In case multiple derivatives are registered for |
| 26 | +the same function, they must have different numbers of arguments. A derivative for an |
| 27 | +exact number of arguments is more specific than a variadic definition. For example, |
| 28 | +`@register_derivatives f(a, b) #...` is more specific than |
| 29 | +`@register_derivatives f(args...) #...` for a 2-argument call to `f`. The arguments |
| 30 | +can be referred to with their declared names inside the derivative definition. |
| 31 | +
|
| 32 | +The second argument to the macro is the argument with respect to which the derivative |
| 33 | +rule is defined. For example, `@register_derivative f(a, b) 2 #...` is a derivative rule |
| 34 | +with respect to the second argument of `f`. Mathematically, it represents |
| 35 | +``\\frac{ \\partial f(a, b) }{ \\partial b }``. To define a generic derivative, this |
| 36 | +argument can be an identifier. For example, `@register_derivative f(a, b) I #...` makes |
| 37 | +`I` available in the derivative definition as the index of the argument with respect to |
| 38 | +which the derivative is being taken. |
| 39 | +
|
| 40 | +The third argument to the macro is the derivative expression. This should be a symbolically |
| 41 | +traceable expression returning the derivative of the specified function with respect to |
| 42 | +the specified argument. In case of a variadic definition, the identifier `Nargs` is available |
| 43 | +to denote the number of arguments provided to the function. In case the variadic form is |
| 44 | +used, the arguments are available as a read-only array (mutation will error). Mutating |
| 45 | +the array is unsafe and undefined behavior. |
| 46 | +
|
| 47 | +!!! note |
| 48 | + For functions that return arrays (such as those registered via `@register_array_symbolic`) |
| 49 | + the returned expression must be the Jacobian. Currently, support for differentiating array |
| 50 | + functions is considered experimental. |
| 51 | +
|
| 52 | +!!! warning |
| 53 | + The derivative expression MUST return a symbolic value, or `nothing` if the derivative is |
| 54 | + not defined. In case the result is a non-symbolic value, such as a constant derivative or |
| 55 | + Jacobian of array functions, the result MUST be wrapped in `Symbolics.SConst(..)`. |
| 56 | +
|
| 57 | +Following are example definitions of derivatives: |
| 58 | +
|
| 59 | +```julia |
| 60 | +@register_derivative sin(x) 1 cos(x) |
| 61 | +@register_derivative max(x, y) 2 ifelse(x >= y, 0, 1) |
| 62 | +@register_derivative min(args...) I begin |
| 63 | + error("The rule for the derivative of `min` with \$Nargs arguments w.r.t the \$I-th argument is undefined.") |
| 64 | +end |
| 65 | +@register_derivative (foo::MyCallableStruct)(args...) I begin |
| 66 | + error("Oops! Didn't implement the derivative for \$foo") |
| 67 | +end |
| 68 | +``` |
| 69 | +""" |
| 70 | +macro register_derivative(f::Expr, I::Union{Symbol, Int}, body) |
| 71 | + @assert Meta.isexpr(f, :call) """ |
| 72 | + Incorrect `@register_derivative` syntax. The function must be provided as a call \ |
| 73 | + signature. Got `$f` which is not a call signature. |
| 74 | + """ |
| 75 | + fnhead = f.args[1] |
| 76 | + fncallargs = @view f.args[2:end] |
| 77 | + is_struct_der = Meta.isexpr(fnhead, :(::)) |
| 78 | + if is_struct_der |
| 79 | + @assert length(fnhead.args) == 2 """ |
| 80 | + Incorrect `@register_derivative` syntax. Registering derivatives of callable \ |
| 81 | + structs requires providing a name for the struct. For example, instead of |
| 82 | + `@register_derivative (::MyStruct) # ...` use \ |
| 83 | + `@register_derivative (x::MyStruct) # ...`. |
| 84 | + """ |
| 85 | + end |
| 86 | + @assert !any(Base.Fix2(Meta.isexpr, :kw), fncallargs) """ |
| 87 | + Incorrect `@register_derivative` syntax. The function cannot have default arguments. |
| 88 | + """ |
| 89 | + @assert !Meta.isexpr(fncallargs[1], :parameters) """ |
| 90 | + Incorrect `@register_derivative` syntax. The function cannot have keyword arguments. |
| 91 | + """ |
| 92 | + |
| 93 | + is_varargs = Meta.isexpr(fncallargs[end], :...) |
| 94 | + if is_varargs |
| 95 | + @assert length(fncallargs) == 1 """ |
| 96 | + Incorrect `@register_derivative` syntax. The function call signature must either \ |
| 97 | + be a single variadic argument `@register_derivative foo(args...) #...` or a \ |
| 98 | + concrete number of arguments `@register_derivative foo(arg1, arg2, arg3) # ...`. |
| 99 | + """ |
| 100 | + end |
| 101 | + |
| 102 | + derhead = Expr(:call, :($Symbolics.derivative_rule), is_struct_der ? fnhead : :(::($typeof($fnhead)))) |
| 103 | + Nargs = is_varargs ? :Nargs : length(fncallargs) |
| 104 | + push!(derhead.args, :(::Val{$Nargs})) |
| 105 | + args_name = gensym(:args) |
| 106 | + push!(derhead.args, :($args_name::$SymbolicUtils.ROArgsT{$VartypeT})) |
| 107 | + push!(derhead.args, :(::Val{$I})) |
| 108 | + |
| 109 | + if is_varargs || I isa Symbol |
| 110 | + derhead = Expr(:where, derhead) |
| 111 | + is_varargs && push!(derhead.args, Nargs) |
| 112 | + I isa Symbol && push!(derhead.args, I) |
| 113 | + end |
| 114 | + |
| 115 | + if is_varargs |
| 116 | + unpack = Expr(:(=), fncallargs[1].args[1], args_name) |
| 117 | + else |
| 118 | + unpack = Expr(:tuple) |
| 119 | + append!(unpack.args, fncallargs) |
| 120 | + unpack = Expr(:(=), unpack, args_name) |
| 121 | + end |
| 122 | + |
| 123 | + return esc(Expr(:function, derhead, Expr(:block, unpack, body))) |
| 124 | +end |
| 125 | + |
| 126 | +# Pre-defined derivatives |
| 127 | +import DiffRules |
| 128 | +for (modu, fun, arity) ∈ DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath)) |
| 129 | + fun in [:*, :+, :abs, :mod, :rem, :max, :min] && continue # special |
| 130 | + for i ∈ 1:arity |
| 131 | + |
| 132 | + expr = if arity == 1 |
| 133 | + DiffRules.diffrule(modu, fun, :(args[1])) |
| 134 | + else |
| 135 | + DiffRules.diffrule(modu, fun, ntuple(k->:(args[$k]), arity)...)[i] |
| 136 | + end |
| 137 | + |
| 138 | + # Using the macro here doesn't work somehow. |
| 139 | + @eval function derivative_rule(::typeof($modu.$fun), ::Val{$arity}, args::SymbolicUtils.ArgsT{VartypeT}, ::Val{$i}) |
| 140 | + $SConst($expr) |
| 141 | + end |
| 142 | + end |
| 143 | +end |
| 144 | + |
| 145 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, args::NTuple{N, SymbolicT}, ::Val{I}) where {N, I} |
| 146 | + _derivative_rule_proxy(f, Val{N}(), args, Val{I}()) |
| 147 | +end |
| 148 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::NTuple{N, SymbolicT}, ::Val{I}) where {N, I} |
| 149 | + _derivative_rule_proxy(f, Val{N}(), SymbolicUtils.ArgsT{VartypeT}(args), Val{I}()) |
| 150 | +end |
| 151 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, args::Tuple, ::Val{I}) where {I} |
| 152 | + _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}()) |
| 153 | +end |
| 154 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::Tuple{Vararg{Any, N}}, ::Val{I}) where {N, I} |
| 155 | + args = ntuple(BSImpl.Const{VartypeT} ∘ Base.Fix1(getindex, args), Val{N}()) |
| 156 | + _derivative_rule_proxy(f, Val{N}(), args, Val{I}()) |
| 157 | +end |
| 158 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, args::ROArgsT{VartypeT}, ::Val{I}) where {I} |
| 159 | + @inbounds _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}()) |
| 160 | +end |
| 161 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::ROArgsT{VartypeT}, ::Val{I}) where {N, I} |
| 162 | + @boundscheck checkbounds(args, N) |
| 163 | + derivative_rule(f, Val{N}(), args, Val{I}()) |
| 164 | +end |
| 165 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, args::ArgsT{VartypeT}, ::Val{I}) where {I} |
| 166 | + @inbounds _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}()) |
| 167 | +end |
| 168 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::ArgsT{VartypeT}, ::Val{I}) where {N, I} |
| 169 | + @boundscheck checkbounds(args, N) |
| 170 | + _derivative_rule_proxy(f, Val{N}(), ROArgsT{VartypeT}(args), Val{I}()) |
| 171 | +end |
| 172 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, args::AbstractArray{SymbolicT}, ::Val{I}) where {I} |
| 173 | + @inbounds _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}()) |
| 174 | +end |
| 175 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::AbstractArray{SymbolicT}, ::Val{I}) where {N, I} |
| 176 | + @boundscheck checkbounds(args, N) |
| 177 | + _derivative_rule_proxy(f, Val{N}(), ArgsT{VartypeT}(args), Val{I}()) |
| 178 | +end |
| 179 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, args::AbstractArray, ::Val{I}) where {I} |
| 180 | + @inbounds _derivative_rule_proxy(f, Val{length(args)}(), args, Val{I}()) |
| 181 | +end |
| 182 | +Base.@propagate_inbounds function _derivative_rule_proxy(f, ::Val{N}, args::AbstractArray, ::Val{I}) where {N, I} |
| 183 | + @boundscheck checkbounds(args, N) |
| 184 | + _args = ArgsT{VartypeT}() |
| 185 | + sizehint!(_args, N) |
| 186 | + for a in args |
| 187 | + push!(_args, BSImpl.Const{VartypeT}(a)) |
| 188 | + end |
| 189 | + _derivative_rule_proxy(f, Val{N}(), _args, Val{I}()) |
| 190 | +end |
| 191 | + |
| 192 | +""" |
| 193 | + @derivative_rule f(args...) I |
| 194 | +
|
| 195 | +Query Symbolics.jl's derivative rule system for the derivative of `f(args...)` with respect to |
| 196 | +`args[I]`. Returns a symbolic result representing the derivative. In case the derivative rule is |
| 197 | +not defined, evaluates to `nothing`. |
| 198 | +
|
| 199 | +The first argument to the macro must be a valid function call syntax. Splatting of arguments is |
| 200 | +permitted. The second argument must be an expression or literal evaluating to the index of the |
| 201 | +argument with respect to which the derivative is required. |
| 202 | +
|
| 203 | +The derivative rule can dispatch statically if `f`, the number of arguments and `I` are known |
| 204 | +at compile time. Example invocations are: |
| 205 | +
|
| 206 | +```julia |
| 207 | +# static dispatch |
| 208 | +@derivative_rule sin(x) 1 |
| 209 | +# static dispatch if `xs` is a tuple |
| 210 | +@derivative_rule max(xs...) 2 |
| 211 | +# static dispatch if `y` and `w` are tuples, and `N + 2K` is a compile-time constant |
| 212 | +@derivative_rule foo(x, y..., z, w...) (N + 2K) |
| 213 | +``` |
| 214 | +""" |
| 215 | +macro derivative_rule(f, I) |
| 216 | + @assert Meta.isexpr(f, :call) """ |
| 217 | + Incorrect `@derivative_rule` syntax. The function must be provided as a call \ |
| 218 | + signature. Got `$f` which is not a call signature. |
| 219 | + """ |
| 220 | + fnhead = f.args[1] |
| 221 | + fncallargs = @view f.args[2:end] |
| 222 | + result = Expr(:call, _derivative_rule_proxy, fnhead) |
| 223 | + if length(fncallargs) == 1 && Meta.isexpr(fncallargs[1], :...) |
| 224 | + push!(result.args, fncallargs[1].args[1]) |
| 225 | + elseif any(Base.Fix2(Meta.isexpr, :...), fncallargs) |
| 226 | + args = Expr(:tuple) |
| 227 | + append!(args.args, fncallargs) |
| 228 | + push!(result.args, args) |
| 229 | + else |
| 230 | + push!(result.args, :(Val{$(length(fncallargs))}())) |
| 231 | + args = Expr(:tuple) |
| 232 | + append!(args.args, fncallargs) |
| 233 | + push!(result.args, args) |
| 234 | + end |
| 235 | + push!(result.args, :(Val{$I}())) |
| 236 | + return esc(result) |
| 237 | +end |
| 238 | + |
| 239 | +@register_derivative +(args...) I COMMON_ONE |
| 240 | +@register_derivative *(args...) I begin |
| 241 | + if I == 1 |
| 242 | + SymbolicUtils.mul_worker(VartypeT, view(args, 2:Nargs)) |
| 243 | + elseif I == Nargs |
| 244 | + SymbolicUtils.mul_worker(VartypeT, view(args, 1:(Nargs-1))) |
| 245 | + else |
| 246 | + t1 = SymbolicUtils.mul_worker(VartypeT, view(args, 1:(Nargs-1))) |
| 247 | + t2 = SymbolicUtils.mul_worker(VartypeT, view(args, 2:Nargs)) |
| 248 | + t1 * t2 |
| 249 | + end |
| 250 | +end |
| 251 | +@register_derivative one(x) 1 COMMON_ZERO |
| 252 | + |
| 253 | +""" |
| 254 | +$(SIGNATURES) |
| 255 | +
|
| 256 | +Calculate the derivative of the op `O` with respect to its argument with index |
| 257 | +`idx`. |
| 258 | +
|
| 259 | +# Examples |
| 260 | +
|
| 261 | +```jldoctest label1 |
| 262 | +julia> using Symbolics |
| 263 | +
|
| 264 | +julia> @variables x y; |
| 265 | +
|
| 266 | +julia> Symbolics.derivative_idx(Symbolics.value(sin(x)), 1) |
| 267 | +cos(x) |
| 268 | +``` |
| 269 | +
|
| 270 | +Note that the function does not recurse into the operation's arguments, i.e., the |
| 271 | +chain rule is not applied: |
| 272 | +
|
| 273 | +```jldoctest label1 |
| 274 | +julia> myop = Symbolics.value(sin(x) * y^2) |
| 275 | +sin(x)*(y^2) |
| 276 | +
|
| 277 | +julia> typeof(Symbolics.operation(myop)) # Op is multiplication function |
| 278 | +typeof(*) |
| 279 | +
|
| 280 | +julia> Symbolics.derivative_idx(myop, 1) # wrt. sin(x) |
| 281 | +y^2 |
| 282 | +
|
| 283 | +julia> Symbolics.derivative_idx(myop, 2) # wrt. y^2 |
| 284 | +sin(x) |
| 285 | +``` |
| 286 | +""" |
| 287 | +@inline derivative_idx(::Any, ::Any) = COMMON_ZERO |
| 288 | +function derivative_idx(O::VartypeT, idx::Int) |
| 289 | + iscall(O) || return COMMON_ZERO |
| 290 | + f = operation(O) |
| 291 | + args = arguments(O) |
| 292 | + return @derivative_rule f(args...) idx |
| 293 | +end |
| 294 | + |
0 commit comments