|
2 | 2 | matmul_params(::Val{T}) where {T} = LoopVectorization.matmul_params() |
3 | 3 |
|
4 | 4 | function block_sizes(::Val{T}, _α, _β, R₁, R₂) where {T} |
5 | | - W = pick_vector_width(T) |
6 | | - α = _α * W |
7 | | - β = _β * W |
8 | | - L₁ₑ = first_cache_size(Val(T)) * R₁ |
9 | | - L₂ₑ = second_cache_size(Val(T)) * R₂ |
10 | | - block_sizes(Val(T), W, α, β, L₁ₑ, L₂ₑ) |
| 5 | + W = pick_vector_width(T) |
| 6 | + Wfloat = StaticFloat64(W) |
| 7 | + α = _α * Wfloat |
| 8 | + β = _β * Wfloat |
| 9 | + L₁ₑ = StaticFloat64(first_cache_size(Val(T))) * R₁ |
| 10 | + L₂ₑ = StaticFloat64(second_cache_size(Val(T))) * R₂ |
| 11 | + block_sizes(Val(T), W, α, β, L₁ₑ, L₂ₑ) |
11 | 12 | end |
12 | 13 | function block_sizes(::Val{T}, W, α, β, L₁ₑ, L₂ₑ) where {T} |
13 | | - mᵣ, nᵣ = matmul_params(Val(T)) |
14 | | - MᵣW = mᵣ * W |
| 14 | + mᵣnᵣ = matmul_params(Val(T)) |
| 15 | + mᵣ = getfield(mᵣnᵣ, 1) |
| 16 | + nᵣ = getfield(mᵣnᵣ, 2) |
| 17 | + MᵣW = mᵣ * W |
15 | 18 |
|
16 | | - Mc = floortostaticint(√(L₁ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₂ₑ) / MᵣW) * MᵣW |
17 | | - Kc = roundtostaticint(√(L₁ₑ)*√(L₂ₑ)/√(L₁ₑ*β + L₂ₑ*α)) |
18 | | - Nc = floortostaticint(√(L₂ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₁ₑ) / nᵣ) * nᵣ |
19 | | - |
20 | | - Mc, Kc, Nc |
| 19 | + Mc = floortostaticint(√(L₁ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₂ₑ) / StaticFloat64(MᵣW)) * MᵣW |
| 20 | + Kc = roundtostaticint(√(L₁ₑ)*√(L₂ₑ)/√(L₁ₑ*β + L₂ₑ*α)) |
| 21 | + Nc = floortostaticint(√(L₂ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₁ₑ) / StaticFloat64(nᵣ)) * nᵣ |
| 22 | + |
| 23 | + Mc, Kc, Nc |
21 | 24 | end |
22 | 25 | function block_sizes(::Val{T}) where {T} |
23 | 26 | block_sizes(Val(T), W₁Default(), W₂Default(), R₁Default(), R₂Default()) |
|
179 | 182 | # Takes Nc, calcs Mc and Kc |
180 | 183 | @inline function solve_McKc(::Val{T}, M, K, Nc, _α, _β, R₂, R₃, Wfactor) where {T} |
181 | 184 | W = pick_vector_width(T) |
182 | | - α = _α * W |
183 | | - β = _β * W |
| 185 | + Wfloat = StaticFloat64(W) |
| 186 | + α = _α * Wfloat |
| 187 | + β = _β * Wfloat |
184 | 188 | L₁ₑ = first_cache_size(Val(T)) * R₂ |
185 | 189 | L₂ₑ = second_cache_size(Val(T)) * R₃ |
186 | 190 |
|
|
0 commit comments