Skip to content

Commit ef1d88b

Browse files
authored
Merge branch 'ikawrakow:main' into main
2 parents 6f82dca + 808ce49 commit ef1d88b

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

ggml/src/iqk/iqk_common.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,15 @@ static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values,
923923

924924
#endif
925925

926+
// static unrool for:
927+
template<int N, typename T>
928+
inline void static_for(T&&f) {
929+
if constexpr(N>0) {
930+
static_for<N-1>(f);
931+
f(N-1);
932+
}
933+
}
934+
926935
#if defined(_MSC_VER)
927936
#pragma warning(disable: 4244 4267) // possible loss of data
928937
#include <intrin.h>

ggml/src/iqk/iqk_gemm_floats.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,8 @@ template <int nrc_y>
333333
static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
334334
GGML_ASSERT(nrc_x%16 == 0);
335335
const ggml_bf16_t * y[nrc_y];
336-
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
336+
static_for<nrc_y>([&](const int iy) { y[iy] = (const ggml_bf16_t *)info.src1_row(iy); });
337+
337338
for (int ix = 0; ix < nrc_x/32; ++ix) {
338339
__m512 acc[2*nrc_y] = {};
339340
__m512bh qx[8];
@@ -348,7 +349,7 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
348349
qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
349350
qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
350351
qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
351-
for (int iy = 0; iy < nrc_y; ++iy) {
352+
static_for<nrc_y>([&](const int iy) {
352353
auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
353354
//auto y = _mm512_broadcast_i32x4(y128);
354355
auto y256 = MM256_SET_M128I(y128, y128);
@@ -361,12 +362,12 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
361362
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
362363
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
363364
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
364-
}
365+
});
365366
}
366-
for (int iy = 0; iy < nrc_y; ++iy) {
367+
static_for<nrc_y>([&](const int iy) {
367368
info.store(32*ix+ 0, iy, acc[2*iy+0]);
368369
info.store(32*ix+16, iy, acc[2*iy+1]);
369-
}
370+
});
370371
}
371372
for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
372373
__m512 acc[nrc_y] = {};
@@ -377,19 +378,19 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
377378
qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
378379
qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
379380
qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
380-
for (int iy = 0; iy < nrc_y; ++iy) {
381+
static_for<nrc_y>([&](const int iy) {
381382
auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
382383
auto y256 = MM256_SET_M128I(y128, y128);
383384
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
384385
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
385386
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
386387
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
387388
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
388-
}
389+
});
389390
}
390-
for (int iy = 0; iy < nrc_y; ++iy) {
391+
static_for<nrc_y>([&](const int iy) {
391392
info.store(ix, iy, acc[iy]);
392-
}
393+
});
393394
}
394395
}
395396

0 commit comments

Comments
 (0)