Commit b62b1bb

fix: Fix mamba / mamba2 creation code reuse

This also fixes the mamba2 n_head value to be correct.

Branch: BambaAbstractMemory
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 2034224 commit b62b1bb
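The reuse fix below replaces virtual dispatch (a separate llm_build_mamba2 subclass overriding build_mamba_layer) with a single struct parameterized on a compile-time bool. A minimal standalone sketch of that pattern, with hypothetical names (GraphBuilder, build_v1_layer, and build_v2_layer are illustrative stand-ins, not the llama.cpp API):

#include <cstdio>

// Sketch only: one struct serves both variants; the bool template
// parameter selects which layer builder runs at construction time.
template <bool v2>
struct GraphBuilder {
    GraphBuilder() {
        // 'v2' is a compile-time constant, so the compiler can fold
        // away the untaken branch in each instantiation.
        if (v2) {
            build_v2_layer();
        } else {
            build_v1_layer();
        }
    }

    void build_v1_layer() { std::puts("v1 layer"); }
    void build_v2_layer() { std::puts("v2 layer"); }
};

int main() {
    GraphBuilder<false> v1; // cf. llm_build_mamba<false> for LLM_ARCH_MAMBA
    GraphBuilder<true>  v2; // cf. llm_build_mamba<true>  for LLM_ARCH_MAMBA2
    return 0;
}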

File tree: 1 file changed (+12, -16 lines changed)

src/llama-model.cpp

Lines changed: 12 additions & 16 deletions
@@ -8729,11 +8729,10 @@ struct llm_build_starcoder2 : public llm_graph_context {
     }
 };
 
+template<bool mamba2>
 struct llm_build_mamba : public llm_graph_context {
     const llama_model & model;
 
-    virtual ~llm_build_mamba() = default;
-
     llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -8751,7 +8750,11 @@ struct llm_build_mamba : public llm_graph_context {
             cb(cur, "attn_norm", il);
 
             //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
-            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+            if (mamba2) {
+                cur = build_mamba2_layer(gf, cur, state_copy, ubatch, il);
+            } else {
+                cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+            }
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -8788,7 +8791,7 @@ struct llm_build_mamba : public llm_graph_context {
     }
 
     // TODO: split
-    virtual ggml_tensor * build_mamba_layer(
+    ggml_tensor * build_mamba_layer(
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
@@ -8923,30 +8926,23 @@ struct llm_build_mamba : public llm_graph_context {
 
         return cur;
     }
-};
 
 
-struct llm_build_mamba2 : public llm_build_mamba {
-    llm_build_mamba2(
-        const llama_model & model,
-        const llm_graph_params & params,
-        ggml_cgraph * gf) : llm_build_mamba(model, params, gf) {}
-
     // Override to build mamba2 layers
-    virtual ggml_tensor * build_mamba_layer(
+    ggml_tensor * build_mamba2_layer(
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
             const llama_ubatch & ubatch,
-            int il) const override {
+            int il) const {
         const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
         const auto kv_head = kv_self->head;
 
         const int64_t d_conv  = hparams.ssm_d_conv;
         const int64_t d_inner = hparams.ssm_d_inner;
         const int64_t d_state = hparams.ssm_d_state;
-        const int64_t n_head  = d_inner;
+        const int64_t n_head  = hparams.ssm_dt_rank;
         const int64_t head_dim = d_inner / n_head;
         const int64_t n_group = hparams.ssm_n_group;
         const int64_t n_seqs  = ubatch.n_seqs;
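The n_head change above is the second fix named in the commit message: head_dim is derived as d_inner / n_head, so setting n_head = d_inner forced head_dim to 1. A worked sketch of the arithmetic, using illustrative values (d_inner = 4096 and ssm_dt_rank = 64 are assumptions, not values read from any real model):

#include <cassert>
#include <cstdint>

int main() {
    // Illustrative values only; in llama.cpp these come from hparams.
    const int64_t d_inner     = 4096;
    const int64_t ssm_dt_rank = 64;

    // Before the fix: n_head == d_inner, so head_dim collapses to 1.
    const int64_t old_n_head   = d_inner;
    const int64_t old_head_dim = d_inner / old_n_head;
    assert(old_head_dim == 1);

    // After the fix: n_head is taken from ssm_dt_rank.
    const int64_t n_head   = ssm_dt_rank;
    const int64_t head_dim = d_inner / n_head;
    assert(head_dim == 64);

    return 0;
}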
@@ -13201,11 +13197,11 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_MAMBA:
             {
-                llm = std::make_unique<llm_build_mamba>(*this, params, gf);
+                llm = std::make_unique<llm_build_mamba<false>>(*this, params, gf);
             } break;
         case LLM_ARCH_MAMBA2:
             {
-                llm = std::make_unique<llm_build_mamba2>(*this, params, gf);
+                llm = std::make_unique<llm_build_mamba<true>>(*this, params, gf);
             } break;
         case LLM_ARCH_XVERSE:
             {
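A design note on the constructor's dispatch: because mamba2 is a template parameter, the branch could also be written with C++17 if constexpr, which discards the untaken branch at instantiation time rather than relying on the optimizer. A sketch of that alternative fragment for the constructor body (not what this commit does):

// Hypothetical alternative for the dispatch inside the constructor;
// the commit uses a plain runtime 'if' on the constant instead.
if constexpr (mamba2) {
    cur = build_mamba2_layer(gf, cur, state_copy, ubatch, il);
} else {
    cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
}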

0 commit comments