Skip to content

Commit 4239d25

Browse files
ikawrakowIwan Kawrakow
andauthored
Quick hack to improve TG performance for SWA models (#692)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent fc06bc9 commit 4239d25

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

ggml/src/iqk/iqk_flash_attn.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
8080
int rk3 = neq3/nek3;
8181
int rv3 = neq3/nev3;
8282

83+
int first_k = 0, last_k = nek1;
84+
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
85+
// This is a quick hack for SWA models.
86+
// Given that the mask is the same for all layers, ideally we should determinbe the
87+
// cache bounds once, and reuse for the whole graph. But even with this simple hack
88+
// we get non-negligible performance gains for SWA models and long context.
89+
auto umask = (const uint16_t *)mask;
90+
for (; first_k < last_k; ++first_k) {
91+
if (umask[first_k] == 0) break;
92+
}
93+
for (; last_k > first_k; --last_k) {
94+
if (umask[last_k-1] == 0) break;
95+
}
96+
//printf("nek1 = %d, first = %d, last = %d\n", nek1, first, last);
97+
if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
98+
//printf("Reducing from %d to %d\n", nek1, last_k - first_k);
99+
k = (const void *)((const char *)k + first_k*stride_k);
100+
v = (const void *)((const char *)v + first_k*stride_v);
101+
mask = (const void *)((const uint16_t *)mask + first_k);
102+
nek1 = last_k - first_k;
103+
}
104+
}
105+
83106
int int_type_k = int_type_k_in;
84107
auto work_buffer = work_buffer_in;
85108
if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {

0 commit comments

Comments
 (0)