|
333 | 333 | @adjoint conj(x::AbstractArray) = conj(x), r̄ -> (conj(r̄),) |
334 | 334 | @adjoint imag(x::AbstractArray) = imag(x), ī -> (complex.(0, real.(ī)),) |
335 | 335 |
|
336 | | -@adjoint function mean(xs::AbstractArray; dims = :) |
337 | | - return mean(xs, dims=dims), Δ -> (_backmean(xs,Δ,dims),) |
338 | | -end |
339 | | -_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs) |
340 | | -_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(xs,i),*,dims) |
341 | | - |
342 | | -@adjoint function Statistics.var(xs::AbstractArray; corrected::Bool=true, dims=:, mean=mean(xs, dims=dims)) |
343 | | - return Statistics.var(xs; corrected=corrected, mean=mean, dims=dims), Δ -> _backvar(xs, Δ, corrected, mean, dims) |
344 | | -end |
345 | | -_backvar(xs, Δ, corrected::Bool, mean, dims) = _backvar(xs, Δ, mapreduce(i -> size(xs,i),*,dims) - corrected, mean) |
346 | | -_backvar(xs, Δ, corrected::Bool, mean, ::Colon) = _backvar(xs, Δ, length(xs) - corrected, mean) |
347 | | -_backvar(xs, Δ, N::Int, mean) = (convert(eltype(xs), 2/N) .* Δ .* (xs .- mean),) |
348 | | - |
349 | | -@adjoint function Statistics.std(xs::AbstractArray; corrected::Bool=true, dims=:, mean=mean(xs, dims=dims)) |
350 | | - s = Statistics.std(xs; corrected=corrected, mean=mean, dims=dims) |
351 | | - return s, Δ -> _backvar(xs, Δ ./ (2 .* s), corrected, mean, dims) |
352 | | -end |
353 | | - |
354 | 336 |
|
355 | 337 | # LinearAlgebra |
356 | 338 | # ============= |
|
0 commit comments