@@ -86,6 +86,15 @@ function Base.copyto!(dst::AbstractArray{T,N}, src::OneHotArray{<:Any,<:Any,N,<:
8686end
8787function Base. copyto! (dst:: Array{T,N} , src:: OneHotArray{<:Any,<:Any,N,<:AnyGPUArray} ) where {T,N}
8888 copyto! (dst, adapt (Array, src))
89+
90+ @inline function Base. setindex! (x:: OneHotArray{<:Any, N} , v, i:: Integer , I:: Vararg{Integer, N} ) where N
91+ @boundscheck checkbounds (x, i, I... )
92+ if Bool (v)
93+ @inbounds x. indices[I... ] = i
94+ elseif x. indices[I... ] == i
95+ @inbounds x. indices[I... ] = 0
96+ end
97+ x
8998end
9099
91100function Base. showarg (io:: IO , x:: OneHotArray , toplevel)
104113# copy CuArray versions back before trying to print them:
105114for fun in (:show , :print_array ) # print_array is used by 3-arg show
106115 @eval begin
107- Base.$ fun (io:: IO , X:: OneHotLike{T, N, var"N+1", <:AbstractGPUArray} ) where {T, N, var"N+1" } =
116+ Base.$ fun (io:: IO , X:: OneHotLike{T, N, var"N+1", <:AbstractGPUArray} ) where {T, N, var"N+1" } =
108117 Base.$ fun (io, adapt (Array, X))
109- Base.$ fun (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}} ) where {T, N} =
118+ Base.$ fun (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}} ) where {T, N} =
110119 Base.$ fun (io, adapt (Array, X))
111120 end
112121end
0 commit comments