146146 return C
147147end
148148@inline function matmul_serial! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β, MKN, :: StaticInt )
149- _matmul_serial! (C, A, B, α, β, MKN)
150- return C
149+ _matmul_serial! (C, A, B, α, β, MKN)
150+ return C
151151end
152152
153153"""
@@ -164,30 +164,32 @@ If the arrays are small and statically sized, it will dispatch to an inlined mul
164164Otherwise, based on the array's size, whether they are transposed, and whether the columns are already aligned, it decides to not pack at all, to pack only `A`, or to pack both arrays `A` and `B`.
165165"""
166166@inline function _matmul_serial! (
167- C:: AbstractMatrix{T} , A:: AbstractMatrix , B:: AbstractMatrix , α, β, MKN
167+ C:: AbstractMatrix{T} , A:: AbstractMatrix , B:: AbstractMatrix , α, β, MKN
168168) where {T}
169- M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
170- if M * N == 0
171- return
172- elseif K == 0
173- matmul_only_β! (C, β)
174- return
175- end
176- pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
177- Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
178- Mc, Kc, Nc = block_sizes (Val (T)); mᵣ, nᵣ = matmul_params (Val (T));
179- GC. @preserve Cb Ab Bb begin
180- if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
181- inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
182- return
183- elseif (nᵣ ≥ N) || dontpack (pA, M, K, Mc, Kc, T)
184- loopmul! (pC, pA, pB, α, β, M, K, N)
185- return
186- else
187- matmul_st_pack_dispatcher! (pC, pA, pB, α, β, M, K, N)
188- return
189- end
169+ ((β ≢ Zero ()) && iszero (β)) && return _matmul_serial! (C, A, B, α, Zero (), MKN)
170+ (β isa Bool) && return _matmul_serial! (C, A, B, α, One (), MKN)
171+ M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
172+ if M * N == 0
173+ return
174+ elseif K == 0
175+ matmul_only_β! (C, β)
176+ return
177+ end
178+ pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
179+ Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
180+ Mc, Kc, Nc = block_sizes (Val (T)); mᵣ, nᵣ = matmul_params (Val (T));
181+ GC. @preserve Cb Ab Bb begin
182+ if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
183+ inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
184+ return
185+ elseif (nᵣ ≥ N) || dontpack (pA, M, K, Mc, Kc, T)
186+ loopmul! (pC, pA, pB, α, β, M, K, N)
187+ return
188+ else
189+ matmul_st_pack_dispatcher! (pC, pA, pB, α, β, M, K, N)
190+ return
190191 end
192+ end
191193end # function
192194
193195function matmul_only_β! (C:: AbstractMatrix{T} , β:: StaticInt{0} ) where T
@@ -266,35 +268,37 @@ end
266268
267269# passing MKN directly would let osmeone skip the size check.
268270@inline function _matmul! (C:: AbstractMatrix{T} , A, B, α, β, nthread, MKN) where {T}
269- M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
270- if M * N == 0
271- return
272- elseif K == 0
273- matmul_only_β! (C, β)
274- return
275- end
276- W = pick_vector_width (T)
277- pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
278- Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
279- mᵣ, nᵣ = matmul_params (Val (T))
280- GC. @preserve Cb Ab Bb begin
281- if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
282- inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
283- return
284- else
285- (nᵣ ≥ N) && @goto LOOPMUL
286- if (Sys. ARCH === :x86_64 ) || (Sys. ARCH === :i686 )
287- (M* K* N < (StaticInt {4_096} () * W)) && @goto LOOPMUL
288- else
289- (M* K* N < (StaticInt {32_000} () * W)) && @goto LOOPMUL
290- end
291- __matmul! (pC, pA, pB, α, β, M, K, N, nthread)
292- return
293- @label LOOPMUL
294- loopmul! (pC, pA, pB, α, β, M, K, N)
295- return
296- end
271+ ((β ≢ Zero ()) && iszero (β)) && return _matmul! (C, A, B, α, Zero (), nthread, MKN)
272+ (β isa Bool) && return _matmul! (C, A, B, α, One (), nthread, MKN)
273+ M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
274+ if M * N == 0
275+ return
276+ elseif K == 0
277+ matmul_only_β! (C, β)
278+ return
279+ end
280+ W = pick_vector_width (T)
281+ pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
282+ Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
283+ mᵣ, nᵣ = matmul_params (Val (T))
284+ GC. @preserve Cb Ab Bb begin
285+ if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
286+ inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
287+ return
288+ else
289+ (nᵣ ≥ N) && @goto LOOPMUL
290+ if (Sys. ARCH === :x86_64 ) || (Sys. ARCH === :i686 )
291+ (M* K* N < (StaticInt {4_096} () * W)) && @goto LOOPMUL
292+ else
293+ (M* K* N < (StaticInt {32_000} () * W)) && @goto LOOPMUL
294+ end
295+ __matmul! (pC, pA, pB, α, β, M, K, N, nthread)
296+ return
297+ @label LOOPMUL
298+ loopmul! (pC, pA, pB, α, β, M, K, N)
299+ return
297300 end
301+ end
298302end
299303
300304# This funciton is sort of a `pun`. It splits aggressively (it does a lot of "splitin'"), which often means it will split-N.
0 commit comments