Skip to content

Commit 7604288

Browse files
authored
rm rules for eachslice, cumsum (#1253)
* rm rules for eachslice, cumsum * bump * bound chainrules * bump
1 parent 1936109 commit 7604288

File tree

2 files changed

+2
-33
lines changed

2 files changed

+2
-33
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.6.41"
3+
version = "0.6.42"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2525

2626
[compat]
2727
AbstractFFTs = "0.5, 1.0"
28-
ChainRules = "1.35.3"
28+
ChainRules = "1.36.2"
2929
ChainRulesCore = "1.9"
3030
ChainRulesTestUtils = "1"
3131
DiffRules = "1.4"

src/lib/array.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -351,37 +351,6 @@ _backvar(xs, Δ, N::Int, mean) = (convert(eltype(xs), 2/N) .* Δ .* (xs .- mean)
351351
return s, Δ -> _backvar(xs, Δ ./ (2 .* s), corrected, mean, dims)
352352
end
353353

354-
@adjoint function cumsum(xs::AbstractVector; dims::Integer = 1)
355-
dims == 1 || return copy(xs), Δ -> (Δ,)
356-
cumsum(xs), Δ -> (reverse(cumsum(reverse(Δ))),)
357-
end
358-
@adjoint function cumsum(xs::AbstractArray; dims::Integer)
359-
dims <= ndims(xs) || return copy(xs), Δ -> (Δ,)
360-
cumsum(xs; dims=dims), Δ -> begin
361-
(reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims),)
362-
end
363-
end
364-
365-
@adjoint eachrow(x::AbstractVecOrMat) = collect(eachrow(x)), dys -> ∇eachslice(dys, x, 1)
366-
@adjoint eachcol(x::AbstractVecOrMat) = collect(eachcol(x)), dys -> ∇eachslice(dys, x, 2)
367-
@adjoint eachslice(x::AbstractArray; dims::Integer) =
368-
collect(eachslice(x; dims=dims)), dys -> ∇eachslice(dys, x, dims)
369-
370-
function ∇eachslice(dys, x::AbstractArray, dim::Integer) where {TX}
371-
i1 = findfirst(dy -> dy isa AbstractArray, dys)
372-
i1 === nothing && return (zero(x),) # all slices get nothing
373-
T = promote_type(eltype(dys[i1]), eltype(x))
374-
dx = similar(x, T)
375-
for i in axes(x, dim)
376-
if dys[i] isa AbstractArray
377-
copyto!(selectdim(dx,dim,i), dys[i])
378-
else
379-
selectdim(dx,dim,i) .= 0
380-
end
381-
end
382-
(dx,)
383-
end
384-
385354

386355
# LinearAlgebra
387356
# =============

0 commit comments

Comments
 (0)