|
| 1 | +# Operator Overloading |
| 2 | + |
| 3 | +The principal interface for using the operator overload generation method is [`on_new_rule`](@ref). |
| 4 | +This function allows one to register a hook to be run every time a new rule is defined. |
| 5 | +The hook receives a signature type-type as input, and generally will use `eval` to define |
| 6 | +an overload of an AD system's overloaded type. |
| 7 | +For example, using the signature type `Tuple{typeof(+), Real, Real}` to make |
| 8 | +`+(::DualNumber, ::DualNumber)` call the `frule` for `+`. |
| 9 | +A signature type tuple always has the form: |
| 10 | +`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}, ...}`, where `pos_arg1` is the |
| 11 | +first positional argument. |
| 12 | +One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`; |
| 13 | +or more simply you can just use conditions for this. |
| 14 | +For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write: |
| 15 | +```julia |
| 16 | +const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}} |
| 17 | +function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F |
| 18 | + @eval quote |
| 19 | + # ... |
| 20 | + end |
| 21 | +end |
| 22 | +define_overload(::Any) = nothing # don't do anything for any other signature |
| 23 | + |
| 24 | +on_new_rule(frule, define_overload) |
| 25 | +``` |
| 26 | +
|
| 27 | +or you might write: |
| 28 | +```julia |
| 29 | +const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64}) |
| 30 | +function define_overload(sig) where F |
| 31 | + sig = Base.unwrap_unionall(sig) # not really handling most UnionAll, |
| 32 | + opT, argTs = Iterators.peel(sig.parameters) |
| 33 | + all(any(acceptT<: argT for acceptT in ACCEPT_TYPES) for argT in argTs) || return |
| 34 | + @eval quote |
| 35 | + # ... |
| 36 | + end |
| 37 | +end |
| 38 | + |
| 39 | +on_new_rule(frule, define_overload) |
| 40 | +``` |
| 41 | +
|
| 42 | +The generation of overloaded code is the responsibility of the AD implementor. |
| 43 | +Packages like [ExprTools.jl](https://github.com/invenia/ExprTools.jl) can be helpful for this. |
| 44 | +Its generally fairly simple, though can become complex if you need to handle complicated type-constraints. |
| 45 | +Examples are shown below. |
| 46 | +
|
| 47 | +The hook is automatically triggered whenever a package is loaded. |
| 48 | +It can also be triggers manually using `refresh_rules`(@ref). |
| 49 | +This is useful for example if new rules are define in the REPL, or if a package defining rules is modified. |
| 50 | +(Revise.jl will not automatically trigger). |
| 51 | +When the rules are refreshed (automatically or manually), the hooks are only triggered on new/modified rules; not ones that have already had the hooks triggered on. |
| 52 | +
|
| 53 | +`clear_new_rule_hooks!`(@ref) clears all registered hooks. |
| 54 | +It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function. |
| 55 | +
|
| 56 | +## Examples |
| 57 | +
|
| 58 | +### ForwardDiffZero |
| 59 | +The overload generation hook in this example is: `define_dual_overload`. |
| 60 | +
|
| 61 | +````@eval |
| 62 | +using Markdown |
| 63 | +Markdown.parse(""" |
| 64 | +```julia |
| 65 | +$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String)) |
| 66 | +``` |
| 67 | +""") |
| 68 | +```` |
| 69 | +
|
| 70 | +### ReverseDiffZero |
| 71 | +The overload generation hook in this example is: `define_tracked_overload`. |
| 72 | +
|
| 73 | +````@eval |
| 74 | +using Markdown |
| 75 | +Markdown.parse(""" |
| 76 | +```julia |
| 77 | +$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String)) |
| 78 | +``` |
| 79 | +""") |
| 80 | +```` |
0 commit comments