
Commit bc930cd

RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
1 parent: fab0aa7

8 files changed (+75 additions, -98 deletions)


convert_hf_to_gguf.py

Lines changed: 34 additions & 20 deletions
@@ -326,6 +326,7 @@ def prepare_tensors(self):
                     gguf.MODEL_TENSOR.TIME_MIX_W2,
                     gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
                     gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
+                    gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
                     gguf.MODEL_TENSOR.POSNET_NORM1,
                     gguf.MODEL_TENSOR.POSNET_NORM2,
                 )
@@ -3256,6 +3257,7 @@ def set_gguf_parameters(self):
         # required by llama.cpp, unused
         self.gguf_writer.add_head_count(0)
 
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)
 
@@ -3271,16 +3273,32 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
             data_torch = data_torch.squeeze()
 
-        rescale_every_n_layers = self.hparams["rescale_every"]
-        if rescale_every_n_layers > 0:
-            if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
-                data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+        try:
+            rescale_every_n_layers = self.hparams["rescale_every"]
+            if rescale_every_n_layers > 0:
+                if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
+                    data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+        except KeyError:
+            pass
+
+        # concat time_mix_lerp weights to reduce some cpu overhead
+        # also reduces the number of tensors in the model
+        if bid is not None and "time_mix_lerp" in new_name and not "time_mix_lerp_x" in new_name:
+            try:
+                self.lerp_weights[bid][new_name] = data_torch
+            except KeyError:
+                self.lerp_weights[bid] = {new_name: data_torch}
+            if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]):
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1)
+                yield (new_name, data)
+            return
 
         yield (new_name, data_torch)
 
 
 @Model.register("RWKV6Qwen2ForCausalLM")
-class RWKV6Qwen2Model(Model):
+class RWKV6Qwen2Model(Rwkv6Model):
     model_arch = gguf.MODEL_ARCH.RWKV6QWEN2
 
     def set_vocab(self):
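
Note: as a rough illustration of what the stacking above produces (not part of the commit; n_embd is a made-up toy size here), the five per-channel lerp vectors from the checkpoint end up in a single tensor of torch shape (5, 1, 1, n_embd):

    import torch

    n_embd = 8  # toy size for illustration; real models use much larger values

    # five (n_embd,) lerp vectors, as stored per layer in the HF checkpoint
    lerp = {g: torch.randn(n_embd) for g in ["w", "k", "v", "r", "g"]}

    # same stacking as modify_tensors() above: (n_embd,) -> (1, n_embd),
    # stack the five gates -> (5, 1, n_embd), then unsqueeze -> (5, 1, 1, n_embd)
    fused = torch.stack([lerp[g].unsqueeze(0) for g in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1)
    print(fused.shape)  # torch.Size([5, 1, 1, 8])

Since ggml lists dimensions in the reverse order of torch, this is the {n_embd, 1, 1, 5} shape that create_tensor expects for time_mix_lerp_fused in the src/llama.cpp changes further down.
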
@@ -3320,21 +3338,17 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_head_count(0)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        new_name = self.map_tensor_name(name)
-
-        if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
-            new_name += ".weight"
-
-        if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
-            data_torch = data_torch.transpose(0, 1)
-
-        if new_name.endswith("time_mix_w2.weight"):
-            data_torch = data_torch.permute(0, 2, 1)
-
-        if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
-            data_torch = data_torch.squeeze()
-
-        yield (new_name, data_torch)
+        for new_name, data in super().modify_tensors(data_torch, name, bid):
+            if "time_mix_w1" in new_name or "time_mix_w2" in new_name:
+                data = data.view(5, -1, data.shape[-1])
+                # rwkv6qwen2 has a different order of rkvwg instead of the original wkvrg
+                # permute them here to avoid code changes
+                data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1])
+                if "w2" in new_name:
+                    data = data.view(5, -1, data.shape[-1])
+                yield (new_name, data)
+                continue
+            yield (new_name, data)
 
 
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")

gguf-py/gguf/constants.py

Lines changed: 4 additions & 0 deletions
@@ -331,6 +331,7 @@ class MODEL_TENSOR(IntEnum):
     TIME_MIX_LERP_V = auto()
     TIME_MIX_LERP_R = auto()
     TIME_MIX_LERP_G = auto()
+    TIME_MIX_LERP_FUSED = auto()
     TIME_MIX_LERP_W = auto()
     TIME_MIX_FIRST = auto()
     TIME_MIX_DECAY = auto()

@@ -514,6 +515,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
     MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
     MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
+    MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused",
     MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
     MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
     MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",

@@ -1080,6 +1082,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.TIME_MIX_LERP_R,
         MODEL_TENSOR.TIME_MIX_LERP_G,
         MODEL_TENSOR.TIME_MIX_LERP_W,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
         MODEL_TENSOR.TIME_MIX_FIRST,
         MODEL_TENSOR.TIME_MIX_DECAY,
         MODEL_TENSOR.TIME_MIX_DECAY_W1,

@@ -1109,6 +1112,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.TIME_MIX_LERP_R,
         MODEL_TENSOR.TIME_MIX_LERP_G,
         MODEL_TENSOR.TIME_MIX_LERP_W,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
         MODEL_TENSOR.TIME_MIX_FIRST,
         MODEL_TENSOR.TIME_MIX_DECAY,
         MODEL_TENSOR.TIME_MIX_DECAY_W1,

src/llama-arch.cpp

Lines changed: 2 additions & 5 deletions
@@ -1154,11 +1154,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
         { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
         { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
-        { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" },
-        { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" },
-        { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
-        { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
-        { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
+        { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
         { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
         { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
         { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },

@@ -1356,6 +1352,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
     {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -250,6 +250,7 @@ enum llm_tensor {
     LLM_TENSOR_TIME_MIX_LERP_V,
     LLM_TENSOR_TIME_MIX_LERP_R,
     LLM_TENSOR_TIME_MIX_LERP_G,
+    LLM_TENSOR_TIME_MIX_LERP_FUSED,
     LLM_TENSOR_TIME_MIX_FIRST,
     LLM_TENSOR_TIME_MIX_DECAY,
     LLM_TENSOR_TIME_MIX_DECAY_W1,

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
@@ -2164,6 +2164,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA: return true;
         case LLM_ARCH_RWKV6: return true;
+        case LLM_ARCH_RWKV6QWEN2: return true;
         default: return false;
     }
 }

src/llama-model.h

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ struct llama_layer {
     struct ggml_tensor * time_mix_lerp_v = nullptr;
     struct ggml_tensor * time_mix_lerp_r = nullptr;
     struct ggml_tensor * time_mix_lerp_g = nullptr;
+    struct ggml_tensor * time_mix_lerp_fused = nullptr;
 
     struct ggml_tensor * time_mix_first = nullptr;
     struct ggml_tensor * time_mix_decay = nullptr;

src/llama-quant.cpp

Lines changed: 1 addition & 0 deletions
@@ -760,6 +760,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= name.find("time_mix_w2.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
 
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;

src/llama.cpp

Lines changed: 31 additions & 73 deletions
@@ -2123,11 +2123,13 @@ static bool llm_load_tensors(
             layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
 
             layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
+            layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+            layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+            layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+            layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+            layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+            layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
+            GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
 
             layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
             layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
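
The loading rule in this hunk keeps older RWKV6 GGUF files working: all lerp tensors are now TENSOR_NOT_REQUIRED, and the GGML_ASSERT only demands that at least one representation is present. A simplified Python model of that rule (hypothetical helper, not llama.cpp code, and names are stripped of the blk.{bid} prefix):

    def check_rwkv6_lerp_tensors(tensor_names: set[str]) -> None:
        # either the fused tensor or the per-gate tensors must exist
        has_fused = "time_mix_lerp_fused.weight" in tensor_names
        has_split = all(f"time_mix_lerp_{g}.weight" in tensor_names for g in "wkvrg")
        assert has_fused or has_split, "missing time_mix_lerp weights"

The RWKV6QWEN2 branch in the next hunk is a new architecture, so it loads only the fused tensor and marks it required.
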
@@ -2180,11 +2182,7 @@ static bool llm_load_tensors(
             layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
 
             layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-            layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
+            layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
 
             layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
             layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
@@ -3337,72 +3335,32 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         xxx
     );
 
-    struct ggml_tensor *mw, *mk, *mv, *mr, *mg;
-    if (is_qrwkv) {
-        // Why the f*** do they change the order here?
-        mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-        mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-        mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-        mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-        mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights makes some performance improvement
+        sx = ggml_reshape_3d(ctx, sx, n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
     } else {
-        mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-        mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-        mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-        mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-        mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+        // for backward compatibility
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
+        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
     }
 
-    struct ggml_tensor * xw = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mw, layer->time_mix_lerp_w),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * xk = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mk, layer->time_mix_lerp_k),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * xv = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mv, layer->time_mix_lerp_v),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * xr = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mr, layer->time_mix_lerp_r),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * xg = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mg, layer->time_mix_lerp_g),
-            sx
-        ),
-        cur
-    );
-
     struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
     struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk);
     struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv);
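
To see where the saving comes from, here is a hedged torch sketch (torch layouts and toy sizes, not ggml code) of the old per-gate chain versus the new broadcasted one: the old graph built ggml_add/ggml_mul/ggml_add once per gate (15 elementwise nodes per layer), while the fused path builds the same chain once and lets the fused lerp weights broadcast over all five gates.

    import torch

    n_embd, n_tokens = 8, 3                 # toy sizes
    xxx = torch.randn(5, n_tokens, n_embd)  # data-dependent offsets for the five gates (w, k, v, r, g)
    mu  = torch.randn(5, 1, n_embd)         # fused time_mix_lerp weights, one row per gate
    sx  = torch.randn(n_tokens, n_embd)     # token-shift difference term
    cur = torch.randn(n_tokens, n_embd)     # current hidden states

    # fused path: one add/mul/add chain, broadcast over the gate dimension
    fused = (xxx + mu) * sx + cur

    # old path: the same arithmetic done gate by gate
    per_gate = torch.stack([(xxx[i] + mu[i]) * sx + cur for i in range(5)])

    assert torch.allclose(fused, per_gate)

Fewer graph nodes means fewer tiny per-tensor ops to schedule each layer, which is the CPU-overhead reduction named in the commit message.
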
