Skip to content

Commit bd1b8f4

Browse files
authored
Merge pull request #1024 from mcabbott/paramcast
Make broadcasting over `Params` in the gradient an error
2 parents 9ff7624 + d7cd2ec commit bd1b8f4

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

src/compiler/interface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ end
201201

202202
Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order)
203203

204+
@adjoint function Broadcast.broadcasted(f::Function, ps::Params)
205+
f.(ps), _ -> throw(ArgumentError("Zygote.Params does not support broadcasting within gradients, try iteration `for p in ps`"))
206+
end
207+
204208
Base.:(==)(x::Params, y::Params) = x.order.data == y.order.data
205209

206210
function Base.show(io::IO, ps::Params)

src/lib/array.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,14 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
7373
@adjoint getindex(::Type{T}, xs...) where {T} = T[xs...], dy -> (nothing, dy...)
7474

7575
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
76-
_ -> error("Mutating arrays is not supported")
76+
_ -> error("Mutating arrays is not supported -- called setindex!(::$(typeof(xs)), _...)")
7777

7878
@adjoint! copyto!(args...) = copyto!(args...),
79-
_ -> error("Mutating arrays is not supported")
79+
_ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), _...)")
8080

8181
for f in [push!, pop!, pushfirst!, popfirst!]
82-
@eval @adjoint! $f(xs, x...) =
83-
push!(xs, x...), _ -> error("Mutating arrays is not supported")
82+
@eval @adjoint! $f(xs, x...) = $f(xs, x...),
83+
_ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(xs)), _...)")
8484
end
8585

8686
# This is kind of bad, but at least we don't materialize the whole

test/interface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ using Zygote: Grads
3838
x, y = [1,2], [1]
3939
ps = Params([x, y])
4040
@test length.(ps) == length.([x, y]) # 617
41+
@test size.(ps, 1) == [2, 1]
4142
@test all(Params([[1,1]]) .== Params([[1,1]]))
43+
44+
@test_throws ArgumentError gradient(() -> sum(sum.(ps)), ps)
4245
end
4346

4447
@testset "indexing" begin

0 commit comments

Comments
 (0)