@@ -505,8 +505,53 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla
505505end
506506
507507# XsyevBatched
508+ function XsyevBatched! (jobz:: Char , uplo:: Char , A:: StridedCuArray{T, 3} ) where {T <: BlasFloat }
509+ minimum_version = v " 11.7.1"
510+ CUSOLVER. version () < minimum_version && throw (ErrorException (" This operation requires cuSOLVER
511+ $(minimum_version) or later. Current cuSOLVER version: $(CUSOLVER. version ()) ." ))
512+ chkuplo (uplo)
513+ n = checksquare (A)
514+ batch_size = size (A, 3 )
515+ R = real (T)
516+ lda = max (1 , stride (A, 2 ))
517+ W = CuMatrix {R} (undef, n, batch_size)
518+ params = CuSolverParameters ()
519+ dh = dense_handle ()
520+ resize! (dh. info, batch_size)
521+
522+ function bufferSize ()
523+ out_cpu = Ref {Csize_t} (0 )
524+ out_gpu = Ref {Csize_t} (0 )
525+ cusolverDnXsyevBatched_bufferSize (
526+ dh, params, jobz, uplo, n,
527+ T, A, lda, R, W, T, out_gpu, out_cpu, batch_size
528+ )
529+ return out_gpu[], out_cpu[]
530+ end
531+ with_workspaces (dh. workspace_gpu, dh. workspace_cpu, bufferSize ()... ) do buffer_gpu, buffer_cpu
532+ cusolverDnXsyevBatched (
533+ dh, params, jobz, uplo, n, T, A,
534+ lda, R, W, T, buffer_gpu, sizeof (buffer_gpu),
535+ buffer_cpu, sizeof (buffer_cpu), dh. info, batch_size
536+ )
537+ end
538+
539+ info = @allowscalar collect (dh. info)
540+ for i in 1 : batch_size
541+ chkargsok (info[i] |> BlasInt)
542+ end
543+
544+ if jobz == ' N'
545+ return W
546+ elseif jobz == ' V'
547+ return W, A
548+ end
549+ end
550+
508551function XsyevBatched! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{T} ) where {T <: BlasFloat }
509- CUSOLVER. version () < v " 11.7.1" && throw (ErrorException (" This operation is not supported by the current CUDA version." ))
552+ minimum_version = v " 11.7.1"
553+ CUSOLVER. version () < minimum_version && throw (ErrorException (" This operation requires cuSOLVER
554+ $(minimum_version) or later. Current cuSOLVER version: $(CUSOLVER. version ()) ." ))
510555 chkuplo (uplo)
511556 n, num_matrices = size (A)
512557 batch_size = num_matrices ÷ n
0 commit comments