@@ -333,7 +333,8 @@ template <int nrc_y>
333333static void mul_mat_bf16_r16_bf16 (int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
334334 GGML_ASSERT (nrc_x%16 == 0 );
335335 const ggml_bf16_t * y[nrc_y];
336- for (int iy = 0 ; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row (iy);
336+ static_for<nrc_y>([&](const int iy) { y[iy] = (const ggml_bf16_t *)info.src1_row (iy); });
337+
337338 for (int ix = 0 ; ix < nrc_x/32 ; ++ix) {
338339 __m512 acc[2 *nrc_y] = {};
339340 __m512bh qx[8 ];
@@ -348,7 +349,7 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
348349 qx[5 ] = (__m512bh)_mm512_loadu_si512 ((const __m512i *)b8_2+4 *ib+1 );
349350 qx[6 ] = (__m512bh)_mm512_loadu_si512 ((const __m512i *)b8_2+4 *ib+2 );
350351 qx[7 ] = (__m512bh)_mm512_loadu_si512 ((const __m512i *)b8_2+4 *ib+3 );
351- for ( int iy = 0 ; iy < nrc_y; ++ iy) {
352+ static_for< nrc_y>([&]( const int iy) {
352353 auto y128 = _mm_loadu_si128 ((const __m128i*)y[iy]+ib);
353354 // auto y = _mm512_broadcast_i32x4(y128);
354355 auto y256 = MM256_SET_M128I (y128, y128);
@@ -361,12 +362,12 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
361362 acc[2 *iy+1 ] = _mm512_dpbf16_ps (acc[2 *iy+1 ], qx[5 ], (__m512bh)_mm512_shuffle_epi32 (y, _MM_PERM_ENUM (0x55 )));
362363 acc[2 *iy+1 ] = _mm512_dpbf16_ps (acc[2 *iy+1 ], qx[6 ], (__m512bh)_mm512_shuffle_epi32 (y, _MM_PERM_ENUM (0xaa )));
363364 acc[2 *iy+1 ] = _mm512_dpbf16_ps (acc[2 *iy+1 ], qx[7 ], (__m512bh)_mm512_shuffle_epi32 (y, _MM_PERM_ENUM (0xff )));
364- }
365+ });
365366 }
366- for ( int iy = 0 ; iy < nrc_y; ++ iy) {
367+ static_for< nrc_y>([&]( const int iy) {
367368 info.store (32 *ix+ 0 , iy, acc[2 *iy+0 ]);
368369 info.store (32 *ix+16 , iy, acc[2 *iy+1 ]);
369- }
370+ });
370371 }
371372 for (int ix = 32 *(nrc_x/32 ); ix < nrc_x; ix += 16 ) {
372373 __m512 acc[nrc_y] = {};
@@ -377,19 +378,19 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
377378 qx[1 ] = (__m512bh)_mm512_loadu_si512 ((const __m512i *)b8+4 *ib+1 );
378379 qx[2 ] = (__m512bh)_mm512_loadu_si512 ((const __m512i *)b8+4 *ib+2 );
379380 qx[3 ] = (__m512bh)_mm512_loadu_si512 ((const __m512i *)b8+4 *ib+3 );
380- for ( int iy = 0 ; iy < nrc_y; ++ iy) {
381+ static_for< nrc_y>([&]( const int iy) {
381382 auto y128 = _mm_loadu_si128 ((const __m128i*)y[iy]+ib);
382383 auto y256 = MM256_SET_M128I (y128, y128);
383384 auto y = _mm512_inserti32x8 (_mm512_castsi256_si512 (y256), y256, 1 );
384385 acc[iy] = _mm512_dpbf16_ps (acc[iy], qx[0 ], (__m512bh)_mm512_shuffle_epi32 (y, _MM_PERM_ENUM (0x00 )));
385386 acc[iy] = _mm512_dpbf16_ps (acc[iy], qx[1 ], (__m512bh)_mm512_shuffle_epi32 (y, _MM_PERM_ENUM (0x55 )));
386387 acc[iy] = _mm512_dpbf16_ps (acc[iy], qx[2 ], (__m512bh)_mm512_shuffle_epi32 (y, _MM_PERM_ENUM (0xaa )));
387388 acc[iy] = _mm512_dpbf16_ps (acc[iy], qx[3 ], (__m512bh)_mm512_shuffle_epi32 (y, _MM_PERM_ENUM (0xff )));
388- }
389+ });
389390 }
390- for ( int iy = 0 ; iy < nrc_y; ++ iy) {
391+ static_for< nrc_y>([&]( const int iy) {
391392 info.store (ix, iy, acc[iy]);
392- }
393+ });
393394 }
394395}
395396
0 commit comments