Skip to content

Commit aa0fe57

Browse files
authored
support Static 0.7.8, fixes #159 (#160)
1 parent de37846 commit aa0fe57

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

src/block_sizes.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,25 @@
22
matmul_params(::Val{T}) where {T} = LoopVectorization.matmul_params()
33

44
function block_sizes(::Val{T}, _α, _β, R₁, R₂) where {T}
5-
W = pick_vector_width(T)
6-
α = _α * W
7-
β = _β * W
8-
L₁ₑ = first_cache_size(Val(T)) * R₁
9-
L₂ₑ = second_cache_size(Val(T)) * R₂
10-
block_sizes(Val(T), W, α, β, L₁ₑ, L₂ₑ)
5+
W = pick_vector_width(T)
6+
Wfloat = StaticFloat64(W)
7+
α = _α * Wfloat
8+
β = _β * Wfloat
9+
L₁ₑ = StaticFloat64(first_cache_size(Val(T))) * R₁
10+
L₂ₑ = StaticFloat64(second_cache_size(Val(T))) * R₂
11+
block_sizes(Val(T), W, α, β, L₁ₑ, L₂ₑ)
1112
end
1213
function block_sizes(::Val{T}, W, α, β, L₁ₑ, L₂ₑ) where {T}
13-
mᵣ, nᵣ = matmul_params(Val(T))
14-
MᵣW = mᵣ * W
14+
mᵣnᵣ = matmul_params(Val(T))
15+
mᵣ = getfield(mᵣnᵣ, 1)
16+
nᵣ = getfield(mᵣnᵣ, 2)
17+
MᵣW = mᵣ * W
1518

16-
Mc = floortostaticint((L₁ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₂ₑ) / MᵣW) * MᵣW
17-
Kc = roundtostaticint((L₁ₑ)*√(L₂ₑ)/√(L₁ₑ*β + L₂ₑ*α))
18-
Nc = floortostaticint((L₂ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₁ₑ) / nᵣ) * nᵣ
19-
20-
Mc, Kc, Nc
19+
Mc = floortostaticint((L₁ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₂ₑ) / StaticFloat64(MᵣW)) * MᵣW
20+
Kc = roundtostaticint((L₁ₑ)*√(L₂ₑ)/√(L₁ₑ*β + L₂ₑ*α))
21+
Nc = floortostaticint((L₂ₑ)*√(L₁ₑ*β + L₂ₑ*α)/√(L₁ₑ) / StaticFloat64(nᵣ)) * nᵣ
22+
23+
Mc, Kc, Nc
2124
end
2225
function block_sizes(::Val{T}) where {T}
2326
block_sizes(Val(T), W₁Default(), W₂Default(), R₁Default(), R₂Default())
@@ -179,8 +182,9 @@ end
179182
# Takes Nc, calcs Mc and Kc
180183
@inline function solve_McKc(::Val{T}, M, K, Nc, _α, _β, R₂, R₃, Wfactor) where {T}
181184
W = pick_vector_width(T)
182-
α = _α * W
183-
β = _β * W
185+
Wfloat = StaticFloat64(W)
186+
α = _α * Wfloat
187+
β = _β * Wfloat
184188
L₁ₑ = first_cache_size(Val(T)) * R₂
185189
L₂ₑ = second_cache_size(Val(T)) * R₃
186190

0 commit comments

Comments
 (0)