Skip to content

Commit 272c6ef

Browse files
authored
XsyevBatched! interface accepting 3D StridedCuArray (#2951)
Signed-off-by: Steven Hahn <hahnse@ornl.gov>
1 parent 2e983fe commit 272c6ef

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

lib/cusolver/dense_generic.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,53 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla
505505
end
506506

507507
# XsyevBatched
508+
"""
    XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T, 3}) where {T <: BlasFloat}

Run `cusolverDnXsyevBatched` on the batch of square matrices stored as the slices
`A[:, :, i]`, overwriting `A` in place.

# Arguments
- `jobz`: `'N'` to return only `W`, `'V'` to return `W` together with `A`.
- `uplo`: triangle selector, validated with `chkuplo`.
- `A`: `n × n × batch_size` array; `n` is obtained from `checksquare(A)`.

# Returns
- `W` when `jobz == 'N'`, where `W` is an `n × batch_size` `CuMatrix{real(T)}`
  (one column per batch entry).
- `(W, A)` when `jobz == 'V'`.

# Throws
- `ErrorException` if the loaded cuSOLVER is older than 11.7.1.
- Whatever `chkuplo`, `checksquare` and `chkargsok` raise on invalid input or on a
  nonzero per-matrix `info` value.
"""
function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T, 3}) where {T <: BlasFloat}
    minimum_version = v"11.7.1"
    current_version = CUSOLVER.version()
    if current_version < minimum_version
        # Assemble the message from single-line pieces so that no newline or source
        # indentation leaks into the text shown to the user.
        throw(ErrorException(
            "This operation requires cuSOLVER $(minimum_version) or later. " *
            "Current cuSOLVER version: $(current_version)."
        ))
    end
    chkuplo(uplo)
    n = checksquare(A)
    batch_size = size(A, 3)
    R = real(T)
    # NOTE(review): only `lda` is handed to cuSOLVER; presumably consecutive matrices
    # are expected to lie `lda * n` elements apart -- confirm for non-contiguous views.
    lda = max(1, stride(A, 2))
    W = CuMatrix{R}(undef, n, batch_size)
    params = CuSolverParameters()
    dh = dense_handle()
    # One status entry per matrix in the batch.
    resize!(dh.info, batch_size)

    # Query the device and host workspace sizes (in that order) for this problem.
    function bufferSize()
        out_cpu = Ref{Csize_t}(0)
        out_gpu = Ref{Csize_t}(0)
        cusolverDnXsyevBatched_bufferSize(
            dh, params, jobz, uplo, n,
            T, A, lda, R, W, T, out_gpu, out_cpu, batch_size
        )
        return out_gpu[], out_cpu[]
    end
    with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
        cusolverDnXsyevBatched(
            dh, params, jobz, uplo, n, T, A,
            lda, R, W, T, buffer_gpu, sizeof(buffer_gpu),
            buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size
        )
    end

    # Copy the per-matrix status codes to the host and validate each of them.
    info = @allowscalar collect(dh.info)
    for i in 1:batch_size
        chkargsok(BlasInt(info[i]))
    end

    if jobz == 'N'
        return W
    elseif jobz == 'V'
        return W, A
    end
end
550+
508551
function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
509-
CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
552+
minimum_version = v"11.7.1"
553+
CUSOLVER.version() < minimum_version && throw(ErrorException("This operation requires cuSOLVER
554+
$(minimum_version) or later. Current cuSOLVER version: $(CUSOLVER.version())."))
510555
chkuplo(uplo)
511556
n, num_matrices = size(A)
512557
batch_size = num_matrices ÷ n

test/libraries/cusolver/dense_generic.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,36 @@ p = 5
3333
end
3434

3535
@testset "syevBatched!" begin
    batch_size = 5
    for uplo in ('L', 'U')
        # NOTE(review): presumably works around an upstream problem with this
        # combination on cuSOLVER < 11.7.2 -- confirm.
        (CUSOLVER.version() < v"11.7.2") && (uplo == 'L') && (elty == ComplexF32) && continue

        # `B` keeps the full matrices used for verification; `A` holds only the
        # `uplo` triangle of each one and is what gets passed to the solver.
        A = rand(elty, n, n, batch_size)
        B = rand(elty, n, n, batch_size)
        for i in 1:batch_size
            S = rand(elty, n, n)
            S = S * S' + I
            B[:, :, i] .= S
            S = uplo == 'L' ? tril(S) : triu(S)
            A[:, :, i] .= S
        end
        d_A = CuArray(A)
        d_W, d_V = CUSOLVER.XsyevBatched!('V', uplo, d_A)
        W = collect(d_W)
        V = collect(d_V)
        for i in 1:batch_size
            Bᵢ = B[:, :, i]
            Wᵢ = Diagonal(W[:, i])
            Vᵢ = V[:, :, i]
            # Eigen-relation B * V ≈ V * Λ for every matrix in the batch.
            @test Bᵢ * Vᵢ ≈ Vᵢ * Wᵢ
        end

        # Values-only mode: upload a fresh copy because the previous call overwrote d_A.
        d_A = CuArray(A)
        d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A)
        @test size(d_W) == (n, batch_size)
    end
end
64+
65+
@testset "syevBatched! updated" begin
3666
batch_size = 5
3767
for uplo in ('L', 'U')
3868
(CUSOLVER.version() < v"11.7.2") && (uplo == 'L') && (elty == ComplexF32) && continue

0 commit comments

Comments
 (0)