@@ -359,25 +359,24 @@ b, b_pullback = rrule(+, 0.2, a);
359359c, c_pullback = rrule(asin, b)
360360
361361#### Then the backward pass calculating gradients
362- c̄ = 1; # ∂c/∂c
363- _, b̄ = c_pullback(unthunk( c̄)) ; # ∂c/∂b
364- _, _, ā = b_pullback(unthunk( b̄)) ; # ∂c/∂a
365- _, x̄ = a_pullback(unthunk(ā)) ; # ∂c/∂x = ∂f /∂x
366- unthunk(x̄)
362+ c̄ = 1; # ∂c/∂c
363+ _, b̄ = c_pullback(c̄); # ∂c/∂b = ∂c/∂b ⋅ ∂c/∂c
364+ _, _, ā = b_pullback(b̄); # ∂c/∂a = ∂c/∂b ⋅ ∂b /∂a
365+ _, x̄ = a_pullback(ā) ; # ∂c/∂x = ∂c/∂a ⋅ ∂a /∂x
366+ x̄ # ∂c/∂x = ∂foo/∂x
367367# output
368368-1.0531613736418153
369369```
370370``` jldoctest index
371371#### Find dfoo/dx via frules
372372x = 3;
373- ẋ = 1; # ∂x/∂x
373+ ẋ = 1; # ∂x/∂x
374374nofields = Zero(); # ∂self/∂self
375375
376- a, ȧ = frule((nofields, ẋ), sin, x); # ∂a/∂x
377- b, ḃ = frule((nofields, Zero(), unthunk(ȧ)), +, 0.2, a); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
378-
379- c, ċ = frule((nofields, unthunk(ḃ)), asin, b); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
380- unthunk(ċ)
376+ a, ȧ = frule((nofields, ẋ), sin, x); # ∂a/∂x = ∂a/∂x ⋅ ∂x/∂x
377+ b, ḃ = frule((nofields, Zero(), ȧ), +, 0.2, a); # ∂b/∂x = ∂b/∂a ⋅ ∂a/∂x
378+ c, ċ = frule((nofields, ḃ), asin, b); # ∂c/∂x = ∂c/∂b ⋅ ∂b/∂x
379+ ċ # ∂c/∂x = ∂foo/∂x
381380# output
382381-1.0531613736418153
383382```
0 commit comments