Skip to content

Commit 7d32509

Browse files
committed
More docs on generation
1 parent d97535f commit 7d32509

File tree

3 files changed

+37
-4
lines changed

3 files changed

+37
-4
lines changed

docs/src/autodiff/operator_overloading.md

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,44 @@ A signature type tuple always has the form:
1111
first positional argument.
1212
One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`;
1313
or more simply you can just use conditions for this.
14-
The hook is automatically triggered whenever a package is loaded.
14+
For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write:
15+
```julia
16+
const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}}
17+
function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F
18+
@eval quote
19+
# ...
20+
end
21+
end
22+
define_overload(::Any) = nothing # don't do anything for any other signature
23+
24+
on_new_rule(frule, define_overload)
25+
```
26+
27+
or you might write:
28+
```julia
29+
const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64})
30+
function define_overload(sig) where F
31+
sig = Base.unwrap_unionall(sig) # not really handling most UnionAll,
32+
opT, argTs = Iterators.peel(sig.parameters)
33+
all(any(acceptT<: argT for acceptT in ACCEPT_TYPES) for argT in argTs) || return
34+
@eval quote
35+
# ...
36+
end
37+
end
38+
39+
on_new_rule(frule, define_overload)
40+
```
1541
16-
`refresh_rules`(@ref) is used to manually trigger the hook function on any new rules.
42+
The generation of overloaded code is the responsibility of the AD implementor.
43+
Packages like [ExprTools.jl](https://github.com/invenia/ExprTools.jl) can be helpful for this.
44+
Its generally fairly simple, though can become complex if you need to handle complicated type-constraints.
45+
Examples are shown below.
46+
47+
The hook is automatically triggered whenever a package is loaded.
48+
It can also be triggers manually using `refresh_rules`(@ref).
1749
This is useful for example if new rules are define in the REPL, or if a package defining rules is modified.
1850
(Revise.jl will not automatically trigger).
51+
When the rules are refreshed (automatically or manually), the hooks are only triggered on new/modified rules; not ones that have already had the hooks triggered on.
1952
2053
`clear_new_rule_hooks!`(@ref) clears all registered hooks.
2154
It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function.

test/demos/forwarddiffzero.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Base.to_power_type(x::Dual) = x
2828

2929

3030
function define_dual_overload(sig)
31-
sig = Base.unwrap_unionall(sig)
31+
sig = Base.unwrap_unionall(sig) # Not really handling most UnionAlls
3232
opT, argTs = Iterators.peel(sig.parameters)
3333
fieldcount(opT) == 0 || return # not handling functors
3434
all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops.

test/demos/reversediffzero.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Base.to_power_type(x::Tracked) = x
5757

5858
"What to do when a new rrule is declared"
5959
function define_tracked_overload(sig)
60-
sig = Base.unwrap_unionall(sig)
60+
sig = Base.unwrap_unionall(sig) # not really handling most UnionAll
6161
opT, argTs = Iterators.peel(sig.parameters)
6262
fieldcount(opT) == 0 || return # not handling functors
6363
all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops.

0 commit comments

Comments
 (0)