Skip to content

Commit 4a61d2b

Browse files
authored
avoid restricting types to real (#99)
* not restricting type to real * fix ambiguity errors
1 parent 5ded095 commit 4a61d2b

File tree

4 files changed

+125
-121
lines changed

4 files changed

+125
-121
lines changed

src/block_sizes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
matmul_params(::Val{T}) where {T <: Base.HWReal} = LoopVectorization.matmul_params()
2+
# Every `Val{T}` forwards to `LoopVectorization.matmul_params()`; the element type `T`
# takes part in dispatch only and is not used in the body.
function matmul_params(::Val{T}) where {T}
    return LoopVectorization.matmul_params()
end
33

44
function block_sizes(::Val{T}, _α, _β, R₁, R₂) where {T}
55
W = pick_vector_width(T)

src/complex_matmul.jl

Lines changed: 117 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,129 +1,133 @@
11
"""
    real_rep(a::AbstractArray{Complex{T}, N})

Return a view of `a` reinterpreted with element type `T`. Because
`reinterpret(reshape, T, a)` is used, the result gains a leading dimension of
length 2: index `1` holds the real part and index `2` the imaginary part of
each complex element. No data is copied.
"""
function real_rep(a::AbstractArray{Complex{T}, N}) where {T, N}
    return reinterpret(reshape, T, a)
end
22
#PtrArray(Ptr{T}(pointer(a)), (StaticInt(2), size(a)...))
33

4-
@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractVecOrMat{Complex{V}},
5-
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
6-
C, A, B = map(real_rep, (_C, _A, _B))
7-
8-
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
9-
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
10-
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
11-
ηθ = η*θ
12-
13-
@tturbo for n indices((C, B), 3), m indices((C, A), 2)
14-
Cmn_re = zero(T)
15-
Cmn_im = zero(T)
16-
for k indices((A, B), (3, 2))
17-
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
18-
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
4+
for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error
5+
@eval begin
6+
function _matmul!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}},
7+
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
8+
C, A, B = map(real_rep, (_C, _A, _B))
9+
10+
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
11+
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
12+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
13+
ηθ = η*θ
14+
15+
@tturbo for n indices((C, B), 3), m indices((C, A), 2)
16+
Cmn_re = zero(T)
17+
Cmn_im = zero(T)
18+
for k indices((A, B), (3, 2))
19+
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
20+
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
21+
end
22+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
23+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
24+
end
25+
_C
1926
end
20-
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
21-
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
22-
end
23-
_C
24-
end
25-
26-
@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractVecOrMat{Complex{V}},
27-
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
28-
C, B = map(real_rep, (_C, _B))
29-
30-
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
31-
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
32-
33-
@tturbo for n indices((C, B), 3), m indices((C, A), (2, 1))
34-
Cmn_re = zero(T)
35-
Cmn_im = zero(T)
36-
for k indices((A, B), (2, 2))
37-
Cmn_re += A[m, k] * B[1, k, n]
38-
Cmn_im += θ * A[m, k] * B[2, k, n]
27+
28+
@inline function _matmul!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}},
29+
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
30+
C, B = map(real_rep, (_C, _B))
31+
32+
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
33+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
34+
35+
@tturbo for n indices((C, B), 3), m indices((C, A), (2, 1))
36+
Cmn_re = zero(T)
37+
Cmn_im = zero(T)
38+
for k indices((A, B), (2, 2))
39+
Cmn_re += A[m, k] * B[1, k, n]
40+
Cmn_im += θ * A[m, k] * B[2, k, n]
41+
end
42+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
43+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
44+
end
45+
_C
3946
end
40-
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
41-
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
42-
end
43-
_C
44-
end
45-
46-
@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractVecOrMat{V},
47-
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
48-
C, A = map(real_rep, (_C, _A))
49-
50-
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
51-
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
52-
53-
@tturbo for n indices((C, B), (3, 2)), m indices((C, A), 2)
54-
Cmn_re = zero(T)
55-
Cmn_im = zero(T)
56-
for k indices((A, B), (3, 1))
57-
Cmn_re += A[1, m, k] * B[k, n]
58-
Cmn_im += η * A[2, m, k] * B[k, n]
47+
48+
@inline function _matmul!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::$AT{V},
49+
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
50+
C, A = map(real_rep, (_C, _A))
51+
52+
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
53+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
54+
55+
@tturbo for n indices((C, B), (3, 2)), m indices((C, A), 2)
56+
Cmn_re = zero(T)
57+
Cmn_im = zero(T)
58+
for k indices((A, B), (3, 1))
59+
Cmn_re += A[1, m, k] * B[k, n]
60+
Cmn_im += η * A[2, m, k] * B[k, n]
61+
end
62+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
63+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
64+
end
65+
_C
5966
end
60-
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
61-
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
62-
end
63-
_C
64-
end
6567

6668

6769

6870

6971

70-
@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractVecOrMat{Complex{V}},
71-
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
72-
C, A, B = map(real_rep, (_C, _A, _B))
72+
@inline function _matmul_serial!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}},
73+
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
74+
C, A, B = map(real_rep, (_C, _A, _B))
7375

74-
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
75-
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
76-
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
77-
ηθ = η*θ
78-
@turbo for n indices((C, B), 3), m indices((C, A), 2)
79-
Cmn_re = zero(T)
80-
Cmn_im = zero(T)
81-
for k indices((A, B), (3, 2))
82-
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
83-
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
76+
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
77+
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
78+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
79+
ηθ = η*θ
80+
@turbo for n indices((C, B), 3), m indices((C, A), 2)
81+
Cmn_re = zero(T)
82+
Cmn_im = zero(T)
83+
for k indices((A, B), (3, 2))
84+
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
85+
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
86+
end
87+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
88+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
89+
end
90+
_C
8491
end
85-
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
86-
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
87-
end
88-
_C
89-
end
90-
91-
@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractVecOrMat{Complex{V}},
92-
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
93-
C, B = map(real_rep, (_C, _B))
94-
95-
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
96-
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
97-
98-
@turbo for n indices((C, B), 3), m indices((C, A), (2, 1))
99-
Cmn_re = zero(T)
100-
Cmn_im = zero(T)
101-
for k indices((A, B), (2, 2))
102-
Cmn_re += A[m, k] * B[1, k, n]
103-
Cmn_im += θ * A[m, k] * B[2, k, n]
92+
93+
@inline function _matmul_serial!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}},
94+
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
95+
C, B = map(real_rep, (_C, _B))
96+
97+
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
98+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
99+
100+
@turbo for n indices((C, B), 3), m indices((C, A), (2, 1))
101+
Cmn_re = zero(T)
102+
Cmn_im = zero(T)
103+
for k indices((A, B), (2, 2))
104+
Cmn_re += A[m, k] * B[1, k, n]
105+
Cmn_im += θ * A[m, k] * B[2, k, n]
106+
end
107+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
108+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
109+
end
110+
_C
104111
end
105-
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
106-
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
107-
end
108-
_C
109-
end
110-
111-
@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractVecOrMat{V},
112-
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
113-
C, A = map(real_rep, (_C, _A))
114-
115-
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
116-
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
117-
118-
@turbo for n indices((C, B), (3, 2)), m indices((C, A), 2)
119-
Cmn_re = zero(T)
120-
Cmn_im = zero(T)
121-
for k indices((A, B), (3, 1))
122-
Cmn_re += A[1, m, k] * B[k, n]
123-
Cmn_im += η * A[2, m, k] * B[k, n]
112+
113+
@inline function _matmul_serial!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::$AT{V},
114+
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
115+
C, A = map(real_rep, (_C, _A))
116+
117+
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
118+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
119+
120+
@turbo for n indices((C, B), (3, 2)), m indices((C, A), 2)
121+
Cmn_re = zero(T)
122+
Cmn_im = zero(T)
123+
for k indices((A, B), (3, 1))
124+
Cmn_re += A[1, m, k] * B[k, n]
125+
Cmn_im += η * A[2, m, k] * B[k, n]
126+
end
127+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
128+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
129+
end
130+
_C
124131
end
125-
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
126-
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
127132
end
128-
_C
129-
end
133+
end

src/macrokernels.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

2-
@inline incrementp(A::AbstractStridedPointer{<:Base.HWReal,3}, a::Ptr) = VectorizationBase.increment_ptr(A, a, (Zero(), Zero(), One()))
3-
@inline increment2(B::AbstractStridedPointer{<:Base.HWReal,2}, b::Ptr, ::StaticInt{nᵣ}) where {nᵣ} = VectorizationBase.increment_ptr(B, b, (Zero(), StaticInt{nᵣ}()))
4-
@inline increment1(C::AbstractStridedPointer{<:Base.HWReal,2}, c::Ptr, ::StaticInt{mᵣW}) where {mᵣW} = VectorizationBase.increment_ptr(C, c, (StaticInt{mᵣW}(), Zero()))
2+
# Step the raw pointer `a` belonging to the 3-dimensional strided pointer `A` by calling
# `VectorizationBase.increment_ptr` with a unit offset in the third dimension only.
@inline function incrementp(A::AbstractStridedPointer{T,3} where T, a::Ptr)
    offsets = (Zero(), Zero(), One())
    return VectorizationBase.increment_ptr(A, a, offsets)
end
3+
# Step the raw pointer `b` belonging to the 2-dimensional strided pointer `B` by calling
# `VectorizationBase.increment_ptr` with a static offset of `nᵣ` in the second dimension.
@inline function increment2(B::AbstractStridedPointer{T,2} where T, b::Ptr, ::StaticInt{nᵣ}) where {nᵣ}
    offsets = (Zero(), StaticInt{nᵣ}())
    return VectorizationBase.increment_ptr(B, b, offsets)
end
4+
# Step the raw pointer `c` belonging to the 2-dimensional strided pointer `C` by calling
# `VectorizationBase.increment_ptr` with a static offset of `mᵣW` in the first dimension.
@inline function increment1(C::AbstractStridedPointer{T,2} where T, c::Ptr, ::StaticInt{mᵣW}) where {mᵣW}
    offsets = (StaticInt{mᵣW}(), Zero())
    return VectorizationBase.increment_ptr(C, c, offsets)
end
55
macro kernel(pack::Bool, ex::Expr)
66
ex.head === :for || throw(ArgumentError("Must be a matmul for loop."))
77
mincrements = Expr[:(c = increment1(C, c, mᵣW)), :(ãₚ = incrementp(Ãₚ, ãₚ)), :(m = vsub_nsw(m, mᵣW))]

src/matmul.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
165165
"""
166166
@inline function _matmul_serial!(
167167
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
168-
) where {T<:Real}
168+
) where {T}
169169
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
170170
if M * N == 0
171171
return
@@ -263,7 +263,7 @@ end
263263
end
264264

265265
# passing MKN directly would let someone skip the size check.
266-
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T<:Real}
266+
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}
267267
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
268268
if M * N == 0
269269
return
@@ -504,7 +504,7 @@ function sync_mul!(
504504
nothing
505505
end
506506

507-
function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN, contig_axis) where {T<:Real}
507+
function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN, contig_axis) where {T}
508508
@tturbo for m indices((A,y),1)
509509
yₘ = zero(T)
510510
for n indices((A,x),(2,1))
@@ -514,7 +514,7 @@ function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α
514514
end
515515
return y
516516
end
517-
function _matmul_serial!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN) where {T<:Real}
517+
function _matmul_serial!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN) where {T}
518518
@turbo for m indices((A,y),1)
519519
yₘ = zero(T)
520520
for n indices((A,x),(2,1))

0 commit comments

Comments
 (0)