-
Notifications
You must be signed in to change notification settings - Fork 14k
Mamba2 SSD #16982
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Mamba2 SSD #16982
Conversation
|
Yeah, I had an issue with |
|
Regarding the chunking: won't this explode the graph a lot? In case of Delta Net attention, since you have to use triangular solve there, you don't want the chunk size over 64 or performance drops drastically. But that means that you're going to go up to 8 chunks for a typical ubatch size of 512. The graph for Qwen3 Next already has 9000 nodes. I'm a bit afraid of doing chunking this way (and I know @ggerganov had strong objections too). |
Yep, it sure will. I also suspect this as one of the reasons this is slower currently. I don't think SSD has the same need for chunking based on computational complexity, so I think it's mostly there for memory overhead management. |
|
I've been further experimenting with a few tweaks to get more performance out of this.
Local notes using variants of the following command: ./bin/llama-batched-bench -m ~/models/ibm-granite/granite-4.0-h-1b/granite-4.0-h-1B-BF16-exp.gguf -c 2048 -b 2048 -ub 512 -npp 128,256 -ntg 128 -npl 1,2,4 -ngl 99NOTE: Baseline SSM_SCANF16 cache w/ F32 conv
F32 cache / F32 conv
F16 cache w/ BF16 conv
With SSDF16 cache w/ F32 conv
F32 cache / F32 conv
F16 cache w/ BF16 conv
F16 cache w/ BF16 conv and SSD cast at end
F16 cache w/ BF16 conv, SSD cast at end, and no sub-ubatch batching
|
Probably you have to use large ubatch and do some chunking in order to get some benefits from the SSD. But I don't have a good estimate about what the optimal sizes would be. At the default ubatch of 512, you can do the following experiment on make -j && ./bin/llama-bench -m ../models/granite-4-h-tiny/ggml-model-q8_0.gguf -fa 1 -t 1 -p 2048 -ub 512 -n 0
Now make the ssm scan a noop and run the test again: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 424c400f2..7881e63e0 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2129,6 +2129,7 @@ kernel void kernel_ssm_scan_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
+ return;
constexpr short NW = N_SIMDWIDTH;
shared[tpitg.x] = 0.0f;
This is the upper bound that you would get at this ubatch size. I.e. any SSD implementation will not be faster than this. Increasing the ubatch size increases the gap, so it gives more room for a good SSD implementation to outperform the ssm scan. In any case, first step seems to be to reduce the amount of ops, permutations, conts in the SSD branch as much as possible. |
🤦 I feel really silly for not figuring this trick out. I've been snipping out chunks of the graph and trying to coerce the input/output tensors to the same shape to simulate this upper bound part!
This makes a lot of sense! I'll do some large-ubatch experiments to see if the current code may already be at a cross-over point where SSD can start offering better performance with larger batches. The speed advantages are very much supposed to be primarily felt at longer context which likely also means longer ubatches. |
|
It looks like the current code is not there yet and in fact starts to degrade further when |
|
I guess it's expected since it does not have the chunking logic. |
|
@gabe-l-hart you can look at the discussion in the Qwen3 Next thread, but basically, if you ever use the recurrent update logic aka |
4435600 to
d2779ae
Compare
|
Now that we've got the underlying ops merged, I've redone the core SSD changes here. It's still quite a bit slower than with SSM_SCAN, so it still needs optimization work. |
It builds but doesn't run yet Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
d2779ae to
8ba4d39
Compare
|
@gabe-l-hart regarding max-nodes you might want to also look at #17794 |
DRAFT STATUS
This PR will remain in
Draftuntil the items in the discussion section are resolved.Description
This PR is a draft implementation of the Structured Statespace Duality described in the original mamba2 paper which reframes the
SSM_SCANop as a pseudo-attention operation. The paper describes it in great detail, but the short version is that when performing a multi-token scan, the recurrent formulation ofSSM_SCANis inefficient because it cannot parallelize over the sequence dimension the way an attention calculation can. With the SSD formulation, the logical attention matrix is decomposed into chunks and the state is updated at the chunk boundaries, allowing prefill to "jump" by the size of the chunk rather than proceed with tokens one-at-a-time.Reference Links
mlx-lm: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/ssm.pyChanges
Introduce new primitive operations in
ggml:ggml_cumsum/ggml_cumsum_0: Perform a cumulative sum along a give dimensionggml_tri_dims/ggml_tri/ggml_tri_keep: Apply a triangular mask to the given matrixggml_tri_dimsggml_softplus: Perform the unarysoftplusoperationImplement an alternate path through
llm_graph_context_mamba::build_mamba2_layerwhen a multi-token update is detectedSSM_SCANin favor of the chunked pseudo-attention formulationDiscussion
There are a number of outstanding discussion points on this work that need to be resolved before moving it forward:
SSM_SCANwhich roundly defeats the purpose of the change! I suspect that the performance issues are due to the number ofggml_permute/ggml_contops that are added to the graph, but could use assistance figuring out how to eliminate them or identifying other sources of slowness.ubatchchunking implemented. I had it mostly working before the corresponding discussion on Qwen3Next. The inter-chunk update would be needed anyway, so I didn't strip it out, but it would be fairly trivial to do so and might offer some performance improvements.repeat_interleave: Similar to the issue that came up when initially implementingNemotronHsupport, I believe thatggml_repeatbehaves differently thanmx.repeat, resulting in incorrect results for models withn_groups > 1(tested withNemotronH).Testing
I've tested this locally with various members of the Granite 4 family and with
nvidia/NVIDIA-Nemotron-Nano-9B-v2. For the Granite 4 models withn_groups == 1, I get nearly identical results to running with purelySSM_SCAN, butNemotronHstill struggles due torepeat_interleaveissues (see above). I'll flesh out more testing results once we've worked through some of the above issues.cc @compilade since I know this has been on your TODO list since the original
mamba2implementation.