diff --git a/lib/cublas/linalg.jl b/lib/cublas/linalg.jl index 949f8621c6..5e73716c1a 100644 --- a/lib/cublas/linalg.jl +++ b/lib/cublas/linalg.jl @@ -9,6 +9,12 @@ end # BLAS 1 # +function LinearAlgebra.rmul!(x::CuArray{<:CublasFloat}, k::Bool) + # explicitly fill x with zero to comply with julias "false = strong zero" + !k && fill!(x, zero(eltype(x))) + return x +end + LinearAlgebra.rmul!(x::StridedCuArray{<:CublasFloat}, k::Number) = scal!(length(x), k, x) @@ -267,7 +273,7 @@ function LinearAlgebra.generic_matvecmul!(Y::StridedCuVector, tA::AbstractChar, end if nA == 0 - return rmul!(Y, 0) + return rmul!(Y, beta) end T = eltype(Y) @@ -356,7 +362,7 @@ function LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::Strid if size(C) != (mA, nB) throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)")) end - return LinearAlgebra.rmul!(C, 0) + return LinearAlgebra.rmul!(C, beta) end if all(in(('N', 'T', 'C')), (tA, tB)) diff --git a/test/libraries/cublas/level1.jl b/test/libraries/cublas/level1.jl index 4fd7aa1006..2bc8f7b55f 100644 --- a/test/libraries/cublas/level1.jl +++ b/test/libraries/cublas/level1.jl @@ -41,6 +41,12 @@ k = 13 @test dz ≈ z end + @testset "rmul! strong zero" begin + @test testf(rmul!, fill(T(NaN), 3), false) + @test testf(rmul!, rand(T, 3), false) + @test testf(rmul!, rand(T, 3), true) + end + @testset "rotate!" begin @test testf(rotate!, rand(T, m), rand(T, m), rand(real(T)), rand(real(T))) @test testf(rotate!, rand(T, m), rand(T, m), rand(real(T)), rand(T))