Skip to content

Commit bc2b292

Browse files
authored
Merge branch 'ikawrakow:main' into main
2 parents b6cc848 + a3737f4 commit bc2b292

File tree

4 files changed

+37
-37
lines changed

4 files changed

+37
-37
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2950,6 +2950,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
29502950
int64_t tim1 = ggml_time_us();
29512951
#endif
29522952

2953+
if (ggml_is_noop(dst)) {
2954+
return true;
2955+
}
2956+
2957+
// In case we forget to do that in some kernel.
2958+
ggml_cuda_set_device(ctx.device);
2959+
29532960
auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;
29542961

29552962
auto fusion = ctx.fusion;

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
#define FATTN_KQ_STRIDE_TILE_F32 32
1313

14-
template<int Dk, int Dv, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
14+
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
1515
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
1616
__launch_bounds__(nwarps*WARP_SIZE, 1)
1717
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -52,9 +52,9 @@ static __global__ void flash_attn_tile_ext_f32(
5252
const int ne1,
5353
const int ne2,
5454
const int ne3) {
55-
static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512));
55+
5656
// Skip unused kernel variants for faster compilation:
57-
if (use_softcap && !(Dk == 128 || Dk == 256)) {
57+
if (use_softcap && !(D == 128 || D == 256)) {
5858
NO_DEVICE_CODE;
5959
return;
6060
}
@@ -70,22 +70,15 @@ static __global__ void flash_attn_tile_ext_f32(
7070
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
7171
const half * maskh = (const half *) mask + ne11*ic0;
7272

73-
const int stride_K2 = nb11 / sizeof(half2);
74-
const int stride_V2 = nb12 / sizeof(half2);
73+
const int stride_KV2 = nb11 / sizeof(half2);
7574

7675
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
77-
78-
// TODO: is it Dk or Dv or both that need to be multiple of 2*WARP_SIZE ?
79-
// let's assume it is is both.
80-
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
81-
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
82-
83-
constexpr int Dkv = Dk < Dv ? Dv : Dk; // let's use this when we don't understand if it is Dk or Dv
76+
static_assert(D % (2 * WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
8477

8578
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
8679

87-
// This is being used to store either K or V data. Hence we need max(Dk, Dv) as the dimension
88-
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts.
80+
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
81+
8982
float2 * KV_tmp2 = (float2 *) KV_tmp;
9083

9184
float kqmax[ncols/nwarps];
@@ -95,16 +88,16 @@ static __global__ void flash_attn_tile_ext_f32(
9588
}
9689
float kqsum[ncols/nwarps] = {0.0f};
9790

98-
float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
91+
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
9992

10093
// Convert Q to half2 and store in registers:
101-
__shared__ float Q_f[ncols][Dk];
94+
__shared__ float Q_f[ncols][D];
10295
#pragma unroll
10396
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
10497
const int j = j0 + threadIdx.y;
10598

10699
#pragma unroll
107-
for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) {
100+
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
108101
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
109102
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
110103
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
@@ -128,8 +121,8 @@ static __global__ void flash_attn_tile_ext_f32(
128121
const int i_KQ = i_KQ_0 + threadIdx.y;
129122

130123
#pragma unroll
131-
for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) {
132-
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x];
124+
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
125+
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
133126
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
134127
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
135128
}
@@ -140,7 +133,7 @@ static __global__ void flash_attn_tile_ext_f32(
140133
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
141134

142135
#pragma unroll
143-
for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) {
136+
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
144137
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
145138
float Q_k[ncols/nwarps];
146139

@@ -209,7 +202,7 @@ static __global__ void flash_attn_tile_ext_f32(
209202
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
210203

211204
#pragma unroll
212-
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
205+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
213206
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
214207
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
215208
}
@@ -222,26 +215,26 @@ static __global__ void flash_attn_tile_ext_f32(
222215
const int k = k0 + threadIdx.y;
223216

224217
#pragma unroll
225-
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
218+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
226219
const int i = i0 + threadIdx.x;
227220

228-
KV_tmp2[k*(Dv/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
229-
KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
221+
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)* stride_KV2 + i]);
222+
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)* stride_KV2 + i]);
230223
}
231224
}
232225

233226
__syncthreads();
234227

235228
#pragma unroll
236229
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
237-
float2 V_k[(Dv/2)/WARP_SIZE];
230+
float2 V_k[(D/2)/WARP_SIZE];
238231
float KQ_k[ncols/nwarps];
239232

240233
#pragma unroll
241-
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
234+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
242235
const int i = i0 + threadIdx.x;
243236

244-
V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i];
237+
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
245238
}
246239
#pragma unroll
247240
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
@@ -251,7 +244,7 @@ static __global__ void flash_attn_tile_ext_f32(
251244
}
252245

253246
#pragma unroll
254-
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
247+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
255248
#pragma unroll
256249
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
257250
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
@@ -275,7 +268,7 @@ static __global__ void flash_attn_tile_ext_f32(
275268
kqsum_j = warp_reduce_sum(kqsum_j);
276269

277270
#pragma unroll
278-
for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) {
271+
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
279272
const int i0 = i00 + 2*threadIdx.x;
280273

281274
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
@@ -284,8 +277,8 @@ static __global__ void flash_attn_tile_ext_f32(
284277
dst_val.y /= kqsum_j;
285278
}
286279
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
287-
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x;
288-
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y;
280+
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
281+
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
289282
}
290283

291284
if (parallel_blocks != 1 && threadIdx.x == 0) {
@@ -301,13 +294,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
301294
case 64: {
302295
constexpr int D = 64;
303296
constexpr int nwarps = 8;
304-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
297+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
305298
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
306299
} break;
307300
case 128: {
308301
constexpr int D = 128;
309302
constexpr int nwarps = 8;
310-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
303+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
311304
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
312305
} break;
313306
default: {

ggml/src/ggml.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6156,9 +6156,9 @@ struct ggml_tensor * ggml_mul_multi_add(
61566156
#include <nmmintrin.h>
61576157
#include <immintrin.h>
61586158
#include <stdlib.h>
6159-
inline int popcount(uint32_t x) { return __popcnt(x); }
6159+
static inline int popcount(uint32_t x) { return __popcnt(x); }
61606160
#else
6161-
inline int popcount(uint32_t x) { return __builtin_popcount(x); }
6161+
static inline int popcount(uint32_t x) { return __builtin_popcount(x); }
61626162
#endif
61636163

61646164
struct ggml_tensor * ggml_hadamard(

src/llama-load-tensors.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ void create_tensors_helper::create_std_ffn(int i, const LLM_TN & tn, llama_layer
388388

389389
bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
390390
LOADING_PRELUDE
391-
create_embd_output(tn, n_embd, n_vocab, true, true);
391+
create_embd_output(tn, n_embd, n_vocab, true, false); //true);
392392

393393
for (int i = 0; i < n_layer; ++i) {
394394
ggml_context * ctx_layer = ctx_for_layer(i);
@@ -1843,7 +1843,7 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
18431843
GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers");
18441844
GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers");
18451845

1846-
create_embd_output(tn, n_embd, n_vocab, true, true);
1846+
create_embd_output(tn, n_embd, n_vocab, true, false); //true);
18471847

18481848
for (int i = 0; i < n_layer; ++i) {
18491849
ggml_context * ctx_layer = ctx_for_layer(i);

0 commit comments

Comments
 (0)