@@ -10,6 +10,7 @@ function (::LoopMulFunc{P,TC,TA,TB,Α,Β,Md,Kd,Nd})(p::Ptr{UInt}) where {P,TC,TA
1010 offset, K = load (p, Kd, offset)
1111 offset, N = load (p, Nd, offset)
1212 _call_loopmul! (C, A, B, α, β, M, K, N, Val {P} ())
13+ _atomic_store! (p, SPIN)
1314 nothing
1415end
1516@inline _call_loopmul! (C, A, B, α, β, M, K, N, :: Val{false} ) = loopmul! (C, A, B, α, β, M, K, N)
@@ -39,6 +40,7 @@ function (::SyncMulFunc{TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂})(
3940 offset, id = load (p, ID, offset)
4041 offset, total_ids = load (p, TT, offset)
4142 sync_mul! (C, A, B, α, β, M, K, N, atomicp, bcachep, id, total_ids, StaticFloat64 {W₁} (), StaticFloat64 {W₂} (), StaticFloat64 {R₁} (), StaticFloat64 {R₂} ())
43+ _atomic_store! (p, SPIN)
4244 nothing
4345end
4446
6365 nothing
6466end
6567
66- @inline function setup_syncmul! (
68+ @inline function launch_thread_mul! (C, A, B, α, β, M, K, N, tid:: UInt32 , :: Val{P} ) where {P}
69+ launch (setup_matmul!, tid, C, A, B, α, β, M, K, N, Val {P} ())
70+ end
71+
72+ struct SyncMulLauncher{W₁, W₂, R₁, R₂} end
73+ @inline function (:: SyncMulLauncher{W₁, W₂, R₁, R₂} )(
6774 p:: Ptr{UInt} , C:: TC , A:: TA , B:: TB , α:: Α, β:: Β, M:: Md , K:: Kd , N:: Nd ,
68- ap:: Ptr{UInt32} ,bcp:: BCP ,id:: ID ,tt:: TT , :: StaticFloat64{W₁} , :: StaticFloat64{W₂} , :: StaticFloat64{R₁} , :: StaticFloat64{R₂}
75+ ap:: Ptr{UInt32} ,bcp:: BCP ,id:: ID ,tt:: TT
6976) where {TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂}
70- offset = store! (p, cfuncpointer (SyncMulFunc {TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂} ()), sizeof (UInt))
77+ fptr = cfuncpointer (SyncMulFunc {TC,TA,TB,Α,Β,Md,Kd,Nd,BCP,ID,TT,W₁,W₂,R₁,R₂} ())
78+ offset = store! (p, fptr, sizeof (UInt))
7179 offset = store! (p, C, offset)
7280 offset = store! (p, A, offset)
7381 offset = store! (p, B, offset)
8290 offset = store! (p, tt, offset)
8391 nothing
8492end
85-
86- @inline function launch_thread_mul! (C, A, B, α, β, M, K, N, tid:: UInt32 , :: Val{P} ) where {P}
87- launch (setup_matmul!, tid, C, A, B, α, β, M, K, N, Val {P} ())
88- end
8993@inline function launch_thread_mul! (
90- C, A, B, α, β, M, K, N, ap, bcp, tid, id, tt, :: StaticFloat64{W₁} ,:: StaticFloat64{W₂} ,:: StaticFloat64{R₁} ,:: StaticFloat64{R₂}
94+ C, A, B, α, β, M, K, N, ap, bcp, tid, id, tt,
95+ :: StaticFloat64{W₁} ,:: StaticFloat64{W₂} ,:: StaticFloat64{R₁} ,:: StaticFloat64{R₂}
9196) where {W₁,W₂,R₁,R₂}
92- launch (tid, C, A, B, α, β, M, K, N, ap, bcp, id, tt) do p, C, A, B, α, β, M, K, N, ap, bcp, id, tt
93- Base. @_inline_meta
94- setup_syncmul! (
95- p, C, A, B, α, β, M, K, N, ap, bcp, id, tt,
96- StaticFloat64 {W₁} (),StaticFloat64 {W₂} (),StaticFloat64 {R₁} (),StaticFloat64 {R₂} ()
97- )
98- end
97+ launch (SyncMulLauncher {W₁, W₂, R₁, R₂} (), tid, C, A, B, α, β, M, K, N, ap, bcp, id, tt)
9998end
10099
101100
0 commit comments