Draft
Changes from all commits (429 commits)
782c946
Add cb logging for indexer_k_cache_head
createthis Oct 24, 2025
1d5a878
More cb logs for indexer
createthis Oct 24, 2025
c761a8f
arguing with sentient rocks
createthis Oct 24, 2025
cc8b7ef
- Ignored the micro-window n_kv argument for the indexer path and
createthis Oct 24, 2025
e53b2ed
Fix crash (hopefully)
createthis Oct 24, 2025
5ef5a53
Another fix
createthis Oct 24, 2025
9bbd46e
Add clamp and get_k_full/get_v_full (for future use, apparently)
createthis Oct 24, 2025
d38859a
Fix warning
createthis Oct 25, 2025
3d862a2
Hide sparse attention logging behind env var LLAMA_SPARSE_DEBUG for
createthis Oct 25, 2025
5194bfa
Don't modify get_k_indexer. Create get_k_indexer_full instead.
createthis Oct 25, 2025
73b35db
Fix crash (hopefully)
createthis Oct 25, 2025
472a41d
Switch to using get_k_indexer_full
createthis Oct 25, 2025
be3ef9e
Fix crash
createthis Oct 25, 2025
04bb17f
Fix crash
createthis Oct 25, 2025
0500521
Build and use full kq_mask for sparse attention.
createthis Oct 25, 2025
aec054d
Merge branch 'deepseek_v3_2_exp' of github.com:createthis/llama.cpp i…
createthis Oct 25, 2025
ba4780c
Trying to track down why we aren't always getting the full width
createthis Oct 25, 2025
16a9d43
Fix compile error
createthis Oct 25, 2025
568c5e3
Attempt to fix crash
createthis Oct 25, 2025
6097468
Remove rotate_activation.
createthis Oct 25, 2025
ef9b177
Keep validity masking (causal window) in the indexer/top-k path, but
createthis Oct 26, 2025
48ba041
Bump LLAMA_SPARSE_TOPK default to 2048 to be inline with vllm and
createthis Oct 26, 2025
d2ac39e
set_input_kq_mask_full_2d: restore ALiBi in the full-width mask
createthis Oct 26, 2025
a37bf05
Q-scale proxy is active: idx_weights = W_proj(cur) × 1/sqrt(H) ×
createthis Oct 26, 2025
76f576d
added a safe K-scale proxy multiply in the top-k selector, after
createthis Oct 26, 2025
84ae1f1
K-scale was accidentally placed in a dbg block. Move out.
createthis Oct 26, 2025
1ebd7e7
- I replaced the fragile mask slicing in the top-k selector.
createthis Oct 26, 2025
d7870c3
clamp scores_tc immediately after adding the mask in the top-k selector
createthis Oct 26, 2025
f4cd950
top_k is now limited by the currently available KV, not just cache
createthis Oct 27, 2025
a85fad5
- In src/llama-sparse-mla-fwd.cpp, I removed the conditional “VIEW when
createthis Oct 27, 2025
0555126
top_k clamp changes
createthis Oct 27, 2025
b284bd1
Fix top-k clamp. Sparse attention generation is working!
createthis Oct 27, 2025
6fb54c1
comment out or hide all debug prints behind LLAMA_SPARSE_DEBUG
createthis Oct 27, 2025
edc23f9
Streaming per-head accumulation to avoid [N_kv, H, Tc] temporaries
createthis Oct 27, 2025
9e9a84a
Revert last change as it was objectively worse.
createthis Oct 27, 2025
06bd370
- Keep Top-K indices on device:
createthis Oct 27, 2025
7866fd5
kept the sparse attention output tensor “cur” on device in the sparse
createthis Oct 27, 2025
b96f5fb
WIP radix top-k
createthis Oct 27, 2025
d3e4a6a
Ported radix top-k selection with thresholding and tail refinement
createthis Oct 28, 2025
100535b
Integrate radix top-k
createthis Oct 28, 2025
7780061
Guard printf's with dbg
createthis Oct 28, 2025
90f0e17
Add a repro test for the context > 50k issue. Also attempt to fix it.
createthis Oct 28, 2025
8dd82d8
Add another potential repro test
createthis Oct 28, 2025
2d3cb41
Add some logging
createthis Oct 28, 2025
3d547a6
fix compile errors
createthis Oct 28, 2025
ca658a8
more logging
createthis Oct 28, 2025
2d810d5
Add an include to fix compile issue
createthis Oct 28, 2025
1ab31d4
more logging
createthis Oct 28, 2025
4de2e22
Add logging to ggml to track this down.
createthis Oct 28, 2025
e1634c7
Helping sentient rocks put their changes where they intended.
createthis Oct 28, 2025
6c048aa
Add fflush
createthis Oct 28, 2025
42024cd
Try to get logging working
createthis Oct 28, 2025
60f71ac
More logging
createthis Oct 28, 2025
07f84a3
label cur
createthis Oct 28, 2025
5983712
Add another tensor name as I argue with sentient rocks
createthis Oct 28, 2025
f2cef1a
Attempt to fix the problem, remove unused tests
createthis Oct 29, 2025
f2896cf
Oops, fix makefile
createthis Oct 29, 2025
86f7d87
Remove unnecessary prints now that problem is solved.
createthis Oct 29, 2025
8670597
Initial CUDA topk. It's only bitonic though.
createthis Oct 31, 2025
d064c27
Silence a warning
createthis Oct 31, 2025
bfc4387
Add histogram test
createthis Oct 31, 2025
785d299
Silence some warnings
createthis Oct 31, 2025
17834b8
Silence warning
createthis Oct 31, 2025
da81a3a
Add unit test for k_select_topk_bins
createthis Oct 31, 2025
327c926
Add debug printing so we can see what is going on inside the cuda
createthis Nov 1, 2025
c28506e
Fix code so that test-sparse-topk-select-cuda test passes.
createthis Nov 1, 2025
9b48325
Remove unused code.
createthis Nov 1, 2025
a71e8b6
Add streaming fallback for constrained memory
createthis Nov 1, 2025
c2ccf43
Make streaming the default for now.
createthis Nov 1, 2025
1424740
Add some large-N stress tests
createthis Nov 1, 2025
b482754
Turn off device prints
createthis Nov 1, 2025
28cf484
GGML Radix Top-k Op - currently broken during inference startup
createthis Nov 1, 2025
f2cb601
Inferring again, but slower than CPU radix top-k
createthis Nov 1, 2025
df2937a
7.07 tok/s non-streaming cuda top-k, 0.84 tok/s streaming cuda top-k.
createthis Nov 1, 2025
a60ff2a
Remove the streaming path. It was distracting.
createthis Nov 1, 2025
4694b6b
Fix tests, but also way worse performance.
createthis Nov 1, 2025
1b7c139
Revert "Fix tests, but also way worse performance."
createthis Nov 2, 2025
25ac258
Remove "stream" verbiage
createthis Nov 2, 2025
e678cbc
Add eq_capacity
createthis Nov 2, 2025
b20a910
Passing tests. Manage memory correctly.
createthis Nov 2, 2025
fdaa234
7.15 tok/s - Replace the O(K·eq_count) tail with block-parallel top-K…
createthis Nov 2, 2025
5c2d35e
Remove unused s_sel_count.
createthis Nov 2, 2025
e0deaa5
Turn exiting on nans back on.
createthis Nov 2, 2025
d5aacb8
Remove outdated comment
createthis Nov 2, 2025
256da56
- select_topk_tokens_indexer_kvaware (src/llama-sparse-topk.cpp)
createthis Nov 2, 2025
ce47463
- Head-chunked indexer GEMM
createthis Nov 2, 2025
3077737
Radix top-k kernel tuning: reduced shared memory footprint
createthis Nov 2, 2025
81fa0ef
Started staged refinement for radix selector (second and third bytes)
createthis Nov 2, 2025
0a9c899
Revert last change
createthis Nov 2, 2025
1d7ced5
- Ensure contiguous before 3D reshapes in the chunk path:
createthis Nov 2, 2025
e5a4cab
Fix for assertion
createthis Nov 2, 2025
1557621
- Reintroduced staged refinement, safely and without an env flag:
createthis Nov 2, 2025
b1f4ec3
- Replaced the chunk-wide broadcast-mul + permute + sum path with a p…
createthis Nov 2, 2025
b4dcfee
Change 1: Fuse per-head weighted reduction in the indexer
createthis Nov 2, 2025
730f111
Fix test-sparse-topk-radix-stress-cuda
createthis Nov 3, 2025
14060e0
- Added env-controlled FP16 for sparse MLA QK in apply_sparse_attenti…
createthis Nov 3, 2025
a55936b
- Pretransposed V for reuse across tokens:
createthis Nov 3, 2025
f07dbad
4.4 tok/s. Going in the wrong direction it seems.
createthis Nov 3, 2025
2fa5874
Another env var. I'll probably remove this later.
createthis Nov 3, 2025
50ca733
- Top‑k can now accept FP16 scores on the API side. CUDA forward prom…
createthis Nov 3, 2025
d88c649
9.13 tok/s - - Indexer tile cropping (causal prefill): we now slice t…
createthis Nov 3, 2025
210689e
Default FP16 paths ON for indexer and sparse MLA GEMMs
createthis Nov 3, 2025
02aa193
A) Per-token KV windows (host-side derivation)
createthis Nov 3, 2025
1f950ec
1) Per-token KV windows (host-side derivation)
createthis Nov 3, 2025
e1be324
- Scaffolded fused-kernel source and build plumbing:
createthis Nov 3, 2025
2ef96fd
- Unit test for the fused kernel:
createthis Nov 3, 2025
fe37e57
- Integrated fused path into select_topk_tokens_indexer_kvaware() wit…
createthis Nov 3, 2025
866620f
Fix warning, reindent
createthis Nov 3, 2025
0af730a
What I implemented (Phase 1: device-resident fused kernel)
createthis Nov 3, 2025
b8380d3
More debugging prints
createthis Nov 3, 2025
3b4891c
arguing with sentient rocks and printf debugging
createthis Nov 4, 2025
14f0a1b
more logging
createthis Nov 4, 2025
7c7e3a7
degenerate generation fixed - new test is broken though.
createthis Nov 4, 2025
1f533fb
All tests passing.
createthis Nov 4, 2025
0ce252d
- New tiled shared-memory fused kernel (Phase 1):
createthis Nov 4, 2025
d263802
Fix test
createthis Nov 4, 2025
f038b6a
WMMA kernel. test-indexer-fused-op-cuda fails with this, so I'm putting
createthis Nov 4, 2025
21bccaa
- Implemented a BF16 WMMA kernel: k_indexer_logits_wmma16_bf16(Q,K,W,…
createthis Nov 4, 2025
d96e906
- Fixed dispatch so WMMA isn’t launched when LLAMA_INDEXER_USE_WMMA=0:
createthis Nov 4, 2025
a438696
Fix crash when LLAMA_SPARSE_INDEXER_FUSED_DEVICE=1
createthis Nov 4, 2025
1c11e40
Gated prints behind env var
createthis Nov 4, 2025
e722f07
Some fixes
createthis Nov 4, 2025
c3bea92
- Upgraded the “tiled” fused indexer kernel to compute multiple heads…
createthis Nov 4, 2025
bc2784e
- Added double-buffered shared-memory tiling inside k_indexer_logits_…
createthis Nov 4, 2025
562158e
- Double-buffered shared memory layout:
createthis Nov 4, 2025
16de1f0
- cp.async for Q tile:
createthis Nov 4, 2025
46dc6bd
- Tile-level KV windowing in select_topk_tokens_indexer_kvaware:
createthis Nov 4, 2025
13e4487
- New CUDA helper API in ggml-cuda-indexer.h:
createthis Nov 4, 2025
c20495f
- CUDA helper and wrapper at file scope:
createthis Nov 4, 2025
12a91e4
I ensured ggml-cuda-indexer.h is included once (under GGML_USE_CUDA) …
createthis Nov 4, 2025
0e4cd5c
Profiling code so we can see whether the lightning selector or cuda
createthis Nov 5, 2025
9941319
- Optimized CUDA top-k radix selection to avoid repeated full-column …
createthis Nov 5, 2025
0bf0ff1
1) New GGML op and API
createthis Nov 5, 2025
9af0e7a
- Updated the fused kernel (sparse-mla-decode.cu) to support Hq != Hkv:
createthis Nov 5, 2025
7567255
Fused MLA kernel, but inference is broken.
createthis Nov 5, 2025
c625488
Add test-sparse-attn-mqa-cuda. It isn't repro'ing what I want yet
createthis Nov 5, 2025
7f13f21
Successfully repro the assertion we see during inference.
createthis Nov 5, 2025
9cc1334
Working fused mla!
createthis Nov 5, 2025
26ad022
Average (aggregate) profiling logs
createthis Nov 5, 2025
ada2399
Profile fused MLA kernel.
createthis Nov 5, 2025
cb3a281
Profile the indexer kernel runtime, not just the pre-kernel runtime.
createthis Nov 5, 2025
1247270
set some defaults
createthis Nov 5, 2025
de95e56
Allow profiling in test-indexer-fused-op-cuda
createthis Nov 6, 2025
70a0634
Remove comment
createthis Nov 6, 2025
52248db
Head-grouped WMMA kernel for the indexer. Things just keep getting
createthis Nov 6, 2025
b210d7b
Wire in the head-grouped WMMA kernel. 5.44 tok/s!
createthis Nov 6, 2025
b130401
Set head-grouped WMMA as the default.
createthis Nov 6, 2025
28a4ddf
Clean up warnings
createthis Nov 6, 2025
f0943af
Trying to figure out why the top-k OP test fails when profiling.
createthis Nov 6, 2025
a8567ff
Update top-k OP test. I figured out the profiling issue. I just have to
createthis Nov 6, 2025
acee2b4
5.5 - 6 tok/s. SPARSE_TOPK_RADIX avg_ms=0.216
createthis Nov 6, 2025
c9f8e0e
Remove unused prints
createthis Nov 7, 2025
0bd6441
Restructured and enhanced topk-radix.cu toward TileLang’s scheme
createthis Nov 7, 2025
6091692
a full 4-round histogram refinement on candidate buffers (no pool/arg…
createthis Nov 7, 2025
3237334
k_select_topk_bins now uses two shared candidate buffers in dynamic s…
createthis Nov 7, 2025
5c9ab96
Avoid full-column fallbacks via ping-pong candidate buffers
createthis Nov 7, 2025
98aba6c
CUDA sparse top-k: fix selection correctness and clean up debug; alig…
createthis Nov 7, 2025
89657da
Replace global sel_sofar atomic with per-bin prefix allocation (Tilel…
createthis Nov 8, 2025
50b80cd
Port of tilelang top-k kernel line-by-line.
createthis Nov 8, 2025
1b8947a
Profile tilelang top-k kernel port separately from glue code.
createthis Nov 8, 2025
10f0d56
Fix porting bug in cumsum. Unit tests to 40k. Only infers at 4k though.
createthis Nov 8, 2025
ba41773
Add warmup to top-k OP test
createthis Nov 8, 2025
f977ed1
Don't transpose.
createthis Nov 9, 2025
439ffc6
Reduce test size for now since the tilelang kernel can't handle larger.
createthis Nov 9, 2025
f367245
Wire up d_ends for the tilelang kernel
createthis Nov 9, 2025
86bbe45
Remove unused functions
createthis Nov 9, 2025
65e3ba8
Restore l_val. Arguing with sentient rocks again.
createthis Nov 9, 2025
bffea59
Fix a straggling spot where the "home-grown" kernel was running instead
createthis Nov 9, 2025
cc46776
Fix bug/parity with tilelang by implementing
createthis Nov 9, 2025
4c90531
Reproduce issue with tilelang kernel where end = 1
createthis Nov 9, 2025
d9c0c07
Add test to repro very slow tilelang kernel glue code at full context
createthis Nov 9, 2025
d3730fb
Label arguments
createthis Nov 9, 2025
2cff7dd
We want starts/ends (per-column KV windows) to be device-resident
createthis Nov 9, 2025
6813b14
- Plumbed real KV windows (starts/ends) into the CUDA top‑k op:
createthis Nov 9, 2025
a10a35a
This change makes ggml_sparse_topk_radix_ex active.
createthis Nov 9, 2025
ca34f21
Fix warnings
createthis Nov 9, 2025
95e3405
Fix failing test
createthis Nov 10, 2025
864b9c1
Add unit test to replicate the problem.
createthis Nov 10, 2025
44a0c3f
Revert "Add unit test to replicate the problem."
createthis Nov 10, 2025
756ef3e
Fix runtest.sh. It was broken by commit 95e340.
createthis Nov 10, 2025
59c0c79
Cleanup formatting.
createthis Nov 10, 2025
a22111c
Remove asm line. Unneeded.
createthis Nov 10, 2025
7798e4a
Line-by-line port of the tilelang lightning indexer kernel
createthis Nov 10, 2025
4028d86
Wire up k_tl_mqa_attn_return_logits_port behind LLAMA_INDEXER_TL_PORT=1
createthis Nov 10, 2025
d194f56
Replicate behavior of T.Parallel faithfully
createthis Nov 11, 2025
d702eac
Use LLAMA_SPARSE_PROF_EACH for consistency.
createthis Nov 11, 2025
8a8646a
Attempt to faithfully reproduce T.Pipelined and T.Copy. I don't think it
createthis Nov 11, 2025
e452213
Remove SEL_DEBUG. I keep forgetting that this file handles it
createthis Nov 12, 2025
4c25f75
Add another iteration of the tilelang ported kernel. It's a little
createthis Nov 12, 2025
951f0d8
More logging
createthis Nov 12, 2025
16e7e84
Distinguish between the two tiled launch points.
createthis Nov 14, 2025
491ec34
1) GGML API
createthis Nov 15, 2025
2eca6f9
Where we previously called ggml_indexer_logits_fused(ctx, q_tile2d, k…
createthis Nov 15, 2025
622ce52
- k_indexer_logits_tiled_f32 (ggml/src/ggml-cuda/indexer-fused.cu):
createthis Nov 15, 2025
fe82367
Plumb indexer start/end through the unit test.
createthis Nov 15, 2025
abe33ae
1) k_indexer_logits_wmma16_f32 now accepts per-token windows and uses…
createthis Nov 15, 2025
9dd9311
At the end of the WMMA16 F32 kernel, each kv_idx/tok write now checks…
createthis Nov 15, 2025
aa68f15
- Kernel signature:
createthis Nov 15, 2025
65f3397
Use FP16 CPU comparison for WMMA and TL indexer kernels.
createthis Nov 15, 2025
ac45b4f
Fix tiled indexer tests
createthis Nov 15, 2025
344cf4e
Update k_indexer_logits_wmma16_bf16 with start/end and make it the
createthis Nov 15, 2025
29f3654
Remove k_indexer_logits_warp_row and k_indexer_logits_fused as they are
createthis Nov 15, 2025
802a918
DRY up profiling code with LAUNCH_PROFILE_KERNEL macro.
createthis Nov 15, 2025
5927dd0
- Rewrote the inner post-processing in k_tl_mqa_attn_return_logits_po…
createthis Nov 16, 2025
0e4e152
tilelang indexer port: TMA + FP8 work in progress
createthis Nov 16, 2025
954d2ca
Phase A – FP16-global correctness baseline
createthis Nov 16, 2025
39c5c0f
Phase B.
createthis Nov 16, 2025
04decee
Phase C (step 1: K-only TMA skeleton, FP16 in shared)
createthis Nov 16, 2025
d82e5c9
In the TL wrapper path (the one the unit test uses), when LLAMA_TL_TM…
createthis Nov 16, 2025
0bfc80d
Phase C, minor change
createthis Nov 16, 2025
0f9f650
- In the TMA path, I forced FP16-in-shared even when LLAMA_TL_TMA_F…
createthis Nov 16, 2025
72eda10
- Scoped SM90 TMA helper stubs at file scope:
createthis Nov 16, 2025
a4394f9
Finish Phase C - SM90 mbarrier + cp.async.bulk.tensor for K and Q.
createthis Nov 16, 2025
14958fe
- Implemented SM100+ cp.async path for K tiles in k_tl_mqa_attn_retur…
createthis Nov 16, 2025
7246530
Replace k_tl_mqa_attn_return_logits_tma_f16_konly with k_tl_mqa_attn_…
createthis Nov 16, 2025
5d71a1d
Remove unused k_tl_mqa_attn_return_logits_port_f16global
createthis Nov 16, 2025
1669ede
Only build compute_smem when needed.
createthis Nov 16, 2025
55f480c
Fix PROFILE_TL_ONLY correctness.
createthis Nov 16, 2025
0d99043
Revert from pool alloc to cudaMalloc in an attempt to pinpoint this
createthis Nov 16, 2025
857fc6c
Revert "Revert from pool alloc to cudaMalloc in an attempt to pinpoin…
createthis Nov 16, 2025
f1964b8
Restore performance for PROFILE_TL_ONLY and PROFILE_TL_TMA_FP8_KONLY
createthis Nov 16, 2025
7d6fa9c
Remove unused functions.
createthis Nov 17, 2025
f1a567f
Vendor the tilelang fp8 indexer kernel. Not a port. The kernel
createthis Nov 17, 2025
53635b0
Rename env var
createthis Nov 17, 2025
e0ff9d6
Change profile name
createthis Nov 18, 2025
576871d
Wire up the tilelang lightning indexer. Something is wrong though as the
createthis Nov 18, 2025
1042e34
Remove unused k_tl_mqa_attn_return_logits_tma_fp8_full
createthis Nov 18, 2025
c46ba47
Attempt to fix test by making fp8like cpu reference and correctly doing
createthis Nov 18, 2025
11b757a
Test passes, but only with these specific settings.
createthis Nov 18, 2025
2574f0b
Move idx_compute_scores_tile() out of src/llama-sparse-topk.cpp into
createthis Nov 18, 2025
c404075
Add test for CPU indexer. Currently failing.
createthis Nov 18, 2025
234a185
Add a test that compares the CPU lightning indexer (
createthis Nov 19, 2025
f65ee12
Add test/test-fp8-e4m3-cutlass-vs-cpu.cpp so we can have confidence
createthis Nov 21, 2025
f72c04e
Yoink FP8 code from https://github.com/ggml-org/llama.cpp/pull/10055
createthis Nov 23, 2025
71d3b73
Wire GGML fp8 into our unit test. Currently failing.
createthis Nov 24, 2025
5dde911
Fix the following NaN issue, so that we have a faithful implementation:
createthis Nov 24, 2025
481e4bd
Fix warnings
createthis Nov 24, 2025
ef7e0e4
LLAMA_SPARSE_INDEXER_FUSED=1 \
createthis Nov 24, 2025
7c63a89
Random test changes
createthis Nov 24, 2025
729e044
Make the FP8 behavior the only behavior.
createthis Nov 24, 2025
d25442a
WMMA HGRP kernel changed to use FP8 internally. Output now matches
createthis Nov 24, 2025
82ed1e6
Restore 6.23 tok/s performance in WMMA HGRP kernel while retaining FP8
createthis Nov 24, 2025
3350c3c
Add LLAMA_SPARSE_PROF profiling to the idx_compute_scores_tile CPU path.
createthis Nov 24, 2025
d495b26
Remove old unused CPU graph building path from idx_compute_scores_tile.
createthis Nov 24, 2025
8715063
I don't have a benchmark from before we added CPU FP8, but this change
createthis Nov 24, 2025
8431a4f
k_indexer_logits_wmma16_bf16 is now doing real NVIDIA FP8 E4M3 math i…
createthis Nov 24, 2025
b183e8e
The tiled CUDA Lightning Indexer kernel (`k_indexer_logits_tiled_f32`…
createthis Nov 24, 2025
a5ae544
Remove tests/fp8-e4m3-cpu.h and all references to it.
createthis Nov 24, 2025
e68b55b
Remove unused use_fp16 argument to idx_compute_scores_tile
createthis Nov 24, 2025
7289478
Add warmup to tests/test-indexer-fused-op-cuda.cpp so we don't pollute
createthis Nov 25, 2025
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "ggml/src/ggml-cuda/vendors/cutlass"]
path = ggml/src/ggml-cuda/vendors/cutlass
url = https://github.com/NVIDIA/cutlass
190 changes: 190 additions & 0 deletions convert_hf_to_gguf.py
@@ -852,6 +852,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
        if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5":
            # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
            res = "deepseek-r1-qwen"
        if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
            # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp
            res = "deepseek-v3.2"
        if chkhsh == "ccc2ef013c104be7bae2965776d611e1d7a8a2a9c547dd93a682c9a9fc80352e":
            # ref: https://huggingface.co/Xenova/gpt-4o
            res = "gpt-4o"
@@ -6503,6 +6506,193 @@ def prepare_tensors(self):
                raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register(
    "DeepseekV32ForCausalLM",
)
class DeepseekV3_2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.DEEPSEEK3_2

    def set_vocab(self):
        try:
            self._set_vocab_gpt2()
            return
        except Exception:
            pass

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
        tokpre = self.get_vocab_base_pre(tokenizer)

        if tokpre == "kimi-k2":
            # Build merges list using the approach similar to HunYuanMoE
            merges = []
            vocab = {}
            mergeable_ranks = tokenizer.model._mergeable_ranks
            for token, rank in mergeable_ranks.items():
                vocab[QwenModel.token_bytes_to_string(token)] = rank
                if len(token) == 1:
                    continue
                merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
                if len(merged) == 2:
                    merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

            # Build token list
            vocab_size = self.hparams["vocab_size"]
            special_tokens = tokenizer.special_tokens
            reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
            tokens: list[str] = []
            toktypes: list[int] = []

            for i in range(vocab_size):
                if i not in reverse_vocab:
                    tokens.append(f"[PAD{i}]")
                    toktypes.append(gguf.TokenType.UNUSED)
                else:
                    token = reverse_vocab[i]
                    tokens.append(token)
                    if i in special_tokens.values():
                        toktypes.append(gguf.TokenType.CONTROL)
                    else:
                        toktypes.append(gguf.TokenType.NORMAL)

            self.gguf_writer.add_tokenizer_model("gpt2")
            self.gguf_writer.add_tokenizer_pre(tokpre)
            self.gguf_writer.add_token_list(tokens)
            self.gguf_writer.add_token_types(toktypes)
            self.gguf_writer.add_token_merges(merges)

            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
            special_vocab.add_to_gguf(self.gguf_writer)
        else:
            raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")

    def set_gguf_parameters(self):

        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
        self.hparams["num_key_value_heads"] = 1

        super().set_gguf_parameters()
        hparams = self.hparams

        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
            self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
        self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])

        # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
        self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
        self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])

        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])

        if hparams["scoring_func"] == "sigmoid":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
        elif hparams["scoring_func"] == "softmax":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
        else:
            raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")

        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

        rope_scaling = self.hparams.get("rope_scaling") or {}
        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
            self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])

    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # skip vision tensors and remove "language_model." for Kimi-VL
        if "vision_tower" in name or "multi_modal_projector" in name:
            return []

        if name.startswith("language_model."):
            name = name.replace("language_model.", "")

        # rename e_score_correction_bias tensors
        if name.endswith("e_score_correction_bias"):
            name = name.replace("e_score_correction_bias", "e_score_correction.bias")

        # skip Multi-Token Prediction (MTP) layers
        block_count = self.hparams["num_hidden_layers"]
        match = re.match(r"model.layers.(\d+)", name)
        if match and int(match.group(1)) >= block_count:
            return []

        # process the experts separately
        if name.find("mlp.experts") != -1:
            n_experts = self.hparams["n_routed_experts"]
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                tensors: list[tuple[str, Tensor]] = []

                # merge the experts into a single 3d tensor
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

                    new_name = self.map_tensor_name(merged_name)

                    tensors.append((new_name, data_torch))
                return tensors
            else:
                return []

        # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed
        if name.endswith("kv_b_proj.weight"):
            name_kb = name.replace("kv_b_proj", "k_b_proj")
            name_vb = name.replace("kv_b_proj", "v_b_proj")

            n_head_kv = self.hparams["num_key_value_heads"]
            v_head_dim = self.hparams["v_head_dim"]
            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
            k_b = k_b.transpose(1, 2)

            return [
                (self.map_tensor_name(name_kb), k_b),
                (self.map_tensor_name(name_vb), v_b)
            ]

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        super().prepare_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -127,6 +127,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
{"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
{"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
{"name": "deepseek-v3.2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp"},
{"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", },
{"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
{"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
5 changes: 5 additions & 0 deletions ggml/CMakeLists.txt
@@ -106,6 +106,10 @@ if (NOT GGML_LLAMAFILE_DEFAULT)
set(GGML_LLAMAFILE_DEFAULT OFF)
endif()

if (NOT GGML_OPENMP_SIMD_DEFAULT)
set(GGML_OPENMP_SIMD_DEFAULT OFF)
endif()

if (NOT GGML_CUDA_GRAPHS_DEFAULT)
set(GGML_CUDA_GRAPHS_DEFAULT OFF)
endif()
@@ -169,6 +173,7 @@ option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_OPENMP_SIMD "ggml: enable OPENMP_SIMD" ${GGML_OPENMP_SIMD_DEFAULT})

option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
36 changes: 36 additions & 0 deletions ggml/include/ggml-cuda-indexer.h
@@ -0,0 +1,36 @@
#pragma once
#include "ggml-cuda.h"
#ifdef __cplusplus
extern "C" {
#endif

// Forward-declare the CUDA context type; definition is in common.cuh
struct ggml_backend_cuda_context;

// Derive per-token KV window ends from device-resident mask [N_kv, T]
// mask values <= -1e29 are treated as masked; ends[t] = last i where mask[i,t] > -1e29, or 0 if none
void ggml_cuda_mask_window_ends_device(struct ggml_backend_cuda_context & ctx,
const float * dMask, int N_kv, int T,
int * dEnds);

// Device-resident entry: takes device pointers and current CUDA context
void ggml_cuda_indexer_logits_fused_device(struct ggml_backend_cuda_context & ctx,
const float * dQ,
const float * dK,
const float * dW,
const float * dKS,
const int * dStarts, const int * dEnds,
int D, int H, int Tc, int kv_end,
float * dOut);

// Derive per-token KV window ends from device-resident mask and copy to host buffer
void ggml_cuda_mask_window_ends_device_to_host(struct ggml_backend_cuda_context & ctx,
const float * dMask, int N_kv, int T, int * hEnds);

// Simple convenience wrappers using current device and default stream
void ggml_cuda_mask_window_ends_device_to_host_simple(const float * dMask, int N_kv, int T, int * hEnds);
void ggml_cuda_mask_window_starts_device_to_host_simple(const float * dMask, int N_kv, int T, int * hStarts);

#ifdef __cplusplus
}
#endif
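
For reference, a minimal CPU sketch of the windowing rule the comment above states, assuming the mask is stored per token column with element (i, t) at mask[i + N_kv*t] and following the comment literally (ends[t] is the last unmasked index, or 0 if the whole column is masked); note the half-open [start, end) windows used by the top-k API elsewhere in this PR may instead expect one past that index. The helper name is hypothetical and not part of this header.

#include <stddef.h>

// Illustrative CPU reference, not part of the PR.
static void mask_window_ends_reference(const float * mask, int N_kv, int T, int * ends) {
    for (int t = 0; t < T; ++t) {
        int end = 0;                                    // 0 when no entry in the column is unmasked
        for (int i = 0; i < N_kv; ++i) {
            if (mask[i + (size_t) N_kv * t] > -1e29f) { // values <= -1e29 are treated as masked
                end = i;                                // last index holding an unmasked value
            }
        }
        ends[t] = end;
    }
}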
28 changes: 28 additions & 0 deletions ggml/include/ggml-cuda-radix.h
@@ -0,0 +1,28 @@
#pragma once

#ifdef __cplusplus
extern "C" {
#endif

// Compute top-k indices per column using a CUDA radix-style selection.
// scores is a row-major 2D array with shape [N, T]: element(i,t) at scores[i + N*t].
// Writes indices into idx (shape [k, T], same storage rule: idx[i + k*t]).
void ggml_cuda_topk_radix_indices_host(const float * scores, int N, int T, int k, int * idx);

// Build per-column histogram on the top byte of float->key mapping.
// scores: [N, T] row-major. Outputs:
// - gt_counts: size 256*T, gt_counts[b + 256*t] = sum_{bb>b} counts[bb]
// - thr_bins: size T (currently placeholder; can be 0)
void ggml_cuda_topk_histogram_host(const float * scores, int N, int T,
unsigned int * gt_counts, unsigned int * thr_bins);

// Launch equal-bin selection kernel only, given precomputed histogram greater-counts per column
// scores: [N, T] row-major
// gt_counts: [256, T] greater-counts per bin
// idx: [k, T] output indices (row-major leading dimension k)
void ggml_cuda_topk_select_host(const float * scores, int N, int T, int k,
const unsigned int * gt_counts, int * idx);

#ifdef __cplusplus
}
#endif
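
To make the indexing contract above concrete, a plain CPU reference that produces the same output shape and layout: for each column t, the indices of the k largest scores are written to idx[j + k*t]. This only illustrates the expected result (tie order unspecified, k <= N assumed); it is not the radix selection the CUDA kernels implement.

#include <stdlib.h>

// Illustrative CPU reference for the top-k contract; not part of the PR.
static void topk_indices_reference(const float * scores, int N, int T, int k, int * idx) {
    for (int t = 0; t < T; ++t) {
        const float * col = scores + (size_t) N * t; // column t: scores[i + N*t]
        int         * out = idx    + (size_t) k * t; // output:   idx[j + k*t]
        char * taken = (char *) calloc(N, 1);
        for (int j = 0; j < k; ++j) {
            int best = -1;
            for (int i = 0; i < N; ++i) {
                if (!taken[i] && (best < 0 || col[i] > col[best])) {
                    best = i;
                }
            }
            taken[best] = 1;
            out[j] = best;
        }
        free(taken);
    }
}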
59 changes: 58 additions & 1 deletion ggml/include/ggml.h
@@ -417,7 +417,11 @@ extern "C" {
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_COUNT = 40,
GGML_TYPE_E5M2 = 40,
GGML_TYPE_E4M3 = 41,
GGML_TYPE_E4M3_Q = 42,
GGML_TYPE_E3M4_Q = 43,
GGML_TYPE_COUNT = 44,
};

// precision
@@ -453,6 +457,10 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_E5M2 = 26, // except 1d tensors
GGML_FTYPE_MOSTLY_E4M3 = 27, // except 1d tensors
GGML_FTYPE_MOSTLY_E4M3_Q = 28, // except 1d tensors
GGML_FTYPE_MOSTLY_E3M4_Q = 29, // except 1d tensors
};

// available tensor operations:
@@ -555,6 +563,9 @@ extern "C" {
GGML_OP_OPT_STEP_ADAMW,
GGML_OP_OPT_STEP_SGD,

GGML_OP_SPARSE_TOPK_RADIX,
GGML_OP_INDEXER_FUSED,
GGML_OP_SPARSE_MLA_DECODE,
GGML_OP_GLU,

GGML_OP_COUNT,
@@ -725,12 +736,56 @@ extern "C" {
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);

// sparse MLA decode fused (CUDA backend)
GGML_API struct ggml_tensor * ggml_sparse_mla_decode_fused(
struct ggml_context * ctx,
struct ggml_tensor * q2d,
struct ggml_tensor * k_cache,
struct ggml_tensor * v_cache,
struct ggml_tensor * idx_topk,
float kq_scale,
float attn_softcap);

GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars

// returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()

// radix-based sparse top-k indices per column (specialized CUDA path with CPU fallback)
GGML_API struct ggml_tensor * ggml_sparse_topk_radix(
struct ggml_context * ctx,
struct ggml_tensor * scores,
int k);


// Variant that accepts optional per-column windows [start,end)
GGML_API struct ggml_tensor * ggml_sparse_topk_radix_ex(
struct ggml_context * ctx,
struct ggml_tensor * scores,
int k,
struct ggml_tensor * starts,
struct ggml_tensor * ends);

// fused lightning-indexer logits: inputs Q[D, Tc*H], K[D, kv_end], W[H, Tc], k_scale[kv_end] => out [kv_end, Tc]
GGML_API struct ggml_tensor * ggml_indexer_logits_fused(
struct ggml_context * ctx,
struct ggml_tensor * q2d,
struct ggml_tensor * k2d,
struct ggml_tensor * w2d,
struct ggml_tensor * k_scale);

GGML_API struct ggml_tensor * ggml_indexer_logits_fused_ex(
struct ggml_context * ctx,
struct ggml_tensor * q2d,
struct ggml_tensor * k2d,
struct ggml_tensor * w2d,
struct ggml_tensor * k_scale,
struct ggml_tensor * starts,
struct ggml_tensor * ends);

GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2

@@ -2546,3 +2601,5 @@ extern "C" {
#ifdef __cplusplus
}
#endif

// optional [Tc] I32
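
As a usage sketch, here is how a caller might chain two of the ops declared in the ggml.h hunk above: build the fused lightning-indexer logits, then take the per-column top-k of the result. The shapes follow the comments in the header (Q[D, Tc*H], K[D, kv_end], W[H, Tc], k_scale[kv_end] => logits [kv_end, Tc]); the function name, the F32 input types, and the surrounding graph code are assumptions for illustration, not taken from the PR.

#include "ggml.h"

// Hypothetical graph wiring of the new ops; a sketch, not code from this PR.
static struct ggml_tensor * build_indexer_topk(struct ggml_context * ctx,
                                               int D, int H, int Tc, int kv_end, int topk) {
    struct ggml_tensor * q  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, Tc * H); // Q[D, Tc*H]
    struct ggml_tensor * kc = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, kv_end); // K[D, kv_end]
    struct ggml_tensor * w  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, H, Tc);     // W[H, Tc]
    struct ggml_tensor * ks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, kv_end);    // k_scale[kv_end]

    struct ggml_tensor * logits = ggml_indexer_logits_fused(ctx, q, kc, w, ks);  // [kv_end, Tc]

    // indices of the topk highest-scoring KV positions per token column, [topk, Tc]
    return ggml_sparse_topk_radix(ctx, logits, topk);
}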