Skip to content

Commit 74201e6

Browse files
authored
Merge pull request #310 from JuliaDiff/mzgubic-patch-1
fix some typos in docs
2 parents 01b956f + 6767d67 commit 74201e6

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

docs/src/design/changing_the_primal.md

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ We will call this function `pullback_at`, as it pulls back the sensitivity at a
2525
To make this concrete:
2626
```julia
2727
y = f(x) # primal program
28-
= pullback_at(f, x, y, )
28+
= pullback_at(f, x, y, ȳ)
2929
```
3030
Let's illustrate this with examples for `sin` and for the [logistic sigmoid](https://en.wikipedia.org/wiki/Logistic_function#Derivative).
3131

@@ -34,9 +34,9 @@ Let's illustrate this with examples for `sin` and for the [logistic sigmoid](htt
3434
```
3535
```julia
3636
y = sin(x)
37-
pullback_at(::typeof(sin), x, y, ) = * cos(x)
37+
pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)
3838
```
39-
`pullback_at` uses the primal input `x`, and the sensitivity being pulled back ``.
39+
`pullback_at` uses the primal input `x`, and the sensitivity being pulled back `ȳ`.
4040

4141
```@raw html
4242
</details>
@@ -48,7 +48,7 @@ pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)
4848
```julia
4949
σ(x) = 1/(1 + exp(-x)) # = exp(x) / (1 + exp(x))
5050
y = σ(x)
51-
pullback_at(::typeof(σ), x, y, ) = * y * σ(-x) # = * σ(x) * σ(-x)
51+
pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x) # = ȳ * σ(x) * σ(-x)
5252
```
5353
Notice that in `pullback_at` we are not only using input `x` but also using the primal output `y` .
5454
This is a nice bit of symmetry that shows up around `exp`.
@@ -130,7 +130,7 @@ So we are talking about a 30-40% speed-up from these optimizations.[^4]
130130
It is faster to compute `sin` and `cos` at the same time via `sincos` than it is to compute them one after the other.
131131
And it is faster to reuse the `exp(x)` in computing `σ(x)` and `σ(-x)`.
132132
How can we incorporate this insight into our system?
133-
We know we can compute both of these in the primal — because they only depend on `x` and not on `` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code.
133+
We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code.
134134

135135
What if we introduced some variable called `intermediates` that is also recorded onto the tape during the primal pass?
136136
We would need to be able to modify the primal pass to do this, so that we can actually put the data into the `intermediates`.
@@ -140,7 +140,7 @@ So that would look like:
140140
```julia
141141
y = f(x) # primal program
142142
y, intermediates = augmented_primal(f, x)
143-
= pullback_at(f, x, y, , intermediates)
143+
= pullback_at(f, x, y, ȳ, intermediates)
144144
```
145145

146146
```@raw html
@@ -152,7 +152,7 @@ function augmented_primal(::typeof(sin), x)
152152
return y, (; cx=cx) # use a NamedTuple for the intermediates
153153
end
154154

155-
pullback_at(::typeof(sin), x, y, , intermediates) = * intermediates.cx
155+
pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
156156
```
157157
```@raw html
158158
</details>
@@ -168,7 +168,7 @@ function augmented_primal(::typeof(σ), x)
168168
return y, (; ex=ex) # use a NamedTuple for the intermediates
169169
end
170170

171-
pullback_at(::typeof(σ), x, y, , intermediates) = * y / (1 + intermediates.ex)
171+
pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex)
172172
```
173173
```@raw html
174174
</details>
@@ -202,7 +202,7 @@ So changing our API we have:
202202
```julia
203203
y = f(x) # primal program
204204
y, pb = augmented_primal(f, x)
205-
= pullback_at(pb, )
205+
= pullback_at(pb, ȳ)
206206
```
207207
which is much cleaner.
208208

@@ -215,7 +215,7 @@ function augmented_primal(::typeof(sin), x)
215215
return y, PullbackMemory(sin; cx=cx)
216216
end
217217

218-
pullback_at(pb::PullbackMemory{typeof(sin)}, ) = * pb.cx
218+
pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
219219
```
220220
```@raw html
221221
</details>
@@ -231,7 +231,7 @@ function augmented_primal(::typeof(σ), x)
231231
return y, PullbackMemory(σ; y=y, ex=ex)
232232
end
233233

234-
pullback_at(pb::PullbackMemory{typeof(σ)}, ) = * pb.y / (1 + pb.ex)
234+
pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex)
235235
```
236236
```@raw html
237237
</details>
@@ -242,13 +242,13 @@ That now looks much simpler; `pullback_at` only ever has 2 arguments.
242242
One way we could make it nicer to use is by making `PullbackMemory` a callable object.
243243
Conceptually the `PullbackMemory` is a fixed thing it the contents of the tape for a particular operation.
244244
It is fully determined by the end of the primal pass.
245-
The during the gradient (reverse) pass the `PullbackMemory` is used to successively compute the `` argument.
245+
The during the gradient (reverse) pass the `PullbackMemory` is used to successively compute the `ȳ` argument.
246246
So it makes sense to make `PullbackMemory` a callable object that acts on the sensitivity.
247247
We can do that via call overloading:
248248
```julia
249249
y = f(x) # primal program
250250
y, pb = augmented_primal(f, x)
251-
= pb()
251+
= pb(ȳ)
252252
```
253253

254254
```@raw html
@@ -259,7 +259,7 @@ function augmented_primal(::typeof(sin), x)
259259
y, cx = sincos(x)
260260
return y, PullbackMemory(sin; cx=cx)
261261
end
262-
(pb::PullbackMemory)(ȳ) = * pb.cx
262+
(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx
263263
```
264264

265265
```@raw html
@@ -276,14 +276,14 @@ function augmented_primal(::typeof(σ), x)
276276
return y, PullbackMemory(σ; y=y, ex=ex)
277277
end
278278

279-
(pb::PullbackMemory{typeof(σ)})() = * pb.y / (1 + pb.ex)
279+
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)
280280
```
281281
```@raw html
282282
</details>
283283
```
284284

285285
Let's recap what we have done here.
286-
We now have an object `pb` that acts on the cotangent of the output of the primal `` to give us the cotangent of the input of the primal function ``.
286+
We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function ``.
287287
_`pb` is not just the **memory** of state required for the `pullback`, it **is** the pullback._
288288

289289
We have one final thing to do, which is to think about how we make the code easy to modify.
@@ -298,15 +298,15 @@ function augmented_primal(::typeof(sin), x)
298298
y = sin(x)
299299
return y, PullbackMemory(sin; x=x)
300300
end
301-
(pb::PullbackMemory)() = * cos(pb.x)
301+
(pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x)
302302
```
303303
To go from that to:
304304
```julia
305305
function augmented_primal(::typeof(sin), x)
306306
y, cx = sincos(x)
307307
return y, PullbackMemory(sin; cx=cx)
308308
end
309-
(pb::PullbackMemory)() = * pb.cx
309+
(pb::PullbackMemory)(ȳ) = ȳ * pb.cx
310310
```
311311
```@raw html
312312
</details>
@@ -320,7 +320,7 @@ function augmented_primal(::typeof(σ), x)
320320
y = σ(x)
321321
return y, PullbackMemory(σ; y=y, x=x)
322322
end
323-
(pb::PullbackMemory{typeof(σ)})() = * pb.y * σ(-pb.x)
323+
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x)
324324
```
325325
to get to:
326326
```julia
@@ -329,7 +329,7 @@ function augmented_primal(::typeof(σ), x)
329329
y = ex/(1 + ex)
330330
return y, PullbackMemory(σ; y=y, ex=ex)
331331
end
332-
(pb::PullbackMemory{typeof(σ)})() = * pb.y/(1 + pb.ex)
332+
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex)
333333
```
334334
```@raw html
335335
</details>
@@ -344,7 +344,7 @@ We need to make a series of changes:
344344
It's important these parts all stay in sync.
345345
It's not too bad for this simple example with just one or two things to remember.
346346
For more complicated multi-argument functions, which we will show below, you often end up needing to remember half a dozen things, like sizes and indices relating to each input/output, so it gets a little more fiddly to make sure you remember all the things you need to and give them the same name in both places.
347-
_Is there a way we can automatically just have all the things we use remembered for us?_
347+
_Is there a way we can automatically just have all the things we use remembered for us?_
348348
Surprisingly for such a specific request, there actually is: a closure.
349349

350350
A closure in Julia is a callable structure that automatically contains a field for every object from its parent scope that is used in its body.
@@ -357,7 +357,7 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid
357357
```julia
358358
function augmented_primal(::typeof(sin), x)
359359
y, cx = sincos(x)
360-
pb = -> cx * # pullback closure. closes over `cx`
360+
pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
361361
return y, pb
362362
end
363363
```
@@ -372,7 +372,7 @@ end
372372
function augmented_primal(::typeof(σ), x)
373373
ex = exp(x)
374374
y = ex / (1 + ex)
375-
pb = -> * y / (1 + ex) # pullback closure. closes over `y` and `ex`
375+
pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
376376
return y, pb
377377
end
378378
```

0 commit comments

Comments
 (0)