Skip to content

Commit 1e86463

Browse files
authored
faster inv and div for Complex{Union{Float16, Float32}} (#44111)
* faster inv and div for Complex{Union{Float16, Float32}} * fix float64 division bug
1 parent befe38f commit 1e86463

File tree

2 files changed

+40
-20
lines changed

2 files changed

+40
-20
lines changed

base/complex.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -347,30 +347,37 @@ muladd(z::Complex, w::Complex, x::Real) =
347347

348348
function /(a::Complex{T}, b::Complex{T}) where T<:Real
349349
are = real(a); aim = imag(a); bre = real(b); bim = imag(b)
350-
if abs(bre) <= abs(bim)
351-
if isinf(bre) && isinf(bim)
352-
r = sign(bre)/sign(bim)
353-
else
354-
r = bre / bim
350+
if (isinf(bre) | isinf(bim))
351+
if isfinite(a)
352+
return complex(zero(T)*sign(are)*sign(bre), -zero(T)*sign(aim)*sign(bim))
355353
end
354+
return T(NaN)+T(NaN)*im
355+
end
356+
if abs(bre) <= abs(bim)
357+
r = bre / bim
356358
den = bim + r*bre
357359
Complex((are*r + aim)/den, (aim*r - are)/den)
358360
else
359-
if isinf(bre) && isinf(bim)
360-
r = sign(bim)/sign(bre)
361-
else
362-
r = bim / bre
363-
end
361+
r = bim / bre
364362
den = bre + r*bim
365363
Complex((are + aim*r)/den, (aim - are*r)/den)
366364
end
367365
end
368366

369-
inv(z::Complex{<:Union{Float16,Float32}}) =
370-
oftype(z, inv(widen(z)))
371-
372-
/(z::Complex{T}, w::Complex{T}) where {T<:Union{Float16,Float32}} =
373-
oftype(z, widen(z)*inv(widen(w)))
367+
function /(z::Complex{T}, w::Complex{T}) where {T<:Union{Float16,Float32}}
368+
c, d = reim(widen(w))
369+
a, b = reim(widen(z))
370+
if (isinf(c) | isinf(d))
371+
if isfinite(z)
372+
return complex(zero(T)*sign(real(z))*sign(real(w)), -zero(T)*sign(imag(z))*sign(imag(w)))
373+
end
374+
return T(NaN)+T(NaN)*im
375+
end
376+
mag = inv(muladd(c, c, d^2))
377+
re_part = muladd(a, c, b*d)
378+
im_part = muladd(b, c, -a*d)
379+
return oftype(z, Complex(re_part*mag, im_part*mag))
380+
end
374381

375382
# robust complex division for double precision
376383
# variables are scaled & unscaled to avoid over/underflow, if necessary
@@ -382,7 +389,12 @@ function /(z::ComplexF64, w::ComplexF64)
382389
a, b = reim(z); c, d = reim(w)
383390
absa = abs(a); absb = abs(b); ab = absa >= absb ? absa : absb # equiv. to max(abs(a),abs(b)) but without NaN-handling (faster)
384391
absc = abs(c); absd = abs(d); cd = absc >= absd ? absc : absd
385-
392+
if (isinf(c) | isinf(d))
393+
if isfinite(z)
394+
return complex(0.0*sign(a)*sign(c), -0.0*sign(b)*sign(d))
395+
end
396+
return NaN+NaN*im
397+
end
386398
halfov = 0.5*floatmax(Float64) # overflow threshold
387399
twounϵ = floatmin(Float64)*2.0/eps(Float64) # underflow threshold
388400

@@ -449,6 +461,12 @@ function robust_cdiv2(a::Float64, b::Float64, c::Float64, d::Float64, r::Float64
449461
end
450462
end
451463

464+
function inv(z::Complex{T}) where T<:Union{Float16,Float32}
465+
c, d = reim(widen(z))
466+
(isinf(c) | isinf(d)) && return complex(copysign(zero(T), c), flipsign(-zero(T), d))
467+
mag = inv(muladd(c, c, d^2))
468+
return oftype(z, Complex(c*mag, -d*mag))
469+
end
452470
function inv(w::ComplexF64)
453471
c, d = reim(w)
454472
(isinf(c) | isinf(d)) && return complex(copysign(0.0, c), flipsign(-0.0, d))

test/complex.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,7 @@ end
10391039
@testset "corner cases of division, issue #22983" begin
10401040
# These results abide by ISO/IEC 10967-3:2006(E) and
10411041
# mathematical definition of division of complex numbers.
1042-
for T in (Float32, Float64, BigFloat)
1042+
for T in (Float16, Float32, Float64, BigFloat)
10431043
@test isequal(one(T) / zero(Complex{T}), one(Complex{T}) / zero(Complex{T}))
10441044
@test isequal(one(T) / zero(Complex{T}), Complex{T}(NaN, NaN))
10451045
@test isequal(one(Complex{T}) / zero(T), Complex{T}(Inf, NaN))
@@ -1050,7 +1050,7 @@ end
10501050
end
10511051

10521052
@testset "division by Inf, issue#23134" begin
1053-
@testset "$T" for T in (Float32, Float64, BigFloat)
1053+
@testset "$T" for T in (Float16, Float32, Float64, BigFloat)
10541054
@test isequal(one(T) / complex(T(Inf)), complex(zero(T), -zero(T)))
10551055
@test isequal(one(T) / complex(T(Inf), one(T)), complex(zero(T), -zero(T)))
10561056
@test isequal(one(T) / complex(T(Inf), T(NaN)), complex(zero(T), -zero(T)))
@@ -1088,8 +1088,10 @@ end
10881088
@test isequal(one(T) / complex(T(-NaN), T(-Inf)), complex(-zero(T), zero(T)))
10891089

10901090
# divide complex by complex Inf
1091-
@test isequal(complex(one(T)) / complex(T(Inf), T(-Inf)), complex(zero(T), zero(T))) broken=(T==Float64)
1092-
@test isequal(complex(one(T)) / complex(T(-Inf), T(Inf)), complex(-zero(T), -zero(T))) broken=(T in (Float32, Float64))
1091+
@test isequal(complex(one(T)) / complex(T(Inf), T(-Inf)), complex(zero(T), zero(T)))
1092+
@test isequal(complex(one(T)) / complex(T(-Inf), T(Inf)), complex(-zero(T), -zero(T)))
1093+
@test isequal(complex(T(Inf)) / complex(T(Inf), T(-Inf)), complex(T(NaN), T(NaN)))
1094+
@test isequal(complex(T(NaN)) / complex(T(-Inf), T(Inf)), complex(T(NaN), T(NaN)))
10931095
end
10941096
end
10951097

0 commit comments

Comments
 (0)