Skip to content

Commit b90ee24

Browse files
authored
Merge pull request #196 from JuliaDiff/ox/tweakderiving
Small tweaks to docs on deriving rules
2 parents faf7e6c + 0ea406b commit b90ee24

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

docs/src/arrays.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ Our approach for deriving the adjoints $\overline{X}_m$ is then:
258258

259259
Note that the final expressions for the adjoints will not contain any $\dot{X}_m$ terms.
260260

261-
!!! info
261+
!!! note
262262
Why do we conjugate, and why do we only use the real part of the dot product in \eqref{pbident}?
263263
Recall from [Complex Numbers](complex.md) that we treat a complex number as a pair of real numbers.
264264
These identities are a direct consequence of this convention.
@@ -656,8 +656,8 @@ function frule((_, ΔA), ::typeof(logabsdet), A::Matrix{<:RealOrComplex})
656656
# The primal function uses the lu decomposition to compute logabsdet
657657
# we reuse this decomposition to compute inv(A) * ΔA
658658
F = lu(A, check = false)
659-
Ω = logabsdet(F)
660-
b = tr(F \ ΔA) # tr(inv(A) * ΔA)
659+
Ω = logabsdet(F) # == logabsdet(A)
660+
b = tr(F \ ΔA) # == tr(inv(A) * ΔA)
661661
s = last(Ω)
662662
∂l = real(b)
663663
# for real A, ∂s will always be zero (because imag(b) = 0)
@@ -752,14 +752,14 @@ function rrule(::typeof(logabsdet), A::Matrix{<:RealOrComplex})
752752
# The primal function uses the lu decomposition to compute logabsdet
753753
# we reuse this decomposition to compute inv(A)
754754
F = lu(A, check = false)
755-
Ω = logabsdet(F)
755+
Ω = logabsdet(F) # == logabsdet(A)
756756
s = last(Ω)
757757
function logabsdet_pullback(ΔΩ)
758758
(Δl, Δs) = ΔΩ
759759
f = conj(s) * Δs
760-
imagf = f - real(f) # 0 for real A and Δs, im * imag(f) for complex A and/or Δs
760+
imagf = f - real(f) # 0 for real A and Δs, im * imag(f) for complex A and/or Δs
761761
g = real(Δl) + imagf
762-
∂A = g * inv(F)' # g * inv(A)'
762+
∂A = g * inv(F)' # == g * inv(A)'
763763
return (NO_FIELDS, ∂A)
764764
end
765765
return (Ω, logabsdet_pullback)

0 commit comments

Comments
 (0)