@@ -80,6 +80,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
     int rk3 = neq3/nek3;
     int rv3 = neq3/nev3;

+    int first_k = 0, last_k = nek1;
+    if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
+        // This is a quick hack for SWA models.
+        // Given that the mask is the same for all layers, ideally we should determine the
+        // cache bounds once and reuse them for the whole graph. But even with this simple hack
+        // we get non-negligible performance gains for SWA models and long context.
+        auto umask = (const uint16_t *)mask;
+        for (; first_k < last_k; ++first_k) {
+            if (umask[first_k] == 0) break;
+        }
+        for (; last_k > first_k; --last_k) {
+            if (umask[last_k-1] == 0) break;
+        }
+        //printf("nek1 = %d, first_k = %d, last_k = %d\n", nek1, first_k, last_k);
+        if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
+            //printf("Reducing from %d to %d\n", nek1, last_k - first_k);
+            k = (const void *)((const char *)k + first_k*stride_k);
+            v = (const void *)((const char *)v + first_k*stride_v);
+            mask = (const void *)((const uint16_t *)mask + first_k);
+            nek1 = last_k - first_k;
+        }
+    }
+
     int int_type_k = int_type_k_in;
     auto work_buffer = work_buffer_in;
     if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
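
A minimal standalone sketch of the window detection in the hunk above, under the assumption (my reading of the patch, not code from the repository) that the fp16 mask stores +0.0f (bit pattern 0x0000) for positions the query may attend to and -inf (0xFC00) for masked-out positions; the mask contents and window shape here are made up for illustration. The two scans find the contiguous active window of an SWA mask; the kernel then offsets `k`, `v` and `mask` by `first_k` and shrinks `nek1` to `last_k - first_k`.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    constexpr uint16_t F16_ZERO = 0x0000; // +0.0f in IEEE fp16: position may be attended
    constexpr uint16_t F16_NINF = 0xFC00; // -inf  in IEEE fp16: position is masked out

    // Hypothetical mask row for one query token: a cache of 12 positions with a
    // sliding window that leaves only positions 5..8 visible.
    std::vector<uint16_t> mask(12, F16_NINF);
    for (int i = 5; i <= 8; ++i) mask[i] = F16_ZERO;

    // Same scans as in the patch: advance first_k past the masked-out head,
    // pull last_k back past the masked-out tail.
    int first_k = 0, last_k = (int)mask.size();
    for (; first_k < last_k; ++first_k) if (mask[first_k] == 0) break;
    for (; last_k > first_k; --last_k) if (mask[last_k-1] == 0) break;

    // The kernel would now offset k, v and the mask by first_k and set
    // nek1 = last_k - first_k, so the fully masked head and tail are never visited.
    printf("active window: [%d, %d) of %zu positions\n", first_k, last_k, mask.size());
    return 0;
}
```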