Commit 2034224

Merge remote-tracking branch 'origin/compilade/mamba2' into BambaAbstractMemory
This is definitely a surgical merge, so there are likely pieces still missing, especially from llama-model.cpp

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* origin/compilade/mamba2:
  convert : fix flake8 lint
  ggml : avoid multiply by D in GGML_OP_SSM_SCAN
  ggml : remove unused fast broadcast path in GGML_MUL
  metal : fix wrong number of tokens per sequence in SSM_SCAN
  metal : fix SSM_SCAN state head offset
  metal : add back n_seqs to SSM_SCAN args
  metal : remove unused arguments for SSM_SCAN
  metal : use log and exp instead of log1pf and expf in SSM_SCAN
  metal : fix SSM_SCAN pipeline scope
  metal : attempt to adapt SSM_SCAN for Mamba-2
  llama : avoid redundant state copy for Mamba 1 and 2
  convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
  llama : add missing break
  llama : remove unused variable
  llama : fix Mamba-2 conv state saving
  llama : support running Mamba-Codestral-7B-v0.1
  ggml : SIMD ggml_ssm_scan for Mamba-2
  llama : initial Mamba-2 support
2 parents: e5007a5 + c9ecf62

20 files changed: +719 -167 lines

convert_hf_to_gguf.py

Lines changed: 85 additions & 0 deletions
@@ -4171,6 +4171,91 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Mamba2ForCausalLM")
+class Mamba2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MAMBA2
+
+    def set_vocab(self):
+        vocab_size = self.hparams["vocab_size"]
+        # Round vocab size to next multiple of 16
+        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
+        # pad using ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+        self.hparams["vocab_size"] = vocab_size
+
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        elif (self.dir_model / "tokenizer.model.v3").is_file():
+            # mamba-codestral
+            raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
+        elif (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # Use the GPT-NeoX tokenizer when no tokenizer files are present
+            self._set_vocab_builtin("gpt-neox", vocab_size)
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        # TODO: does this really matter?
+        assert d_inner == 2 * d_model
+        assert d_inner % head_dim == 0
+
+        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        if name.startswith("model.backbone") or name.startswith("model.lm_head"):
+            # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
+            name = name.removeprefix("model.")
+
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+        elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            data_torch = data_torch.reshape((*data_torch.shape, 1))
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
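Note on the vocab padding in Mamba2Model.set_vocab above: -(a // -b) is Python floor division applied to a negated divisor, which yields a ceiling division. A minimal standalone sketch (the helper name pad_vocab_size is made up here for illustration):

    # Round vocab_size up to the next multiple of pad_vocab, as set_vocab does above.
    def pad_vocab_size(vocab_size: int, pad_vocab: int = 16) -> int:
        # -(a // -b) == ceil(a / b) for positive ints, since // floors toward -inf
        return -(vocab_size // -pad_vocab) * pad_vocab

    assert pad_vocab_size(50277) == 50288  # rounds up to the next multiple of 16
    assert pad_vocab_size(32000) == 32000  # already a multiple of 16, unchanged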

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
@@ -1854,7 +1854,8 @@ extern "C" {
             struct ggml_tensor * dt,
             struct ggml_tensor * A,
             struct ggml_tensor * B,
-            struct ggml_tensor * C);
+            struct ggml_tensor * C,
+            struct ggml_tensor * ids);
 
     // partition into non-overlapping windows with padding if needed
     // example:
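For orientation, the header change above threads one extra tensor, ids, through ggml_ssm_scan. Below is a rough, illustrative Python/NumPy sketch of the kind of per-token recurrence a selective state-space scan computes, with ids standing in for the slot of recurrent state a sequence reads and writes (consistent with the "avoid redundant state copy" commit in the merge list). This is only a sketch under those assumptions; the real op's packed tensor layout and exact ids semantics are defined by the ggml sources, not by this snippet.

    import numpy as np

    def ssm_scan_ref(x, dt, A, B, C, state, ids):
        # x: (n_tok, d_inner), dt: (n_tok, d_inner), A: (d_inner, d_state)
        # B, C: (n_tok, d_state), state: (n_slots, d_inner, d_state), ids: int slot index
        h = state[ids]                              # pick this sequence's recurrent state slot
        y = np.empty_like(x)
        for t in range(x.shape[0]):
            dA = np.exp(dt[t][:, None] * A)         # per-channel state decay
            h = h * dA + dt[t][:, None] * np.outer(x[t], B[t])
            y[t] = (h * C[t]).sum(axis=-1)          # project the state back out
        state[ids] = h                              # write the updated state back in place
        return y

    # toy usage with hypothetical sizes
    rng = np.random.default_rng(0)
    n_tok, d_inner, d_state, n_slots = 4, 8, 16, 2
    state = np.zeros((n_slots, d_inner, d_state))
    y = ssm_scan_ref(
        x=rng.standard_normal((n_tok, d_inner)),
        dt=np.abs(rng.standard_normal((n_tok, d_inner))),
        A=-np.abs(rng.standard_normal((d_inner, d_state))),
        B=rng.standard_normal((n_tok, d_state)),
        C=rng.standard_normal((n_tok, d_state)),
        state=state,
        ids=1,
    )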

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 5 additions & 6 deletions
@@ -463,26 +463,25 @@ typedef struct {
 typedef struct {
     int64_t d_state;
     int64_t d_inner;
+    int64_t n_head;
+    int64_t n_group;
     int64_t n_seq_tokens;
     int64_t n_seqs;
-    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
-    uint64_t nb10;
+    uint64_t nb03;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
-    uint64_t nb20;
     uint64_t nb21;
     uint64_t nb22;
-    uint64_t nb30;
     uint64_t nb31;
-    uint64_t nb40;
     uint64_t nb41;
     uint64_t nb42;
-    uint64_t nb50;
+    uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t nb53;
 } ggml_metal_kargs_ssm_scan;
 
 typedef struct {

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 47 additions & 17 deletions
@@ -189,6 +189,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -1125,6 +1126,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -2586,71 +2588,91 @@ static bool ggml_metal_encode_node(
                 struct ggml_tensor * src3 = node->src[3];
                 struct ggml_tensor * src4 = node->src[4];
                 struct ggml_tensor * src5 = node->src[5];
+                struct ggml_tensor * src6 = node->src[6];
 
                 GGML_ASSERT(src3);
                 GGML_ASSERT(src4);
                 GGML_ASSERT(src5);
+                GGML_ASSERT(src6);
 
                 size_t offs_src3 = 0;
                 size_t offs_src4 = 0;
                 size_t offs_src5 = 0;
+                size_t offs_src6 = 0;
 
                 id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
                 id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
                 id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+                id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
 
-                const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                const int64_t ne30 = src3->ne[0];
                 const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
 
-                const uint64_t nb30 = src3->nb[0];
+                const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
                 const uint64_t nb31 = src3->nb[1];
 
                 const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-                const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                const int64_t ne41 = src4->ne[1];
                 const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+                const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
 
-                const uint64_t nb40 = src4->nb[0];
+                const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
                 const uint64_t nb41 = src4->nb[1];
                 const uint64_t nb42 = src4->nb[2];
+                const uint64_t nb43 = src4->nb[3];
 
                 const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
                 const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
                 const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+                const int64_t ne53 = src5->ne[2]; GGML_UNUSED(ne53);
 
-                const uint64_t nb50 = src5->nb[0];
+                const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
                 const uint64_t nb51 = src5->nb[1];
                 const uint64_t nb52 = src5->nb[2];
+                const uint64_t nb53 = src5->nb[3];
+
+                const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
+
+                const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
 
                 const int64_t d_state = ne00;
                 const int64_t d_inner = ne01;
-                const int64_t n_seq_tokens = ne11;
-                const int64_t n_seqs = ne02;
+                const int64_t n_head = ne02;
+                const int64_t n_group = ne41;
+                const int64_t n_seq_tokens = ne12;
+                const int64_t n_seqs = ne13;
+
+                id<MTLComputePipelineState> pipeline = nil;
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+                if (ne30 == 1) {
+                    // Mamba-2
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+                }
 
                 ggml_metal_kargs_ssm_scan args = {
                     /*.d_state      =*/ d_state,
                     /*.d_inner      =*/ d_inner,
+                    /*.n_head       =*/ n_head,
+                    /*.n_group      =*/ n_group,
                     /*.n_seq_tokens =*/ n_seq_tokens,
                     /*.n_seqs       =*/ n_seqs,
-                    /*.nb00         =*/ nb00,
                     /*.nb01         =*/ nb01,
                     /*.nb02         =*/ nb02,
-                    /*.nb10         =*/ nb10,
+                    /*.nb03         =*/ nb03,
                     /*.nb11         =*/ nb11,
                     /*.nb12         =*/ nb12,
                     /*.nb13         =*/ nb13,
-                    /*.nb20         =*/ nb20,
                     /*.nb21         =*/ nb21,
                     /*.nb22         =*/ nb22,
-                    /*.nb30         =*/ nb30,
                     /*.nb31         =*/ nb31,
-                    /*.nb40         =*/ nb40,
                     /*.nb41         =*/ nb41,
                     /*.nb42         =*/ nb42,
-                    /*.nb50         =*/ nb50,
+                    /*.nb43         =*/ nb43,
                     /*.nb51         =*/ nb51,
                     /*.nb52         =*/ nb52,
+                    /*.nb53         =*/ nb53,
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -2660,10 +2682,18 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
                 [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
-                [encoder setBytes:&args length:sizeof(args) atIndex:7];
+                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+                [encoder setBytes:&args length:sizeof(args) atIndex:8];
+                // NOTE: max index is 31
 
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                if (ne30 == 1) {
+                    // Mamba-2
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } else {
+                    GGML_ASSERT(d_inner == 1);
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
             } break;
         case GGML_OP_RWKV_WKV6:
             {
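The kernel selection and dispatch logic above keys off ne30, the first dimension of src3 (the A tensor): for Mamba-2 it is 1, so the grouped kernel and a (d_inner, n_head, n_seqs) threadgroup grid are used; otherwise the original kernel is dispatched over (n_head, n_seqs, 1) with d_inner asserted to be 1. A plain-Python mirror of that branch, for illustration only (the function name and string return values are made up, not part of the Metal code):

    def pick_ssm_scan_dispatch(ne30, d_inner, n_head, n_seqs):
        # Mirrors the host-side branch in ggml-metal.m: ne30 == 1 means A has a
        # single column per head, i.e. the Mamba-2 layout.
        if ne30 == 1:
            return "ssm_scan_f32_group", (d_inner, n_head, n_seqs)
        # Mamba-1 layout: the encoder asserts d_inner == 1 before dispatching
        assert d_inner == 1
        return "ssm_scan_f32", (n_head, n_seqs, 1)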
