Skip to content

Commit 612e56f

Browse files
nHackenHacke
authored andcommitted
Add storage_type for adjoint, tranpose and Diagonal s.t. they work for GPU arrays
1 parent 5cb74c4 commit 612e56f

File tree

3 files changed

+20
-0
lines changed

3 files changed

+20
-0
lines changed

src/abstract.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ storage_type(op::AbstractLinearOperator) = error("please implement storage_type
174174
storage_type(op::LinearOperator) = typeof(op.Mv5)
175175
storage_type(M::AbstractMatrix{T}) where {T} = Vector{T}
176176

177+
# Lazy wrappers
178+
storage_type(op::Adjoint) = storage_type(parent(op))
179+
storage_type(op::Transpose) = storage_type(parent(op))
180+
storage_type(op::Diagonal) = typeof(parent(op))
181+
177182
"""
178183
reset!(op)
179184

test/gpu/amdgpu.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,18 @@ using LinearOperators, AMDGPU
66
B = ROCArray(rand(Float32, 10, 10))
77
C = ROCArray(rand(Float32, 20, 20))
88
M = BlockDiagonalOperator(A, B, C)
9+
v = ROCArray(rand(5))
10+
911

1012
v = ROCArray(rand(Float32, 35))
1113
y = M * v
1214
@test y isa ROCArray{Float32}
1315

16+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
17+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(transpose(A))
18+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
19+
@test LinearOperators.storage_type(Diagonal(v)) == typeof(v)
20+
21+
1422
@testset "AMDGPU S kwarg" test_S_kwarg(arrayType = ROCArray)
1523
end

test/gpu/nvidia.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@ using LinearOperators, CUDA, CUDA.CUSPARSE, CUDA.CUSOLVER
99
B = CUDA.rand(10, 10)
1010
C = CUDA.rand(20, 20)
1111
M = BlockDiagonalOperator(A, B, C)
12+
v = CUDA.rand(5)
13+
14+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
15+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(transpose(A))
16+
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
17+
@test LinearOperators.storage_type(Diagonal(v)) == typeof(v)
18+
1219

1320
v = CUDA.rand(35)
1421
y = M * v

0 commit comments

Comments
 (0)