@@ -2007,11 +2007,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
20072007
20082008 GGML_ASSERT (ne01 < 65536 );
20092009
2010+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id (op->src [0 ]);
2011+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id (op->src [1 ]);
2012+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id (op->src [2 ]);
2013+ ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id (op->src [3 ]) : bid_src0;
2014+ ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id (op->src [4 ]) : bid_src0;
2015+
20102016 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id (op);
20112017
20122018 ggml_metal_buffer_id bid_pad = bid_dst;
20132019 bid_pad.offs += ggml_nbytes (op);
20142020
2021+ ggml_metal_buffer_id bid_tmp = bid_pad;
2022+ bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad (op);
2023+
20152024 if (!ggml_metal_op_flash_attn_ext_use_vec (op)) {
20162025 // half8x8 kernel
20172026 const int64_t nqptg = 8 ; // queries per threadgroup !! sync with kernel template arguments !!
@@ -2048,14 +2057,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
20482057
20492058 ggml_metal_encoder_set_pipeline (enc, pipeline0);
20502059 ggml_metal_encoder_set_bytes (enc, &args0, sizeof (args0), 0 );
2051- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 1 );
2052- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 2 );
2053- if (op->src [3 ]) {
2054- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 3 );
2055- } else {
2056- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 3 );
2057- }
2058- ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2060+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1 );
2061+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2 );
2062+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3 );
2063+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
20592064
20602065 assert (ne12 == ne22);
20612066 assert (ne13 == ne23);
@@ -2137,21 +2142,13 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
21372142
21382143 ggml_metal_encoder_set_pipeline (enc, pipeline);
21392144 ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
2140- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
2141- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
2142- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 3 );
2143- if (op->src [3 ]) {
2144- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 4 );
2145- } else {
2146- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 4 );
2147- }
2148- if (op->src [4 ]) {
2149- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [4 ]), 5 );
2150- } else {
2151- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 5 );
2152- }
2153- ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2154- ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
2145+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
2146+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2 );
2147+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3 );
2148+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4 );
2149+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5 );
2150+ ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2151+ ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
21552152
21562153 ggml_metal_encoder_set_threadgroup_memory_size (enc, smem, 0 );
21572154
@@ -2194,14 +2191,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
21942191
21952192 ggml_metal_encoder_set_pipeline (enc, pipeline0);
21962193 ggml_metal_encoder_set_bytes (enc, &args0, sizeof (args0), 0 );
2197- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 1 );
2198- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 2 );
2199- if (op->src [3 ]) {
2200- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 3 );
2201- } else {
2202- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 3 );
2203- }
2204- ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2194+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1 );
2195+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2 );
2196+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3 );
2197+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
22052198
22062199 assert (ne12 == ne22);
22072200 assert (ne13 == ne23);
@@ -2300,26 +2293,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
23002293
23012294 ggml_metal_encoder_set_pipeline (enc, pipeline);
23022295 ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
2303- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
2304- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
2305- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 3 );
2306- if (op->src [3 ]) {
2307- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 4 );
2308- } else {
2309- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 4 );
2310- }
2311- if (op->src [4 ]) {
2312- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [4 ]), 5 );
2313- } else {
2314- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 5 );
2315- }
2296+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
2297+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2 );
2298+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3 );
2299+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4 );
2300+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5 );
23162301
23172302 const size_t smem = FATTN_SMEM (nsg);
23182303
23192304 // printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
23202305 GGML_ASSERT (smem <= props_dev->max_theadgroup_memory_size );
23212306
23222307 if (nwg == 1 ) {
2308+ assert (ggml_metal_op_flash_attn_ext_extra_tmp (op) == 0 );
2309+
23232310 // using 1 workgroup -> write the result directly into dst
23242311 ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
23252312 ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
@@ -2329,13 +2316,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
23292316 ggml_metal_encoder_dispatch_threadgroups (enc, (ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg, 32 , nsg, 1 );
23302317 } else {
23312318 // sanity checks
2319+ assert (ggml_metal_op_flash_attn_ext_extra_tmp (op) != 0 );
2320+
23322321 GGML_ASSERT (ne01*ne02*ne03 == ne1*ne2*ne3);
23332322 GGML_ASSERT ((uint64_t )ne1*ne2*ne3 <= (1u << 31 ));
23342323
23352324 // write the results from each workgroup into a temp buffer
2336- ggml_metal_buffer_id bid_tmp = bid_dst;
2337- bid_tmp.offs += ggml_nbytes (op) + ggml_metal_op_flash_attn_ext_extra_pad (op);
2338-
23392325 ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
23402326 ggml_metal_encoder_set_buffer (enc, bid_tmp, 7 );
23412327
0 commit comments