@@ -225,15 +225,15 @@ function Base.showerror(io::IO, err::DerivativeNotDefinedError)
225225 # and code fences
226226 err_str = Markdown. parse ("""
227227 Derivative of `$(err. expr) ` with respect to its $(err. i) -th argument is not defined.
228- Define a derivative by adding a method to `Symbolics.derivative `:
228+ Define a derivative by using `@register_derivative `:
229229
230230 ```julia
231- function Symbolics.derivative(::typeof( $op ), args::NTuple{ $nargs , Any}, ::Val{ $ (err. i)})
231+ @register_derivative $op (args...) $ (err. i) begin
232232 # ...
233233 end
234234 ```
235235
236- Refer to the documentation for `Symbolics.derivative ` and the
236+ Refer to the documentation for `@register_derivative ` and the
237237 "[Adding Analytical Derivatives](@ref)" section of the docs for further information.
238238 """ )
239239 show (io, MIME " text/plain" (), err_str)
@@ -321,12 +321,12 @@ function executediff(D::Differential, arg::BasicSymbolic{VartypeT}; simplify=fal
321321 # We know `D.x` is in `arg`, so the derivative is not identically zero.
322322 # `arg` cannot be `D.x` since, that would have also early exited.
323323 for (i, a) in enumerate (inner_args)
324- der = derivative_idx (arr, i)
324+ der = derivative_idx (arr, i):: Union{Nothing, SymbolicT}
325325 if isequal (a, D. x)
326- der isa NoDeriv && return D (arg)
326+ der === nothing && return D (arg)
327327 push! (summed_args, der[idx])
328328 continue
329- elseif der isa NoDeriv
329+ elseif der === nothing
330330 push! (summed_args, Differential (a)(arg) * executediff (D, a))
331331 else
332332 push! (summed_args, der[idx] * executediff (D, a))
@@ -381,8 +381,8 @@ function executediff(D::Differential, arg::BasicSymbolic{VartypeT}; simplify=fal
381381 for (i, iarg) in enumerate (inner_args)
382382 t2 = executediff (D, iarg; simplify, throw_no_derivative):: SymbolicT
383383 _iszero (t2) && continue
384- t = derivative_idx (arg, i):: Union{NoDeriv , SymbolicT}
385- if t isa NoDeriv
384+ t = derivative_idx (arg, i):: Union{Nothing , SymbolicT}
385+ if t === nothing
386386 throw_no_derivative && throw (DerivativeNotDefinedError (arg, i))
387387 t = D (arg)
388388 end
@@ -506,89 +506,6 @@ function expand_derivatives(n::Complex{Num}, simplify=false; kwargs...)
506506end
507507expand_derivatives (x, simplify= false ; kwargs... ) = x
508508
509- # Don't specialize on the function here
510- """
511- $(SIGNATURES)
512-
513- Calculate the derivative of the op `O` with respect to its argument with index
514- `idx`.
515-
516- # Examples
517-
518- ```jldoctest label1
519- julia> using Symbolics
520-
521- julia> @variables x y;
522-
523- julia> Symbolics.derivative_idx(Symbolics.value(sin(x)), 1)
524- cos(x)
525- ```
526-
527- Note that the function does not recurse into the operation's arguments, i.e., the
528- chain rule is not applied:
529-
530- ```jldoctest label1
531- julia> myop = Symbolics.value(sin(x) * y^2)
532- sin(x)*(y^2)
533-
534- julia> typeof(Symbolics.operation(myop)) # Op is multiplication function
535- typeof(*)
536-
537- julia> Symbolics.derivative_idx(myop, 1) # wrt. sin(x)
538- y^2
539-
540- julia> Symbolics.derivative_idx(myop, 2) # wrt. y^2
541- sin(x)
542- ```
543- """
544- derivative_idx (O:: Any , :: Any ) = COMMON_ZERO
545- function derivative_idx (O:: BasicSymbolic , idx)
546- iscall (O) || return COMMON_ZERO
547- res = derivative (operation (O), (arguments (O)... ,), Val (idx))
548- if res isa NoDeriv
549- return res
550- else
551- return Const {VartypeT} (res)
552- end
553- end
554-
555- # Indicate that no derivative is defined.
556- struct NoDeriv
557- end
558-
559- """
560- Symbolics.derivative(::typeof(f), args::NTuple{N, Any}, ::Val{i})
561-
562- Return the derivative of `f(args...)` with respect to `args[i]`. `N` should be the number
563- of arguments that `f` takes and `i` is the argument with respect to which the derivative
564- is taken. The result can be a numeric value (if the derivative is constant) or a symbolic
565- expression. This function is useful for defining derivatives of custom functions registered
566- via `@register_symbolic`, to be used when calling `expand_derivatives`.
567- """
568- derivative (f, args, v) = NoDeriv ()
569-
570- # Pre-defined derivatives
571- import DiffRules
572- for (modu, fun, arity) ∈ DiffRules. diffrules (; filter_modules= (:Base , :SpecialFunctions , :NaNMath ))
573- fun in [:* , :+ , :abs , :mod , :rem , :max , :min ] && continue # special
574- for i ∈ 1 : arity
575-
576- expr = if arity == 1
577- DiffRules. diffrule (modu, fun, :(args[1 ]))
578- else
579- DiffRules. diffrule (modu, fun, ntuple (k-> :(args[$ k]), arity)... )[i]
580- end
581- @eval derivative (:: typeof ($ modu.$ fun), args:: NTuple{$arity,Any} , :: Val{$i} ) = $ expr
582- end
583- end
584-
585- derivative (:: typeof (+ ), args:: NTuple{N,Any} , :: Val ) where {N} = 1
586- derivative (:: typeof (* ), args:: NTuple{N,Any} , :: Val{i} ) where {N,i} = * (deleteat! (collect (args), i)... )
587- derivative (:: typeof (one), args:: Tuple{<:Any} , :: Val ) = 0
588-
589- derivative (f:: Function , x:: Union{Num, <:BasicSymbolic} ) = derivative (f (x), x)
590- derivative (:: Function , x:: Any ) = TypeError (:derivative , " 2nd argument" , Union{Num, <: BasicSymbolic }, x) |> throw
591-
592509function count_order (x)
593510 @assert ! (x isa Symbol) " The variable $x must have an order of differentiation that is greater or equal to 1!"
594511 n = 1
@@ -666,6 +583,8 @@ function derivative(O, var; simplify=false, kwargs...)
666583 Num (expand_derivatives (Differential (var)(unwrap (O)), simplify; kwargs... ))
667584 end
668585end
586+ derivative (f:: Function , var:: Union{SymbolicT, Num} ) = derivative (f (var), var)
587+ derivative (:: Function , x:: Any ) = throw (TypeError (:derivative , " 2nd argument" , Union{Num, SymbolicT}, x))
669588
670589"""
671590$(SIGNATURES)
0 commit comments