Skip to content

Commit 93a4f60

Browse files
ikawrakowIwan Kawrakow
andauthored
Better CPU prompt processing performance for SWA models (#696)
* This does the trick for PP * Compute mask bounds when creating the mask * Set mask bounds for all supported SWA models --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 4bf5c81 commit 93a4f60

File tree

5 files changed

+140
-30
lines changed

5 files changed

+140
-30
lines changed

ggml/include/ggml.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,6 +2043,10 @@ extern "C" {
20432043
struct ggml_tensor * a,
20442044
struct ggml_tensor * sinks);
20452045

2046+
GGML_API void ggml_flash_attn_ext_add_bounds(
2047+
struct ggml_tensor * a,
2048+
struct ggml_tensor * bounds);
2049+
20462050
// TODO: needs to be adapted to ggml_flash_attn_ext
20472051
GGML_API struct ggml_tensor * ggml_flash_attn_back(
20482052
struct ggml_context * ctx,

ggml/src/ggml.c

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8993,6 +8993,22 @@ void ggml_flash_attn_ext_add_sinks(
89938993
a->src[4] = sinks;
89948994
}
89958995

8996+
void ggml_flash_attn_ext_add_bounds(
8997+
struct ggml_tensor * a,
8998+
struct ggml_tensor * bounds) {
8999+
if (!bounds) {
9000+
a->src[5] = NULL;
9001+
return;
9002+
}
9003+
9004+
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
9005+
GGML_ASSERT(bounds->type == GGML_TYPE_I32);
9006+
GGML_ASSERT(bounds->ne[0] == 2);
9007+
GGML_ASSERT(bounds->ne[1] >= a->src[0]->ne[1]);
9008+
9009+
a->src[5] = bounds;
9010+
}
9011+
89969012
// ggml_flash_attn_back
89979013

89989014
struct ggml_tensor * ggml_flash_attn_back(
@@ -18661,6 +18677,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
1866118677
const struct ggml_tensor * v = dst->src[2];
1866218678
const struct ggml_tensor * mask = dst->src[3];
1866318679
const struct ggml_tensor * sinks = dst->src[4];
18680+
const struct ggml_tensor * bounds= dst->src[5];
1866418681

1866518682
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
1866618683
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
@@ -18739,7 +18756,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
1873918756
dst->ne[2], dst->ne[1], dst->nb[1],
1874018757
k->type, v->type,
1874118758
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
18742-
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
18759+
q->data, k->data, v->data, mask->data,
18760+
sinks ? sinks->data : NULL,
18761+
bounds ? bounds->data : NULL,
1874318762
scale, softcap, (float *)dst->data,
1874418763
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
1874518764

ggml/src/iqk/iqk_flash_attn.cpp

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,27 @@ inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float
4343
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
4444
}
4545
}
46+
inline std::pair<int, int> mask_range(int nek1, const uint16_t * umask) {
47+
int first_k = 0, last_k = nek1;
48+
for (; first_k < last_k; ++first_k) {
49+
if (umask[first_k] == 0) break;
50+
}
51+
for (; last_k > first_k; --last_k) {
52+
if (umask[last_k-1] == 0) break;
53+
}
54+
return { first_k, last_k };
55+
}
56+
inline bool reduce_k_range(int nek1, int& first_k, int& last_k) {
57+
int nk = last_k - first_k;
58+
if (nk >= nek1) return false;
59+
if (nk%32) {
60+
int nk32 = 32*((nk + 31)/32);
61+
int diff = nk32 - nk;
62+
first_k = std::max(0, first_k - diff);
63+
last_k = first_k + nk32;
64+
}
65+
return last_k - first_k < nek1;
66+
}
4667
}
4768

4869
// TODO: get the ggml_type enum here without polution
@@ -66,7 +87,8 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
6687
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
6788
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
6889
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
69-
const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements
90+
const void * sinks, // attention sinks
91+
const void * bounds, // attention mask bounds
7092
float scale, // scale applied before softmax
7193
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
7294
float * qkv, // v*softmax(scale*(k*q))
@@ -80,22 +102,13 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
80102
int rk3 = neq3/nek3;
81103
int rv3 = neq3/nev3;
82104

83-
int first_k = 0, last_k = nek1;
84-
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
85-
// This is a quick hack for SWA models.
86-
// Given that the mask is the same for all layers, ideally we should determinbe the
87-
// cache bounds once, and reuse for the whole graph. But even with this simple hack
88-
// we get non-negligible performance gains for SWA models and long context.
89-
auto umask = (const uint16_t *)mask;
90-
for (; first_k < last_k; ++first_k) {
91-
if (umask[first_k] == 0) break;
92-
}
93-
for (; last_k > first_k; --last_k) {
94-
if (umask[last_k-1] == 0) break;
95-
}
96-
//printf("nek1 = %d, first = %d, last = %d\n", nek1, first, last);
97-
if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
98-
//printf("Reducing from %d to %d\n", nek1, last_k - first_k);
105+
bool range_found = false;
106+
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && bounds && nek1 > 32) {
107+
range_found = true;
108+
auto b = (const int32_t *)bounds;
109+
int first_k = b[0];
110+
int last_k = b[1];
111+
if ((last_k - first_k)%32 == 0) { // why is this not better? : if (reduce_k_range(nek1, first_k, last_k)) {
99112
k = (const void *)((const char *)k + first_k*stride_k);
100113
v = (const void *)((const char *)v + first_k*stride_v);
101114
mask = (const void *)((const uint16_t *)mask + first_k);
@@ -105,7 +118,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
105118

106119
int int_type_k = int_type_k_in;
107120
auto work_buffer = work_buffer_in;
108-
if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
121+
if (neq1 >= 8 || (false && rk2 >= 8 && nek2 > 1)) {
109122
uint64_t row_size = 0;
110123
work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
111124
if (int_type_k != int_type_k_in) {
@@ -299,6 +312,25 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
299312
if (counter++ % (nth/ntg) == ith/ntg) {
300313
int iq1 = (ith%ntg)*neq1g;
301314
int this_neq1 = std::min(neq1g, neq1-iq1);
315+
if (bounds && !range_found) {
316+
auto b = (const int32_t *)bounds + 2*iq1;
317+
int kmin = nek1, kmax = 0;
318+
for (int i = 0; i < this_neq1; ++i) {
319+
kmin = std::min(kmin, b[2*i+0]);
320+
kmax = std::max(kmax, b[2*i+1]);
321+
}
322+
if (reduce_k_range(nek1, kmin, kmax)) {
323+
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
324+
Dk, Dv, this_neq1, kmax-kmin, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
325+
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
326+
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3 + kmin*stride_k),
327+
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3 + kmin*stride_v),
328+
(const void *)((const char *)mask + iq1*stride_m + kmin*sizeof(uint16_t)), sinksf, 1,
329+
scale, softcap,
330+
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
331+
continue;
332+
}
333+
}
302334
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
303335
Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
304336
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),

ggml/src/iqk/iqk_mul_mat.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
5858
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
5959
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
6060
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
61-
const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements
61+
const void * sinks, // attention sinks
62+
const void * bounds, // attention mask bounds
6263
float scale, // scale applied before softmax
6364
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
6465
float * qkv, // v*softmax(scale*(k*q))

0 commit comments

Comments
 (0)