@@ -8729,11 +8729,10 @@ struct llm_build_starcoder2 : public llm_graph_context {
     }
 };
 
+template<bool mamba2>
 struct llm_build_mamba : public llm_graph_context {
     const llama_model & model;
 
-    virtual ~llm_build_mamba() = default;
-
     llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -8751,7 +8750,11 @@ struct llm_build_mamba : public llm_graph_context {
             cb(cur, "attn_norm", il);
 
             //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
-            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+            if (mamba2) {
+                cur = build_mamba2_layer(gf, cur, state_copy, ubatch, il);
+            } else {
+                cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+            }
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -8788,7 +8791,7 @@ struct llm_build_mamba : public llm_graph_context {
     }
 
     // TODO: split
-    virtual ggml_tensor * build_mamba_layer(
+    ggml_tensor * build_mamba_layer(
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
@@ -8923,30 +8926,23 @@ struct llm_build_mamba : public llm_graph_context {
 
         return cur;
     }
-};
 
 
-struct llm_build_mamba2 : public llm_build_mamba {
-    llm_build_mamba2(
-        const llama_model & model,
-        const llm_graph_params & params,
-        ggml_cgraph * gf) : llm_build_mamba(model, params, gf) {}
-
     // Override to build mamba2 layers
-    virtual ggml_tensor * build_mamba_layer(
+    ggml_tensor * build_mamba2_layer(
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
             const llama_ubatch & ubatch,
-            int il) const override {
+            int il) const {
         const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
         const auto kv_head = kv_self->head;
 
         const int64_t d_conv   = hparams.ssm_d_conv;
         const int64_t d_inner  = hparams.ssm_d_inner;
         const int64_t d_state  = hparams.ssm_d_state;
-        const int64_t n_head   = d_inner;
+        const int64_t n_head   = hparams.ssm_dt_rank;
         const int64_t head_dim = d_inner / n_head;
         const int64_t n_group  = hparams.ssm_n_group;
         const int64_t n_seqs   = ubatch.n_seqs;
@@ -13201,11 +13197,11 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_MAMBA:
             {
-                llm = std::make_unique<llm_build_mamba>(*this, params, gf);
+                llm = std::make_unique<llm_build_mamba<false>>(*this, params, gf);
             } break;
         case LLM_ARCH_MAMBA2:
             {
-                llm = std::make_unique<llm_build_mamba2>(*this, params, gf);
+                llm = std::make_unique<llm_build_mamba<true>>(*this, params, gf);
             } break;
         case LLM_ARCH_XVERSE:
             {
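The change above folds the `llm_build_mamba2` subclass back into `llm_build_mamba` and selects the Mamba-1 or Mamba-2 layer path with a `bool` template parameter instead of a virtual `build_mamba_layer` override; `build_graph` then instantiates `llm_build_mamba<false>` or `llm_build_mamba<true>` per architecture. Below is a minimal, self-contained sketch of that compile-time dispatch pattern, not the llama.cpp code itself: the names `graph_builder`, `build_layer_v1`, and `build_layer_v2` are invented for illustration.

    // Sketch of dispatching on a bool template parameter, assuming C++14 or later.
    #include <cstdio>
    #include <memory>

    template <bool mamba2>
    struct graph_builder {
        // The branch condition is a compile-time constant, so the choice is made
        // per instantiation; with a plain `if` (as in the diff) the untaken
        // branch must still compile for both instantiations.
        void build_layer(int il) const {
            if (mamba2) {
                build_layer_v2(il);
            } else {
                build_layer_v1(il);
            }
        }

        void build_layer_v1(int il) const { std::printf("mamba1 layer %d\n", il); }
        void build_layer_v2(int il) const { std::printf("mamba2 layer %d\n", il); }
    };

    int main() {
        // Mirrors the LLM_ARCH_MAMBA / LLM_ARCH_MAMBA2 cases in build_graph():
        // each architecture gets its own instantiation of the same builder.
        auto m1 = std::make_unique<graph_builder<false>>();
        auto m2 = std::make_unique<graph_builder<true>>();
        m1->build_layer(0);
        m2->build_layer(0);
        return 0;
    }

Compared with the previous virtual-method design, this removes the need for a virtual destructor and the `override` machinery, at the cost of compiling the builder twice (once per instantiation).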