@@ -2123,11 +2123,13 @@ static bool llm_load_tensors(
                     layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
 
                     layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
+                    layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
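+                    // optional fused tensor that stacks the w/k/v/r/g lerp coefficients along an extra dimension of size 5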
+                    layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
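+                    // a model must provide either the fused tensor or the individual lerp tensors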
+                    GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
 
                     layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
                     layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
@@ -2180,11 +2182,7 @@ static bool llm_load_tensors(
                     layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
 
                     layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-                    layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
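+                    // the fused lerp tensor fully replaces the separate ones here and is therefore required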
+                    layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
 
                     layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
@@ -3337,72 +3335,32 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         xxx
     );
 
-    struct ggml_tensor *mw, *mk, *mv, *mr, *mg;
-    if (is_qrwkv) {
-        // Why the f*** do they change the order here?
-        mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-        mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-        mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-        mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-        mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights makes some performance improvement
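+        // reshape sx and cur so they broadcast against all five stacked mixing branches at once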
+        sx = ggml_reshape_3d(ctx, sx, n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
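+        // slice the packed result back into the five per-branch views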
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
     } else {
-        mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-        mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-        mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-        mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-        mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+        // for backward compatibility
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
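+        // classic per-branch token-shift lerp: x? = (x? + time_mix_lerp_?) * sx + cur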
+        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
     }
 
-    struct ggml_tensor * xw = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mw, layer->time_mix_lerp_w),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * xk = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mk, layer->time_mix_lerp_k),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * xv = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mv, layer->time_mix_lerp_v),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * xr = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mr, layer->time_mix_lerp_r),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * xg = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mg, layer->time_mix_lerp_g),
-            sx
-        ),
-        cur
-    );
-
     struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
     struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk);
     struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv);