 #version 450

+#extension GL_EXT_control_flow_attributes : enable
+
 #ifdef USE_COLLECTIVES
 # extension GL_KHR_shader_subgroup_shuffle : enable
 #endif

 #include "types.comp"

-// Make spec constant
-#define SHMEM_PAD 0
-
 // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
 layout(binding = 0) readonly buffer A {
     A_TYPE knl_data[];
@@ -56,6 +55,12 @@ layout(push_constant) uniform parameter {
     uint32_t nb1;
     uint32_t nb2;
     uint32_t nb3;
+
+    // fastdiv helper values
+    uint32_t KWmp; uint32_t KWL;
+    uint32_t KWKHmp; uint32_t KWKHL;
+    uint32_t OWmp; uint32_t OWL;
+    uint32_t OWOHmp; uint32_t OWOHL;
 }

 p;
@@ -68,6 +73,7 @@ layout(constant_id = 3) const uint BS_NPQ = 128;
 // Thread-tile sizes
 layout(constant_id = 4) const uint TS_K = 8;
 layout(constant_id = 5) const uint use_collectives = 1;
+layout(constant_id = 6) const uint SHMEM_PAD = 4;

 uint32_t tid = gl_LocalInvocationID.x;
 const uint32_t WG_SIZE = gl_WorkGroupSize.x;
@@ -131,6 +137,14 @@ uint32_t Br = tid / BS_NPQ;
 uint32_t Bc = tid % BS_NPQ;
 const uint32_t BrpWg = WG_SIZE / BS_NPQ;

+// see init_fastdiv_values in ggml-vulkan.cpp
+uint fastdiv(uint n, uint mp, uint L) {
+    uint msbs, lsbs;
+    // msbs = mulhi(n, mp)
+    umulExtended(n, mp, msbs, lsbs);
+    return (msbs + n) >> L;
+}
+
 void main() {
     for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
         for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
@@ -151,9 +165,9 @@ void main() {
             uint32_t cached_KW_idx;
             if (use_collectives == 1) {
                 cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
-                cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
+                cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
                 uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
-                cached_KH_idx = cached_CRS_remainder / p.KW;
+                cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
                 cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;

                 CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
@@ -162,16 +176,16 @@ void main() {
                 KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
             } else {
                 CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
-                Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+                Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
                 uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
-                KH_idx_a = CRS_remainder / p.KW;
+                KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
                 KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
             }
 #else
             CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
-            Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+            Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
             CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
-            KH_idx_a = CRS_remainder / p.KW;
+            KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
             KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
 #endif

@@ -188,13 +202,13 @@ void main() {
             Ash[B_ly * Ash_stride + B_lx] = val;
         }
         /* Load input to B_block: (BS_CRS x BS_NPQ) */
-        for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
+        UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
             uint32_t B_ly = r_offset + Br; /* Row index of B block */
             uint32_t B_lx = Bc;
             uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
-            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
+            uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
             uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
-            uint32_t OH_idx = NPQ_remainder / p.OW;
+            uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
             uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;

             uint32_t CRS_idx_b;
@@ -209,16 +223,16 @@ void main() {
                 KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
             } else {
                 CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
-                Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
+                Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
                 uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
-                KH_idx_b = CRS_remainder / p.KW;
+                KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
                 KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
             }
 #else
             CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
-            Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
+            Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
             uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
-            KH_idx_b = CRS_remainder / p.KW;
+            KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
             KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
 #endif

@@ -233,32 +247,36 @@ void main() {
             Bsh[B_ly * Bsh_stride + B_lx] = val;
         }
         barrier();
-        for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
-            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
-                regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
-            }
-            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
-                regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
-            }
-            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+        if (T_y * TS_K < K) {
+            UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
+                for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+                    regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
+                }
                 for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
-                    regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
+                    regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
+                }
+                for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+                    for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+                        regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
+                    }
                 }
             }
         }
         barrier();
     }
     /* Save C* */
-    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
-        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
-            uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
-            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
-            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
-            uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
-            uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
-            uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
-            if (K_idx < K && NPQ_idx < NPQ) {
-                dst_data[dst_idx] = regC[T_ly][T_lx];
+    if (T_y * TS_K < K) {
+        for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+                uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
+                uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
+                uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
+                uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
+                uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
+                uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
+                if (K_idx < K && NPQ_idx < NPQ) {
+                    dst_data[dst_idx] = regC[T_ly][T_lx];
+                }
             }
         }
     }
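
Note: the shader-side fastdiv() above depends on per-divisor constants (KWmp/KWL, KWKHmp/KWKHL, OWmp/OWL, OWOHmp/OWOHL) that the host writes into the push constants; the comment points at init_fastdiv_values in ggml-vulkan.cpp, which is not part of this diff. As a rough C++ sketch of what such a helper computes (names, signature, and exact formula here are assumptions, not the committed code), the "round-up" magic-number method picks L = ceil(log2(d)) and mp so that n / d == (mulhi(n, mp) + n) >> L for all 32-bit n, which is exactly the expression fastdiv() evaluates via umulExtended():

    #include <cstdint>

    // Sketch only: host-side computation of the fastdiv magic constants
    // for a divisor d, matching the shader's (mulhi(n, mp) + n) >> L.
    static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
        // smallest L with 2^L >= d
        L = 0;
        while (L < 32 && (uint32_t{1} << L) < d) {
            L++;
        }
        mp = (uint32_t) ((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
    }

    // The push-constant fields would then be filled once per dispatch, e.g.:
    //     init_fastdiv_values(KW,      KWmp,   KWL);
    //     init_fastdiv_values(KW * KH, KWKHmp, KWKHL);
    //     init_fastdiv_values(OW,      OWmp,   OWL);
    //     init_fastdiv_values(OW * OH, OWOHmp, OWOHL);

This trades one integer division per index computation for a multiply-high, an add, and a shift, which is typically cheaper on GPUs when the divisor is uniform across the dispatch.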