@@ -43,6 +43,27 @@ inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float
         for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
     }
 }
+inline std::pair<int, int> mask_range(int nek1, const uint16_t * umask) {
+    int first_k = 0, last_k = nek1;
+    for (; first_k < last_k; ++first_k) {
+        if (umask[first_k] == 0) break;
+    }
+    for (; last_k > first_k; --last_k) {
+        if (umask[last_k-1] == 0) break;
+    }
+    return { first_k, last_k };
+}
+inline bool reduce_k_range(int nek1, int& first_k, int& last_k) {
+    int nk = last_k - first_k;
+    if (nk >= nek1) return false;
+    if (nk%32) {
+        int nk32 = 32*((nk + 31)/32);
+        int diff = nk32 - nk;
+        first_k = std::max(0, first_k - diff);
+        last_k = first_k + nk32;
+    }
+    return last_k - first_k < nek1;
+}
 }

 // TODO: get the ggml_type enum here without polution
@@ -66,7 +87,8 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
         const void * k,      // k matrix. Assumed to be fp16, nq x nk elements
         const void * v,      // v matrix. Assumed to be fp16, nq x nk elements
         const void * mask,   // mask. If not null, assumed to be fp16. nq x nk elements
-        const void * sinks,  // mask. If not null, assumed to be fp16. nq x nk elements
+        const void * sinks,  // attention sinks
+        const void * bounds, // attention mask bounds
         float scale,         // scale applied before softmax
         float softcap,       // if > 0, a "soft-cap" operation is applied before softmax
         float * qkv,         // v*softmax(scale*(k*q))
@@ -80,22 +102,13 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
     int rk3 = neq3/nek3;
     int rv3 = neq3/nev3;

-    int first_k = 0, last_k = nek1;
-    if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
-        // This is a quick hack for SWA models.
-        // Given that the mask is the same for all layers, ideally we should determinbe the
-        // cache bounds once, and reuse for the whole graph. But even with this simple hack
-        // we get non-negligible performance gains for SWA models and long context.
-        auto umask = (const uint16_t *)mask;
-        for (; first_k < last_k; ++first_k) {
-            if (umask[first_k] == 0) break;
-        }
-        for (; last_k > first_k; --last_k) {
-            if (umask[last_k-1] == 0) break;
-        }
-        // printf("nek1 = %d, first = %d, last = %d\n", nek1, first, last);
-        if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
-            // printf("Reducing from %d to %d\n", nek1, last_k - first_k);
+    bool range_found = false;
+    if (neq3 == 1 && rk2 > 1 && neq1 == 1 && bounds && nek1 > 32) {
+        range_found = true;
+        auto b = (const int32_t *)bounds;
+        int first_k = b[0];
+        int last_k = b[1];
+        if ((last_k - first_k)%32 == 0) { // why is this not better? : if (reduce_k_range(nek1, first_k, last_k)) {
             k = (const void *)((const char *)k + first_k*stride_k);
             v = (const void *)((const char *)v + first_k*stride_v);
             mask = (const void *)((const uint16_t *)mask + first_k);
@@ -105,7 +118,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float

     int int_type_k = int_type_k_in;
     auto work_buffer = work_buffer_in;
-    if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
+    if (neq1 >= 8 || (false && rk2 >= 8 && nek2 > 1)) {
         uint64_t row_size = 0;
         work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
         if (int_type_k != int_type_k_in) {
@@ -299,6 +312,25 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
         if (counter++ % (nth/ntg) == ith/ntg) {
             int iq1 = (ith%ntg)*neq1g;
             int this_neq1 = std::min(neq1g, neq1-iq1);
+            if (bounds && !range_found) {
+                auto b = (const int32_t *)bounds + 2*iq1;
+                int kmin = nek1, kmax = 0;
+                for (int i = 0; i < this_neq1; ++i) {
+                    kmin = std::min(kmin, b[2*i+0]);
+                    kmax = std::max(kmax, b[2*i+1]);
+                }
+                if (reduce_k_range(nek1, kmin, kmax)) {
+                    if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+                                Dk, Dv, this_neq1, kmax-kmin, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
+                                (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
+                                (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3 + kmin*stride_k),
+                                (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3 + kmin*stride_v),
+                                (const void *)((const char *)mask + iq1*stride_m + kmin*sizeof(uint16_t)), sinksf, 1,
+                                scale, softcap,
+                                (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
+                    continue;
+                }
+            }
             if (!iqk_flash_attn_impl(int_type_k, int_type_v,
                         Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
                         (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
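A note on where `bounds` is expected to come from: the diff only shows the consumer side, but the new `mask_range()` helper together with the per-row indexing `b[2*i+0]` / `b[2*i+1]` suggests an array of `int32_t` pairs, one `[first_k, last_k)` pair per query row, derived from the fp16 mask (value 0 = attend, non-zero = -inf/masked, as in the removed SWA hack). Below is a minimal sketch of such a one-time precomputation; `compute_bounds` and its signature are hypothetical and assume contiguous mask rows (the kernel itself uses `stride_m`), only `mask_range` mirrors the helper added in this commit.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Mirrors the mask_range() helper added above: trims fully masked (non-zero fp16,
// i.e. -inf) positions from both ends of one mask row and returns [first_k, last_k).
static std::pair<int, int> mask_range(int nek1, const uint16_t * umask) {
    int first_k = 0, last_k = nek1;
    while (first_k < last_k && umask[first_k] != 0) ++first_k;
    while (last_k > first_k && umask[last_k-1] != 0) --last_k;
    return { first_k, last_k };
}

// Hypothetical one-time precomputation: one (first_k, last_k) int32 pair per query
// row, stored in the layout the kernel indexes as b[2*iq1 + 0] / b[2*iq1 + 1].
static std::vector<int32_t> compute_bounds(int n_rows, int nek1, const uint16_t * mask) {
    std::vector<int32_t> bounds(2*n_rows);
    for (int row = 0; row < n_rows; ++row) {
        auto [first_k, last_k] = mask_range(nek1, mask + (size_t)row*nek1);
        bounds[2*row + 0] = first_k;
        bounds[2*row + 1] = last_k;
    }
    return bounds;
}
```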
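On the consumer side, the sketch below walks through what `reduce_k_range()` does with a per-tile `[kmin, kmax)` range gathered from `bounds`: the range is widened to a multiple of 32 (preferring to move `kmin` down toward 0) so the kernel still sees nicely sized blocks, and the reduced range is only used when it is actually shorter than the full `nek1`. This is a standalone illustration with made-up numbers; only `reduce_k_range` itself mirrors the helper added in this commit.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors the reduce_k_range() helper added in this commit.
static bool reduce_k_range(int nek1, int& first_k, int& last_k) {
    int nk = last_k - first_k;
    if (nk >= nek1) return false;      // range already covers the whole cache, nothing to gain
    if (nk % 32) {                     // widen so the kernel works on a multiple of 32 K positions
        int nk32 = 32*((nk + 31)/32);
        first_k = std::max(0, first_k - (nk32 - nk));
        last_k  = first_k + nk32;
    }
    return last_k - first_k < nek1;    // only worth using if it is shorter than the full range
}

int main() {
    // Hypothetical bounds for 2 query rows of a 4096-token KV cache:
    // row 0 attends to [3000, 3500), row 1 to [3010, 3600).
    int32_t bounds[] = { 3000, 3500, 3010, 3600 };
    int kmin = 4096, kmax = 0;
    for (int i = 0; i < 2; ++i) {      // same min/max gather as in the kernel loop above
        kmin = std::min(kmin, (int)bounds[2*i+0]);
        kmax = std::max(kmax, (int)bounds[2*i+1]);
    }
    if (reduce_k_range(4096, kmin, kmax)) {
        // [3000, 3600) has length 600, which is widened to [2992, 3600) (608 = 19*32),
        // so the kernel touches 608 of the 4096 K/V rows instead of all of them.
        printf("reduced K range: [%d, %d)\n", kmin, kmax);
    }
    return 0;
}
```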