Skip to content

Commit 3ce8f0e

Browse files
authored
Merge pull request #43 from FluxML/bc/broadcast-style
More robust BroadcastStyle handling
2 parents 06aba31 + 8a6bc82 commit 3ce8f0e

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

src/array.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,9 @@ end
143143
Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels)
144144

145145
function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
146-
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
146+
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
147147
S = Base.BroadcastStyle(T)
148-
# S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
149-
# isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
150-
(typeof(S).name.wrapper){var"N+1"}()
148+
typeof(S)(Val{var"N+1"}())
151149
end
152150

153151
Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

0 commit comments

Comments
 (0)