22# #### `frule`/`rrule`
33# ####
44
5+ # TODO : remember to update the examples
56"""
6- frule(f, x...)
7+ frule(f, x..., ṡelf, Δx... )
78
8- Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
9- as `Ω`, return the tuple:
9+ Expressing `x` as the tuple `(x₁, x₂, ...)`, `Δx` as the tuple `(Δx₁, Δx₂,
10+ ...)`, and the output tuple of `f(x...)` as `Ω`, return the tuple:
1011
11- (Ω, (ṡelf, ẋ₁, ẋ₂, ...) -> Ω̇₁, Ω̇₂, ...)
12+ (Ω, (Ω̇₁, Ω̇₂, ...) )
1213
1314The second return value is the propagation rule, or the pushforward.
1415It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`)
1516and `ṡelf` the internal values of the function (for closures).
1617
1718
18- If no method matching `frule(f, xs...)` has been defined, then return `nothing`.
19+ If no method matching `frule(f, x..., ṡelf, Δx...)` has been defined, then
20+ return `nothing`.
1921
2022Examples:
2123
2224unary input, unary output scalar function:
2325
2426```
27+ julia> dself = Zero()
28+ Zero()
29+
2530julia> x = rand();
2631
27- julia> sinx, sin_pushforward = frule(sin, x);
32+ julia> sinx, sin_pushforward = frule(sin, x, dself, 1)
33+ (0.35696518021277485, 0.9341176907197836)
2834
2935julia> sinx == sin(x)
3036true
3137
32- julia> sin_pushforward(NamedTuple(), 1) == cos(x)
38+ julia> sin_pushforward == cos(x)
3339true
3440```
3541
@@ -38,12 +44,12 @@ unary input, binary output scalar function:
3844```
3945julia> x = rand();
4046
41- julia> sincosx, sincos_pushforward = frule(sincos, x);
47+ julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1 );
4248
4349julia> sincosx == sincos(x)
4450true
4551
46- julia> sincos_pushforward(NamedTuple(), 1) == (cos(x), -sin(x))
52+ julia> sincos_pushforward == (cos(x), -sin(x))
4753true
4854```
4955
0 commit comments