|
2 | 2 | ##### `frule`/`rrule` |
3 | 3 | ##### |
4 | 4 |
|
5 | | -#= |
6 | | -In some weird ideal sense, the fallback for e.g. `frule` should actually be "get |
7 | | -the derivative via forward-mode AD". This is necessary to enable mixed-mode |
8 | | -rules, where e.g. `frule` is used within a `rrule` definition. For example, |
9 | | -broadcasted functions may not themselves be forward-mode *primitives*, but are |
10 | | -often forward-mode *differentiable*. |
11 | | -
|
12 | | -ChainRulesCore, by design, is decoupled from any specific AD implementation. How, |
13 | | -then, do we know which AD to fall back to when there isn't a primitive defined? |
14 | | -
|
15 | | -Well, if you're a greedy AD implementation, you can just overload `frule` and/or |
16 | | -`rrule` to use your AD directly. However, this won't play nice with other AD |
17 | | -packages doing the same thing, and thus could cause load-order-dependent |
18 | | -problems for downstream users. |
19 | | -
|
20 | | -It turns out, Cassette solves this problem nicely by allowing AD authors to |
21 | | -overload the fallbacks w.r.t. their own context. Example using ForwardDiff: |
22 | | -
|
23 | | -``` |
24 | | -using ChainRulesCore, ForwardDiff, Cassette |
25 | | -
|
26 | | -Cassette.@context MyChainRuleCtx |
27 | | -
|
28 | | -# ForwardDiff, itself, can call `my_frule` instead of |
29 | | -# `frule` to utilize the ForwardDiff-injected ChainRulesCore |
30 | | -# infrastructure |
31 | | -my_frule(args...) = Cassette.overdub(MyChainRuleCtx(), frule, args...) |
32 | | -
|
33 | | -function Cassette.execute(::MyChainRuleCtx, ::typeof(frule), f, x::Number) |
34 | | - r = frule(f, x) |
35 | | - if isa(r, Nothing) |
36 | | - fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx) |
37 | | - else |
38 | | - fx, df = r |
39 | | - end |
40 | | - return fx, df |
41 | | -end |
42 | | -``` |
43 | | -=# |
44 | | - |
45 | 5 | """ |
46 | 6 | frule(f, x...) |
47 | 7 |
|
|