@@ -11,19 +11,17 @@ function block_sizes(::Val{T}, _α, _β, R₁, R₂) where {T}
1111 block_sizes (Val (T), W, α, β, L₁ₑ, L₂ₑ)
1212end
1313function block_sizes (:: Val{T} , W, α, β, L₁ₑ, L₂ₑ) where {T}
14- mᵣnᵣ = matmul_params (Val (T))
15- mᵣ = getfield (mᵣnᵣ, 1 )
16- nᵣ = getfield (mᵣnᵣ, 2 )
14+ mᵣ, nᵣ = matmul_params (Val (T))
1715 MᵣW = mᵣ * W
18-
19- Mc = floortostaticint (√ (L₁ₑ)* √ (L₁ₑ* β + L₂ₑ* α) / √ (L₂ₑ) / StaticFloat64 (MᵣW)) * MᵣW
20- Kc = roundtostaticint (√ (L₁ₑ)* √ (L₂ₑ)/ √ (L₁ₑ* β + L₂ₑ* α))
21- Nc = floortostaticint (√ (L₂ₑ)* √ (L₁ₑ* β + L₂ₑ* α) / √ (L₁ₑ) / StaticFloat64 (nᵣ)) * nᵣ
22-
16+
17+ Mc = floortostaticint (√ (L₁ₑ) * √ (L₁ₑ * β + L₂ₑ * α) / √ (L₂ₑ) / StaticFloat64 (MᵣW)) * MᵣW
18+ Kc = roundtostaticint (√ (L₁ₑ) * √ (L₂ₑ) / √ (L₁ₑ * β + L₂ₑ * α))
19+ Nc = floortostaticint (√ (L₂ₑ) * √ (L₁ₑ * β + L₂ₑ * α) / √ (L₁ₑ) / StaticFloat64 (nᵣ)) * nᵣ
20+
2321 Mc, Kc, Nc
2422end
2523function block_sizes (:: Val{T} ) where {T}
26- block_sizes (Val (T), W₁Default (), W₂Default (), R₁Default (), R₂Default ())
24+ block_sizes (Val (T), W₁Default (), W₂Default (), R₁Default (), R₂Default ())
2725end
2826
2927"""
@@ -48,12 +46,12 @@ This is meant to specify roughly the requested amount of blocks, and return rela
4846This method is used fairly generally.
4947"""
5048@inline function split_m (M, _Mblocks, W)
51- Miters = cld_fast (M, W)
52- Mblocks = min (_Mblocks, Miters)
53- Miter_per_block, Mrem = divrem_fast (Miters, Mblocks)
54- Mbsize = Miter_per_block * W
55- Mremfinal = M - Mbsize* (Mblocks- 1 ) - Mrem * W
56- Mbsize, Mrem, Mremfinal, Mblocks
49+ Miters = cld_fast (M, W)
50+ Mblocks = min (_Mblocks, Miters)
51+ Miter_per_block, Mrem = divrem_fast (Miters, Mblocks)
52+ Mbsize = Miter_per_block * W
53+ Mremfinal = M - Mbsize * (Mblocks - 1 ) - Mrem * W
54+ Mbsize, Mrem, Mremfinal, Mblocks
5755end
5856
5957"""
@@ -162,33 +160,36 @@ Note that for synchronization on `B`, all threads must have the same values for
162160independently of `M`, this algorithm guarantees all threads are on the same page.
163161"""
164162@inline function solve_block_sizes (:: Val{T} , M, K, N, _α, _β, R₂, R₃, Wfactor) where {T}
165- W = pick_vector_width (T)
166- α = _α * W
167- β = _β * W
168- L₁ₑ = first_cache_size (Val (T)) * R₂
169- L₂ₑ = second_cache_size (Val (T)) * R₃
163+ W = pick_vector_width (T)
164+ α = _α * W
165+ β = _β * W
166+ L₁ₑ = first_cache_size (Val (T)) * R₂
167+ L₂ₑ = second_cache_size (Val (T)) * R₃
170168
171- # Nc_init = round(Int, √(L₂ₑ)*√(α * L₂ₑ + β * L₁ₑ)/√(L₁ₑ))
172- Nc_init⁻¹ = √ (L₁ₑ) / (√ (L₂ₑ)*√ (α * L₂ₑ + β * L₁ₑ))
173-
174- Niter = cldapproxi (N, Nc_init⁻¹) # approximate `ceil`
175- Nblock, Nrem = divrem_fast (N, Niter)
176- Nblock_Nrem = Nblock + One ()# (Nrem > 0)
169+ # Nc_init = round(Int, √(L₂ₑ)*√(α * L₂ₑ + β * L₁ₑ)/√(L₁ₑ))
170+ Nc_init⁻¹ = √ (L₁ₑ) / (√ (L₂ₑ) * √ (α * L₂ₑ + β * L₁ₑ))
177171
178- ((Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter)) = solve_McKc (Val (T), M, K, Nblock_Nrem, _α, _β, R₂, R₃, Wfactor)
179-
180- (Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter), promote (Nblock, Nblock_Nrem, Nrem, Niter)
172+ Niter = cldapproxi (N, Nc_init⁻¹) # approximate `ceil`
173+ Nblock, Nrem = divrem_fast (N, Niter)
174+ Nblock_Nrem = Nblock + One ()# (Nrem > 0)
175+
176+ ((Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter)) =
177+ solve_McKc (Val (T), M, K, Nblock_Nrem, _α, _β, R₂, R₃, Wfactor)
178+
179+ (Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter),
180+ (Kblock, Kblock_Krem, Krem, Kiter),
181+ promote (Nblock, Nblock_Nrem, Nrem, Niter)
181182end
182183# Takes Nc, calcs Mc and Kc
183184@inline function solve_McKc (:: Val{T} , M, K, Nc, _α, _β, R₂, R₃, Wfactor) where {T}
184185 W = pick_vector_width (T)
185186 Wfloat = StaticFloat64 (W)
186187 α = _α * Wfloat
187- β = _β * Wfloat
188- L₁ₑ = first_cache_size (Val (T)) * R₂
188+ # β = _β * Wfloat
189+ L₁ₑ = first_cache_size (Val (T)) * R₂
189190 L₂ₑ = second_cache_size (Val (T)) * R₃
190191
191- Kc_init⁻¹ = Base. FastMath. max_fast (√ (α/ L₁ₑ), Nc* inv (L₂ₑ))
192+ Kc_init⁻¹ = Base. FastMath. max_fast (√ (α / L₁ₑ), Nc * inv (L₂ₑ))
192193 Kiter = cldapproxi (K, Kc_init⁻¹) # approximate `ceil`
193194 Kblock, Krem = divrem_fast (K, Kiter)
194195 Kblock_Krem = Kblock + One ()
202203 Mblocks, Mblocks_rem = divrem_fast (M, Mᵣ)
203204 Miter, Mrem = divrem_fast (Mblocks, Mc_init_base)
204205 if Miter == 0
205- return (0 , 0 , Int (M):: Int , 0 , 1 ), Kblock_summary
206+ return (0 , 0 , Int (M):: Int , 0 , 1 ), Kblock_summary
206207 elseif Miter > Mrem
207208 Mblock_Mrem = Mbsize + Mᵣ
208209 Mremfinal = Mbsize + Mblocks_rem
221222 end
222223end
223224
224- @inline cldapproxi (n, d⁻¹) = Base. fptosi (Int, Base. FastMath. add_fast (Base. FastMath. mul_fast (n, d⁻¹), 0.9999999999999432 )) # approximate `ceil`
225+ @inline cldapproxi (n, d⁻¹) = Base. fptosi (
226+ Int,
227+ Base. FastMath. add_fast (Base. FastMath. mul_fast (n, d⁻¹), 0.9999999999999432 ),
228+ ) # approximate `ceil`
225229# @inline divapproxi(n, d⁻¹) = Base.fptosi(Int, Base.FastMath.mul_fast(n, d⁻¹)) # approximate `div`
226230
227231"""
@@ -231,14 +235,14 @@ Finds first combination of `Miter` and `Niter` that doesn't make `M` too small w
231235This would be awkward if there are computers with prime numbers of cores. I should probably consider that possibility at some point.
232236"""
233237@inline function find_first_acceptable (:: Val{T} , M, W) where {T}
234- Mᵣ, Nᵣ = matmul_params (Val (T))
235- factors = calc_factors ()
236- for (miter, niter) ∈ factors
237- if miter * (StaticInt (2 ) * Mᵣ * W) ≤ M + (W + W)
238- return miter, niter
239- end
238+ Mᵣ, _ = matmul_params (Val (T))
239+ factors = calc_factors ()
240+ for (miter, niter) ∈ factors
241+ if miter * (StaticInt (2 ) * Mᵣ * W) ≤ M + (W + W)
242+ return miter, niter
240243 end
241- last (factors)
244+ end
245+ last (factors)
242246end
243247"""
244248 divide_blocks(M, Ntotal, _nspawn, W)
@@ -247,8 +251,8 @@ Splits both `M` and `N` into blocks when trying to spawn a large number of threa
247251"""
248252@inline function divide_blocks (:: Val{T} , M, Ntotal, _nspawn, W) where {T}
249253 _nspawn == num_cores () && return find_first_acceptable (Val (T), M, W)
250- mᵣ, nᵣ = matmul_params (Val (T))
251- Miter = clamp (div_fast (M, W* mᵣ * MᵣW_mul_factor ()), 1 , _nspawn)
254+ mᵣ, _ = matmul_params (Val (T))
255+ Miter = clamp (div_fast (M, W * mᵣ * MᵣW_mul_factor ()), 1 , _nspawn)
252256 nspawn = div_fast (_nspawn, Miter)
253257 if (nspawn ≤ 1 ) & (Miter < _nspawn)
254258 # rebalance Miter
0 commit comments