11
# View an array of `ForwardDiff.Dual{TAG,T}` as an array of its underlying scalars `T`
# without copying. `reinterpret(reshape, T, a)` exposes the scalars stored inside each
# dual along a new leading dimension (when a dual is larger than a single `T`).
function real_rep(a::AbstractArray{DualT}) where {TAG,T,DualT<:ForwardDiff.Dual{TAG,T}}
    return reinterpret(reshape, T, a)
end
# Non-copying view of the slice at index 1 of the first dimension:
# `B[1, :]` for a matrix, `B[1, :, :]` for a 3-dimensional array.
_view1(B::AbstractMatrix) = view(B, 1, :)
_view1(B::AbstractArray{<:Any,3}) = view(B, 1, :, :)
5+
36
# Multiplication of a dual vector/matrix by a standard matrix from the left.
#
# Updates `_C` in place to `α * A * _B + β * _C` and returns `_C`. The dual arrays are
# accessed through `real_rep`, so the leading index `l` runs over the scalars stored in
# each dual and `A` is applied to every one of them alike.
# `nthread`, `MKN` and `contig_axis` are not read in this method.
function _matmul!(
    _C::AbstractVecOrMat{DualT},
    A::AbstractMatrix,
    _B::AbstractVecOrMat{DualT},
    α,
    β = Zero(),
    nthread::Nothing = nothing,
    MKN = nothing,
    contig_axis = nothing
) where {DualT<:ForwardDiff.Dual}
    C = real_rep(_C)
    B = real_rep(_B)

    @tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1)), l ∈ indices((C, B), 1)
        acc = zero(eltype(C))
        for k ∈ indices((A, B), 2)
            acc += A[m, k] * B[l, k, n]
        end
        C[l, m, n] = α * acc + β * C[l, m, n]
    end

    return _C
end
2023
# Multiplication of a dual matrix by a standard vector/matrix from the right.
#
# Updates `_C` in place to `α * _A * B + β * _C` and returns `_C`.
# `MKN` is not read in this method.
@inline function _matmul!(
    _C::AbstractVecOrMat{DualT},
    _A::AbstractMatrix{DualT},
    B::AbstractVecOrMat,
    α = One(),
    β = Zero(),
    nthread::Nothing = nothing,
    MKN = nothing
) where {TAG,T,DualT<:ForwardDiff.Dual{TAG,T}}
    plain_layout =
        Bool(ArrayInterface.is_dense(_C)) && Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) && Bool(ArrayInterface.is_column_major(_A))

    if plain_layout
        # Dense column-major storage: no reshape is needed, so reinterpret both dual
        # arrays as arrays of `T` and defer to the standard method.
        _matmul!(reinterpret(T, _C), reinterpret(T, _A), B, α, β, nthread, nothing)
        return _C
    end

    # Otherwise the standard method cannot be used directly; loop over the
    # reshaped representation, where `l` indexes the scalars inside each dual.
    C = real_rep(_C)
    A = real_rep(_A)

    @tturbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2), l ∈ indices((C, A), 1)
        acc = zero(eltype(C))
        for k ∈ indices((A, B), (3, 1))
            acc += A[l, m, k] * B[k, n]
        end
        C[l, m, n] = α * acc + β * C[l, m, n]
    end

    return _C
end
4649
47- _view1 (B:: AbstractMatrix ) = @view (B[1 ,:])
48- _view1 (B:: AbstractArray{<:Any,3} ) = @view (B[1 ,:,:])
4950@inline function _matmul! (_C:: AbstractVecOrMat{DualT} , _A:: AbstractMatrix{DualT} , _B:: AbstractVecOrMat{DualT} ,
50- α= One (), β= Zero (), nthread:: Nothing = nothing , MKN= nothing ) where {TAG, T, P, DualT <: ForwardDiff.Dual{TAG, T, P} }
51+ α= One (), β= Zero (), nthread:: Nothing = nothing , MKN= nothing , contig = nothing ) where {TAG, T, P, DualT <: ForwardDiff.Dual{TAG, T, P} }
5152 A = real_rep (_A)
5253 C = real_rep (_C)
5354 B = real_rep (_B)
54- if all (( ArrayInterface. is_dense (_C), ArrayInterface. is_column_major (_C),
55- ArrayInterface. is_dense (_A), ArrayInterface. is_column_major (_A) ))
55+ if Bool ( ArrayInterface. is_dense (_C)) && Bool ( ArrayInterface. is_column_major (_C)) &&
56+ Bool ( ArrayInterface. is_dense (_A)) && Bool ( ArrayInterface. is_column_major (_A))
5657 # we can avoid the reshape and call the standard method
5758 Ar = reinterpret (T, _A)
5859 Cr = reinterpret (T, _C)
@@ -77,3 +78,80 @@ _view1(B::AbstractArray{<:Any,3}) = @view(B[1,:,:])
7778 end
7879 _C
7980end
81+
82+
# Serial multiplication of a dual vector/matrix by a standard matrix from the left.
#
# Updates `_C` in place to `α * A * _B + β * _C` using `@turbo` and returns `_C`.
# The dual arrays are accessed through `real_rep`, so the leading index `l` runs over
# the scalars stored in each dual. `MKN` is not read in this method.
function _matmul_serial!(
    _C::AbstractVecOrMat{DualT},
    A::AbstractMatrix,
    _B::AbstractVecOrMat{DualT},
    α,
    β,
    MKN
) where {DualT<:ForwardDiff.Dual}
    C = real_rep(_C)
    B = real_rep(_B)

    @turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1)), l ∈ indices((C, B), 1)
        acc = zero(eltype(C))
        for k ∈ indices((A, B), 2)
            acc += A[m, k] * B[l, k, n]
        end
        C[l, m, n] = α * acc + β * C[l, m, n]
    end

    return _C
end
99+
# Serial multiplication of a dual matrix by a standard vector/matrix from the right.
#
# Updates `_C` in place to `α * _A * B + β * _C` and returns `_C`.
# `MKN` is not read in this method.
@inline function _matmul_serial!(
    _C::AbstractVecOrMat{DualT},
    _A::AbstractMatrix{DualT},
    B::AbstractVecOrMat,
    α,
    β,
    MKN
) where {TAG,T,DualT<:ForwardDiff.Dual{TAG,T}}
    plain_layout =
        Bool(ArrayInterface.is_dense(_C)) && Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) && Bool(ArrayInterface.is_column_major(_A))

    if plain_layout
        # Dense column-major storage: no reshape is needed, so reinterpret both dual
        # arrays as arrays of `T` and defer to the standard method.
        _matmul_serial!(reinterpret(T, _C), reinterpret(T, _A), B, α, β, nothing)
        return _C
    end

    # Otherwise the standard method cannot be used directly; loop over the
    # reshaped representation, where `l` indexes the scalars inside each dual.
    C = real_rep(_C)
    A = real_rep(_A)

    @turbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2), l ∈ indices((C, A), 1)
        acc = zero(eltype(C))
        for k ∈ indices((A, B), (3, 1))
            acc += A[l, m, k] * B[k, n]
        end
        C[l, m, n] = α * acc + β * C[l, m, n]
    end

    return _C
end
125+
# Serial multiplication of a dual matrix by a dual vector/matrix.
#
# Works in two stages on the `real_rep` views and returns `_C`:
#   1. every scalar component of `_A` times the index-1 slice of `B`
#      (presumably the value part of each dual — confirm against ForwardDiff's layout),
#      scaled by `α` and combined with `β * C`;
#   2. the index-1 slice of `A` times slices `2:P+1` of `B`, scaled by `α` and
#      accumulated onto the matching slices of `C`.
# `MKN` is not read in this method.
@inline function _matmul_serial!(
    _C::AbstractVecOrMat{DualT},
    _A::AbstractMatrix{DualT},
    _B::AbstractVecOrMat{DualT},
    α,
    β,
    MKN
) where {TAG,T,P,DualT<:ForwardDiff.Dual{TAG,T,P}}
    C = real_rep(_C)
    A = real_rep(_A)
    B = real_rep(_B)

    plain_layout =
        Bool(ArrayInterface.is_dense(_C)) && Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) && Bool(ArrayInterface.is_column_major(_A))

    # Stage 1: all components of `_A` against the first slice of `B`.
    if plain_layout
        # Dense column-major storage: no reshape is needed, so reinterpret `_A` and
        # `_C` as arrays of `T` and defer to the standard method.
        _matmul_serial!(reinterpret(T, _C), reinterpret(T, _A), _view1(B), α, β, nothing)
    else
        # The standard method cannot be used directly; loop explicitly.
        @turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2), l ∈ indices((C, A), 1)
            acc = zero(eltype(C))
            for k ∈ indices((A, B), (3, 2))
                acc += A[l, m, k] * B[1, k, n]
            end
            C[l, m, n] = α * acc + β * C[l, m, n]
        end
    end

    # Stage 2: first slice of `A` against slices 2:P+1 of `B`, added onto `C`.
    # `β` is not applied again here; the update only accumulates.
    nparts = static(P)
    @turbo for n ∈ indices((B, C), 3), m ∈ indices((A, C), 2), p ∈ 1:nparts
        part_acc = zero(eltype(C))
        for k ∈ indices((A, B), (3, 2))
            part_acc += A[1, m, k] * B[p + 1, k, n]
        end
        C[p + 1, m, n] = C[p + 1, m, n] + α * part_acc
    end

    return _C
end
157+
0 commit comments