You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
pullback_at(::typeof(σ), x, y, ȳ) =ȳ* y *σ(-x) # = ȳ * σ(x) * σ(-x)
51
+
pullback_at(::typeof(σ), x, y, ȳ) =ȳ* y *σ(-x) # = ȳ * σ(x) * σ(-x)
52
52
```
53
53
Notice that in `pullback_at` we are not only using input `x` but also using the primal output `y` .
54
54
This is a nice bit of symmetry that shows up around `exp`.
@@ -130,7 +130,7 @@ So we are talking about a 30-40% speed-up from these optimizations.[^4]
130
130
It is faster to compute `sin` and `cos` at the same time via `sincos` than it is to compute them one after the other.
131
131
And it is faster to reuse the `exp(x)` in computing `σ(x)` and `σ(-x)`.
132
132
How can we incorporate this insight into our system?
133
-
We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code.
133
+
We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code.
134
134
135
135
What if we introduced some variable called `intermediates` that is also recorded onto the tape during the primal pass?
136
136
We would need to be able to modify the primal pass to do this, so that we can actually put the data into the `intermediates`.
@@ -140,7 +140,7 @@ So that would look like:
140
140
```julia
141
141
y =f(x) # primal program
142
142
y, intermediates =augmented_primal(f, x)
143
-
x̄ =pullback_at(f, x, y, ȳ, intermediates)
143
+
x̄ =pullback_at(f, x, y, ȳ, intermediates)
144
144
```
145
145
146
146
```@raw html
@@ -152,7 +152,7 @@ function augmented_primal(::typeof(sin), x)
152
152
return y, (; cx=cx) # use a NamedTuple for the intermediates
153
153
end
154
154
155
-
pullback_at(::typeof(sin), x, y, ȳ, intermediates) =ȳ* intermediates.cx
155
+
pullback_at(::typeof(sin), x, y, ȳ, intermediates) =ȳ* intermediates.cx
156
156
```
157
157
```@raw html
158
158
</details>
@@ -168,7 +168,7 @@ function augmented_primal(::typeof(σ), x)
168
168
return y, (; ex=ex) # use a NamedTuple for the intermediates
169
169
end
170
170
171
-
pullback_at(::typeof(σ), x, y, ȳ, intermediates) =ȳ* y / (1+ intermediates.ex)
171
+
pullback_at(::typeof(σ), x, y, ȳ, intermediates) =ȳ* y / (1+ intermediates.ex)
172
172
```
173
173
```@raw html
174
174
</details>
@@ -202,7 +202,7 @@ So changing our API we have:
202
202
```julia
203
203
y =f(x) # primal program
204
204
y, pb =augmented_primal(f, x)
205
-
x̄ =pullback_at(pb, ȳ)
205
+
x̄ =pullback_at(pb, ȳ)
206
206
```
207
207
which is much cleaner.
208
208
@@ -215,7 +215,7 @@ function augmented_primal(::typeof(sin), x)
We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function `x̄`.
286
+
We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function `x̄`.
287
287
_`pb` is not just the **memory** of state required for the `pullback`, it **is** the pullback._
288
288
289
289
We have one final thing to do, which is to think about how we make the code easy to modify.
@@ -298,15 +298,15 @@ function augmented_primal(::typeof(sin), x)
298
298
y =sin(x)
299
299
return y, PullbackMemory(sin; x=x)
300
300
end
301
-
(pb::PullbackMemory)(ȳ) =ȳ*cos(pb.x)
301
+
(pb::PullbackMemory)(ȳ) =ȳ*cos(pb.x)
302
302
```
303
303
To go from that to:
304
304
```julia
305
305
functionaugmented_primal(::typeof(sin), x)
306
306
y, cx =sincos(x)
307
307
return y, PullbackMemory(sin; cx=cx)
308
308
end
309
-
(pb::PullbackMemory)(ȳ) =ȳ* pb.cx
309
+
(pb::PullbackMemory)(ȳ) =ȳ* pb.cx
310
310
```
311
311
```@raw html
312
312
</details>
@@ -320,7 +320,7 @@ function augmented_primal(::typeof(σ), x)
@@ -344,7 +344,7 @@ We need to make a series of changes:
344
344
It's important these parts all stay in sync.
345
345
It's not too bad for this simple example with just one or two things to remember.
346
346
For more complicated multi-argument functions, which we will show below, you often end up needing to remember half a dozen things, like sizes and indices relating to each input/output, so it gets a little more fiddly to make sure you remember all the things you need to and give them the same name in both places.
347
-
_Is there a way we can automatically just have all the things we use remembered for us?_
347
+
_Is there a way we can automatically just have all the things we use remembered for us?_
348
348
Surprisingly for such a specific request, there actually is: a closure.
349
349
350
350
A closure in Julia is a callable structure that automatically contains a field for every object from its parent scope that is used in its body.
@@ -357,7 +357,7 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid
357
357
```julia
358
358
functionaugmented_primal(::typeof(sin), x)
359
359
y, cx =sincos(x)
360
-
pb =ȳ-> cx *ȳ# pullback closure. closes over `cx`
360
+
pb =ȳ-> cx *ȳ# pullback closure. closes over `cx`
361
361
return y, pb
362
362
end
363
363
```
@@ -372,7 +372,7 @@ end
372
372
functionaugmented_primal(::typeof(σ), x)
373
373
ex =exp(x)
374
374
y = ex / (1+ ex)
375
-
pb =ȳ->ȳ* y / (1+ ex) # pullback closure. closes over `y` and `ex`
375
+
pb =ȳ->ȳ* y / (1+ ex) # pullback closure. closes over `y` and `ex`
0 commit comments