diff --git a/.gitmodules b/.gitmodules index e69de29bb2d..b563cab27cf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "ggml/src/ggml-cuda/vendors/cutlass"] + path = ggml/src/ggml-cuda/vendors/cutlass + url = https://github.com/NVIDIA/cutlass diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 411e36f8cf4..0c22d951c31 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -852,6 +852,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5": # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B res = "deepseek-r1-qwen" + if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp + res = "deepseek-v3.2" if chkhsh == "ccc2ef013c104be7bae2965776d611e1d7a8a2a9c547dd93a682c9a9fc80352e": # ref: https://huggingface.co/Xenova/gpt-4o res = "gpt-4o" @@ -6503,6 +6506,193 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register( + "DeepseekV32ForCausalLM", +) +class DeepseekV3_2Model(TextModel): + model_arch = gguf.MODEL_ARCH.DEEPSEEK3_2 + + def set_vocab(self): + try: + self._set_vocab_gpt2() + return + except Exception: + pass + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + tokpre = self.get_vocab_base_pre(tokenizer) + + if tokpre == "kimi-k2": + # Build merges list using the approach similar to HunYuanMoE + merges = [] + vocab = {} + mergeable_ranks = tokenizer.model._mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + # Build token list + vocab_size = self.hparams["vocab_size"] + special_tokens = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} + tokens: list[str] = [] + toktypes: list[int] = [] + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token = reverse_vocab[i] + tokens.append(token) + if i in special_tokens.values(): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + else: + raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!") + + def set_gguf_parameters(self): + + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 + + super().set_gguf_parameters() + hparams = self.hparams + + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + + # note: deepseek2 using MLA converts into MQA with 
larger heads, then decompresses to MHA + self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # skip vision tensors and remove "language_model." 
for Kimi-VL + if "vision_tower" in name or "multi_modal_projector" in name: + return [] + + if name.startswith("language_model."): + name = name.replace("language_model.", "") + + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + + return [ + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 21bb4a9f3e5..5d29026e966 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -127,6 +127,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}, + {"name": "deepseek-v3.2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp"}, {"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", }, {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", }, {"name": 
"trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 56420587a95..18193091339 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -106,6 +106,10 @@ if (NOT GGML_LLAMAFILE_DEFAULT) set(GGML_LLAMAFILE_DEFAULT OFF) endif() +if (NOT GGML_OPENMP_SIMD_DEFAULT) + set(GGML_OPENMP_SIMD_DEFAULT OFF) +endif() + if (NOT GGML_CUDA_GRAPHS_DEFAULT) set(GGML_CUDA_GRAPHS_DEFAULT OFF) endif() @@ -169,6 +173,7 @@ option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_VXE "ggml: enable vxe" ON) +option(GGML_OPENMP_SIMD "ggml: enable OPENMP_SIMD" ${GGML_OPENMP_SIMD_DEFAULT}) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") diff --git a/ggml/include/ggml-cuda-indexer.h b/ggml/include/ggml-cuda-indexer.h new file mode 100644 index 00000000000..c5ef52235b5 --- /dev/null +++ b/ggml/include/ggml-cuda-indexer.h @@ -0,0 +1,36 @@ +#pragma once +#include "ggml-cuda.h" +#ifdef __cplusplus +extern "C" { +#endif + +// Forward-declare the CUDA context type; definition is in common.cuh +struct ggml_backend_cuda_context; + +// Derive per-token KV window ends from device-resident mask [N_kv, T] +// mask values <= -1e29 are treated as masked; ends[t] = last i where mask[i,t] > -1e29, or 0 if none +void ggml_cuda_mask_window_ends_device(struct ggml_backend_cuda_context & ctx, + const float * dMask, int N_kv, int T, + int * dEnds); + +// Device-resident entry: takes device pointers and current CUDA context +void ggml_cuda_indexer_logits_fused_device(struct ggml_backend_cuda_context & ctx, + const float * dQ, + const float * dK, + const float * dW, + const float * dKS, + const int * dStarts, const int * dEnds, + int D, int H, int Tc, int kv_end, + float * dOut); + +// Derive per-token KV window ends from device-resident mask and copy to host buffer +void ggml_cuda_mask_window_ends_device_to_host(struct ggml_backend_cuda_context & ctx, + const float * dMask, int N_kv, int T, int * hEnds); + +// Simple convenience wrappers using current device and default stream +void ggml_cuda_mask_window_ends_device_to_host_simple(const float * dMask, int N_kv, int T, int * hEnds); +void ggml_cuda_mask_window_starts_device_to_host_simple(const float * dMask, int N_kv, int T, int * hStarts); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml-cuda-radix.h b/ggml/include/ggml-cuda-radix.h new file mode 100644 index 00000000000..f8f2a611dd0 --- /dev/null +++ b/ggml/include/ggml-cuda-radix.h @@ -0,0 +1,28 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +// Compute top-k indices per column using a CUDA radix-style selection. +// scores is a row-major 2D array with shape [N, T]: element(i,t) at scores[i + N*t]. +// Writes indices into idx (shape [k, T], same storage rule: idx[i + k*t]). +void ggml_cuda_topk_radix_indices_host(const float * scores, int N, int T, int k, int * idx); + +// Build per-column histogram on the top byte of float->key mapping. +// scores: [N, T] row-major. 
Outputs: +// - gt_counts: size 256*T, gt_counts[b + 256*t] = sum_{bb>b} counts[bb] +// - thr_bins: size T (currently placeholder; can be 0) +void ggml_cuda_topk_histogram_host(const float * scores, int N, int T, + unsigned int * gt_counts, unsigned int * thr_bins); + +// Launch equal-bin selection kernel only, given precomputed histogram greater-counts per column +// scores: [N, T] row-major +// gt_counts: [256, T] greater-counts per bin +// idx: [k, T] output indices (row-major leading dimension k) +void ggml_cuda_topk_select_host(const float * scores, int N, int T, int k, + const unsigned int * gt_counts, int * idx); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5028a9cebf2..12bc84839b9 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -417,7 +417,11 @@ extern "C" { // GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) - GGML_TYPE_COUNT = 40, + GGML_TYPE_E5M2 = 40, + GGML_TYPE_E4M3 = 41, + GGML_TYPE_E4M3_Q = 42, + GGML_TYPE_E3M4_Q = 43, + GGML_TYPE_COUNT = 44, }; // precision @@ -453,6 +457,10 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors + GGML_FTYPE_MOSTLY_E5M2 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_E4M3 = 27, // except 1d tensors + GGML_FTYPE_MOSTLY_E4M3_Q = 28, // except 1d tensors + GGML_FTYPE_MOSTLY_E3M4_Q = 29, // except 1d tensors }; // available tensor operations: @@ -555,6 +563,9 @@ extern "C" { GGML_OP_OPT_STEP_ADAMW, GGML_OP_OPT_STEP_SGD, + GGML_OP_SPARSE_TOPK_RADIX, + GGML_OP_INDEXER_FUSED, + GGML_OP_SPARSE_MLA_DECODE, GGML_OP_GLU, GGML_OP_COUNT, @@ -725,12 +736,56 @@ extern "C" { GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); + + // sparse MLA decode fused (CUDA backend) + GGML_API struct ggml_tensor * ggml_sparse_mla_decode_fused( + struct ggml_context * ctx, + struct ggml_tensor * q2d, + struct ggml_tensor * k_cache, + struct ggml_tensor * v_cache, + struct ggml_tensor * idx_topk, + float kq_scale, + float attn_softcap); + GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation) GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor); GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous() + + // radix-based sparse top-k indices per column (specialized CUDA path with CPU fallback) + GGML_API struct ggml_tensor * ggml_sparse_topk_radix( + struct ggml_context * ctx, + struct ggml_tensor * scores, + int k); + + + // Variant that accepts optional per-column windows [start,end) + GGML_API struct ggml_tensor * ggml_sparse_topk_radix_ex( + struct ggml_context * ctx, + struct ggml_tensor * scores, + int k, + struct ggml_tensor * starts, + struct ggml_tensor * ends); + + // fused lightning-indexer logits: inputs Q[D, Tc*H], K[D, kv_end], W[H, Tc], k_scale[kv_end] => out [kv_end, Tc] + GGML_API struct ggml_tensor * ggml_indexer_logits_fused( + struct ggml_context * ctx, + struct ggml_tensor * q2d, + struct ggml_tensor * k2d, + struct ggml_tensor * w2d, + struct ggml_tensor * k_scale); + + GGML_API struct ggml_tensor * 
ggml_indexer_logits_fused_ex( + struct ggml_context * ctx, + struct ggml_tensor * q2d, + struct ggml_tensor * k2d, + struct ggml_tensor * w2d, + struct ggml_tensor * k_scale, + struct ggml_tensor * starts, + struct ggml_tensor * ends); + GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 @@ -2546,3 +2601,5 @@ extern "C" { #ifdef __cplusplus } #endif + + // optional [Tc] I32 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index c8f3d859642..230a585b9e9 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -206,6 +206,7 @@ add_library(ggml-base ggml-threading.h ggml-quants.c ggml-quants.h + ggml-fp8.cpp gguf.cpp) target_include_directories(ggml-base PRIVATE .) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 93ab7ea446e..1f122fc8349 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -427,6 +427,24 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +// fp8 support +// - fp8 simple type +typedef struct { uint8_t bits; } ggml_e5m2_t; +typedef struct { uint8_t bits; } ggml_e4m3_t; + +// - fp8 with bloc delta => 8.125 bpw +typedef struct { + float d; // delta + uint8_t qs[QK_K]; +} block_e4m3_q; +static_assert(sizeof(block_e4m3_q) == sizeof(float) + QK_K, "wrong block_e4m3_q block size/padding"); + +typedef struct { + float d; // delta + uint8_t qs[QK_K]; +} block_e3m4_q; +static_assert(sizeof(block_e3m4_q) == sizeof(float) + QK_K, "wrong block_e3m4_q block size/padding"); + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index dbc07301b29..93027df6ac6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1671,6 +1671,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } switch (tensor->op) { + case GGML_OP_SPARSE_TOPK_RADIX: + case GGML_OP_SPARSE_MLA_DECODE: + // Not implemented on CPU yet; should be caught by backend selection. 
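+            // (ggml_backend_cuda_device_supports_op() reports these ops as supported, so the
+            //  scheduler is expected to place them on the CUDA backend; hitting the assert
+            //  below means the node was scheduled onto the CPU backend anyway)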
+ GGML_ASSERT(false && "CPU backend: SPARSE_TOPK_RADIX / SPARSE_MLA_DECODE not supported"); + return; case GGML_OP_DUP: { ggml_compute_forward_dup(params, tensor); @@ -2133,6 +2138,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } switch (node->op) { + case GGML_OP_SPARSE_TOPK_RADIX: + n_tasks = 1; + break; + case GGML_OP_INDEXER_FUSED: + n_tasks = 1; + break; + case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 14f7dcf4f41..640e9772c13 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5398,6 +5398,12 @@ void ggml_compute_forward_clamp( { ggml_compute_forward_clamp_f16(params, dst); } break; + case GGML_TYPE_E5M2: + case GGML_TYPE_E4M3: + case GGML_TYPE_E4M3_Q: + case GGML_TYPE_E3M4_Q: + GGML_ASSERT(false && "clamp for fp8 types not implemented on CPU"); + break; case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index bdcefe7b7ed..d0f2e455dc5 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -42,8 +42,11 @@ if (CUDAToolkit_FOUND) file(GLOB GGML_HEADERS_CUDA "*.cuh") list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") + list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda-radix.h") file(GLOB GGML_SOURCES_CUDA "*.cu") + list(APPEND GGML_SOURCES_CUDA "topk-radix.cu" indexer-fused.cu sparse-mla-decode.cu) + file(GLOB SRCS "template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") @@ -69,6 +72,28 @@ if (CUDAToolkit_FOUND) ${GGML_SOURCES_CUDA} ) + + list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX "mqa_attn_return_logits_kernel\\.cu$") + + # Build the Lightning Indexer kernel as its own OBJECT library so we can + # give it a different CUDA arch than the rest of the project. 
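+    # Note: the kernel TU comes from the vendored TileLang sources under
+    # vendors/tilelang/fp8_lightning_indexer and depends on headers from the CUTLASS
+    # submodule registered in .gitmodules (ggml/src/ggml-cuda/vendors/cutlass).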
+ add_library(lightning_kernels OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/vendors/tilelang/fp8_lightning_indexer/mqa_attn_return_logits_kernel.cu + ) + # Compile just this TU for Hopper/ADA style WGMMA (change as needed) + set_property(TARGET lightning_kernels PROPERTY CUDA_ARCHITECTURES "120a") + set_property(TARGET lightning_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) + # Includes required to compile this TU + target_include_directories(lightning_kernels PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/vendors/tilelang/fp8_lightning_indexer + ${CMAKE_CURRENT_SOURCE_DIR}/vendors/cutlass/include + ) + target_compile_options(lightning_kernels PRIVATE + $<$:-Xcompiler=-fPIC -std=c++17 -w -Xcudafe --diag_suppress=177 -lineinfo --use_fast_math -gencode arch=compute_120a,code=sm_120a>) + # Inject the compiled objects into ggml-cuda + target_sources(ggml-cuda PRIVATE $) + + add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) if (GGML_CUDA_GRAPHS) @@ -186,3 +211,10 @@ if (CUDAToolkit_FOUND) else() message(FATAL_ERROR "CUDA Toolkit not found") endif() + # (mirror includes outside the if() block) + target_include_directories(ggml-cuda PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/vendors/tilelang/fp8_lightning_indexer + ) + target_include_directories(ggml-cuda PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/vendors/cutlass/include + ) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 60240102741..80589b2ef17 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -1,6 +1,7 @@ #include "binbcast.cuh" #include #include +#include static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b7e81b21bcb..4938017ab7b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -46,6 +46,7 @@ #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/topk-moe.cuh" +#include "ggml-cuda/topk-radix.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" @@ -54,6 +55,10 @@ #include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml.h" +extern "C" void ggml_cuda_sparse_mla_decode_device(ggml_backend_cuda_context & ctx, + const float * q, const float * k, const float * v, const int32_t * topk, + int D, int Hq, int Hkv, int Dv, int Nkv, int K, float kq_scale, float softcap, float * out); + #include #include #include @@ -73,6 +78,7 @@ #include #include #include +#include "../../include/ggml-cuda-indexer.h" static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); @@ -2518,7 +2524,252 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_SGD: ggml_cuda_opt_step_sgd(ctx, dst); break; + case GGML_OP_SPARSE_TOPK_RADIX: + { + ggml_tensor * scores = dst->src[0]; + int k = (int)ggml_get_op_params_i32(dst, 0); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + int N = (int)scores->ne[0]; + int T = (int)ggml_nrows(scores); + // profiling wrapper for SPARSE_TOPK_RADIX (selection only) + // guarded by env LLAMA_SPARSE_PROF to avoid log spam + auto * __prof_env = getenv("LLAMA_SPARSE_PROF"); + auto * __prof_each_env = getenv("LLAMA_SPARSE_PROF_EACH"); + cudaStreamCaptureStatus __iscap = cudaStreamCaptureStatusNone; + CUDA_CHECK(cudaStreamIsCapturing(ctx.stream(), &__iscap)); + bool __do_prof = (__prof_env && *__prof_env && __iscap == cudaStreamCaptureStatusNone); + cudaError_t __err_kernel = cudaSuccess; 
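+                // Layout contract (see ggml-cuda-radix.h): scores is row-major [N, T] with
+                // element (i, t) stored at scores[i + N*t]; the selected indices are written
+                // to dst as an I32 tensor with the same storage rule, shape [k, T].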
+ // Optional per-column KV windows (starts/ends) passed via src[1]/src[2] + const int * tl_starts = nullptr; + const int * tl_ends = nullptr; + if (dst->src[1] && dst->src[1]->type == GGML_TYPE_I32) tl_starts = (const int *)dst->src[1]->data; + if (dst->src[2] && dst->src[2]->type == GGML_TYPE_I32) tl_ends = (const int *)dst->src[2]->data; + if (scores->type == GGML_TYPE_F16) { + // Promote half->float for current kernel + const to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + ggml_cuda_pool_alloc tmp(ctx.pool(ggml_cuda_get_device()), (size_t)N*T); + to_fp32((const void *)scores->data, (float *)tmp.get(), (size_t)N*T, ctx.stream()); + if (__do_prof) { + cudaEvent_t __e0, __e1; cudaError_t __e0c = cudaEventCreate(&__e0); cudaError_t __e1c = cudaEventCreate(&__e1); + bool __ev_ok = (__e0c == cudaSuccess && __e1c == cudaSuccess); + if (__ev_ok) cudaEventRecord(__e0, ctx.stream()); + { + const char * use_tl = getenv("LLAMA_SPARSE_TOPK_TL"); + if (use_tl && *use_tl) { + ggml_cuda_topk_tilelang_port_device(ctx, (const float *)tmp.get(), N, T, k, (int *)dst->data, tl_starts, tl_ends); + } else { + ggml_cuda_topk_radix_indices_device(ctx, (const float *)tmp.get(), N, T, k, (int *)dst->data); + } + } + __err_kernel = cudaGetLastError(); + float __ms = 0.0f; + if (__ev_ok) { + cudaEventRecord(__e1, ctx.stream()); cudaEventSynchronize(__e1); + cudaEventElapsedTime(&__ms, __e0, __e1); + cudaEventDestroy(__e0); cudaEventDestroy(__e1); + } + static int __cnt = 0; static double __sum = 0.0; __sum += __ms; __cnt++; + if (__prof_each_env && *__prof_each_env) { + fprintf(stderr, "[PROFILE] SPARSE_TOPK_RADIX N=%d T=%d k=%d ms=%.3f\n", N, T, k, (float)(__ms)); + } else { + if (__cnt % 50 == 0) { fprintf(stderr, "[PROFILE] SPARSE_TOPK_RADIX N=%d T=%d k=%d avg_ms=%.3f over 50 calls\n", N, T, k, (float)(__sum/50.0)); __sum = 0.0; } + } + } else { + { + const char * use_tl = getenv("LLAMA_SPARSE_TOPK_TL"); + if (use_tl && *use_tl) { + ggml_cuda_topk_tilelang_port_device(ctx, (const float *)tmp.get(), N, T, k, (int *)dst->data, tl_starts, tl_ends); + } else { + ggml_cuda_topk_radix_indices_device(ctx, (const float *)tmp.get(), N, T, k, (int *)dst->data); + } + } + __err_kernel = cudaGetLastError(); + } + } else { + GGML_ASSERT(scores->type == GGML_TYPE_F32); + if (__do_prof) { + cudaEvent_t __e0, __e1; + cudaEventCreate(&__e0); + cudaEventCreate(&__e1); + cudaEventRecord(__e0, ctx.stream()); + { + const char * use_tl = getenv("LLAMA_SPARSE_TOPK_TL"); + if (use_tl && *use_tl) { + ggml_cuda_topk_tilelang_port_device(ctx, (const float *)scores->data, N, T, k, (int *)dst->data, tl_starts, tl_ends); + } else { + ggml_cuda_topk_radix_indices_device(ctx, (const float *)scores->data, N, T, k, (int *)dst->data); + } + } + __err_kernel = cudaGetLastError(); + cudaEventRecord(__e1, ctx.stream()); + cudaEventSynchronize(__e1); + float __ms = 0.0f; + cudaEventElapsedTime(&__ms, __e0, __e1); + static int __cnt = 0; + static double __sum = 0.0; + __sum += __ms; + __cnt++; + if (__prof_each_env && *__prof_each_env) { + fprintf(stderr, "[PROFILE] SPARSE_TOPK_RADIX N=%d T=%d k=%d ms=%.3f\n", N, T, k, (float)(__ms)); + } else { + if (__cnt % 50 == 0) { fprintf(stderr, "[PROFILE] SPARSE_TOPK_RADIX N=%d T=%d k=%d avg_ms=%.3f over 50 calls\n", N, T, k, (float)(__sum/50.0)); __sum = 0.0; } + } + cudaEventDestroy(__e0); + cudaEventDestroy(__e1); + } else { + const char * use_tl = getenv("LLAMA_SPARSE_TOPK_TL"); + if (use_tl && *use_tl) { + ggml_cuda_topk_tilelang_port_device(ctx, (const float *)scores->data, N, T, k, 
(int *)dst->data, tl_starts, tl_ends); + } else { + ggml_cuda_topk_radix_indices_device(ctx, (const float *)scores->data, N, T, k, (int *)dst->data); + } + __err_kernel = cudaGetLastError(); + } + } + // Prefer kernel error captured immediately after launch if set, otherwise last error + cudaError_t err_topk = (__err_kernel != cudaSuccess) ? __err_kernel : cudaGetLastError(); + if (err_topk != cudaSuccess) { + GGML_LOG_ERROR("ggml_cuda_compute_forward: SPARSE_TOPK_RADIX failed"); + CUDA_CHECK(err_topk); + } + } + break; + case GGML_OP_INDEXER_FUSED: + { + // inputs: Q[D, Tc*H], K[D, kv], W[H, Tc], k_scale[kv] + ggml_tensor * q2d = dst->src[0]; + ggml_tensor * k2d = dst->src[1]; + ggml_tensor * w2d = dst->src[2]; + ggml_tensor * ks = dst->src[3]; + int D = (int)q2d->ne[0]; + int TcH = (int)q2d->ne[1]; + int kv = (int)k2d->ne[1]; + int Tc = (int)w2d->ne[1]; + int H = (int)w2d->ne[0]; + GGML_ASSERT(TcH == Tc*H); + // Inputs must be contiguous; enforce at graph construction time + GGML_ASSERT(ggml_is_contiguous(q2d)); + GGML_ASSERT(ggml_is_contiguous(k2d)); + GGML_ASSERT(ggml_is_contiguous(w2d)); + GGML_ASSERT(ggml_is_contiguous(ks)); + ggml_tensor * q2d_c = q2d; + ggml_tensor * k2d_c = k2d; + ggml_tensor * w2d_c = w2d; + ggml_tensor * ks_c = ks; + // Optional starts/ends (per-token windows) + const int * dStarts = nullptr; + const int * dEnds = nullptr; + if (dst->src[4] && dst->src[4]->type == GGML_TYPE_I32) dStarts = (const int *)dst->src[4]->data; + if (dst->src[5] && dst->src[5]->type == GGML_TYPE_I32) dEnds = (const int *)dst->src[5]->data; +// Optional profiling for fused indexer + auto * __prof_env2 = getenv("LLAMA_SPARSE_PROF"); + auto * __prof_each_env = getenv("LLAMA_SPARSE_PROF_EACH"); + cudaEvent_t __i0, __i1; + bool __do_prof2 = false; + float __ms2 = 0.0f; + if (__prof_env2 && *__prof_env2) { + cudaEventCreate(&__i0); + cudaEventCreate(&__i1); + __do_prof2 = true; + } + + // Promote half/bf16 to float for now (Phase 1) + ggml_cuda_pool_alloc qf(ctx.pool(ggml_cuda_get_device()), (size_t)D*TcH); + ggml_cuda_pool_alloc kf(ctx.pool(ggml_cuda_get_device()), (size_t)D*kv); + const to_fp32_cuda_t to_q = ggml_get_to_fp32_cuda(q2d_c->type); + const to_fp32_cuda_t to_k = ggml_get_to_fp32_cuda(k2d_c->type); + if (to_q) { + to_q((const void *)q2d_c->data, (float *)qf.get(), (size_t)D*TcH, ctx.stream()); + } else { + // F32 fast path: device-to-device copy + CUDA_CHECK(cudaMemcpyAsync((void *)qf.get(), (const void *)q2d_c->data, + sizeof(float)*(size_t)D*TcH, cudaMemcpyDeviceToDevice, ctx.stream())); + } + if (to_k) { + to_k((const void *)k2d_c->data, (float *)kf.get(), (size_t)D*kv, ctx.stream()); + } else { + CUDA_CHECK(cudaMemcpyAsync((void *)kf.get(), (const void *)k2d_c->data, + sizeof(float)*(size_t)D*kv, cudaMemcpyDeviceToDevice, ctx.stream())); + } + // Launch naive device kernel (implemented in indexer-fused.cu) directly writing to dst + if (__do_prof2) { cudaEventRecord(__i0, ctx.stream()); } + +#ifndef NDEBUG + printf("[GGML_OP_INDEXER_FUSED] D=%d H=%d Tc=%d kv=%d TcH=%d\n", D, H, Tc, kv, TcH); + printf("[GGML_OP_INDEXER_FUSED] ptrs dQ=%p dK=%p dW=%p dKS=%p dOut=%p\n", (void*)qf.get(), (void*)kf.get(), w2d_c->data, ks_c->data, dst->data); + fflush(stdout); +#endif + ggml_cuda_indexer_logits_fused_device(ctx, (const float *)qf.get(), (const float *)kf.get(), (const float *)w2d_c->data, (const float *)ks_c->data, dStarts, dEnds, D, H, Tc, kv, (float *)dst->data); + CUDA_CHECK(cudaGetLastError()); + if (__do_prof2) { + cudaEventRecord(__i1, ctx.stream()); + 
cudaEventSynchronize(__i1); + cudaEventElapsedTime(&__ms2, __i0, __i1); + cudaEventDestroy(__i0); + cudaEventDestroy(__i1); + static int __cnt_idx_cuda = 0; + static double __sum_idx_cuda = 0.0; + __sum_idx_cuda += __ms2; + __cnt_idx_cuda++; + if (__prof_each_env && *__prof_each_env) { + fprintf(stderr, "[PROFILE] IDX_TILE CUDA D=%d H=%d Tc=%d kv=%d ms=%.3f\n", + D, H, Tc, kv, (float)(__ms2)); + } else { + if (__cnt_idx_cuda % 50 == 0) { + fprintf(stderr, "[PROFILE] IDX_TILE CUDA D=%d H=%d Tc=%d kv=%d avg_ms=%.3f over 50 calls\n", + D, H, Tc, kv, (float)(__sum_idx_cuda/50.0)); + __sum_idx_cuda = 0.0; + } + } + } + (void)D; (void)H; (void)Tc; (void)kv; (void)TcH; // silence warnings if asserts disabled + } + break; + + case GGML_OP_SPARSE_MLA_DECODE: + { + ggml_tensor * q2d = dst->src[0]; + ggml_tensor * kc = dst->src[1]; + ggml_tensor * vc = dst->src[2]; + ggml_tensor * idx = dst->src[3]; + int Dq = (int)q2d->ne[0]; + int Hq = (int)q2d->ne[1]; + int Dk = (int)kc->ne[0]; + int Hkv = (int)kc->ne[1]; + int Nkv = (int)kc->ne[2]; + int Dv = (int)vc->ne[0]; + GGML_ASSERT(Dq == Dk); + // Allow MQA/GQA: Hq may differ from Hkv; kernel maps h-> (h % Hkv) + int K = (int)idx->ne[0]; + float kq_scale = ggml_get_op_params_f32(dst, 0); + float softcap = ggml_get_op_params_f32(dst, 1); + + // promote to float + ggml_cuda_pool_alloc qtmp(ctx.pool(ggml_cuda_get_device()), (size_t)Dq*Hq); + const to_fp32_cuda_t to_q = ggml_get_to_fp32_cuda(q2d->type); + if (to_q) { to_q((const void *)q2d->data, (float *)qtmp.get(), (size_t)Dq*Hq, ctx.stream()); } + else { CUDA_CHECK(cudaMemcpyAsync((void*)qtmp.get(), q2d->data, sizeof(float)*(size_t)Dq*Hq, cudaMemcpyDeviceToDevice, ctx.stream())); } + ggml_cuda_pool_alloc ktmp(ctx.pool(ggml_cuda_get_device()), (size_t)Dk*Hkv*Nkv); + const to_fp32_cuda_t to_k = ggml_get_to_fp32_cuda(kc->type); + if (to_k) { to_k((const void *)kc->data, (float *)ktmp.get(), (size_t)Dk*Hkv*Nkv, ctx.stream()); } + else { CUDA_CHECK(cudaMemcpyAsync((void*)ktmp.get(), kc->data, sizeof(float)*(size_t)Dk*Hkv*Nkv, cudaMemcpyDeviceToDevice, ctx.stream())); } + ggml_cuda_pool_alloc vtmp(ctx.pool(ggml_cuda_get_device()), (size_t)Dv*Hkv*Nkv); + const to_fp32_cuda_t to_v = ggml_get_to_fp32_cuda(vc->type); + if (to_v) { to_v((const void *)vc->data, (float *)vtmp.get(), (size_t)Dv*Hkv*Nkv, ctx.stream()); } + else { CUDA_CHECK(cudaMemcpyAsync((void*)vtmp.get(), vc->data, sizeof(float)*(size_t)Dv*Hkv*Nkv, cudaMemcpyDeviceToDevice, ctx.stream())); } + auto * __prof_env3 = getenv("LLAMA_SPARSE_PROF"); + cudaEvent_t __m0, __m1; bool __do_prof3 = false; float __ms3 = 0.0f; + if (__prof_env3 && *__prof_env3) { cudaEventCreate(&__m0); cudaEventCreate(&__m1); __do_prof3 = true; cudaEventRecord(__m0, ctx.stream()); } + ggml_cuda_sparse_mla_decode_device(ctx, (const float*)qtmp.get(), (const float*)ktmp.get(), (const float*)vtmp.get(), (const int32_t*)idx->data, + Dq, Hq, Hkv, Dv, Nkv, K, kq_scale, softcap, (float*)dst->data); + CUDA_CHECK(cudaGetLastError()); + if (__do_prof3) { cudaEventRecord(__m1, ctx.stream()); cudaEventSynchronize(__m1); cudaEventElapsedTime(&__ms3, __m0, __m1); cudaEventDestroy(__m0); cudaEventDestroy(__m1); static int __cnt_mla = 0; static double __sum_mla = 0.0; __sum_mla += __ms3; __cnt_mla++; if (__cnt_mla % 50 == 0) { fprintf(stderr, "[PROFILE] SPARSE_MLA_DECODE D=%d Hq=%d Hkv=%d Dv=%d Nkv=%d K=%d avg_ms=%.3f over 50 calls\n", Dq, Hq, Hkv, Dv, Nkv, K, (float)(__sum_mla/50.0)); __sum_mla = 0.0; } } + } + break; + default: + + return false; } @@ -3379,6 +3630,31 @@ static bool 
ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } break; + case GGML_OP_INDEXER_FUSED: + { + const struct ggml_tensor * q = op->src[0]; + const struct ggml_tensor * k = op->src[1]; + const struct ggml_tensor * w = op->src[2]; + const struct ggml_tensor * ks = op->src[3]; + if (!q || !k || !w || !ks) return false; + if (!ggml_is_contiguous(q) || !ggml_is_contiguous(k) || !ggml_is_contiguous(w) || !ggml_is_contiguous(ks)) return false; + if (!(q->type == GGML_TYPE_F32 || q->type == GGML_TYPE_F16 || q->type == GGML_TYPE_BF16)) return false; + if (!(k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16 || k->type == GGML_TYPE_BF16)) return false; + if (w->type != GGML_TYPE_F32) return false; + if (ks->type != GGML_TYPE_F32) return false; + if (q->ne[0] != k->ne[0]) return false; + return true; + } break; + case GGML_OP_SPARSE_TOPK_RADIX: + { + const struct ggml_tensor * a = op->src[0]; + if (a == NULL) return false; + if (a->type != GGML_TYPE_F32) return false; + if (op->type != GGML_TYPE_I32) return false; + if (a->ne[2] != 1 || a->ne[3] != 1) return false; + if (op->ne[2] != 1 || op->ne[3] != 1) return false; + return ggml_is_contiguous(a); + } break; case GGML_OP_GLU: switch (ggml_get_glu_op(op)) { case GGML_GLU_OP_REGLU: @@ -3669,6 +3945,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; + + case GGML_OP_SPARSE_MLA_DECODE: + return true; + default: return false; } @@ -3917,3 +4197,79 @@ ggml_backend_t ggml_backend_cuda_init(int device) { } GGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg) + + +// Device-side mask window ends derivation (appended) +__global__ void k_mask_window_ends(const float * __restrict__ mask, int N, int T, int * __restrict__ ends) { + int t = blockIdx.x * blockDim.x + threadIdx.x; + if (t >= T) return; + // mask is [N, T] row-major: element(i,t) at mask[i + (size_t)N * t] + const float * col = mask + (size_t)N * (size_t)t; + int e = 0; + for (int i = N-1; i >= 0; --i) { + float v = col[i]; + if (v > -1.0e29f) { e = i+1; break; } + } + ends[t] = e; +} + +extern "C" void ggml_cuda_mask_window_ends_device(ggml_backend_cuda_context & ctx, + const float * dMask, int N_kv, int T, + int * dEnds) { + int block = 128; int grid = (T + block - 1) / block; + k_mask_window_ends<<>>(dMask, N_kv, T, dEnds); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { ggml_cuda_error("k_mask_window_ends launch", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } +} + +extern "C" void ggml_cuda_mask_window_ends_device_to_host(ggml_backend_cuda_context & ctx, + const float * dMask, int N_kv, int T, int * hEnds) { + int * dEnds = nullptr; + size_t bytes = sizeof(int) * (size_t)T; + cudaError_t err = cudaMalloc((void**)&dEnds, bytes); + if (err != cudaSuccess) { ggml_cuda_error("cudaMalloc dEnds", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + int block = 128; int grid = (T + block - 1) / block; + k_mask_window_ends<<>>(dMask, N_kv, T, dEnds); + err = cudaGetLastError(); + if (err != cudaSuccess) { ggml_cuda_error("k_mask_window_ends launch", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + err = cudaMemcpy(hEnds, dEnds, bytes, cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { ggml_cuda_error("cudaMemcpy D2H dEnds", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + cudaFree(dEnds); +} + +extern "C" void ggml_cuda_mask_window_ends_device_to_host_simple(const float * dMask, int N_kv, int T, int * hEnds) { + 
// Use current device and default stream + int * dEnds = nullptr; + size_t bytes = sizeof(int) * (size_t)T; + cudaError_t err = cudaMalloc((void**)&dEnds, bytes); + if (err != cudaSuccess) { ggml_cuda_error("cudaMalloc dEnds", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + int block = 128; int grid = (T + block - 1) / block; + k_mask_window_ends<<>>(dMask, N_kv, T, dEnds); + err = cudaGetLastError(); + if (err != cudaSuccess) { ggml_cuda_error("k_mask_window_ends launch", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + err = cudaMemcpy(hEnds, dEnds, bytes, cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { ggml_cuda_error("cudaMemcpy D2H dEnds", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + cudaFree(dEnds); +} + +__global__ void k_mask_window_starts(const float * __restrict__ mask, int N, int T, int * __restrict__ starts) { + int t = blockIdx.x * blockDim.x + threadIdx.x; + if (t >= T) return; + const float * col = mask + (size_t)N * (size_t)t; + int s = 0; + for (int i = 0; i < N; ++i) { float v = col[i]; if (v > -1.0e29f) { s = i; break; } } + starts[t] = s; +} + +extern "C" void ggml_cuda_mask_window_starts_device_to_host_simple(const float * dMask, int N_kv, int T, int * hStarts) { + int * dStarts = nullptr; size_t bytes = sizeof(int) * (size_t)T; + cudaError_t err = cudaMalloc((void**)&dStarts, bytes); + if (err != cudaSuccess) { ggml_cuda_error("cudaMalloc dStarts", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + int block = 128; int grid = (T + block - 1) / block; + k_mask_window_starts<<>>(dMask, N_kv, T, dStarts); + err = cudaGetLastError(); + if (err != cudaSuccess) { ggml_cuda_error("k_mask_window_starts launch", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + err = cudaMemcpy(hStarts, dStarts, bytes, cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { ggml_cuda_error("cudaMemcpy D2H dStarts", __func__, __FILE__, __LINE__, cudaGetErrorString(err)); } + cudaFree(dStarts); +} diff --git a/ggml/src/ggml-cuda/indexer-fused.cu b/ggml/src/ggml-cuda/indexer-fused.cu new file mode 100644 index 00000000000..65a66357edd --- /dev/null +++ b/ggml/src/ggml-cuda/indexer-fused.cu @@ -0,0 +1,1509 @@ +#include "common.cuh" +#ifndef LLAMA_ENABLE_CP_ASYNC +#define LLAMA_ENABLE_CP_ASYNC 1 +#endif + +#include +using namespace nvcuda; + +#include + +#include +#include +#include +#include +#include +#include "../../include/ggml-cuda-indexer.h" + +struct fp8_e4_t; +extern "C" int call(fp8_e4_t* __restrict__ IndexQ, fp8_e4_t* __restrict__ IndexK, float* __restrict__ IndexKScale, float* __restrict__ Logits, float* __restrict__ Weights, int* __restrict__ CuSeqLenKS, int* __restrict__ CuSeqLenKE, int seq_len_kv, int seq_len, cudaStream_t stream); +extern "C" int init(); +extern "C" const char* get_last_error(); + + +#ifndef SEL_DEBUG +#endif + + +#if __CUDA_ARCH__ >= 800 +// Forward declarations for cp.async helpers used by kernels defined earlier in this file +static __device__ __forceinline__ void cp_async_16B_all(void* __restrict__ dst, const void* __restrict__ src, size_t bytes); +static __device__ __forceinline__ void cp_async_16B_issue_all(void* __restrict__ dst, const void* __restrict__ src, size_t bytes); +#endif + +#ifndef LAUNCH_PROFILE_KERNEL +#define LAUNCH_PROFILE_KERNEL(TAG_STR, TAGNAME, STREAM, LAUNCH_STMT, D_, H_, Tc_, KV_) do { \ + if (__prof_env && *__prof_env) { \ + cudaEvent_t __e0, __e1; cudaEventCreate(&__e0); cudaEventCreate(&__e1); \ + cudaEventRecord(__e0, STREAM); \ + LAUNCH_STMT; \ + cudaEventRecord(__e1, STREAM); 
\ + cudaEventSynchronize(__e1); \ + float __ms = 0.0f; cudaEventElapsedTime(&__ms, __e0, __e1); \ + cudaEventDestroy(__e0); cudaEventDestroy(__e1); \ + static int __cnt_##TAGNAME = 0; \ + static double __sum_##TAGNAME = 0.0; \ + __sum_##TAGNAME += __ms; \ + __cnt_##TAGNAME++; \ + if (__prof_each_env && *__prof_each_env) { \ + fprintf(stderr, "[" TAG_STR "] TILELANG_INDEXER D=%d H=%d Tc=%d kv=%d ms=%.3f\n", D_, H_, Tc_, KV_, __ms); \ + } else if (__cnt_##TAGNAME % 50 == 0) { \ + fprintf(stderr, "[" TAG_STR "] TILELANG_INDEXER D=%d H=%d Tc=%d kv=%d avg_ms=%.3f over 50 calls\n", D_, H_, Tc_, KV_, (float)(__sum_##TAGNAME/50.0)); \ + __sum_##TAGNAME = 0.0; \ + } \ + } else { \ + LAUNCH_STMT; \ + } \ +} while(0) +#endif + +#if __CUDA_ARCH__ >= 800 && defined(LLAMA_ENABLE_CP_ASYNC) +static __device__ inline void cp_async_16b(void * smem_ptr, const void * gmem_ptr) { + unsigned smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile ("cp.async.cg.shared.global [%0], [%1], 16;\n" :: "r"(smem), "l"(gmem_ptr)); +} +static __device__ inline void cp_async_commit() { + asm volatile ("cp.async.commit_group;\n" ::); +} +static __device__ inline void cp_async_wait() { + asm volatile ("cp.async.wait_group 0;\n" ::); +} +#endif + +// helper kernels +static __device__ __forceinline__ uint8_t f32_to_fp8e4m3(float); +static __device__ __forceinline__ float fp8e4m3_to_f32(uint8_t); + +__global__ void k_fill_int(int *arr, int n, int val) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) arr[i] = val; +} + +__global__ void k_colmajor_DN_to_rowmajor_ND(const float *src, int D, int N, float *dst) { + int n = blockIdx.x * blockDim.x + threadIdx.x; + int d = blockIdx.y * blockDim.y + threadIdx.y; + if (n < N && d < D) { + dst[(size_t)n * (size_t)D + (size_t)d] = src[(size_t)d + (size_t)D * (size_t)n]; + } +} +__global__ void k_transpose_TcKv_to_KvTc(const float *in, int Tc, int kv, float *out) { + int t = blockIdx.x * blockDim.x + threadIdx.x; + int k = blockIdx.y * blockDim.y + threadIdx.y; + if (t < Tc && k < kv) out[k + (size_t)kv * t] = in[(size_t)t * kv + k]; +} + +// Row-major float32 -> FP8 E4M3 packer (no per-row scaling) +__global__ void k_rowmajor_f32_to_fp8_e4m3(const float *src, int rows, int cols, unsigned char *dst) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = (size_t)rows * (size_t)cols; + if (idx < total) { + float x = src[idx]; + __nv_fp8_storage_t v = __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3); + dst[idx] = (unsigned char) v; + } +} + +// Row-major float32 -> per-row absolute max +__global__ void k_rowmajor_f32_rowwise_absmax(const float *src, int rows, int cols, float *row_amax) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) return; + float maxv = 0.0f; + for (int c = 0; c < cols; ++c) { + float v = fabsf(src[(size_t)row * (size_t)cols + (size_t)c]); + if (v > maxv) maxv = v; + } + row_amax[row] = maxv; +} + +// Compute per-row FP8 scales (amax/448) and their reciprocals +__global__ void k_fp8_compute_row_scales(const float *row_amax, int rows, float *sf, float *inv_sf) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= rows) return; + float a = row_amax[i]; + if (a < 1e-4f) a = 1e-4f; + float s = a / 448.0f; + sf[i] = s; + inv_sf[i] = 1.0f / s; +} + +// Row-major float32 -> FP8 E4M3 with per-row scaling +__global__ void k_rowmajor_f32_to_fp8_e4m3_rowwise_scaled(const float *src, int rows, int cols, + const float *inv_sf, unsigned char *dst) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + 
size_t total = (size_t)rows * (size_t)cols; + if (idx >= total) return; + int row = (int)(idx / (size_t)cols); + float x = src[idx] * inv_sf[row]; + __nv_fp8_storage_t v = __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3); + dst[idx] = (unsigned char) v; +} + +// Elementwise product of two float vectors +__global__ void k_elemwise_mul(const float *a, const float *b, float *out, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n) return; + out[i] = a[i] * b[i]; +} + +// helpers to read env +static inline int getenv_int_(const char * name, int def) { + const char * s = getenv(name); + if (!s || !*s) return def; + int v = atoi(s); + return v > 0 ? v : def; +} + + +// Simple baseline fused kernel: compute K^T * Q -> ReLU, then per-head weighted sum, multiply k_scale. +// This is a placeholder for a fully-optimized version. It assumes row-major contiguous inputs. +static inline bool sparse_debug_on(){ const char *d=getenv("LLAMA_SPARSE_DEBUG_INDEXER"); return d && *d && atoi(d)!=0; } + + +// Tiled, shared-memory fused kernel (float inputs, float accum) +// Q: [D, Tc*H], K: [D, kv], W: [H, Tc], k_scale: [kv]; Out: [kv, Tc] +__global__ void k_indexer_logits_tiled_f32( + const float * __restrict__ Q, + const float * __restrict__ K, + const float * __restrict__ W, + const float * __restrict__ k_scale, + int D, int H, int Tc, int kv, + const int * __restrict__ starts, + const int * __restrict__ ends, + int D_TILE, int BLOCK_Q, int BLOCK_N, int exact_flag, + int HEAD_CHUNK_ARG, + int PIPE_STAGES_ARG, + float * __restrict__ Out) { + __shared__ int s_min_blk; + __shared__ int s_max_blk; + __shared__ float s_K_sf[512]; + + // Dynamic select exact vs optimized based on workload or env + bool exact = (exact_flag != 0); + /* env read on host */ + (void)Tc; (void)kv; + if (exact) { + // Exact global-load path (bit-exact with reference; slower) + int t_local = threadIdx.x; + int k_local = threadIdx.y; + int t0 = blockIdx.x * BLOCK_Q; + int k0 = blockIdx.y * BLOCK_N; + int token = t0 + t_local; + int kv_idx = k0 + k_local; + // Compute union window for this block + if (threadIdx.x == 0 && threadIdx.y == 0) { + int smin = 0; int smax = kv; + if (starts != nullptr && ends != nullptr) { + smin = kv; smax = 0; + for (int q = 0; q < BLOCK_Q; ++q) { + int tok = t0 + q; + if (tok < Tc) { + int s0 = starts[tok]; int e0 = ends[tok]; + if (s0 < 0) s0 = 0; if (s0 > kv) s0 = kv; + if (e0 < 0) e0 = 0; if (e0 > kv) e0 = kv; + if (s0 < smin) smin = s0; + if (e0 > smax) smax = e0; + } + } + if (smin > smax) smin = smax; + } + s_min_blk = smin; s_max_blk = smax; + } + __syncthreads(); + int smin_blk = s_min_blk; int smax_blk = s_max_blk; + bool in_bounds = (t_local < BLOCK_Q) && (k_local < BLOCK_N) && (token < Tc) && (kv_idx < kv); + bool in_union = (kv_idx >= smin_blk && kv_idx < smax_blk); + + float acc = 0.0f; + if (in_bounds && in_union) { + for (int h = 0; h < H; ++h) { + const float *qv = Q + (size_t)D * (token*H + h); + const float *kvp= K + (size_t)D * kv_idx; + float dot = 0.0f; + #pragma unroll 1 + for (int d = 0; d < D; ++d) dot += qv[d] * kvp[d]; + if (dot < 0.0f) dot = 0.0f; + acc += dot * W[h + (size_t)H * token]; + } + acc *= k_scale[kv_idx]; + if (starts != nullptr && ends != nullptr) { + int s0 = starts[token]; int e0 = ends[token]; + if (s0 < 0) s0 = 0; if (s0 > kv) s0 = kv; + if (e0 < 0) e0 = 0; if (e0 > kv) e0 = kv; + if (kv_idx < s0 || kv_idx >= e0) acc = 0.0f; + } + } + if (in_bounds) { + Out[kv_idx + (size_t)kv * token] = acc; + } + return; + } + // Optimized shared-memory tiled 
path with head-chunked reduction + int t_local = threadIdx.x; // [0..BLOCK_Q) + int k_local = threadIdx.y; // [0..BLOCK_N) + int t0 = blockIdx.x * BLOCK_Q; + int k0 = blockIdx.y * BLOCK_N; + int token = t0 + t_local; + int kv_idx = k0 + k_local; + // Compute union window for this block + if (threadIdx.x == 0 && threadIdx.y == 0) { + int smin = 0; int smax = kv; + if (starts != nullptr && ends != nullptr) { + smin = kv; smax = 0; + for (int q = 0; q < BLOCK_Q; ++q) { + int tok = t0 + q; + if (tok < Tc) { + int s0 = starts[tok]; int e0 = ends[tok]; + if (s0 < 0) s0 = 0; if (s0 > kv) s0 = kv; + if (e0 < 0) e0 = 0; if (e0 > kv) e0 = kv; + if (s0 < smin) smin = s0; + if (e0 > smax) smax = e0; + } + } + if (smin > smax) smin = smax; + } + s_min_blk = smin; s_max_blk = smax; + } + __syncthreads(); + + // Per-row K scale (amax/448) for BLOCK_N kv rows in this tile, matching CPU FP8 path + if (threadIdx.x == 0) { + int j = threadIdx.y; + if (j < BLOCK_N) { + int row = k0 + j; + float maxv = 0.0f; + if (row < kv) { + for (int d0 = 0; d0 < D; ++d0) { + float v = K[(size_t)d0 + (size_t)D * (size_t)row]; + float av = fabsf(v); + if (av > maxv) maxv = av; + } + } + if (maxv < 1e-4f) maxv = 1e-4f; + s_K_sf[j] = maxv / 448.0f; + } + } + __syncthreads(); + + // Head-chunk size + int Hc = HEAD_CHUNK_ARG > 0 ? HEAD_CHUNK_ARG : 16; + if (Hc < 1) Hc = 1; + if (Hc > H) Hc = H; + + extern __shared__ float shmem[]; + // Double-buffered layout if PIPE_STAGES_ARG >= 2: + // [K0][Q0][K1][Q1][W] else [K0][Q0][W] + int STAGES = (PIPE_STAGES_ARG >= 2 ? 2 : 1); + int sizeK = D_TILE * BLOCK_N; + int sizeQ = D_TILE * BLOCK_Q * Hc; + float * K0 = shmem; + float * Q0 = K0 + sizeK; + float * K1 = (STAGES == 2) ? (Q0 + sizeQ) : K0; + float * Q1 = (STAGES == 2) ? (K1 + sizeK) : Q0; + float * W_sh = (STAGES == 2) ? (Q1 + sizeQ) : (Q0 + sizeQ); + + // Accumulator per (kv,row) x (token,col) + float acc = 0.0f; + + for (int h0 = 0; h0 < H; h0 += Hc) { + int hc = min(Hc, H - h0); +#if SEL_DEBUG + if(threadIdx.x==0 && blockIdx.x==0){ + printf("[IDX_DBG] tiled_f32 params: D=%d H=%d Tc=%d kv=%d BLOCK_Q=%d BLOCK_N=%d Hc=%d\n", D, H, Tc, kv, BLOCK_Q, BLOCK_N, hc); + } +#endif + + // load weights W[h0:h0+hc, token-range] + // cooperative load: map 2D [hc, BLOCK_Q] + int stride2 = blockDim.x * blockDim.y; + int tid2 = threadIdx.y * blockDim.x + threadIdx.x; + for (int idx = tid2; idx < hc*BLOCK_Q; idx += stride2) { + int hi = idx / BLOCK_Q; + int q = idx % BLOCK_Q; + int tok = t0 + q; + float wv = 0.0f; + if (tok < Tc) wv = W[(h0 + hi) + (size_t)H * tok]; + W_sh[hi * BLOCK_Q + q] = wv; + } + __syncthreads(); + + // Compute S = K^T * Q over this head-chunk + float sum_hc = 0.0f; + // Accumulate per-head dot across D, then apply ReLU and weights once + const int MAX_HC = 64; + float dot_vec[MAX_HC]; + for (int i = 0; i < hc; ++i) dot_vec[i] = 0.0f; + for (int d0 = 0, stage = 0; d0 < D; d0 += D_TILE, stage ^= 1) { + int cur = min(D_TILE, D - d0); + int stride = blockDim.x * blockDim.y; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + + // Select buffers + float * Kbuf = (STAGES == 2 && (stage == 1)) ? K1 : K0; + float * Qbuf = (STAGES == 2 && (stage == 1)) ? 
Q1 : Q0; + + // Cooperative (or cp.async) load Kbuf: [cur, BLOCK_N] and Qbuf: [cur, BLOCK_Q*hc] +#if __CUDA_ARCH__ >= 800 && defined(LLAMA_ENABLE_CP_ASYNC) + if (PIPE_STAGES_ARG >= 2) { + // K: cooperative load + for (int idx = tid; idx < cur * BLOCK_N; idx += stride) { + int di = idx / BLOCK_N; + int j = idx % BLOCK_N; + int gk = k0 + j; + float v = 0.0f; + if (gk < kv) v = K[(size_t)(d0 + di) + (size_t)D * gk]; + Kbuf[di * BLOCK_N + j] = v; + } + // Q: cp.async 16B using transposed layout [BLOCK_Q*hc, D_TILE] + int groups = (cur / 4); + for (int idx = tid; idx < groups * BLOCK_Q * hc; idx += stride) { + int di4 = (idx % groups) * 4; + int rem = idx / groups; + int q = rem % BLOCK_Q; + int hi = rem / BLOCK_Q; + int gt = t0 + q; + const float * gptr = (gt < Tc) ? &Q[(size_t)(d0 + di4) + (size_t)D * (gt*H + (h0 + hi))] : nullptr; + float * sptr = &Qbuf[(hi * BLOCK_Q + q) * D_TILE + di4]; + if (gptr) { + cp_async_16b(sptr, gptr); + } else { + // zero-fill when out of range + reinterpret_cast(sptr)[0] = make_float4(0.f,0.f,0.f,0.f); + } + } + cp_async_commit(); + cp_async_wait(); + __syncthreads(); + // Handle Q tail when cur % 4 != 0 via cooperative scalar loads into transposed storage + int tail = cur & 3; + if (tail) { + for (int idx = tid; idx < tail * BLOCK_Q * hc; idx += stride) { + int di = idx % tail; + int rem = idx / tail; + int q = rem % BLOCK_Q; + int hi = rem / BLOCK_Q; + int gt = t0 + q; + float v = 0.0f; + if (gt < Tc) v = Q[(size_t)(d0 + (cur - tail) + di) + (size_t)D * (gt*H + (h0 + hi))]; + Qbuf[(hi * BLOCK_Q + q) * D_TILE + (cur - tail) + di] = v; + } + } + } else +#endif + { + for (int idx = tid; idx < cur * BLOCK_N; idx += stride) { + int di = idx / BLOCK_N; + int j = idx % BLOCK_N; + int gk = k0 + j; + float v = 0.0f; + if (gk < kv) v = K[(size_t)(d0 + di) + (size_t)D * gk]; + Kbuf[di * BLOCK_N + j] = v; + } + for (int idx = tid; idx < cur * BLOCK_Q * hc; idx += stride) { + int di = idx / (BLOCK_Q * hc); + int rem = idx % (BLOCK_Q * hc); + int q = rem % BLOCK_Q; + int hi = rem / BLOCK_Q; + int gt = t0 + q; + float v = 0.0f; + if (gt < Tc) v = Q[(size_t)(d0 + di) + (size_t)D * (gt*H + (h0 + hi))]; + Qbuf[di * (BLOCK_Q * hc) + hi * BLOCK_Q + q] = v; + } + __syncthreads(); + } + + // Apply FP8 E4M3 quant/dequant in shared memory to match CPU Lightning Indexer + // Kbuf: [cur, BLOCK_N] row-major, with per-row scale s_K_sf[j] + for (int idx = tid; idx < cur * BLOCK_N; idx += stride) { + int di = idx / BLOCK_N; + int j = idx % BLOCK_N; + float x = Kbuf[di * BLOCK_N + j]; + float sf = s_K_sf[j]; + float scaled = x / sf; + uint8_t code = f32_to_fp8e4m3(scaled); + float dec = fp8e4m3_to_f32(code); + Kbuf[di * BLOCK_N + j] = dec; + } + // Qbuf: contiguous buffer of size cur * BLOCK_Q * hc (layout depends on cp.async) + for (int idx = tid; idx < cur * BLOCK_Q * hc; idx += stride) { + float x = Qbuf[idx]; + uint8_t code = f32_to_fp8e4m3(x); + float dec = fp8e4m3_to_f32(code); + Qbuf[idx] = dec; + } + __syncthreads(); + + // Compute partial dot across cur for this (kv_idx, token), accumulating per head + float * Kcomp = Kbuf; + for (int di = 0; di < cur; ++di) { + float kval = Kcomp[di * BLOCK_N + k_local]; + for (int hi = 0; hi < hc; ++hi) { +#if __CUDA_ARCH__ >= 800 && defined(LLAMA_ENABLE_CP_ASYNC) + // When cp.async-enabled, Qbuf is transposed [BLOCK_Q*hc, D_TILE] + float qval = Qbuf[(hi * BLOCK_Q + t_local) * D_TILE + di]; +#else + // Cooperative layout: Qbuf [cur, BLOCK_Q*hc] + float qval = Qbuf[di * (BLOCK_Q * hc) + hi * BLOCK_Q + t_local]; +#endif + dot_vec[hi] += kval * 
qval; + } + } + __syncthreads(); + } + // Apply ReLU and weights, then sum into this tile accumulator + for (int hi = 0; hi < hc; ++hi) { + float tmp = dot_vec[hi]; + if (tmp < 0.0f) tmp = 0.0f; + sum_hc += tmp * W_sh[hi * BLOCK_Q + t_local]; + } + acc += sum_hc; + } + + // Apply combined scale k_scale * K_sf (per-row amax scale) to match CPU FP8 path + float sf_row = s_K_sf[k_local]; + acc *= k_scale[kv_idx] * sf_row; + Out[kv_idx + (size_t)kv * token] = (starts && ends) ? ((kv_idx >= starts[token] && kv_idx < ends[token]) ? acc : 0.0f) : acc; +} + +// mqa_attn_return_logits_kernel_port.cu +// Self-contained: cp.async double-buffer + WMMA (FP16) port of TileLang kernel. +// Grid: grid.x = ceil_div(seq_len, block_Q); block.x = threads (multiple of 32). +// Build: nvcc -std=c++17 -arch=sm_80 -lineinfo -Xptxas -v -c mqa_attn_return_logits_kernel_port.cu +// Note: Inputs (IndexQ/IndexK/Weights) are float32; converted to __half in shared. + +static __device__ __forceinline__ size_t align16(size_t x){ return (x+15u)&~size_t(15u); } + +#if __CUDA_ARCH__ >= 800 +static __device__ __forceinline__ +void cp_async_16B_all(void* __restrict__ dst, const void* __restrict__ src, size_t bytes){ + const size_t n16 = bytes & ~size_t(15); + const char *s = (const char*)src; char *d = (char*)dst; + for(size_t i=(size_t)threadIdx.x*16;i= 800 +static __device__ __forceinline__ +void cp_async_16B_issue_all(void* __restrict__ dst, const void* __restrict__ src, size_t bytes){ + const size_t n16 = bytes & ~size_t(15); + const char *s = (const char*)src; char *d = (char*)dst; + for(size_t i=(size_t)threadIdx.x*16;i> 5; + + const size_t Q_rows = (size_t)block_Q * (size_t)heads; + const size_t Q_cols = (size_t)index_dim; + const size_t K_rows_max = (size_t)block_N; + const size_t K_cols = (size_t)index_dim; + + // f32 staging + off = align16(off); float* Qs_f32 = (float*)(smem + off); off += Q_rows*Q_cols*sizeof(float); + // two ping-pong K slabs (f32) + off = align16(off); float* K0_f32 = (float*)(smem + off); off += K_rows_max*K_cols*sizeof(float); + off = align16(off); float* K1_f32 = (float*)(smem + off); off += K_rows_max*K_cols*sizeof(float); + // two ping-pong k_scale vectors + off = align16(off); float* ks0 = (float*)(smem + off); off += K_rows_max*sizeof(float); + off = align16(off); float* ks1 = (float*)(smem + off); off += K_rows_max*sizeof(float); + // logits scratch for current K tile + off = align16(off); float* logits_blk = (float*)(smem + off); off += K_rows_max*(size_t)block_Q*sizeof(float); + + // f16 WMMA slabs + off = align16(off); __half* Qs_f16 = (__half*)(smem + off); off += Q_rows*Q_cols*sizeof(__half); + off = align16(off); __half* K0_f16 = (__half*)(smem + off); off += K_rows_max*K_cols*sizeof(__half); + off = align16(off); __half* K1_f16 = (__half*)(smem + off); off += K_rows_max*K_cols*sizeof(__half); + + // per-warp C tile scratch (float, WM*WN each) + off = align16(off); float* Csh = (float*)(smem + off); off += (size_t)warps*(WM*WN)*sizeof(float); + + // ---- compute cu_k_s_min / cu_k_e_max for this block ---- + __shared__ int cu_k_s_min_s, cu_k_e_max_s; + if(threadIdx.x==0){ + int smin= 2147483647, emax= -2147483648; + for(int bq=0;bqseq_len_kv) v = seq_len_kv; + if(v < smin) smin = v; + } + for(int bq=0;bqseq_len_kv) v = seq_len_kv; + if(v > emax) emax = v; + } + cu_k_s_min_s = smin; cu_k_e_max_s = emax; + } + __syncthreads(); + const int cu_k_s_min = cu_k_s_min_s; + const int cu_k_e_max = cu_k_e_max_s; + + // ---- stage Q block (seq range [seq_len_i, seq_len_i+block_Q), all heads) 
---- + { + const size_t bytesQ = Q_rows*Q_cols*sizeof(float); +#if __CUDA_ARCH__ >= 800 + bool ok = ((((uintptr_t)Qs_f32)&0xF)==0) && + ((((uintptr_t)(IndexQ + (size_t)seq_len_i*(size_t)heads*(size_t)index_dim))&0xF)==0) && + (bytesQ>=16); + if(ok){ + const float* srcQ = IndexQ + (size_t)seq_len_i*(size_t)heads*(size_t)index_dim; + cp_async_16B_all(Qs_f32, srcQ, bytesQ); + } else +#endif + { + for(size_t t=threadIdx.x;t0){ + const size_t bytesK = (size_t)curN0*(size_t)index_dim*sizeof(float); + const size_t bytesKs= (size_t)curN0*sizeof(float); +#if __CUDA_ARCH__ >= 800 + bool okK = ((((uintptr_t)K0_f32)&0xF)==0) && ((((uintptr_t)(IndexK + (size_t)cu_k_s_min*(size_t)index_dim))&0xF)==0) && (bytesK>=16); + bool okKs = ((((uintptr_t)ks0 )&0xF)==0) && ((((uintptr_t)(IndexKScale + (size_t)cu_k_s_min))&0xF)==0) && (bytesKs>=16); + if(okK) cp_async_16B_all(K0_f32, IndexK + (size_t)cu_k_s_min*(size_t)index_dim, bytesK); else { +#endif + for(size_t t=threadIdx.x;t<(size_t)curN0*(size_t)index_dim;t+=blockDim.x){ + size_t r=t/(size_t)index_dim, c=t%(size_t)index_dim; + K0_f32[t]=IndexK[(size_t)(cu_k_s_min+(int)r)*(size_t)index_dim + c]; + } + __syncthreads(); +#if __CUDA_ARCH__ >= 800 + } + if(okKs) cp_async_16B_all(ks0, IndexKScale + (size_t)cu_k_s_min, bytesKs); else { +#endif + for(size_t t=threadIdx.x;t<(size_t)curN0;t+=blockDim.x) ks0[t]=IndexKScale[cu_k_s_min+(int)t]; + __syncthreads(); +#if __CUDA_ARCH__ >= 800 + } +#endif + // convert first K to half + for(size_t t=threadIdx.x;t<(size_t)curN0*(size_t)index_dim;t+=blockDim.x) K0_f16[t]=__float2half_rn(K0_f32[t]); + __syncthreads(); + } + + const int warp_id = threadIdx.x>>5; + const int lane = threadIdx.x&31; + const int warps_pb= warps; + const int Nq_all = (int)Q_rows; + const int tiles_n = (Nq_all + WN - 1)/WN; + + for(int it=0; it= 800 + bool okK = ((((uintptr_t)K1_f32)&0xF)==0) && ((((uintptr_t)(IndexK + (size_t)k_start_next*(size_t)index_dim))&0xF)==0) && (bytesK>=16); + bool okKs = ((((uintptr_t)ks1 )&0xF)==0) && ((((uintptr_t)(IndexKScale + (size_t)k_start_next))&0xF)==0) && (bytesKs>=16); + if(okK) cp_async_16B_all(K1_f32, IndexK + (size_t)k_start_next*(size_t)index_dim, bytesK); else { +#endif + for(size_t t=threadIdx.x;t<(size_t)curN1*(size_t)index_dim;t+=blockDim.x){ + size_t r=t/(size_t)index_dim, c=t%(size_t)index_dim; + K1_f32[t]=IndexK[(size_t)(k_start_next+(int)r)*(size_t)index_dim + c]; + } + __syncthreads(); +#if __CUDA_ARCH__ >= 800 + } + if(okKs) cp_async_16B_all(ks1, IndexKScale + (size_t)k_start_next, bytesKs); else { +#endif + for(size_t t=threadIdx.x;t<(size_t)curN1;t+=blockDim.x) ks1[t]=IndexKScale[k_start_next+(int)t]; + __syncthreads(); +#if __CUDA_ARCH__ >= 800 + } +#endif + } + + // Compute on current tile buffer (0) + const int k_start_cur = cu_k_s_min + it*block_N; + const int curN0_it = min(block_N, cu_k_e_max - k_start_cur); + + // If we just prefetched next, convert it while we compute? 
(keep simple: convert after prefetch) + if(it==0 || it>0){ /* K0_f16 already converted for first; for subsequent we swap below */ } + + // WMMA over current tile (K0_f16 vs Qs_f16) + const int tiles_m = (curN0_it + WM - 1)/WM; + + for(int tile_lin = warp_id; tile_lin < tiles_m*tiles_n; tile_lin += warps_pb){ + const int tile_m = tile_lin / tiles_n; // along bn + const int tile_n = tile_lin % tiles_n; // along (bq*heads) + + wmma::fragment c; + wmma::fill_fragment(c, 0.0f); + + for(int kk=0; kk a; + wmma::fragment b; + wmma::load_matrix_sync(a, Ap, index_dim); + wmma::load_matrix_sync(b, Bp, index_dim); + wmma::mma_sync(c, a, b, c); + } + + float* cptr = Csh + (size_t)warp_id*(WM*WN); + wmma::store_matrix_sync(cptr, c, WN, wmma::mem_row_major); + __syncwarp(); + + // Post-process (ReLU * weight * k_scale), reduce heads within warp, no atomics + const int base_bq = (tile_n*WN) / heads; + const int max_cols = min(WN, (int)Q_rows - tile_n*WN); + const int groups = max(0, (max_cols + heads - 1) / heads); + for (int mi = lane; mi < WM; mi += 32) { + int bn = tile_m*WM + mi; + if (bn >= curN0_it) continue; + float ks = ks0[bn]; + float acc_g[16]; + #pragma unroll + for (int u = 0; u < 16; ++u) acc_g[u] = 0.0f; + // accumulate contributions for all columns this tile covers + #pragma unroll + for (int cj = 0; cj < WN; ++cj) { + int ncol = tile_n*WN + cj; // = bq*heads + h + if (cj >= max_cols || ncol >= (int)Q_rows) break; + float val = cptr[mi*WN + cj]; + if (val < 0.f) val = 0.f; + int bq_abs = ncol / heads; + int h = ncol % heads; + int tok = seq_len_i + bq_abs; + float w = 0.0f; + if (tok < seq_len) w = Weights[(size_t)tok*(size_t)heads + h]; + int u = bq_abs - base_bq; + if (u >= 0 && u < 16) acc_g[u] += val * w; + } + // write partial sums to logits scratch (atomic to avoid inter-warp races across tile_n) + #pragma unroll + for (int u = 0; u < 16; ++u) { + if (u >= groups) break; + int bq_abs = base_bq + u; + int tok = seq_len_i + bq_abs; + if (bq_abs < block_Q && tok < seq_len) { + atomicAdd(&logits_blk[(size_t)bn*(size_t)block_Q + (size_t)bq_abs], acc_g[u] * ks); + } + } + } + __syncwarp(); + } + __syncthreads(); + + // Write this tile’s logits to global + for(size_t t=threadIdx.x;t<(size_t)curN0_it*(size_t)block_Q;t+=blockDim.x){ + int bn = (int)(t/(size_t)block_Q); + int bq = (int)(t%(size_t)block_Q); + int tok = seq_len_i + bq; + if(tok < seq_len){ + int kv_col = k_start_cur + bn; + if(kv_col < seq_len_kv) + Logits[(size_t)tok*(size_t)seq_len_kv + (size_t)kv_col] = + logits_blk[(size_t)bn*(size_t)block_Q + (size_t)bq]; + } + } + __syncthreads(); + + // Clear scratch for next tile + for(size_t t=threadIdx.x;t<(size_t)curN0_it*(size_t)block_Q;t+=blockDim.x) + logits_blk[t]=0.f; + __syncthreads(); + + // Swap ping-pong: convert next tile, then swap buffers + if(it+1 < iters){ + // convert next K (in K1_f32 -> K1_f16) + for(size_t t=threadIdx.x;t<(size_t)curN1*(size_t)index_dim;t+=blockDim.x) + K1_f16[t]=__float2half_rn(K1_f32[t]); + __syncthreads(); + // swap pointers + float* tmpf; __half* tmph; float* tmps; + tmpf=K0_f32; K0_f32=K1_f32; K1_f32=tmpf; + tmph=K0_f16; K0_f16=K1_f16; K1_f16=tmph; + tmps=ks0; ks0 =ks1; ks1 =tmps; + } + } +} + +// Device-side FP8 E4M3 helpers using native PTX conversions +static __device__ __forceinline__ uint8_t f32_to_fp8e4m3(float x) { + uint16_t tmp; + float zero = 0.0f; + asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(zero), "f"(x)); + return static_cast(tmp & 0xFFu); +} + +static __device__ __forceinline__ float 
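+// Note on the FP8 helpers: f32_to_fp8e4m3 feeds (0.0f, x) to the packed
+// e4m3x2 convert, so fp8(x) ends up in the low byte ("satfinite" clamps to the
+// E4M3 finite range, roughly +/-448, which is why the per-row K scales in
+// these kernels are amax/448); fp8e4m3_to_f32 widens that byte back through
+// the e4m3x2 -> f16x2 convert and takes lane .x of the resulting __half2.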
fp8e4m3_to_f32(uint8_t code) { + uint16_t bits = code; + uint32_t packed; + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(packed) : "h"(bits)); + return __half2float(reinterpret_cast(packed).x); +} + +// WMMA 16x16x16 BF16 (float input cast to bf16), one warp per block +__global__ void k_indexer_logits_wmma16_bf16( + const float * __restrict__ Q, // [D, Tc*H] + const float * __restrict__ K, // [D, kv] + const float * __restrict__ W, // [H, Tc] + const float * __restrict__ k_scale, // [kv] + int D, int H, int Tc, int kv, + const int * __restrict__ starts, + const int * __restrict__ ends, + float * __restrict__ Out) { +#if __CUDA_ARCH__ >= 800 + const int tokens_per_tile = max(1, 16 / H); + const int t0 = blockIdx.x * tokens_per_tile; + const int k0 = blockIdx.y * 16; + if (t0 >= Tc || k0 >= kv) return; + + int smin_blk = 0, smax_blk = kv; + if (starts != nullptr && ends != nullptr) { + smin_blk = kv; smax_blk = 0; + for (int tl = 0; tl < tokens_per_tile; ++tl) { + int tok = t0 + tl; + if (tok < Tc) { + int s0 = starts[tok]; int e0 = ends[tok]; + if (s0 < 0) s0 = 0; if (s0 > kv) s0 = kv; + if (e0 < 0) e0 = 0; if (e0 > kv) e0 = kv; + if (s0 < smin_blk) smin_blk = s0; + if (e0 > smax_blk) smax_blk = e0; + } + } + if (smin_blk > smax_blk) smin_blk = smax_blk; + } + int curN_all = kv - k0; if (curN_all < 0) curN_all = 0; int curN = curN_all < 16 ? curN_all : 16; + if (starts != nullptr && ends != nullptr) { + if (k0 >= smax_blk || (k0 + 16) <= smin_blk) { + int lane = threadIdx.x & 31; + for (int idx = lane; idx < curN * tokens_per_tile; idx += 32) { + int mi = idx / tokens_per_tile; int tl = idx % tokens_per_tile; int kv_idx = k0 + mi; int tok = t0 + tl; + if (tok < Tc && kv_idx < kv) Out[(size_t)kv_idx + (size_t)kv * tok] = 0.0f; + } + return; + } + } + + // Per-row K scale (amax/448) for 16-row tile + __shared__ float K_sf[16]; + int lane = threadIdx.x & 31; + if (lane < 16) K_sf[lane] = 1.0f; + __syncthreads(); + for (int mi = 0; mi < 16; ++mi) { + int kv_idx = k0 + mi; + float local_max = 0.0f; + if (kv_idx < kv) { + for (int d0 = lane; d0 < D; d0 += 32) { + float v = K[(size_t)d0 + (size_t)D * (size_t)kv_idx]; + float av = fabsf(v); + if (av > local_max) local_max = av; + } + } + for (int off = 16; off > 0; off >>= 1) { + float other = __shfl_down_sync(0xffffffff, local_max, off); + if (other > local_max) local_max = other; + } + if (lane == 0 && kv_idx < kv) { + float maxv = local_max; + if (maxv < 1e-4f) maxv = 1e-4f; + K_sf[mi] = maxv / 448.0f; + } + __syncwarp(); + } + __syncthreads(); + + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + __shared__ __nv_bfloat16 A_sh[16*16]; // row-major + __shared__ __nv_bfloat16 B_sh[16*16]; // col-major + __shared__ float C_sh[16*16]; + + // Iterate K dimension in 16-slices + for (int d0 = 0; d0 < D; d0 += 16) { + int lane2 = threadIdx.x & 31; + // Load A_sh with FP8 quant/dequant and per-row scaling + for (int idx = lane2; idx < 16*16; idx += 32) { + int mi = idx / 16; + int di = idx % 16; + int kv_idx = k0 + mi; + __nv_bfloat16 v = __float2bfloat16(0.0f); + if (kv_idx < kv && d0 + di < D) { + float f = K[(size_t)(d0 + di) + (size_t)D * (size_t)kv_idx]; + float sf = K_sf[mi]; + float scaled = f / sf; + uint8_t code = f32_to_fp8e4m3(scaled); + float dec = fp8e4m3_to_f32(code); + v = __float2bfloat16(dec); + } + A_sh[mi * 16 + di] = v; + } + // Load B_sh (col-major) with FP8 quant/dequant + for (int idx = lane2; idx < 16*16; idx += 32) { + int di = idx / 16; // k index + int cj = idx % 16; // column index 0..15 => (tok_local,h) + 
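+            // Column packing: with tokens_per_tile = 16/H (H divides 16 on this
+            // path), each WMMA column encodes cj = tok_local*H + h, so the two
+            // divisions below recover which tile-local token and which head this
+            // column belongs to.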
int tok_local = cj / H; + int h = cj % H; + int tok = t0 + tok_local; + __nv_bfloat16 v = __float2bfloat16(0.0f); + if (tok < Tc && h < H && d0 + di < D) { + float f = Q[(size_t)(d0 + di) + (size_t)D * (size_t)(tok*H + h)]; + uint8_t code = f32_to_fp8e4m3(f); + float dec = fp8e4m3_to_f32(code); + v = __float2bfloat16(dec); + } + B_sh[cj * 16 + di] = v; + } + __syncthreads(); + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::load_matrix_sync(a_frag, A_sh, 16); + wmma::load_matrix_sync(b_frag, B_sh, 16); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); + } + + wmma::store_matrix_sync(C_sh, c_frag, 16, wmma::mem_row_major); + __syncthreads(); + + int lane3 = threadIdx.x & 31; + for (int idx = lane3; idx < 16 * tokens_per_tile; idx += 32) { + int mi = idx / tokens_per_tile; + int tl = idx % tokens_per_tile; + int kv_idx = k0 + mi; + int tok = t0 + tl; + if (kv_idx < kv && tok < Tc) { + float s = 0.0f; + int col_base = tl * H; + for (int h = 0; h < H; ++h) { + float v = C_sh[mi * 16 + (col_base + h)]; + if (v < 0.0f) v = 0.0f; + float w = W[h + (size_t)H * tok]; + s += v * w; + } + // Apply combined scale k_scale * K_sf, matching CPU FP8 path + s *= k_scale[kv_idx] * K_sf[mi]; + if (starts != nullptr && ends != nullptr) { + int s0 = starts[tok]; + int e0 = ends[tok]; + if (s0 < 0) s0 = 0; + if (s0 > kv) s0 = kv; + if (e0 < 0) e0 = 0; + if (e0 > kv) e0 = kv; + Out[kv_idx + (size_t)kv * tok] = (kv_idx >= s0 && kv_idx < e0) ? s : 0.0f; + } else { + Out[kv_idx + (size_t)kv * tok] = s; + } + } + } +#endif +} + +// WMMA 16x16 with head grouping: supports H multiple of 16 +__global__ void k_indexer_logits_wmma16_f32_hgrp( + const float * __restrict__ Q, // [D, Tc*H] + const float * __restrict__ K, // [D, kv] + const float * __restrict__ W, // [H, Tc] + const float * __restrict__ k_scale, // [kv] + int D, int H, int Tc, int kv, + const int * __restrict__ starts, + const int * __restrict__ ends, + float * __restrict__ Out) { +#if __CUDA_ARCH__ >= 800 + const int tokens_per_tile = 1; + const int t0 = blockIdx.x * tokens_per_tile; + const int k0 = blockIdx.y * 16; + if (t0 >= Tc || k0 >= kv) return; + + // Per-row K scale (amax/448) for 16-row tile + __shared__ float K_sf[16]; + int lane = threadIdx.x & 31; + if (lane < 16) K_sf[lane] = 1.0f; + __syncthreads(); + + for (int mi = 0; mi < 16; ++mi) { + int kv_idx = k0 + mi; + float local_max = 0.0f; + if (kv_idx < kv) { + for (int d0 = lane; d0 < D; d0 += 32) { + float v = K[(size_t)d0 + (size_t)D * (size_t)kv_idx]; + float av = fabsf(v); + if (av > local_max) local_max = av; + } + } + // warp reduce + for (int off = 16; off > 0; off >>= 1) { + float other = __shfl_down_sync(0xffffffff, local_max, off); + if (other > local_max) local_max = other; + } + if (lane == 0 && kv_idx < kv) { + float maxv = local_max; + if (maxv < 1e-4f) maxv = 1e-4f; + K_sf[mi] = maxv / 448.0f; + } + __syncwarp(); + } + __syncthreads(); + + __shared__ __half A_sh[16*16]; // row-major K tile (FP8-quantized then decoded) + __shared__ __half B_sh[16*16]; // col-major Q tile (FP8-quantized then decoded) + __shared__ float C_sh[16*16]; // accumulator dump + __shared__ float S_acc[16]; // accumulate per kv row + + if (threadIdx.x < 16) S_acc[threadIdx.x] = 0.0f; + __syncthreads(); + + for (int h0 = 0; h0 < H; h0 += 16) { + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + // Iterate K dimension in 16-slices + for (int d0 = 0; d0 < D; d0 += 16) { + int lane2 = threadIdx.x & 31; + // Load A_sh: rows are kv rows, cols are k-slice, with FP8 
quant/dequant and per-row scale + for (int idx = lane2; idx < 16*16; idx += 32) { + int mi = idx / 16; // row + int di = idx % 16; // col + int kv_idx = k0 + mi; + __half v = __float2half_rn(0.0f); + if (kv_idx < kv && d0 + di < D) { + float f = K[(size_t)(d0 + di) + (size_t)D * (size_t)kv_idx]; + float sf = K_sf[mi]; + float scaled = f / sf; + uint8_t code = f32_to_fp8e4m3(scaled); + float dec = fp8e4m3_to_f32(code); + v = __float2half_rn(dec); + } + A_sh[mi * 16 + di] = v; + } + // Load B_sh: columns=16 heads in group, rows=16 k-slice; col-major, FP8 quant/dequant + for (int idx = lane2; idx < 16*16; idx += 32) { + int di = idx / 16; // k index + int cj = idx % 16; // head col 0..15 + int h = h0 + cj; + int tok = t0; // one token per tile + __half v = __float2half_rn(0.0f); + if (tok < Tc && h < H && d0 + di < D) { + float f = Q[(size_t)(d0 + di) + (size_t)D * (size_t)(tok*H + h)]; + uint8_t code = f32_to_fp8e4m3(f); + float dec = fp8e4m3_to_f32(code); + v = __float2half_rn(dec); + } + B_sh[cj * 16 + di] = v; + } + __syncthreads(); + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::load_matrix_sync(a_frag, A_sh, 16); + wmma::load_matrix_sync(b_frag, B_sh, 16); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); + } + + // Dump accumulators to shared + wmma::store_matrix_sync(C_sh, c_frag, 16, wmma::mem_row_major); + __syncthreads(); + + // Accumulate this head-group contribution into S_acc per row + int lane3 = threadIdx.x & 31; + for (int mi = lane3; mi < 16; mi += 32) { + float srow = 0.0f; + for (int cj = 0; cj < 16; ++cj) { + float v = C_sh[mi * 16 + cj]; + if (v < 0.0f) v = 0.0f; + int h = h0 + cj; + if (h < H) { + float w = W[h + (size_t)H * (size_t)t0]; + srow += v * w; + } + } + atomicAdd(&S_acc[mi], srow); + } + __syncthreads(); + } + + // Write out with combined scale KS * K_sf and windowing + for (int mi = lane; mi < 16; mi += 32) { + int kv_idx = k0 + mi; + if (kv_idx < kv && t0 < Tc) { + float srow = S_acc[mi] * k_scale[kv_idx] * K_sf[mi]; + if (starts != nullptr && ends != nullptr) { + int s0 = starts[t0]; + int e0 = ends[t0]; + if (s0 < 0) s0 = 0; + if (s0 > kv) s0 = kv; + if (e0 < 0) e0 = 0; + if (e0 > kv) e0 = kv; + Out[kv_idx + (size_t)kv * (size_t)t0] = (kv_idx >= s0 && kv_idx < e0) ? srow : 0.0f; + } else { + Out[kv_idx + (size_t)kv * (size_t)t0] = srow; + } + } + } +#endif +} + +extern "C" void ggml_cuda_indexer_logits_fused_device(ggml_backend_cuda_context & ctx, + const float * dQ, + const float * dK, + const float * dW, + const float * dKS, + const int * dStarts, const int * dEnds, + int D, int H, int Tc, int kv_end, + float * dOut) { + cudaStream_t stream = ctx.stream(); + // Ensure starts/ends are device-resident copies (handles host or device sources) + const int * dStarts_dev = dStarts; + const int * dEnds_dev = dEnds; + int * dStarts_tmp = nullptr; + int * dEnds_tmp = nullptr; + if (dStarts) { + dStarts_dev = dStarts; + } + if (dEnds) { + dEnds_dev = dEnds; + } + + // env knobs for tile sizes with heuristics when unset + const char *env_bq = getenv("LLAMA_INDEXER_BLOCK_Q"); + int BLOCK_Q = env_bq ? max(1, atoi(env_bq)) : 2; // safe default; larger can explode memory + const char *env_bn = getenv("LLAMA_INDEXER_BLOCK_N"); + int BLOCK_N = env_bn ? max(1, atoi(env_bn)) : (kv_end >= 512 ? 256 : 128); + const char *env_dt = getenv("LLAMA_INDEXER_D_TILE"); + int D_TILE = env_dt ? max(16, atoi(env_dt)) : 32; + size_t work_elems = (size_t)Tc * (size_t)kv_end; + int exact_flag = (work_elems <= 4096) ? 
1 : 0; + { + const char *e = getenv("LLAMA_INDEXER_EXACT"); + if (e && *e && atoi(e)!=0) exact_flag = 1; + } + // Select kernel based on env; default to tiled + bool use_wmma = false; + bool do_not_use_wmma = false; + { + const char *s = getenv("LLAMA_INDEXER_USE_WMMA"); + if (s && atoi(s) != 0) use_wmma = true; + if (s && atoi(s) == 0) do_not_use_wmma = true; + } + // Heuristics: + if (!use_wmma) { + size_t work = (size_t)Tc * (size_t)kv_end; + // prefer WMMA when legal: standard (H<=16) or head-grouped (H%16==0) + if (D % 16 == 0 && ((((H <= 16) && ((16 % H) == 0)) || ((H % 16) == 0))) && work >= 16384 && !do_not_use_wmma) { + use_wmma = 1; + } + } + + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] use_wmma=%d D=%d H=%d Tc=%d kv=%d BLOCK_Q=%d BLOCK_N=%d D_TILE=%d\n", (int)use_wmma, D, H, Tc, kv_end, BLOCK_Q, BLOCK_N, D_TILE); + // Optional: TL port path in device wrapper + const char * __prof_env = getenv("LLAMA_SPARSE_PROF"); + auto * __prof_each_env = getenv("LLAMA_SPARSE_PROF_EACH"); + if (const char *s = getenv("LLAMA_INDEXER_TL_PORT"); s && atoi(s) != 0) { + ggml_cuda_pool & __pool = ctx.pool(ggml_cuda_get_device()); + bool use_tma_fp8 = false; + if (const char *e = getenv("LLAMA_TL_FP8"); e && atoi(e) != 0) use_tma_fp8 = true; + // Prepare starts/ends (CuSeqLenKS/KE). If provided by caller via GGML op src[4]/src[5], + // use them; otherwise synthesize [0, kv_end) per token. + ggml_cuda_pool_alloc __KS_i(__pool, (size_t)Tc); + ggml_cuda_pool_alloc __KE_i(__pool, (size_t)Tc); + int *dKS_i = __KS_i.get(), *dKE_i = __KE_i.get(); + // Default: fill 0..kv_end + int tblocks = (Tc + 255) / 256; + k_fill_int<<>>(dKS_i, Tc, 0); + k_fill_int<<>>(dKE_i, Tc, kv_end); + ggml_cuda_pool_alloc __Logits(__pool, (size_t)Tc * (size_t)kv_end); + float *dLogits = __Logits.get(); + cudaMemsetAsync(dLogits, 0, sizeof(float) * (size_t)Tc * (size_t)kv_end, stream); + int block_N = getenv_int_("LLAMA_TL_BLOCK_N", 256); + int threads = getenv_int_("LLAMA_TL_THREADS", 640); + int block_Q = getenv_int_("LLAMA_TL_BLOCK_Q", max(1, 128 / max(1, H))); + int num_stages = getenv_int_("LLAMA_TL_NUM_STAGES", 3); + auto align16 = [](size_t x) { return (x + 15u) & ~size_t(15u); }; + if (sparse_debug_on()) { + int WM = 16, WN = 16; + int Q_rows = block_Q * H; + int Nq_pad = ((Q_rows + WN - 1) / WN) * WN; + size_t Qs_f16_alloc_bytes = (size_t)block_Q * (size_t)H * (size_t)D * sizeof(__half); + size_t Qs_f16_needed_bytes = (size_t)Nq_pad * (size_t)D * sizeof(__half); + int K_rows_max = block_N; + int K_rows_pad = ((K_rows_max + WM - 1) / WM) * WM; + size_t K_f16_alloc_bytes = (size_t)K_rows_max * (size_t)D * sizeof(__half); + size_t K_f16_needed_bytes = (size_t)K_rows_pad * (size_t)D * sizeof(__half); + fprintf(stderr, + "[TL_PORT_DEBUG] Q_rows=%d Nq_pad=%d Qs_f16_alloc=%zu Qs_f16_needed=%zu (diff=%zd) | " + "K_rows_max=%d K_rows_pad=%d K_f16_alloc=%zu K_f16_needed=%zu (diff=%zd)\n", + Q_rows, Nq_pad, (size_t)Qs_f16_alloc_bytes, (size_t)Qs_f16_needed_bytes, (ssize_t)Qs_f16_needed_bytes - (ssize_t)Qs_f16_alloc_bytes, + K_rows_max, K_rows_pad, (size_t)K_f16_alloc_bytes, (size_t)K_f16_needed_bytes, (ssize_t)K_f16_needed_bytes - (ssize_t)K_f16_alloc_bytes); + } + + int maxOpt = 0; + cudaDeviceGetAttribute(&maxOpt, cudaDevAttrMaxSharedMemoryPerBlockOptin, ggml_cuda_get_device()); + size_t max_shmem = (size_t)(maxOpt > 0 ? 
maxOpt : 98304); + dim3 gridTL((Tc + block_Q - 1) / block_Q); + // Convert Q [D, Tc*H] to row-major [Tc*H, D]; K [D, kv] to [kv, D]; W [H, Tc] to [Tc, H] + ggml_cuda_pool_alloc __Qrm(__pool, (size_t)(Tc*H) * (size_t)D); + ggml_cuda_pool_alloc __Krm(__pool, (size_t)kv_end * (size_t)D); + ggml_cuda_pool_alloc __Wrm(__pool, (size_t)Tc * (size_t)H); + float *dQrm = __Qrm.get(); + float *dKrm = __Krm.get(); + float *dWrm = __Wrm.get(); + if (dStarts_dev != nullptr && dEnds_dev != nullptr) { + CUDA_CHECK(cudaMemcpyAsync(dKS_i, dStarts, sizeof(int)*(size_t)Tc, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dKE_i, dEnds, sizeof(int)*(size_t)Tc, cudaMemcpyDeviceToDevice, stream)); + } + if (use_tma_fp8) { + if (getenv("LLAMA_INDEXER_TL_FP8_DEBUG")) { + fprintf(stderr, "[TL_FP8_DBG] entering TL FP8 path: D=%d H=%d Tc=%d kv=%d\n", D, H, Tc, kv_end); + fflush(stderr); + } + static bool __tl_fp8_init_done = false; + static int __tl_fp8_init_status = 0; + if (!__tl_fp8_init_done) { + __tl_fp8_init_status = init(); + __tl_fp8_init_done = true; + if (__tl_fp8_init_status != 0) { + const char * __err = get_last_error(); + fprintf(stderr, "[TL_FP8] init() failed status=%d err=%s\n", __tl_fp8_init_status, __err ? __err : ""); + } + } + ggml_cuda_pool_alloc<__half> __Qh(__pool, (size_t)(Tc*H) * (size_t)D); + ggml_cuda_pool_alloc<__half> __Kh(__pool, (size_t)kv_end * (size_t)D); + ggml_cuda_pool_alloc __Qfp8(__pool, (size_t)(Tc*H) * (size_t)D); + ggml_cuda_pool_alloc __Kfp8(__pool, (size_t)kv_end * (size_t)D); + ggml_cuda_pool_alloc __Kamax(__pool, (size_t)kv_end); + ggml_cuda_pool_alloc __Ksf(__pool, (size_t)kv_end); + ggml_cuda_pool_alloc __KsfInv(__pool, (size_t)kv_end); + ggml_cuda_pool_alloc __IdxKScale(__pool, (size_t)kv_end); + unsigned char *dQfp8 = __Qfp8.get(), *dKfp8 = __Kfp8.get(); + float *dKamax = __Kamax.get(); + float *dKsf = __Ksf.get(); + float *dKsfInv= __KsfInv.get(); + float *dIdxKScale = __IdxKScale.get(); + dim3 tbH(256); + dim3 tbT(32, 8); + dim3 gdQ((Tc*H + tbT.x - 1)/tbT.x, (D + tbT.y - 1)/tbT.y); + dim3 gdK((kv_end + tbT.x - 1)/tbT.x, (D + tbT.y - 1)/tbT.y); + dim3 gdW((Tc + tbT.x - 1)/tbT.x, (H + tbT.y - 1)/tbT.y); + // Convert GGML column-major to row-major layouts matching TileLang kernel + k_colmajor_DN_to_rowmajor_ND<<>>(dQ, D, Tc*H, dQrm); + if (getenv("LLAMA_INDEXER_TL_FP8_DEBUG")) { + std::vector hQrm((size_t)(Tc*H)*(size_t)D); + cudaMemcpy(hQrm.data(), dQrm, hQrm.size()*sizeof(float), cudaMemcpyDeviceToHost); + int maxd = D < 8 ? D : 8; + fprintf(stderr, "[TL_FP8_DBG] Qrm[0,0..%d]:", maxd-1); + for (int d0 = 0; d0 < maxd; ++d0) fprintf(stderr, " % .6f", hQrm[d0]); + fprintf(stderr, "\n"); + } + k_colmajor_DN_to_rowmajor_ND<<>>(dK, D, kv_end, dKrm); + k_colmajor_DN_to_rowmajor_ND<<>>(dW, H, Tc, dWrm); + if (getenv("LLAMA_INDEXER_TL_FP8_DEBUG")) { + std::vector hKrm((size_t)kv_end*(size_t)D); + cudaMemcpy(hKrm.data(), dKrm, hKrm.size()*sizeof(float), cudaMemcpyDeviceToHost); + int maxd = D < 8 ? D : 8; + fprintf(stderr, "[TL_FP8_DBG] Krm[0,0..%d]:", maxd-1); + for (int d0 = 0; d0 < maxd; ++d0) fprintf(stderr, " % .6f", hKrm[d0]); + fprintf(stderr, "\n"); + std::vector hWrm((size_t)Tc*(size_t)H); + cudaMemcpy(hWrm.data(), dWrm, hWrm.size()*sizeof(float), cudaMemcpyDeviceToHost); + int maxh = H < 8 ? 
H : 8; + fprintf(stderr, "[TL_FP8_DBG] Wrm[0,0..%d] for token0:", maxh-1); + for (int h0 = 0; h0 < maxh; ++h0) fprintf(stderr, " % .6f", hWrm[h0]); + fprintf(stderr, "\n"); + } + // Q: direct FP8 cast (no per-row scaling, as in TileLang test) + { + size_t Qtotal = (size_t)(Tc*H) * (size_t)D; + dim3 gdQfp8((unsigned)((Qtotal + tbH.x - 1)/tbH.x)); + k_rowmajor_f32_to_fp8_e4m3<<>>(dQrm, (int)(Tc*H), D, dQfp8); + CUDA_CHECK(cudaGetLastError()); + } + // K: compute per-row amax, scales, and quantize with per-row scaling + { + int rowsK = kv_end; + int colsK = D; + int threadsAmax = 256; + int blocksAmax = (rowsK + threadsAmax - 1) / threadsAmax; + k_rowmajor_f32_rowwise_absmax<<>>(dKrm, rowsK, colsK, dKamax); + CUDA_CHECK(cudaGetLastError()); + k_fp8_compute_row_scales<<>>(dKamax, rowsK, dKsf, dKsfInv); + CUDA_CHECK(cudaGetLastError()); + size_t Ktotal = (size_t)rowsK * (size_t)colsK; + dim3 gdKfp8((unsigned)((Ktotal + tbH.x - 1)/tbH.x)); + k_rowmajor_f32_to_fp8_e4m3_rowwise_scaled<<>>(dKrm, rowsK, colsK, dKsfInv, dKfp8); + CUDA_CHECK(cudaGetLastError()); + // Combine GGML k_scale (dKS) with FP8 per-row scale into effective IndexKScale + k_elemwise_mul<<>>(dKS, dKsf, dIdxKScale, rowsK); + CUDA_CHECK(cudaGetLastError()); + if (getenv("LLAMA_INDEXER_TL_FP8_DEBUG")) { + std::vector hKS(rowsK); + std::vector hKsf(rowsK); + std::vector hIdxKScale(rowsK); + cudaMemcpy(hKS.data(), dKS, rowsK*sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(hKsf.data(), dKsf, rowsK*sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(hIdxKScale.data(), dIdxKScale, rowsK*sizeof(float), cudaMemcpyDeviceToHost); + int maxr = rowsK < 8 ? rowsK : 8; + fprintf(stderr, "[TL_FP8_DBG] KS[0..%d]:", maxr-1); + for (int i = 0; i < maxr; ++i) fprintf(stderr, " % .6f", hKS[i]); + fprintf(stderr, "\n"); + fprintf(stderr, "[TL_FP8_DBG] Ksf[0..%d]:", maxr-1); + for (int i = 0; i < maxr; ++i) fprintf(stderr, " % .6f", hKsf[i]); + fprintf(stderr, "\n"); + fprintf(stderr, "[TL_FP8_DBG] IdxKScale[0..%d]:", maxr-1); + for (int i = 0; i < maxr; ++i) fprintf(stderr, " % .6f", hIdxKScale[i]); + fprintf(stderr, "\n"); + } + } + + int __tl_call_status = 0; + LAUNCH_PROFILE_KERNEL("PROFILE_TL_FP8_KONLY", TL_ONLY, stream, ([&](){ + __tl_call_status = call(reinterpret_cast(dQfp8), reinterpret_cast(dKfp8), + dIdxKScale, + dLogits, dWrm, + dKS_i, dKE_i, + kv_end, Tc, stream); + })(), D, H, Tc, kv_end); + if (__tl_call_status != 0) { + const char * __err = get_last_error(); + fprintf(stderr, "[TL_FP8] call() failed status=%d err=%s\n", __tl_call_status, __err ? __err : ""); + } + CUDA_CHECK(cudaGetLastError()); + dim3 tblock(32, 8); + dim3 tgrid((Tc + tblock.x - 1)/tblock.x, (kv_end + tblock.y - 1)/tblock.y); + k_transpose_TcKv_to_KvTc<<>>(dLogits, Tc, kv_end, dOut); + CUDA_CHECK(cudaGetLastError()); + if (getenv("LLAMA_INDEXER_TL_FP8_DEBUG")) { + std::vector hLogits((size_t)Tc*(size_t)kv_end); + cudaMemcpy(hLogits.data(), dOut, hLogits.size()*sizeof(float), cudaMemcpyDeviceToHost); + int maxk = kv_end < 8 ? kv_end : 8; + int maxt = Tc < 2 ? 
Tc : 2; + fprintf(stderr, "[TL_FP8_DBG] Logits (kv x T) sample:\n"); + for (int k = 0; k < maxk; ++k) { + for (int t = 0; t < maxt; ++t) { + float v = hLogits[k + (size_t)kv_end * t]; + fprintf(stderr, " L[%d,%d]=% .6f", k, t, v); + } + fprintf(stderr, "\n"); + } + } + + } else { + dim3 tbT(32, 8); + dim3 gdQ((Tc*H + tbT.x - 1)/tbT.x, (D + tbT.y - 1)/tbT.y); + k_colmajor_DN_to_rowmajor_ND<<>>(dQ, D, Tc*H, dQrm); + if (getenv("LLAMA_INDEXER_TL_FP8_DEBUG")) { + std::vector hQrm((size_t)(Tc*H)*(size_t)D); + cudaMemcpy(hQrm.data(), dQrm, hQrm.size()*sizeof(float), cudaMemcpyDeviceToHost); + int maxd = D < 8 ? D : 8; + fprintf(stderr, "[TL_FP8_DBG] Qrm[0,0..%d]:", maxd-1); + for (int d0 = 0; d0 < maxd; ++d0) fprintf(stderr, " % .6f", hQrm[d0]); + fprintf(stderr, "\n"); + } + dim3 gdK((kv_end + tbT.x - 1)/tbT.x, (D + tbT.y - 1)/tbT.y); + k_colmajor_DN_to_rowmajor_ND<<>>(dK, D, kv_end, dKrm); + dim3 gdW((Tc + tbT.x - 1)/tbT.x, (H + tbT.y - 1)/tbT.y); + k_colmajor_DN_to_rowmajor_ND<<>>(dW, H, Tc, dWrm); + auto compute_smem = [&](int bq, int bn) -> size_t { + // Mirror k_tl_mqa_attn_return_logits_port shared layout (float staging + f16) + const int WM = 16, WN = 16; + const int warps = max(1, threads/32); + size_t off = 0; + // Qs_f32 + off = align16(off); off += (size_t)bq * (size_t)H * (size_t)D * sizeof(float); + // K0_f32, K1_f32 + off = align16(off); off += (size_t)bn * (size_t)D * sizeof(float); + off = align16(off); off += (size_t)bn * (size_t)D * sizeof(float); + // ks0, ks1 + off = align16(off); off += (size_t)bn * sizeof(float); + off = align16(off); off += (size_t)bn * sizeof(float); + // logits_blk + off = align16(off); off += (size_t)bn * (size_t)bq * sizeof(float); + // Qs_f16, K0_f16, K1_f16 + off = align16(off); off += (size_t)bq * (size_t)H * (size_t)D * sizeof(__half); + off = align16(off); off += (size_t)bn * (size_t)D * sizeof(__half); + off = align16(off); off += (size_t)bn * (size_t)D * sizeof(__half); + // Csh per warp + off = align16(off); off += (size_t)warps * (WM*WN) * sizeof(float); + return off; + }; + size_t shmem_bytes = compute_smem(block_Q, block_N); + while (shmem_bytes > max_shmem && (block_Q > 1 || block_N > 1)) { + if (block_N >= block_Q && block_N > 1) block_N = (block_N + 1) / 2; else if (block_Q > 1) block_Q = (block_Q + 1) / 2; + shmem_bytes = compute_smem(block_Q, block_N); + } + CUDA_SET_SHARED_MEMORY_LIMIT(k_tl_mqa_attn_return_logits_port, (int)shmem_bytes); + LAUNCH_PROFILE_KERNEL("PROFILE_TL_ONLY", TL_ONLY, stream, ([&](){ + k_tl_mqa_attn_return_logits_port<<>>( + dQrm, dKrm, dKS, dLogits, dWrm, dKS_i, dKE_i, + Tc, kv_end, H, D, block_N, num_stages, threads, block_Q); + })(), D, H, Tc, kv_end); + CUDA_CHECK(cudaGetLastError()); + dim3 tblock(32, 8); + dim3 tgrid((Tc + tblock.x - 1)/tblock.x, (kv_end + tblock.y - 1)/tblock.y); + k_transpose_TcKv_to_KvTc<<>>(dLogits, Tc, kv_end, dOut); + CUDA_CHECK(cudaGetLastError()); + } + cudaStreamSynchronize(stream); + if (dStarts_tmp) cudaFree(dStarts_tmp); + if (dEnds_tmp) cudaFree(dEnds_tmp); + return; + + } + if (use_wmma && D % 16 == 0 && (size_t)Tc * kv_end > 4096) { + dim3 block(32,1,1); + const int tokens_per_tile = max(1, 16 / min(H,16)); + dim3 grid((Tc + tokens_per_tile - 1) / tokens_per_tile, (kv_end + 15) / 16, 1); + if (H % 16 == 0) { + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] launch=wmma_hgrp grid=(%d,%d) block=(%d,%d)\n", grid.x, grid.y, block.x, block.y); + LAUNCH_PROFILE_KERNEL("PROFILE_WMMA_HGRP_ONLY", WMMA_HGRP_ONLY, stream, ([&](){ + k_indexer_logits_wmma16_f32_hgrp<<>>(dQ, dK, dW, dKS, 
D, H, Tc, kv_end, dStarts_dev, dEnds_dev, dOut); + })(), D, H, Tc, kv_end); + + } else if (H <= 16 && (16 % H) == 0) { + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] launch=wmma grid=(%d,%d) block=(%d,%d)\n", grid.x, grid.y, block.x, block.y); + LAUNCH_PROFILE_KERNEL("PROFILE_WMMA_ONLY", WMMA_ONLY, stream, ([&]{ + k_indexer_logits_wmma16_bf16<<>>(dQ, dK, dW, dKS, D, H, Tc, kv_end, dStarts_dev, dEnds_dev, dOut); + })(), D, H, Tc, kv_end); + + } else { + // not WMMA-friendly; fallback to tiled below + int HEAD_CHUNK = getenv_int_("LLAMA_INDEXER_HEAD_CHUNK", 32); + if (HEAD_CHUNK > 64) HEAD_CHUNK = 64; + int PIPE_STAGES = getenv_int_("LLAMA_INDEXER_PIPE_STAGES", 2); + if (PIPE_STAGES < 1) PIPE_STAGES = 1; + if (PIPE_STAGES > 2) PIPE_STAGES = 2; + int maxThreadsPerBlock = 1024; + int threadsPerBlock = BLOCK_Q * BLOCK_N; + if (threadsPerBlock > maxThreadsPerBlock) { + int new_BLOCK_N = maxThreadsPerBlock / max(1, BLOCK_Q); + if (new_BLOCK_N < 1) new_BLOCK_N = 1; + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] clamp BLOCK_N %d->%d due to threadsPerBlock=%d>=%d\n", BLOCK_N, new_BLOCK_N, threadsPerBlock, maxThreadsPerBlock); + BLOCK_N = new_BLOCK_N; + } + dim3 blockT(BLOCK_Q, BLOCK_N); + dim3 gridT((Tc + BLOCK_Q - 1)/BLOCK_Q, (kv_end + BLOCK_N - 1)/BLOCK_N); + size_t shmem = (size_t)D_TILE * BLOCK_N * sizeof(float) + + (size_t)D_TILE * BLOCK_Q * HEAD_CHUNK * sizeof(float) + + (PIPE_STAGES >= 2 ? ((size_t)D_TILE * BLOCK_N * sizeof(float) + (size_t)D_TILE * BLOCK_Q * HEAD_CHUNK * sizeof(float)) : 0) + + (size_t)HEAD_CHUNK * BLOCK_Q * sizeof(float); + CUDA_SET_SHARED_MEMORY_LIMIT(k_indexer_logits_tiled_f32, (int)shmem); + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] launch=tiled grid=(%d,%d) block=(%d,%d) shmem=%zu Hc=%d stages=%d\n", gridT.x, gridT.y, blockT.x, blockT.y, (size_t)shmem, HEAD_CHUNK, PIPE_STAGES); + LAUNCH_PROFILE_KERNEL("PROFILE_TILED_ONLY_2", TILED_ONLY_2, stream, ([&]{ + k_indexer_logits_tiled_f32<<>>(dQ, dK, dW, dKS, D, H, Tc, kv_end, dStarts_dev, dEnds_dev, D_TILE, BLOCK_Q, BLOCK_N, exact_flag, HEAD_CHUNK, PIPE_STAGES, dOut); + })(), D, H, Tc, kv_end); + + + cudaStreamSynchronize(stream); + if (dStarts_tmp) cudaFree(dStarts_tmp); + if (dEnds_tmp) cudaFree(dEnds_tmp); + return; + } + { + cudaError_t __err = cudaGetLastError(); + if (__err != cudaSuccess) { + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] WMMA launch failed: %s, falling back to tiled.\n", cudaGetErrorString(__err)); + int HEAD_CHUNK = getenv_int_("LLAMA_INDEXER_HEAD_CHUNK", 32); + if (HEAD_CHUNK > 64) HEAD_CHUNK = 64; + int PIPE_STAGES = getenv_int_("LLAMA_INDEXER_PIPE_STAGES", 2); + if (PIPE_STAGES < 1) PIPE_STAGES = 1; + if (PIPE_STAGES > 2) PIPE_STAGES = 2; + int maxThreadsPerBlock = 1024; + int threadsPerBlock = BLOCK_Q * BLOCK_N; + if (threadsPerBlock > maxThreadsPerBlock) { + int new_BLOCK_N = maxThreadsPerBlock / max(1, BLOCK_Q); + if (new_BLOCK_N < 1) new_BLOCK_N = 1; + BLOCK_N = new_BLOCK_N; + } + dim3 blockT(BLOCK_Q, BLOCK_N); + dim3 gridT((Tc + BLOCK_Q - 1)/BLOCK_Q, (kv_end + BLOCK_N - 1)/BLOCK_N); + size_t shmem = (size_t)D_TILE * BLOCK_N * sizeof(float) + + (size_t)D_TILE * BLOCK_Q * HEAD_CHUNK * sizeof(float) + + (PIPE_STAGES >= 2 ? 
((size_t)D_TILE * BLOCK_N * sizeof(float) + (size_t)D_TILE * BLOCK_Q * HEAD_CHUNK * sizeof(float)) : 0) + + (size_t)HEAD_CHUNK * BLOCK_Q * sizeof(float); + CUDA_SET_SHARED_MEMORY_LIMIT(k_indexer_logits_tiled_f32, (int)shmem); + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] fallback tiled grid=(%d,%d) block=(%d,%d) shmem=%zu Hc=%d stages=%d\n", gridT.x, gridT.y, blockT.x, blockT.y, (size_t)shmem, HEAD_CHUNK, PIPE_STAGES); + k_indexer_logits_tiled_f32<<>>(dQ, dK, dW, dKS, D, H, Tc, kv_end, dStarts_dev, dEnds_dev, D_TILE, BLOCK_Q, BLOCK_N, exact_flag, HEAD_CHUNK, PIPE_STAGES, dOut); + } + } + } else { + int HEAD_CHUNK = getenv_int_("LLAMA_INDEXER_HEAD_CHUNK", 32); + if (HEAD_CHUNK > 64) HEAD_CHUNK = 64; + int PIPE_STAGES = getenv_int_("LLAMA_INDEXER_PIPE_STAGES", 2); + if (PIPE_STAGES < 1) PIPE_STAGES = 1; + if (PIPE_STAGES > 2) PIPE_STAGES = 2; + // Sanity clamp block dims to device maximum threads per block + int maxThreadsPerBlock = 1024; + int threadsPerBlock = BLOCK_Q * BLOCK_N; + if (threadsPerBlock > maxThreadsPerBlock) { + int new_BLOCK_N = maxThreadsPerBlock / max(1, BLOCK_Q); + if (new_BLOCK_N < 1) new_BLOCK_N = 1; + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] clamp BLOCK_N %d->%d due to threadsPerBlock=%d>=%d\n", BLOCK_N, new_BLOCK_N, threadsPerBlock, maxThreadsPerBlock); + BLOCK_N = new_BLOCK_N; + } + dim3 blockT(BLOCK_Q, BLOCK_N); + dim3 gridT((Tc + BLOCK_Q - 1)/BLOCK_Q, (kv_end + BLOCK_N - 1)/BLOCK_N); + size_t shmem = (size_t)D_TILE * BLOCK_N * sizeof(float) + + (size_t)D_TILE * BLOCK_Q * HEAD_CHUNK * sizeof(float) + + (PIPE_STAGES >= 2 ? ((size_t)D_TILE * BLOCK_N * sizeof(float) + (size_t)D_TILE * BLOCK_Q * HEAD_CHUNK * sizeof(float)) : 0) + + (size_t)HEAD_CHUNK * BLOCK_Q * sizeof(float); + // Raise per-kernel dynamic shared memory limit to our requirement (best-effort) + CUDA_SET_SHARED_MEMORY_LIMIT(k_indexer_logits_tiled_f32, (int)shmem); + if (sparse_debug_on()) printf("[INDEXER_DISPATCH] launch=tiled grid=(%d,%d) block=(%d,%d) shmem=%zu Hc=%d stages=%d\n", gridT.x, gridT.y, blockT.x, blockT.y, (size_t)shmem, HEAD_CHUNK, PIPE_STAGES); + LAUNCH_PROFILE_KERNEL("PROFILE_TILED_ONLY_1", TILED_ONLY_1, stream, ([&]{ + k_indexer_logits_tiled_f32<<>>(dQ, dK, dW, dKS, D, H, Tc, kv_end, dStarts_dev, dEnds_dev, D_TILE, BLOCK_Q, BLOCK_N, exact_flag, HEAD_CHUNK, PIPE_STAGES, dOut); + })(), D, H, Tc, kv_end); + } + + if (dStarts_tmp) cudaFree(dStarts_tmp); + if (dEnds_tmp) cudaFree(dEnds_tmp); +} diff --git a/ggml/src/ggml-cuda/sparse-mla-decode.cu b/ggml/src/ggml-cuda/sparse-mla-decode.cu new file mode 100644 index 00000000000..a7c9d2a768f --- /dev/null +++ b/ggml/src/ggml-cuda/sparse-mla-decode.cu @@ -0,0 +1,141 @@ +#include "common.cuh" +#include +#include +#include + +// Fused sparse MLA decode kernel supporting MQA/GQA (Hq may differ from Hkv) +// Layouts: +// Q: [D, Hq] +// K: [D, Hkv, N] +// V: [Dv, Hkv, N] +// topk: [Ksel] +// Output: +// Out: [Dv, Hq] + +__global__ void k_sparse_mla_decode(const float * __restrict__ Q, + const float * __restrict__ K, + const float * __restrict__ V, + const int32_t * __restrict__ topk, + int D, int Hq, int Hkv, int Dv, int N, int Ksel, + float kq_scale, float softcap, + float * __restrict__ Out) { + int h = blockIdx.x; + if (h >= Hq) return; + + const int lane = threadIdx.x & 31; + const int warp = threadIdx.x >> 5; + const int wcount = (blockDim.x + 31) >> 5; + + extern __shared__ float smem[]; + float * scores = smem; // Ksel floats + float * s_wmax = scores + Ksel; // wcount floats + float * s_wsum = s_wmax + wcount; // wcount floats 
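+    // One block handles one query head h. scores[] holds the Ksel attention
+    // logits; s_wmax/s_wsum hold one partial per warp for the block-wide max
+    // and sum reductions of the numerically stable softmax
+    //   p_i = exp(s_i - max_j s_j) / sum_j exp(s_j - max_j s_j),
+    // so the dynamic shared memory needed is (Ksel + 2*wcount) floats, matching
+    // the shmem size computed in the host wrapper below.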
+ + // compute logits and local max + float m_local = -1e30f; + const int hk = (Hkv == 1 ? 0 : (h % Hkv)); + const float * __restrict__ qh = Q + (size_t)D * h; + + for (int i = threadIdx.x; i < Ksel; i += blockDim.x) { + int idx = topk[i]; + float dot = -1e30f; + if (idx >= 0 && idx < N) { + const float * __restrict__ kh = K + (size_t)D * (hk + (size_t)Hkv * idx); + dot = 0.0f; + for (int d = 0; d < D; ++d) dot += qh[d] * kh[d]; + dot *= kq_scale; + if (softcap > 0.0f) { + dot = tanhf(dot / softcap) * softcap; + } + } + scores[i] = dot; + m_local = fmaxf(m_local, dot); + } + + // warp reduce max + for (int off = 16; off > 0; off >>= 1) { + m_local = fmaxf(m_local, __shfl_down_sync(0xffffffff, m_local, off)); + } + if (lane == 0) { + s_wmax[warp] = m_local; + } + __syncthreads(); + + // block reduce max using first warp + float maxv = -1e30f; + if (warp == 0) { + float v = (lane < wcount) ? s_wmax[lane] : -1e30f; + for (int off = 16; off > 0; off >>= 1) { + v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off)); + } + if (lane == 0) s_wmax[0] = v; + } + __syncthreads(); + maxv = s_wmax[0]; + + // compute exp and local sum + float lsum = 0.0f; + for (int i = threadIdx.x; i < Ksel; i += blockDim.x) { + float e = __expf(scores[i] - maxv); + scores[i] = e; + lsum += e; + } + for (int off = 16; off > 0; off >>= 1) { + lsum += __shfl_down_sync(0xffffffff, lsum, off); + } + if (lane == 0) { + s_wsum[warp] = lsum; + } + __syncthreads(); + + // block reduce sum using first warp + float snorm = 0.0f; + if (warp == 0) { + float v = (lane < wcount) ? s_wsum[lane] : 0.0f; + for (int off = 16; off > 0; off >>= 1) { + v += __shfl_down_sync(0xffffffff, v, off); + } + if (lane == 0) s_wsum[0] = v; + } + __syncthreads(); + snorm = s_wsum[0]; + float inv = snorm > 0.0f ? 1.0f / snorm : 0.0f; + + // accumulate output for this head + for (int dv = threadIdx.x; dv < Dv; dv += blockDim.x) { + float acc = 0.0f; + for (int i = 0; i < Ksel; ++i) { + int idx = topk[i]; if (idx < 0 || idx >= N) continue; + float p = scores[i] * inv; + const float * __restrict__ vh = V + (size_t)Dv * (hk + (size_t)Hkv * idx); + acc += p * vh[dv]; + } + Out[dv + (size_t)Dv * h] = acc; + } +} + +extern "C" void ggml_cuda_sparse_mla_decode_device(ggml_backend_cuda_context & ctx, + const float * q, + const float * k, + const float * v, + const int32_t * topk, + int D, int Hq, int Hkv, int Dv, + int N, int Ksel, + float kq_scale, float softcap, + float * out); + +extern "C" void ggml_cuda_sparse_mla_decode_device(ggml_backend_cuda_context & ctx, + const float * q, + const float * k, + const float * v, + const int32_t * topk, + int D, int Hq, int Hkv, int Dv, + int N, int Ksel, + float kq_scale, float softcap, + float * out) { + dim3 grid(Hq); + dim3 block(128); + int warps = (block.x + 31) >> 5; + size_t shmem = (size_t)Ksel * sizeof(float) + 2 * (size_t)warps * sizeof(float); + k_sparse_mla_decode<<>>(q, k, v, topk, D, Hq, Hkv, Dv, N, Ksel, kq_scale, softcap, out); +} diff --git a/ggml/src/ggml-cuda/topk-radix.cu b/ggml/src/ggml-cuda/topk-radix.cu new file mode 100644 index 00000000000..3a25aee14d3 --- /dev/null +++ b/ggml/src/ggml-cuda/topk-radix.cu @@ -0,0 +1,1161 @@ +#include "topk-radix.cuh" +#include "common.cuh" + +#include +#include + +#include +#include +#include "../../include/ggml-cuda-radix.h" +#include + +#include +#include + +static inline uint16_t host_float_to_half_bits_rtne(float f) { + uint32_t x; + memcpy(&x, &f, sizeof(x)); + uint32_t sign = (x >> 16) & 0x8000u; + int32_t exp = (int32_t)((x >> 23) & 0xFFu) - 127 + 15; + 
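+    // Rebias the exponent from fp32 (bias 127) to fp16 (bias 15); the branches
+    // below handle subnormal results (exp <= 0), overflow/Inf/NaN (exp >= 31),
+    // and the normal case with rounding at mantissa bit 12 before the >>13 shift.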
uint32_t mant = x & 0x007FFFFFu; + if (exp <= 0) { + if (exp < -10) return (uint16_t)sign; + mant |= 0x00800000u; + uint32_t sub = mant >> (1 - exp); + if (sub & 0x00001000u) sub += 0x00002000u; + return (uint16_t)(sign | (sub >> 13)); + } else if (exp >= 31) { + if (mant == 0) return (uint16_t)(sign | 0x7C00u); + mant >>= 13; + return (uint16_t)(sign | 0x7C00u | mant | (mant == 0)); + } else { + if (mant & 0x00001000u) { + mant += 0x00002000u; + if (mant & 0x00800000u) { + mant = 0; + exp += 1; + if (exp >= 31) return (uint16_t)(sign | 0x7C00u); + } + } + return (uint16_t)(sign | ((uint32_t)exp << 10) | (mant >> 13)); + } +} +static inline uint8_t host_convert_to_uint16_bin(float x) { + uint16_t h = host_float_to_half_bits_rtne(x); + uint16_t bits = (x < 0.0f) ? (uint16_t)(~h & 0xFFFFu) : (uint16_t)(h | 0x8000u); + return (uint8_t)(bits >> 8); +} + +template +__device__ __forceinline__ void named_sync(int count) { + asm volatile("bar.sync %0, %1;" :: "n"(ID), "r"(count) : "memory"); +} + +#ifndef SEL_DEBUG +#define SEL_DEBUG 0 +#endif +#ifndef SEL_DEBUG_COL +#define SEL_DEBUG_COL 0 +#endif + +static inline __host__ __device__ int env_threads_or_default(const char * name, int deflt) { + int v = deflt; +#ifndef __CUDA_ARCH__ + const char * e = getenv(name); + if (e && *e) { + int t = atoi(e); + if (t > 0) v = t; + } +#endif + if (v < 128) v = 128; + if (v > 1024) v = 1024; + v = (v + 31) & ~31; + return v; +} + +// Key32-based MSB bin for descending order: transform float to lexicographic-descending key and take high byte +static __device__ __forceinline__ uint32_t key32_desc(float x) { + uint32_t u = __float_as_uint(x); + return ((int32_t)u < 0) ? ~u : (u ^ 0x80000000u); +} +static __device__ __forceinline__ uint8_t key32_msb_bin_desc(float x) { + return (uint8_t)(key32_desc(x) >> 24); +} + +// Compute K-th threshold bin for top byte of keys of a given column +// Here we implement a block-per-column approach where each block processes N elements. +// We use shared histogram of 256 bins. 
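// ---------------------------------------------------------------------------
// Illustrative host-side reference for the two kernels below -- a sketch added
// for exposition, not part of the patch; the ref_* names are placeholders. It
// shows the monotonic "descending" key transform shared by both kernels, how
// the MSB threshold bin is derived from a 256-bin histogram of the key's top
// byte, and a trivially correct top-k the radix path can be validated against.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <numeric>
#include <vector>

static inline uint32_t ref_key32_desc(float x) {
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    // negatives: flip all bits; non-negatives: flip the sign bit, so larger
    // floats map to larger unsigned keys (descending floats == descending keys)
    return ((int32_t)u < 0) ? ~u : (u ^ 0x80000000u);
}

// MSB threshold bin for one column (assuming 1 <= k <= N): the bin b whose
// strictly-greater count is still < k while the >=-count reaches k, i.e. the
// bin that rank k falls into -- the same rule k_select_topk_bins applies to
// the gt_counts produced by k_histogram_topbyte.
static inline int ref_topbyte_threshold(const std::vector<float> & col, int k) {
    uint32_t hist[256] = {0};
    for (float v : col) hist[ref_key32_desc(v) >> 24]++;
    uint32_t gt = 0; // elements falling in strictly higher bins
    for (int b = 255; b >= 0; --b) {
        if (gt < (uint32_t)k && gt + hist[b] >= (uint32_t)k) return b;
        gt += hist[b];
    }
    return 0;
}

// Trivially correct top-k (largest first, ties broken arbitrarily) to check
// the GPU selection against.
static inline std::vector<int> ref_topk_desc(const std::vector<float> & col, int k) {
    std::vector<int> idx(col.size());
    std::iota(idx.begin(), idx.end(), 0);
    const size_t kk = std::min<size_t>((size_t)k, idx.size());
    std::partial_sort(idx.begin(), idx.begin() + kk, idx.end(),
        [&](int a, int b) { return ref_key32_desc(col[a]) > ref_key32_desc(col[b]); });
    idx.resize(kk);
    return idx;
}
// ---------------------------------------------------------------------------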
+static __global__ void k_histogram_topbyte(const float * __restrict__ scores, + int N, int T, int ld, + uint32_t * __restrict__ thr_bins, + uint32_t * __restrict__ gt_counts) { + int t = blockIdx.x; + if (t >= T) return; + // dynamic shared memory: [warp_count*256] per-warp histograms + [256] final hist + extern __shared__ uint32_t shmem[]; + const int warp_count = blockDim.x >> 5; + uint32_t * hist_warp = shmem; + uint32_t * hist = shmem + warp_count * 256; + for (int i = threadIdx.x; i < 256 * (warp_count + 1); i += blockDim.x) shmem[i] = 0u; + __syncthreads(); + + const float * col = scores + (size_t)ld * t; + // accumulate per-warp histograms + uint32_t * my_hist = hist_warp + ((threadIdx.x >> 5) * 256); + size_t __addr = (size_t)col; + if ( ( (__addr & 0xFu) == 0u) ) { + int i4 = threadIdx.x * 4; + for (; i4 + 3 < N; i4 += blockDim.x * 4) { + float4 v = *((const float4 *)(col + i4)); + uint8_t b0_0 = key32_msb_bin_desc(v.x); + uint8_t b0_1 = key32_msb_bin_desc(v.y); + uint8_t b0_2 = key32_msb_bin_desc(v.z); + uint8_t b0_3 = key32_msb_bin_desc(v.w); + if (b0_0==b0_1 && b0_1==b0_2 && b0_2==b0_3) { + atomicAdd(&my_hist[b0_0], 4u); + } else if (b0_0==b0_1 && b0_2==b0_3) { + atomicAdd(&my_hist[b0_0], 2u); + atomicAdd(&my_hist[b0_2], 2u); + } else { + atomicAdd(&my_hist[b0_0], 1u); + atomicAdd(&my_hist[b0_1], 1u); + atomicAdd(&my_hist[b0_2], 1u); + atomicAdd(&my_hist[b0_3], 1u); + } + } + int rem = N & 3; + int tail_start = N - rem; + int li = tail_start + threadIdx.x; + if (threadIdx.x < rem && li < N) { + uint8_t b0 = key32_msb_bin_desc(col[li]); + atomicAdd(&my_hist[b0], 1u); + } + } else { + for (int i = threadIdx.x; i < N; i += blockDim.x) { + uint8_t b0 = key32_msb_bin_desc(col[i]); + atomicAdd(&my_hist[b0], 1u); + } + } + __syncthreads(); + + // reduce to final hist + for (int b = threadIdx.x; b < 256; b += blockDim.x) { + uint32_t s = 0u; + for (int w = 0; w < warp_count; ++w) s += hist_warp[w*256 + b]; + hist[b] = s; + } + __syncthreads(); + + if (threadIdx.x == 0) { + uint32_t sum = 0; + for (int b = 255; b >= 0; --b) { + gt_counts[b + 256*t] = sum; + sum += hist[b]; + } + thr_bins[t] = 0; + } +} + +// select indices > threshold bin and collect equals for tail passes +static __global__ void k_select_topk_bins(const float * __restrict__ scores, + int N, int T, int ld, int k, int eq_capacity, + const uint32_t * __restrict__ gt_counts, // [256, T] + int * __restrict__ idx_out) +{ + int t = blockIdx.x; + if (t >= T) return; + const float * col = scores + (size_t)ld * t; + // initialize output indices to -1 to avoid mistaking zeros as valid index 0 + if (threadIdx.x == 0) { + for (int i = 0; i < k; ++i) idx_out[i + k*t] = -1; + } + __syncthreads(); + + // Round 0: determine thr0 (MSB) from gt_counts + int thr0 = 0; + for (int b = 255; b >= 0; --b) { + uint32_t sgt = gt_counts[b + 256*t]; + uint32_t prev = (b == 0 ? 
(uint32_t)N : gt_counts[(b - 1) + 256*t]); + uint32_t eq = prev - gt_counts[b + 256*t]; + if (sgt < (uint32_t)k && sgt + eq >= (uint32_t)k) { + thr0 = b; + break; + } + } + + + // per-warp histogram scratch (max 32 warps) + __shared__ unsigned int pw_hist[32*256]; + __shared__ unsigned int pw_final[256]; + int warp_count = (blockDim.x + 31) >> 5; + __shared__ int sel_sofar; +#if SEL_DEBUG + __shared__ int sel_before_R1; + __shared__ int sel_before_R2; +#endif + if (threadIdx.x == 0) sel_sofar = 0; + __syncthreads(); + uint32_t sgt0 = gt_counts[thr0 + 256*t]; + int take_gt0 = min(k, (int)sgt0); +#if SEL_DEBUG + if (threadIdx.x == 0 && (SEL_DEBUG_COL==0 || blockIdx.x == SEL_DEBUG_COL)) { + printf("[SEL] t=%d thr0=%d sgt0=%u k=%d\n", t, thr0, sgt0, k); + } +#endif + + extern __shared__ int s_eq[]; + int *eq0 = s_eq; + int *eq1 = s_eq + eq_capacity; + __shared__ int eq0_store; + __shared__ int eq0_total; + if (threadIdx.x == 0) { + eq0_store = 0; + eq0_total = 0; + } + __syncthreads(); + + // Select b0>thr0 using per-bin prefix allocation; collect b0==thr0 into eq0 (store up to capacity) + __shared__ int s_written[256]; + // initialize per-bin write cursors to the prefix count from gt_counts + for (int b = threadIdx.x; b < 256; b += blockDim.x) { + s_written[b] = (int)gt_counts[b + 256*t]; + } + __syncthreads(); + + for (int i = threadIdx.x; i < N; i += blockDim.x) { + uint32_t raw = __float_as_uint(col[i]); + int b0 = (int)key32_msb_bin_desc(__uint_as_float(raw)); + if (b0 > thr0) { + int pos = atomicAdd(&s_written[b0], 1); + if (pos < take_gt0) idx_out[pos + k*t] = i; + } else if (b0 == thr0) { + atomicAdd(&eq0_total, 1); + int p = atomicAdd(&eq0_store, 1); + if (p < eq_capacity) eq0[p] = i; + } + } + __syncthreads(); +#if SEL_DEBUG + if (threadIdx.x == 0 && (SEL_DEBUG_COL==0 || blockIdx.x == SEL_DEBUG_COL)) { + printf("[SEL] after R0: eq0_total=%d eq0_store=%d sel_sofar=%d take_gt0=%d\n", eq0_total, eq0_store, sel_sofar, take_gt0); + } +#endif + if (threadIdx.x == 0) sel_sofar = take_gt0; + __syncthreads(); + int remaining = k - sel_sofar; +#if SEL_DEBUG + if (threadIdx.x == 0) sel_before_R1 = sel_sofar; +#endif + __syncthreads(); + + // Round 1: build h1 over b1 on eq0 if fully stored, else scan full column with b0==thr0 + __shared__ unsigned int h1[256]; + for (int i = threadIdx.x; i < 256; i += blockDim.x) h1[i] = 0u; + __syncthreads(); + if (eq0_total <= eq_capacity) { + int lim0 = min(eq0_store, eq_capacity); + for (int j = threadIdx.x; j < lim0; j += blockDim.x) { + int idx = eq0[j]; + uint32_t raw = __float_as_uint(col[idx]); + uint32_t key = ((int32_t)raw < 0)?~raw:(raw^0x80000000u); + int b1 = (key>>16)&0xFF; + atomicAdd(&h1[b1],1u); + } + } else { + for (int i = threadIdx.x; i < N; i += blockDim.x) { + uint32_t raw = __float_as_uint(col[i]); + uint32_t key = ((int32_t)raw < 0)?~raw:(raw^0x80000000u); + int b0c = (int)key32_msb_bin_desc(__uint_as_float(raw)); + if (b0c!=thr0) continue; + int b1=(key>>16)&0xFF; + atomicAdd(&h1[b1],1u); + } + } + __syncthreads(); + __shared__ int thr1; + if (threadIdx.x==0){ + unsigned int sum=0,need=remaining; + thr1=255; + for(int b=255;b>=0;--b){ + unsigned int sgt=sum; + unsigned int eqb=h1[b]; + if(sgt=need){ + thr1=b; + break; + } + sum+=eqb; + } + } + __syncthreads(); +#if SEL_DEBUG + if (threadIdx.x == 0 && (SEL_DEBUG_COL==0 || blockIdx.x == SEL_DEBUG_COL)) { + unsigned int sum_h1 = 0; + for (int b = 0; b < 256; ++b) sum_h1 += h1[b]; + printf("[SEL] R1: thr1=%d remaining=%d path=%s sum_h1=%u eq0_total=%d eq0_store=%d\n", + thr1, remaining, 
(eq0_total <= eq_capacity) ? "buf" : "fallback", sum_h1, eq0_total, eq0_store); + } +#endif +#if SEL_DEBUG + if (threadIdx.x == 0 && (SEL_DEBUG_COL==0 || blockIdx.x == SEL_DEBUG_COL)) { + printf("[TL_KERNEL] R1 thr1=%d\n", thr1); + } +#endif + + + // Select b1>thr1; collect b1==thr1 into eq1 + __shared__ int eq1_store; + __shared__ int eq1_total; + if (threadIdx.x == 0) { + eq1_store = 0; + eq1_total = 0; + } + __syncthreads(); + if (eq0_total <= eq_capacity) { + int lim0 = min(eq0_store, eq_capacity); + for (int j = threadIdx.x; j < lim0; j += blockDim.x) { + int idx = eq0[j]; + uint32_t raw = __float_as_uint(col[idx]); + uint32_t key=((int32_t)raw<0)?~raw:(raw^0x80000000u); + int b1=(key>>16)&0xFF; + if(b1>thr1){ + int pos=atomicAdd(&sel_sofar,1); + if(pos>16)&0xFF; + if(b1>thr1){ + int pos=atomicAdd(&sel_sofar,1); + if(pos>8)&0xFF; + atomicAdd(&pw_hist[((threadIdx.x>>5)*256) + b2], 1u); + } + } else { + for (int i = threadIdx.x; i < N; i += blockDim.x) { + uint32_t raw=__float_as_uint(col[i]); + uint32_t key=((int32_t)raw<0)?~raw:(raw^0x80000000u); + int b0c=(int)key32_msb_bin_desc(__uint_as_float(raw)); + if(b0c!=thr0) continue; + int b1=(key>>16)&0xFF; + if(b1!=thr1) continue; + int b2=(key>>8)&0xFF; + atomicAdd(&pw_hist[((threadIdx.x>>5)*256) + b2], 1u); + } + } + __syncthreads(); + for (int b = threadIdx.x; b < 256; b += blockDim.x) { + unsigned int s=0; + for (int w=0; w=0;--b){ + unsigned int sgt=sum; + unsigned int eqb=pw_final[b]; + if(sgt < need && sgt + eqb >= need){ + thr2=b; + break; + } + sum+=eqb; + } + } + __syncthreads(); +#if SEL_DEBUG + if (threadIdx.x == 0 && (SEL_DEBUG_COL==0 || blockIdx.x == SEL_DEBUG_COL)) { + printf("[SEL] thr2=%d remaining=%d (R2 path=%s)\n", thr2, remaining, (eq1_store==eq1_total && eq1_total<=eq_capacity)?"buf":"fallback"); + } +#endif + + // Select b2>thr2; collect b2==thr2 back into eq0 (ping-pong) + __shared__ int eq2_total; + if (threadIdx.x == 0) { + eq0_store = 0; + eq2_total = 0; + } + __syncthreads(); + if (eq1_total <= eq_capacity) { + int lim1 = min(eq1_store, eq_capacity); + for (int j = threadIdx.x; j < lim1; j += blockDim.x) { + int idx=eq1[j]; + uint32_t raw=__float_as_uint(col[idx]); + uint32_t key=((int32_t)raw<0)?~raw:(raw^0x80000000u); + int b2=(key>>8)&0xFF; + if(b2>thr2){ + int pos=atomicAdd(&sel_sofar,1); + if(pos>16)&0xFF; + if(b1!=thr1) continue; + int b2=(key>>8)&0xFF; + if(b2>thr2){ + int pos=atomicAdd(&sel_sofar,1); + if(pos>5)*256) + b3], 1u); + } + } else { + for (int i = threadIdx.x; i < N; i += blockDim.x) { + uint32_t raw=__float_as_uint(col[i]); + uint32_t key=((int32_t)raw<0)?~raw:(raw^0x80000000u); + int b0c=(int)key32_msb_bin_desc(__uint_as_float(raw)); + if(b0c!=thr0) continue; + int b1=(key>>16)&0xFF; + if(b1!=thr1) continue; + int b2=(key>>8)&0xFF; + if(b2!=thr2) continue; + int b3= key & 0xFF; + atomicAdd(&pw_hist[((threadIdx.x>>5)*256) + b3], 1u); + } + } + __syncthreads(); + for (int b = threadIdx.x; b < 256; b += blockDim.x) { + unsigned int s=0; + for (int w=0; w=0; --b){ + unsigned int sgt=sum; + unsigned int eqb=pw_final[b]; + if(sgt < need && sgt + eqb >= need){ + thr3=b; + break; + } + sum+=eqb; + } + } + __syncthreads(); +#if SEL_DEBUG + if (threadIdx.x == 0 && (SEL_DEBUG_COL==0 || blockIdx.x == SEL_DEBUG_COL)) { + printf("[SEL] thr3=%d remaining=%d (R3 path=%s)\n", thr3, remaining, (eq0_store==eq2_total && eq2_total<=eq_capacity)?"buf":"fallback"); + } +#endif + + // Select b3>thr3; collect b3==thr3 into eq1 + __shared__ int eq3_total; + if (threadIdx.x == 0) { + eq1_store = 0; + eq3_total = 0; + 
} + __syncthreads(); + if (eq2_total <= eq_capacity) { + int lim0 = min(eq0_store, eq_capacity); + for (int j = threadIdx.x; j < lim0; j += blockDim.x) { + int idx=eq0[j]; + uint32_t raw=__float_as_uint(col[idx]); + uint32_t key=((int32_t)raw<0)?~raw:(raw^0x80000000u); + int b3= key & 0xFF; + if(b3>thr3){ + int pos=atomicAdd(&sel_sofar,1); + if(pos= 0 && idx < N) { + float v = col[idx]; + unsigned int raw = __float_as_uint(v); + unsigned int key = ((int)raw < 0) ? ~raw : (raw ^ 0x80000000u); + int b0 = (int)key32_msb_bin_desc(v); + int b1 = (key >> 16) & 0xFF; + int b2 = (key >> 8) & 0xFF; + int b3 = key & 0xFF; + printf("(%d:%.5f b0=%d b1=%d b2=%d b3=%d)\n", idx, v, b0, b1, b2, b3); + } else { + printf("(%d:invalid)\n", idx); + } + } + printf("\n"); + } +#endif + } + } + } else { + for (int i = threadIdx.x; i < N; i += blockDim.x) { + uint32_t raw=__float_as_uint(col[i]); + uint32_t key=((int32_t)raw<0)?~raw:(raw^0x80000000u); + int b0c=(int)key32_msb_bin_desc(__uint_as_float(raw)); + if(b0c!=thr0) continue; + int b1=(key>>16)&0xFF; + if(b1!=thr1) continue; + int b2=(key>>8)&0xFF; + if(b2!=thr2) continue; + int b3= key & 0xFF; + if(b3>thr3){ + int pos=atomicAdd(&sel_sofar,1); + if(pos>16)&0xFF; + if(b1!=thr1) continue; + int b2=(key>>8)&0xFF; + if(b2!=thr2) continue; + int b3= key & 0xFF; +#if SEL_DEBUG + if (threadIdx.x == 0 && (SEL_DEBUG_COL==0 || blockIdx.x == SEL_DEBUG_COL)) { + printf("[SEL] final indices t=%d: ", t); + for (int i = 0; i < k; ++i) { + int idx = idx_out[i + k*t]; + if (idx >= 0 && idx < N) { + float v = col[idx]; + unsigned int raw = __float_as_uint(v); + unsigned int key = ((int)raw < 0) ? ~raw : (raw ^ 0x80000000u); + int b0 = (int)key32_msb_bin_desc(v); + int b1 = (key >> 16) & 0xFF; + int b2 = (key >> 8) & 0xFF; + int b3 = key & 0xFF; + printf("(%d:%.5f b0=%d b1=%d b2=%d b3=%d) ", idx, v, b0, b1, b2, b3); + } else { + printf("(%d:invalid) ", idx); + } + } + printf("\n"); + } +#endif + + if(b3!=thr3) continue; + int pos=atomicAdd(&sel_sofar,1); + if (pos < k) idx_out[pos + k*t] = i; + } + } + __syncthreads(); +#if SEL_DEBUG + if (threadIdx.x == 0 && (SEL_DEBUG_COL==0 || blockIdx.x == SEL_DEBUG_COL)) { + printf("[SEL] final sel_sofar=%d\n", sel_sofar); + } +#endif +} + +void ggml_cuda_topk_radix_indices_device(ggml_backend_cuda_context & ctx, + const float * scores_d, int N, int T, int k, + int * idx_d) { + cudaStream_t stream = ctx.stream(); + // Radix-like path: histogram top byte + select with tie refinement + uint32_t * gt_counts_d = nullptr; + uint32_t * thr_bins_d = nullptr; + cudaMalloc(>_counts_d, sizeof(uint32_t) * 256 * (size_t)T); + cudaMalloc(&thr_bins_d, sizeof(uint32_t) * (size_t)T); + + const int hist_threads = env_threads_or_default("LLAMA_SPARSE_TOPK_THREADS", 1024); + const size_t hist_shmem = (size_t)(((hist_threads/32) + 1) * 256) * sizeof(uint32_t); + k_histogram_topbyte<<>>(scores_d, N, T, /*ld=*/N, thr_bins_d, gt_counts_d); + + // Equal-bin selection kernel; bound dynamic shared memory to device limit + const int sel_threads = hist_threads; + // Conservative eq buffer capacity to avoid exceeding per-block shared mem + int cap_env = 0; + const char *env_cap = getenv("LLAMA_SPARSE_TOPK_EQ_CAP"); + if (env_cap) { + cap_env = atoi(env_cap); + if (cap_env < 0) cap_env = 0; + } + int cap_default = 4096; + const int eq_cap = max(k, min(N, cap_env ? 
cap_env : cap_default)); + size_t sel_shmem = (size_t) (2*eq_cap) * sizeof(int); + CUDA_SET_SHARED_MEMORY_LIMIT(k_select_topk_bins, (int)sel_shmem); + k_select_topk_bins<<>>(scores_d, N, T, /*ld=*/N, k, eq_cap, gt_counts_d, idx_d); + + cudaFree(gt_counts_d); + cudaFree(thr_bins_d); +} + +extern "C" void ggml_cuda_topk_radix_indices_host(const float * scores_h, int N, int T, int k, int * idx_h) { + ggml_backend_cuda_context ctx(0); + cudaStream_t stream = ctx.stream(); + float * scores_d = nullptr; + int * idx_d = nullptr; + cudaMalloc(&scores_d, sizeof(float) * (size_t)N * T); + cudaMalloc(&idx_d, sizeof(int) * (size_t)k * T); + cudaMemcpyAsync(scores_d, scores_h, sizeof(float) * (size_t)N * T, cudaMemcpyHostToDevice, stream); + ggml_cuda_topk_radix_indices_device(ctx, scores_d, N, T, k, idx_d); + cudaMemcpyAsync(idx_h, idx_d, sizeof(int) * (size_t)k * T, cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + cudaFree(scores_d); + cudaFree(idx_d); +} + +extern "C" void ggml_cuda_topk_histogram_host(const float * scores_h, int N, int T, + unsigned int * gt_counts_h, unsigned int * thr_bins_h) { + ggml_backend_cuda_context ctx(0); + cudaStream_t stream = ctx.stream(); + float * scores_d = nullptr; + uint32_t * gt_counts_d = nullptr; + uint32_t * thr_bins_d = nullptr; + cudaMalloc(&scores_d, sizeof(float) * (size_t)N * T); + cudaMalloc(>_counts_d, sizeof(uint32_t) * 256 * (size_t)T); + cudaMalloc(&thr_bins_d, sizeof(uint32_t) * (size_t)T); + cudaMemcpyAsync(scores_d, scores_h, sizeof(float) * (size_t)N * T, cudaMemcpyHostToDevice, stream); + const int hist_threads = env_threads_or_default("LLAMA_SPARSE_TOPK_THREADS", 1024); + const size_t hist_shmem = (size_t)(((hist_threads/32) + 1) * 256) * sizeof(uint32_t); + k_histogram_topbyte<<>>(scores_d, N, T, /*ld=*/N, thr_bins_d, gt_counts_d); + cudaMemcpyAsync(gt_counts_h, gt_counts_d, sizeof(uint32_t) * 256 * (size_t)T, cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(thr_bins_h, thr_bins_d, sizeof(uint32_t) * (size_t)T, cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + cudaFree(scores_d); + cudaFree(gt_counts_d); + cudaFree(thr_bins_d); +} + +extern "C" void ggml_cuda_topk_select_host(const float * scores_h, int N, int T, int k, + const unsigned int * gt_counts_h, int * idx_h) { + ggml_backend_cuda_context ctx(0); + cudaStream_t stream = ctx.stream(); + float * scores_d = nullptr; + uint32_t * gt_counts_d = nullptr; + int * idx_d = nullptr; + cudaMalloc(&scores_d, sizeof(float) * (size_t)N * T); + cudaMalloc(>_counts_d, sizeof(uint32_t) * 256 * (size_t)T); + cudaMalloc(&idx_d, sizeof(int) * (size_t)k * T); + cudaMemcpyAsync(scores_d, scores_h, sizeof(float) * (size_t)N * T, cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(gt_counts_d, gt_counts_h, sizeof(uint32_t) * 256 * (size_t)T, cudaMemcpyHostToDevice, stream); + const int sel_threads = env_threads_or_default("LLAMA_SPARSE_TOPK_THREADS", 1024); + int cap_env = 0; + const char *env_cap = getenv("LLAMA_SPARSE_TOPK_EQ_CAP"); + if (env_cap) { + cap_env = atoi(env_cap); + if (cap_env < 0) cap_env = 0; + } + int cap_default = 4096; + const int eq_cap_host = max(k, min(N, cap_env ? 
cap_env : cap_default)); + const size_t sel_shmem = (size_t) (2*eq_cap_host) * sizeof(int); + CUDA_SET_SHARED_MEMORY_LIMIT(k_select_topk_bins, (int)sel_shmem); + k_select_topk_bins<<>>(scores_d, N, T, /*ld=*/N, k, eq_cap_host, gt_counts_d, idx_d); + cudaMemcpyAsync(idx_h, idx_d, sizeof(int) * (size_t)k * T, cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + cudaFree(scores_d); + cudaFree(gt_counts_d); + cudaFree(idx_d); +} + +// ----------------------------------------------------------------------------- +// TileLang DeepSeek V3.2 top-k selector (ported line-by-line to CUDA) +// This kernel mirrors the control flow and comments of +// /workspace/tilelang/examples/deepseek_v32/topk_selector.py +// Inputs: +// input : [batch, seq_len] float32 scores +// index : [batch, topk] int32 output indices +// starts : [batch] int32 per-batch start index (inclusive) +// ends : [batch] int32 per-batch end index (exclusive) +// Notes: +// - BLOCK_SIZE is 1024 threads per block; one block per batch element +// - RADIX = 256; SMEM_INPUT_SIZE = 4096 (tie buffer per round) +// - convert_to_uint16 / convert_to_uint32 match the TileLang mapping +// Simple glue kernels for wiring the TileLang-ported selector +// ----------------------------------------------------------------------------- + +// Cast to float16, reinterpret bits, then map sign for descending order +static __device__ __forceinline__ uint16_t tl_convert_to_uint16(float x) { + __half h = __float2half(x); + unsigned short bits_uint = __half_as_ushort(h); + bits_uint = (x < 0.0f) ? (unsigned short)(~bits_uint & 0xFFFFu) + : (unsigned short)(bits_uint | 0x8000u); + return (uint16_t)(bits_uint >> 8); +} + +static __device__ __forceinline__ uint32_t tl_convert_to_uint32(float x) { + uint32_t bits_uint = __float_as_uint(x); + return ((int32_t)bits_uint < 0) ? ~bits_uint : (bits_uint | 0x80000000u); +} + +// Derive per-column end (exclusive) from scores by scanning for last value > threshold +// Cooperative within a block: one block per column, threads stride from the end +static __global__ void k_derive_ends_from_scores(const float * __restrict__ scores, + int N, int T, int ld, float masked_thresh, + int * __restrict__ ends) { + int t = blockIdx.x; + if (t >= T) return; + const float * col = scores + (size_t)ld * t; + int e_local = 0; + // Stride from the end: each thread scans its own tail segment + for (int i = N - 1 - threadIdx.x; i >= 0; i -= blockDim.x) { + float v = col[i]; + if (v > masked_thresh) { + e_local = i + 1; + break; + } + } + // Reduce max e_local across threads + extern __shared__ int smax[]; // sized by blockDim.x when launching with dynamic smem=0 -> use static array instead + // Fallback to static shared memory sized for <= 1024 threads + __shared__ int smax_static[1024]; + int *s = smax_static; + s[threadIdx.x] = e_local; + __syncthreads(); + for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) { + if (threadIdx.x < offset) s[threadIdx.x] = max(s[threadIdx.x], s[threadIdx.x + offset]); + __syncthreads(); + } + if (threadIdx.x == 0) { + int e = s[0]; + ends[t] = (e <= 256) ? 
257 : e; + } +} + +// Fixed configuration to match TileLang example +#ifndef TL_TOPK_RADIX +#define TL_TOPK_RADIX 256 +#endif +#ifndef TL_TOPK_BLOCK_SIZE +#define TL_TOPK_BLOCK_SIZE 1024 +#endif +#ifndef TL_TOPK_SMEM_INPUT_SIZE +#define TL_TOPK_SMEM_INPUT_SIZE 4096 +#endif + +// Port of tl_topk_impl kernel +static __global__ void k_tl_topk_port( + const float * __restrict__ input, // [batch, seq_len] + int batch, + int seq_len, + int topk, + int * __restrict__ index, // [batch, topk] + const int * __restrict__ starts, // [batch] + const int * __restrict__ ends) { // [batch] + // with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): + int bx = blockIdx.x; + if (bx >= batch) return; + int tx = threadIdx.x; // T.get_thread_binding() + + // Shared allocations (names match TileLang code) + __shared__ int s_threshold_bin_id[1]; + __shared__ int s_histogram[TL_TOPK_RADIX + 1]; + __shared__ int s_num_input[2]; + __shared__ int s_input_idx[2][TL_TOPK_SMEM_INPUT_SIZE]; + + // Local vars (l_* prefix to mirror TileLang code) + int l_threshold_bin_id = 0; + int l_new_topk = topk; + int l_num_input = 0; + int l_bin_id32 = 0; + int l_val = 0; + int l_start_pos = 0; + int l_start_idx = starts[bx]; + int l_end_idx = ends[bx]; + + if (SEL_DEBUG && bx == 0 && tx == 0) { + printf("[TL_KERNEL] l_new_topk=%d l_start_idx=%d l_end_idx=%d\n", l_new_topk, l_start_idx, l_end_idx); + } + + // stage 1: use 8bit to do quick topk + // T.fill(s_histogram, 0) + for (int i = tx; i < TL_TOPK_RADIX + 1; i += blockDim.x) s_histogram[i] = 0; + if (tx == 0) s_num_input[0] = 0; // T.fill(s_num_input[0], 0) + if (tx == 0) s_threshold_bin_id[0] = -1; + + __syncthreads(); + // for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + int iters = (seq_len + TL_TOPK_BLOCK_SIZE - 1) / TL_TOPK_BLOCK_SIZE; + for (int s = 0; s < iters; ++s) { + int input_idx = s * TL_TOPK_BLOCK_SIZE + tx; + if (input_idx < l_end_idx && input_idx >= l_start_idx && input_idx < seq_len) { + float v = input[(size_t)bx * seq_len + input_idx]; + uint16_t inval_int16 = tl_convert_to_uint16(v); + atomicAdd(&s_histogram[inval_int16], 1); + } + } + __syncthreads(); + + // cumsum over RADIX bins (suffix-style), TileLang parity + if (tx < TL_TOPK_RADIX) { + for (int i = 0; i < 8; ++i) { + int offset = 1 << i; + named_sync<3>(TL_TOPK_RADIX); + if (tx < TL_TOPK_RADIX - offset) { + l_val = s_histogram[tx] + s_histogram[tx + offset]; + } + named_sync<3>(TL_TOPK_RADIX); + if (tx < TL_TOPK_RADIX - offset) { + s_histogram[tx] = l_val; + } + } + // find threshold bin id + named_sync<3>(TL_TOPK_RADIX); + if (s_histogram[tx] > l_new_topk && s_histogram[tx + 1] <= l_new_topk) { + s_threshold_bin_id[0] = tx; + if (SEL_DEBUG) { + printf("[TL_KERNEL] thr0=%d tx=%d\n", l_threshold_bin_id, tx); + } + } + } + __syncthreads(); + + l_threshold_bin_id = s_threshold_bin_id[0]; + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]; + __syncthreads(); + if (SEL_DEBUG && (bx == 0 && tx == 0)) { + int sgt0_dbg = s_histogram[l_threshold_bin_id + 1]; + printf("[TL_KERNEL] thr0=%d sgt0=%d new_topk=%d\n", l_threshold_bin_id, sgt0_dbg, l_new_topk); + } + + // collect all elements with exponent  threshold + for (int s = 0; s < iters; ++s) { + __syncthreads(); + int input_idx = s * TL_TOPK_BLOCK_SIZE + tx; + if (input_idx < l_end_idx && input_idx >= l_start_idx && input_idx < seq_len) { + float v = input[(size_t)bx * seq_len + input_idx]; + int bin_id = (int)tl_convert_to_uint16(v); + l_bin_id32 = bin_id; + if (l_bin_id32 > l_threshold_bin_id) { + // pos = atomic_add(s_histogram[l_bin_id32 + 1], 
1) + int pos = atomicAdd(&s_histogram[l_bin_id32 + 1], 1); + // index[bx, pos] = input_idx + if (pos < topk) index[bx * topk + pos] = input_idx; + } else if (l_bin_id32 == l_threshold_bin_id && l_new_topk > 0) { + int pos = atomicAdd(&s_num_input[0], 1); + if (pos < TL_TOPK_SMEM_INPUT_SIZE) { + s_input_idx[0][pos] = input_idx; + } + } + } + } + + // stage 2: tail pass + for (int round = 0; round < 4; ++round) { + if (l_new_topk <= 0) break; // T.loop_break() + + int r_idx = round & 1; + l_start_pos = topk - l_new_topk; + + __syncthreads(); + // T.fill(s_histogram, 0) + for (int i = tx; i < TL_TOPK_RADIX + 1; i += blockDim.x) s_histogram[i] = 0; + if (tx == 0) s_num_input[r_idx ^ 1] = 0; + __syncthreads(); + + l_num_input = s_num_input[r_idx]; + if (SEL_DEBUG && bx == 0 && tx == 0) { + printf("[TL_KERNEL] R%d start: l_new_topk=%d l_num_input=%d l_start_pos=%d\n", round, l_new_topk, l_num_input, l_start_pos); + } + int it2 = (l_num_input + TL_TOPK_BLOCK_SIZE - 1) / TL_TOPK_BLOCK_SIZE; + for (int s = 0; s < it2; ++s) { + int idx = s * TL_TOPK_BLOCK_SIZE + tx; + if (idx < l_num_input) { + int in_idx = s_input_idx[r_idx][idx]; + float v = input[(size_t)bx * seq_len + in_idx]; + l_bin_id32 = (int)((tl_convert_to_uint32(v) >> (24 - round * 8)) & 0xFFu); + atomicAdd(&s_histogram[l_bin_id32], 1); + } + } + __syncthreads(); + + // cumsum over RADIX bins (suffix-style), TileLang parity + if (tx < TL_TOPK_RADIX) { + for (int i = 0; i < 8; ++i) { + int offset = 1 << i; + named_sync<3>(TL_TOPK_RADIX); + if (tx < TL_TOPK_RADIX - offset) { + l_val = s_histogram[tx] + s_histogram[tx + offset]; + } + named_sync<3>(TL_TOPK_RADIX); + if (tx < TL_TOPK_RADIX - offset) { + s_histogram[tx] = l_val; + } + } + // find threshold bin id + named_sync<3>(TL_TOPK_RADIX); + if (s_histogram[tx] > l_new_topk && s_histogram[tx + 1] <= l_new_topk) { + s_threshold_bin_id[0] = tx; + } + } + __syncthreads(); + + l_threshold_bin_id = s_threshold_bin_id[0]; + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]; + __syncthreads(); + + for (int s = 0; s < it2; ++s) { + __syncthreads(); + int idx = s * TL_TOPK_BLOCK_SIZE + tx; + if (idx < l_num_input) { + int in_idx = s_input_idx[r_idx][idx]; + float v = input[(size_t)bx * seq_len + in_idx]; + l_bin_id32 = (int)((tl_convert_to_uint32(v) >> (24 - round * 8)) & 0xFFu); + if (l_bin_id32 > l_threshold_bin_id) { + int pos = atomicAdd(&s_histogram[l_bin_id32 + 1], 1) + l_start_pos; + if (pos < topk) index[bx * topk + pos] = in_idx; + } else if (l_bin_id32 == l_threshold_bin_id && l_new_topk > 0) { + if (round == 3) { + int l_out_pos = atomicAdd(&s_histogram[l_bin_id32 + 1], 1) + l_start_pos; + if (l_out_pos < topk) index[bx * topk + l_out_pos] = in_idx; + } else { + int pos = atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (pos < TL_TOPK_SMEM_INPUT_SIZE) { + s_input_idx[r_idx ^ 1][pos] = in_idx; + } + } + } + } + } + } +} + + +void ggml_cuda_topk_tilelang_port_device(ggml_backend_cuda_context & ctx, + const float * scores_d, int N, int T, int k, + int * idx_d, + const int * starts_d, + const int * ends_d) { + cudaStream_t stream = ctx.stream(); + // Prepare starts/ends: if provided, use them; else synthesize [0,N) by deriving ends from scores + // Treat starts_d/ends_d as potentially const; when we need to allocate, we create writable buffers + int * d_starts = (int *)(uintptr_t)starts_d; + int * d_ends = (int *)(uintptr_t)ends_d; + int * tmp_alloc_starts = nullptr; + int * tmp_alloc_ends = nullptr; + // Optionally fill device windows from scores (default: enabled) + bool fill_win = 
true; + if (const char *e = getenv("LLAMA_SPARSE_TOPK_WINDOWS_DEVICE")) fill_win = atoi(e) != 0; + if (d_starts == nullptr) { + CUDA_CHECK(cudaMalloc(&tmp_alloc_starts, sizeof(int) * (size_t)T)); + d_starts = tmp_alloc_starts; + } + if (d_ends == nullptr) { + CUDA_CHECK(cudaMalloc(&tmp_alloc_ends, sizeof(int) * (size_t)T)); + d_ends = tmp_alloc_ends; + } + if (fill_win) { + // starts := 0, ends := last index+1 where score > -1e29 + CUDA_CHECK(cudaMemsetAsync(d_starts, 0, sizeof(int) * (size_t)T, stream)); + const int threads = TL_TOPK_BLOCK_SIZE; + const int blocks = T; + k_derive_ends_from_scores<<>>(scores_d, N, T, /*ld=*/N, -1.0e29f, d_ends); + } + + // Directly launch the ported kernel on [N, T] row-major (batch=T, seq_len=N) + dim3 grid(T); + dim3 block(TL_TOPK_BLOCK_SIZE); + CUDA_CHECK(cudaMemsetAsync(idx_d, 0xFF, sizeof(int) * (size_t)k * T, stream)); + const char * __prof_env = getenv("LLAMA_SPARSE_PROF"); + + auto * __prof_each_env = getenv("LLAMA_SPARSE_PROF_EACH"); + if (__prof_env && *__prof_env) { + cudaEvent_t __e0, __e1; + cudaEventCreate(&__e0); + cudaEventCreate(&__e1); + cudaEventRecord(__e0, stream); + k_tl_topk_port<<>>(scores_d, T, N, k, idx_d, d_starts, d_ends); + cudaEventRecord(__e1, stream); + cudaEventSynchronize(__e1); + float __ms = 0.0f; + cudaEventElapsedTime(&__ms, __e0, __e1); + static int __cnt_idx_cuda = 0; + static double __sum_idx_cuda = 0.0; + __sum_idx_cuda += __ms; + __cnt_idx_cuda++; + if (__prof_each_env && *__prof_each_env) { + fprintf(stderr, "[PROFILE_TL_ONLY] TILELANG_TOPK N=%d T=%d k=%d ms=%.3f\n", N, T, k, __ms); + } else { + if (__cnt_idx_cuda % 50 == 0) { + fprintf(stderr, "[PROFILE_TL_ONLY] TILELANG_TOPK N=%d T=%d k=%d avg_ms=%.3f over 50 calls\n", + N, T, k, (float)(__sum_idx_cuda/50.0)); + __sum_idx_cuda = 0.0; + } + } + cudaEventDestroy(__e0); + cudaEventDestroy(__e1); + } else { + k_tl_topk_port<<>>(scores_d, T, N, k, idx_d, d_starts, d_ends); + } + + // Debug: validate first column indices against threshold when profiling is enabled (SEL_DEBUG only) + if (SEL_DEBUG) { + const char * __prof_env2 = getenv("LLAMA_SPARSE_PROF"); + if (__prof_env2 && *__prof_env2) { + // Copy first column t=0 of input and first k indices + std::vector idx0(k, -1); + CUDA_CHECK(cudaMemcpyAsync(idx0.data(), idx_d, sizeof(int) * (size_t)k, cudaMemcpyDeviceToHost, stream)); + std::vector col0(N); + CUDA_CHECK(cudaMemcpyAsync(col0.data(), scores_d, sizeof(float) * (size_t)N, cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + // Compute Kth threshold + // Host compute sgt0 for t=0 for debugging + { + unsigned int hist[256] = {0}; + for (int i = 0; i < N; ++i) { + uint8_t b = host_convert_to_uint16_bin(col0[i]); + hist[b]++; + } + unsigned int S[257]; + S[256] = 0; + for (int b = 255; b >= 0; --b) S[b] = S[b+1] + hist[b]; + int thr0 = 0; + for (int b = 255; b >= 0; --b) { + if (S[b] > (unsigned)k && S[b+1] <= (unsigned)k) { + thr0 = b; + break; + } + } + unsigned int sgt0 = S[thr0+1]; + if (SEL_DEBUG) fprintf(stderr, "[TL_HOST_DBG] thr0=%d sgt0=%u\n", thr0, sgt0); + } + + std::vector sorted = col0; + if (k > 0 && k <= N) { + std::nth_element(sorted.begin(), sorted.begin() + (k-1), sorted.end(), std::greater()); + float thresh = sorted[k-1]; + // compute MSB threshold bin for host check + int thr_bin = host_convert_to_uint16_bin(thresh); + int below = 0; + int bad_idx = -1; + float bad_v = 0.0f; + int bad_bin = -1; + std::vector seen(N, 0); + for (int i = 0; i < k; ++i) { + int ix = idx0[i]; + if (ix < 0 || ix >= N) { + below++; + 
bad_idx = ix; + bad_v = NAN; + bad_bin = -1; + break; + } + if (!seen[ix]) seen[ix] = 1; + float v = col0[ix]; + int vb = host_convert_to_uint16_bin(v); + if (!(v >= thresh)) { + below++; + bad_idx = ix; + bad_v = v; + bad_bin = vb; + break; + } + } + if (SEL_DEBUG) fprintf(stderr, "[TL_DEBUG] t=0 thresh=%.6f (bin=%d) below=%d bad_idx=%d bad_v=%g bad_bin=%d\n", thresh, thr_bin, below, bad_idx, bad_v, bad_bin); + } + if (SEL_DEBUG) { + fprintf(stderr, "[TL_DEBUG_IDX] idx0: "); + for (int i = 0; i < k; ++i) fprintf(stderr, "%d ", idx0[i]); + fprintf(stderr, "\n"); + } + } + + } + + if (tmp_alloc_starts) cudaFree(tmp_alloc_starts); + if (tmp_alloc_ends) cudaFree(tmp_alloc_ends); +} diff --git a/ggml/src/ggml-cuda/topk-radix.cuh b/ggml/src/ggml-cuda/topk-radix.cuh new file mode 100644 index 00000000000..b487b0941f7 --- /dev/null +++ b/ggml/src/ggml-cuda/topk-radix.cuh @@ -0,0 +1,16 @@ +#pragma once +#include "common.cuh" + +// Launch device kernel(s) to compute per-column top-k indices using a radix selection approach. +// scores_d: device pointer to [N, T] row-major (scores[i + N*t]) +// idx_d: device pointer to [k, T] row-major (idx[i + k*t]) +void ggml_cuda_topk_radix_indices_device(ggml_backend_cuda_context & ctx, + const float * scores_d, int N, int T, int k, + int * idx_d); + +// Launch TileLang-ported top-k kernel with starts/ends synthesized as [0, N) +void ggml_cuda_topk_tilelang_port_device(ggml_backend_cuda_context & ctx, + const float * scores_d, int N, int T, int k, + int * idx_d, + const int * starts_d, + const int * ends_d); diff --git a/ggml/src/ggml-cuda/vendors/cutlass b/ggml/src/ggml-cuda/vendors/cutlass new file mode 160000 index 00000000000..b2dd65dc864 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/cutlass @@ -0,0 +1 @@ +Subproject commit b2dd65dc864e09688245b316ac46c4a6cd07e15c diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/mqa_attn_return_logits_kernel.cu b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/mqa_attn_return_logits_kernel.cu new file mode 100644 index 00000000000..8150c5c669c --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/mqa_attn_return_logits_kernel.cu @@ -0,0 +1,222 @@ +#include "vendor_config.h" +#include "tl_templates/cuda/cuda_fp8.h" +#include "tl_templates/cuda/gemm.h" +#include "tl_templates/cuda/copy.h" +#include "tl_templates/cuda/reduce.h" +#include "tl_templates/cuda/ldsm.h" +#include "tl_templates/cuda/threadblock_swizzle.h" +#include "tl_templates/cuda/debug.h" +#ifdef ENABLE_BF16 +#include "tl_templates/cuda/cuda_bf16_fallbacks.cuh" +#endif + +extern "C" __global__ void mqa_attn_return_logits_kernel_kernel(int* __restrict__ CuSeqLenKE, int* __restrict__ CuSeqLenKS, float* __restrict__ IndexKScale, __grid_constant__ const CUtensorMap IndexK_desc, __grid_constant__ const CUtensorMap IndexQ_desc, float* __restrict__ Logits, float* __restrict__ Weights, int seq_len, int seq_len_kv); +extern "C" __global__ void __launch_bounds__(640, 1) mqa_attn_return_logits_kernel_kernel(int* __restrict__ CuSeqLenKE, int* __restrict__ CuSeqLenKS, float* __restrict__ IndexKScale, __grid_constant__ const CUtensorMap IndexK_desc, __grid_constant__ const CUtensorMap IndexQ_desc, float* __restrict__ Logits, float* __restrict__ Weights, int seq_len, int seq_len_kv) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + int cu_k_s_min[1]; + int cu_k_e_max[1]; + float weights[2]; + float index_k_scale_fragment[32]; + float s[64]; + float s_reshaped[64]; + float logits[32]; + __shared__ uint64_t 
mbarrier_mem[7]; + auto mbarrier = reinterpret_cast(mbarrier_mem); + if (tl::tl_shuffle_elect<0>()) { + tl::prefetch_tma_descriptor(IndexQ_desc); + tl::prefetch_tma_descriptor(IndexK_desc); + mbarrier[0].init(128); + mbarrier[1].init(128); + mbarrier[2].init(128); + mbarrier[3].init(512); + mbarrier[4].init(512); + mbarrier[5].init(512); + mbarrier[6].init(128); + } + tl::fence_barrier_init(); + __syncthreads(); + if (512 <= ((int)threadIdx.x)) { + cu_k_s_min[0] = 2147483647; + cu_k_e_max[0] = -2147483648; + for (int bq_i = 0; bq_i < 32; ++bq_i) { + if (((((int)blockIdx.x) * 32) + bq_i) < seq_len) { + cu_k_s_min[0] = min(cu_k_s_min[0], min(CuSeqLenKS[((((int64_t)((int)blockIdx.x)) * (int64_t)32) + ((int64_t)bq_i))], seq_len_kv)); + } else { + cu_k_s_min[0] = min(cu_k_s_min[0], 0); + } + } + for (int bq_i_1 = 0; bq_i_1 < 32; ++bq_i_1) { + if (((((int)blockIdx.x) * 32) + bq_i_1) < seq_len) { + cu_k_e_max[0] = max(cu_k_e_max[0], min(CuSeqLenKE[((((int64_t)((int)blockIdx.x)) * (int64_t)32) + ((int64_t)bq_i_1))], seq_len_kv)); + } else { + cu_k_e_max[0] = max(cu_k_e_max[0], 0); + } + } + if (tl::tl_shuffle_elect<128>()) { + mbarrier[6].expect_transaction(8192); + tl::fence_proxy_async(); + tl::tma_load(IndexQ_desc, mbarrier[6], (&(((fp8_e4_t*)buf_dyn_shmem)[0])), 0, (((int)blockIdx.x) * 128)); + } + mbarrier[6].arrive(); + for (int nbn_i = 0; nbn_i < (((cu_k_e_max[0] + 255) - cu_k_s_min[0]) >> 8); ++nbn_i) { + mbarrier[((nbn_i % 3) + 3)].wait((((nbn_i % 6) / 3) ^ 1)); + if (tl::tl_shuffle_elect<128>()) { + mbarrier[(nbn_i % 3)].expect_transaction(16384); + tl::fence_proxy_async(); + tl::tma_load(IndexK_desc, mbarrier[(nbn_i % 3)], (&(((fp8_e4_t*)buf_dyn_shmem)[(((nbn_i % 3) * 16384) + 8192)])), 0, ((nbn_i * 256) + cu_k_s_min[0])); + } + tl::mbarrier_cp_async_arrive(mbarrier[(nbn_i % 3)]); + mbarrier[(nbn_i % 3)].arrive(); + } + } else { + cu_k_s_min[0] = 2147483647; + cu_k_e_max[0] = -2147483648; + for (int bq_i_2 = 0; bq_i_2 < 32; ++bq_i_2) { + if (((((int)blockIdx.x) * 32) + bq_i_2) < seq_len) { + cu_k_s_min[0] = min(cu_k_s_min[0], min(CuSeqLenKS[((((int64_t)((int)blockIdx.x)) * (int64_t)32) + ((int64_t)bq_i_2))], seq_len_kv)); + } else { + cu_k_s_min[0] = min(cu_k_s_min[0], 0); + } + } + for (int bq_i_3 = 0; bq_i_3 < 32; ++bq_i_3) { + if (((((int)blockIdx.x) * 32) + bq_i_3) < seq_len) { + cu_k_e_max[0] = max(cu_k_e_max[0], min(CuSeqLenKE[((((int64_t)((int)blockIdx.x)) * (int64_t)32) + ((int64_t)bq_i_3))], seq_len_kv)); + } else { + cu_k_e_max[0] = max(cu_k_e_max[0], 0); + } + } + float2 condval; + if (((((((int)blockIdx.x) * 32) + ((((int)threadIdx.x) >> 5) * 2)) + ((((int)threadIdx.x) & 3) >> 1)) < seq_len)) { + condval = *(float2*)(Weights + (((((int64_t)((int)blockIdx.x)) * (int64_t)128) + ((((int64_t)((int)threadIdx.x)) >> (int64_t)5) * (int64_t)8)) + ((((int64_t)((int)threadIdx.x)) & (int64_t)3) * (int64_t)2))); + } else { + condval = make_float2(0x0p+0f/*0.000000e+00*/, 0x0p+0f/*0.000000e+00*/); + } + *(float2*)(weights + 0) = condval; + mbarrier[6].wait(0); + for (int nbn_i_1 = 0; nbn_i_1 < (((cu_k_e_max[0] + 255) - cu_k_s_min[0]) >> 8); ++nbn_i_1) { + #pragma unroll + for (int i = 0; i < 32; ++i) { + if (((((((nbn_i_1 * 256) + (i * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]) < seq_len_kv) && (0 <= ((((nbn_i_1 * 256) + (i * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]))) && (((((nbn_i_1 * 256) + (i * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]) < seq_len_kv)) { + index_k_scale_fragment[i] = IndexKScale[((((((int64_t)nbn_i_1) * (int64_t)256) + 
(((int64_t)i) * (int64_t)8)) + ((((int64_t)((int)threadIdx.x)) & (int64_t)31) >> (int64_t)2)) + ((int64_t)cu_k_s_min[(int64_t)0]))]; + } else { + float condval_1; + if ((((((((nbn_i_1 * 256) + (i * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]) < seq_len_kv) && (0 <= ((((nbn_i_1 * 256) + (i * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]))) && (((((nbn_i_1 * 256) + (i * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]) < seq_len_kv))) { + condval_1 = IndexKScale[((((((int64_t)nbn_i_1) * (int64_t)256) + (((int64_t)i) * (int64_t)8)) + ((((int64_t)((int)threadIdx.x)) & (int64_t)31) >> (int64_t)2)) + ((int64_t)cu_k_s_min[(int64_t)0]))]; + } else { + condval_1 = 0x0p+0f/*0.000000e+00*/; + } + index_k_scale_fragment[i] = condval_1; + } + } + mbarrier[(nbn_i_1 % 3)].wait(((nbn_i_1 % 6) / 3)); + tl::fence_proxy_async(); + tl::gemm_ss<256, 128, 64, 1, 16, 0, 1, 1, 64, 64, 0, 0>((&(((fp8_e4_t*)buf_dyn_shmem)[(((nbn_i_1 % 3) * 16384) + 8192)])), (&(((fp8_e4_t*)buf_dyn_shmem)[0])), (&(s[0]))); + mbarrier[((nbn_i_1 % 3) + 3)].arrive(); + #pragma unroll + for (int i_1 = 0; i_1 < 64; ++i_1) { + s_reshaped[i_1] = ((max(s[i_1], 0x0p+0f/*0.000000e+00*/) * weights[(i_1 & 1)]) * index_k_scale_fragment[(i_1 >> 1)]); + } + #pragma unroll + for (int i_2 = 0; i_2 < 32; ++i_2) { + logits[i_2] = 0x0p+0f/*0.000000e+00*/; + #pragma unroll + for (int rv = 0; rv < 2; ++rv) { + logits[i_2] = (logits[i_2] + s_reshaped[((i_2 * 2) + rv)]); + } + logits[i_2] = tl::AllReduce::run(logits[i_2]); + } + if ((((int)threadIdx.x) % 2) == 0) { + #pragma unroll + for (int i_3 = 0; i_3 < 32; ++i_3) { + if (((0 <= ((((nbn_i_1 * 256) + (i_3 * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0])) && (((((nbn_i_1 * 256) + (i_3 * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]) < seq_len_kv)) && ((((((int)blockIdx.x) * 32) + ((((int)threadIdx.x) >> 5) * 2)) + ((((int)threadIdx.x) & 3) >> 1)) < seq_len)) { + if (((((nbn_i_1 * 256) + (i_3 * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]) < seq_len_kv) { + Logits[(((((((int64_t)nbn_i_1) * (int64_t)256) + (((int64_t)i_3) * (int64_t)8)) + ((((int64_t)((int)threadIdx.x)) & (int64_t)31) >> (int64_t)2)) + ((((((int64_t)((int)blockIdx.x)) * (int64_t)32) + ((((int64_t)((int)threadIdx.x)) >> (int64_t)5) * (int64_t)2)) + ((((int64_t)((int)threadIdx.x)) & (int64_t)3) >> (int64_t)1)) * ((int64_t)seq_len_kv))) + ((int64_t)cu_k_s_min[(int64_t)0]))] = logits[i_3]; + } + } else { + if (0 <= ((((nbn_i_1 * 256) + (i_3 * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0])) { + if (((((nbn_i_1 * 256) + (i_3 * 8)) + ((((int)threadIdx.x) & 31) >> 2)) + cu_k_s_min[0]) < seq_len_kv) { + if ((((((int)blockIdx.x) * 32) + ((((int)threadIdx.x) >> 5) * 2)) + ((((int)threadIdx.x) & 3) >> 1)) < seq_len) { + Logits[(((((((int64_t)nbn_i_1) * (int64_t)256) + (((int64_t)i_3) * (int64_t)8)) + ((((int64_t)((int)threadIdx.x)) & (int64_t)31) >> (int64_t)2)) + ((((((int64_t)((int)blockIdx.x)) * (int64_t)32) + ((((int64_t)((int)threadIdx.x)) >> (int64_t)5) * (int64_t)2)) + ((((int64_t)((int)threadIdx.x)) & (int64_t)3) >> (int64_t)1)) * ((int64_t)seq_len_kv))) + ((int64_t)cu_k_s_min[(int64_t)0]))] = logits[i_3]; + } + } + } + } + } + } + } + } +} + + +#define ERROR_BUF_SIZE 1024 +static char error_buf[ERROR_BUF_SIZE]; + +extern "C" const char* get_last_error() { + return error_buf; +} + +extern "C" int init() { + error_buf[0] = '\0'; + + cudaError_t result_mqa_attn_return_logits_kernel_kernel = cudaFuncSetAttribute(mqa_attn_return_logits_kernel_kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, 57344); + if (result_mqa_attn_return_logits_kernel_kernel != cudaSuccess) { + snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 57344, cudaGetErrorString(result_mqa_attn_return_logits_kernel_kernel)); + return -1; + } + + return 0; +} + +extern "C" int call(fp8_e4_t* __restrict__ IndexQ, fp8_e4_t* __restrict__ IndexK, float* __restrict__ IndexKScale, float* __restrict__ Logits, float* __restrict__ Weights, int* __restrict__ CuSeqLenKS, int* __restrict__ CuSeqLenKE, int seq_len_kv, int seq_len, cudaStream_t stream=cudaStreamDefault) { + + CUtensorMap IndexK_desc; + CUtensorMapDataType IndexK_desc_type= (CUtensorMapDataType)0; + cuuint32_t IndexK_desc_tensorRank= 2; + void *IndexK_desc_globalAddress= IndexK; + cuuint64_t IndexK_desc_globalDim[2]= {64,seq_len_kv}; + cuuint64_t IndexK_desc_globalStride[2]= {1,64}; + cuuint32_t IndexK_desc_boxDim[2]= {64,256}; + cuuint32_t IndexK_desc_elementStrides[2]= {1,1}; + CUtensorMapInterleave IndexK_desc_interleave= (CUtensorMapInterleave)0; + CUtensorMapSwizzle IndexK_desc_swizzle= (CUtensorMapSwizzle)2; + CUtensorMapL2promotion IndexK_desc_l2Promotion= (CUtensorMapL2promotion)2; + CUtensorMapFloatOOBfill IndexK_desc_oobFill= (CUtensorMapFloatOOBfill)0; + + CUresult IndexK_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &IndexK_desc, IndexK_desc_type, IndexK_desc_tensorRank, IndexK_desc_globalAddress, IndexK_desc_globalDim, IndexK_desc_globalStride + 1, IndexK_desc_boxDim, IndexK_desc_elementStrides, IndexK_desc_interleave, IndexK_desc_swizzle, IndexK_desc_l2Promotion, IndexK_desc_oobFill); + + if (IndexK_desc_result != CUDA_SUCCESS) { + std::stringstream ss; + ss << "Error: Failed to initialize the TMA descriptor IndexK_desc"; + snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str()); + return -1; + } + + CUtensorMap IndexQ_desc; + CUtensorMapDataType IndexQ_desc_type= (CUtensorMapDataType)0; + cuuint32_t IndexQ_desc_tensorRank= 2; + void *IndexQ_desc_globalAddress= IndexQ; + cuuint64_t IndexQ_desc_globalDim[2]= {64,seq_len * 4}; + cuuint64_t IndexQ_desc_globalStride[2]= {1,64}; + cuuint32_t IndexQ_desc_boxDim[2]= {64,128}; + cuuint32_t IndexQ_desc_elementStrides[2]= {1,1}; + CUtensorMapInterleave IndexQ_desc_interleave= (CUtensorMapInterleave)0; + CUtensorMapSwizzle IndexQ_desc_swizzle= (CUtensorMapSwizzle)2; + CUtensorMapL2promotion IndexQ_desc_l2Promotion= (CUtensorMapL2promotion)2; + CUtensorMapFloatOOBfill IndexQ_desc_oobFill= (CUtensorMapFloatOOBfill)0; + + CUresult IndexQ_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &IndexQ_desc, IndexQ_desc_type, IndexQ_desc_tensorRank, IndexQ_desc_globalAddress, IndexQ_desc_globalDim, IndexQ_desc_globalStride + 1, IndexQ_desc_boxDim, IndexQ_desc_elementStrides, IndexQ_desc_interleave, IndexQ_desc_swizzle, IndexQ_desc_l2Promotion, IndexQ_desc_oobFill); + + if (IndexQ_desc_result != CUDA_SUCCESS) { + std::stringstream ss; + ss << "Error: Failed to initialize the TMA descriptor IndexQ_desc"; + snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str()); + return -1; + } + mqa_attn_return_logits_kernel_kernel<<>>(CuSeqLenKE, CuSeqLenKS, IndexKScale, IndexK_desc, IndexQ_desc, Logits, Weights, seq_len, seq_len_kv); + TILELANG_CHECK_LAST_ERROR("mqa_attn_return_logits_kernel_kernel"); + + return 0; +} + diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/atomic.h 
b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/atomic.h new file mode 100644 index 00000000000..82eeccfda5a --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/atomic.h @@ -0,0 +1,471 @@ +#pragma once + +#ifndef __CUDACC_RTC__ +#include +#endif + +#include +#include +#include + +using cutlass::bfloat16_t; +using cutlass::half_t; + +#define TL_DEVICE __forceinline__ __device__ + +template struct normalize_atomic_type { + using type = T; +}; + +template <> struct normalize_atomic_type { + using type = half; +}; + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750)) +template <> struct normalize_atomic_type { + using type = __nv_bfloat16; +}; +#endif + +template TL_DEVICE T1 cuda_cast(T2 val) { + return T1(val); +} + +template <> TL_DEVICE half cuda_cast(float val) { + return __float2half(val); +} + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750)) +template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) { + return __float2bfloat16(val); +} +#endif + +template +TL_DEVICE void AtomicMax(T1 &ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = &ref; + if constexpr ((std::is_same_v || + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { + atomicMax(reinterpret_cast(address), static_cast(val)); + } else { + cuda::atomic_ref aref(*address); + aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order)); + } +} + +template +TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = &ref; + if constexpr ((std::is_same_v || + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { + return static_cast( + atomicMax(reinterpret_cast(address), static_cast(val))); + } else { + cuda::atomic_ref aref(*address); + return static_cast( + aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order))); + } +} + +template +TL_DEVICE void AtomicMin(T1 &ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = &ref; + if constexpr ((std::is_same_v || + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { + atomicMin(reinterpret_cast(address), static_cast(val)); + } else { + cuda::atomic_ref aref(*address); + aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order)); + } +} + +template +TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = &ref; + if constexpr ((std::is_same_v || + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { + return static_cast( + atomicMin(reinterpret_cast(address), static_cast(val))); + } else { + cuda::atomic_ref aref(*address); + return static_cast( + aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order))); + } +} + +template +TL_DEVICE void AtomicAdd(T1 &ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = &ref; + if constexpr ((std::is_same_v || + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(address), static_cast(val)); + } else { + cuda::atomic_ref aref(*address); + aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order)); + } +} + 
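+// Usage sketch for AtomicAdd (illustrative only; the kernel below and its
+// launch shape are assumptions, not part of this header). A relaxed add on
+// float/half is intended to take the native atomicAdd fast path selected
+// above, while any other type or an explicit non-relaxed memory order falls
+// back to the cuda::atomic_ref path:
+//
+//   __global__ void block_sum_sketch(float *out, const float *in, int n) {
+//     int i = blockIdx.x * blockDim.x + threadIdx.x;
+//     if (i < n) {
+//       AtomicAdd(out[0], in[i]);                                  // default: relaxed, native atomicAdd
+//       AtomicAdd(out[1], in[i], int(cuda::memory_order_seq_cst)); // cuda::atomic_ref fallback
+//     }
+//   }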
+template +TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val, + int memory_order = int(cuda::memory_order_relaxed)) { + using NT1 = typename normalize_atomic_type::type; + T1 *address = &ref; + if constexpr ((std::is_same_v || + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { + return static_cast( + atomicAdd(reinterpret_cast(address), static_cast(val))); + } else { + cuda::atomic_ref aref(*address); + return static_cast( + aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order))); + } +} + +// TODO add memory_order for vectorized atomic add +TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + // Since atomicAdd does not support memory order, atomic_ref does not + // support vectorized atomic operation we can only inline ptx code here + // Note: Vectorized atomic operations only support global space + // Note: for 16-bit value, we need to reinterpret_cast the value to unsigned + // short and use "h" register in assembly + __half2 add_val = *reinterpret_cast<__half2 *>(val); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __half ret_val_x, ret_val_y; + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val_x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val_y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile( + "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile( + "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile( + "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + } +} + +TL_DEVICE half2 +AtomicAddx2Ret(half_t *ref, half_t *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + __half2 add_val = *reinterpret_cast<__half2 *>(val); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __half ret_val_x, ret_val_y; + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val_x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val_y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile( + "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm 
volatile( + "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile( + "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + return half2(*reinterpret_cast<__half *>(&ret_val_x_cast), + *reinterpret_cast<__half *>(&ret_val_y_cast)); + } +} + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750)) +TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd( + reinterpret_cast<__nv_bfloat162 *>(ref), + static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + } else { + __nv_bfloat162 add_val = *reinterpret_cast<__nv_bfloat162 *>(val); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __nv_bfloat162 ret_val; + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val.x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val.y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + } +} + +TL_DEVICE __nv_bfloat162 +AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd( + reinterpret_cast<__nv_bfloat162 *>(ref), + static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + } else { + __nv_bfloat162 add_val = *reinterpret_cast<__nv_bfloat162 *>(val); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __nv_bfloat162 ret_val; + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val.x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val.y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + 
: "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + return __nv_bfloat162(*reinterpret_cast<__nv_bfloat16 *>(&ret_val_x_cast), + *reinterpret_cast<__nv_bfloat16 *>(&ret_val_y_cast)); + } +} +#endif + +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +TL_DEVICE void AtomicAddx2(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float2 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float2 ret_val; + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } + } +} + +TL_DEVICE float2 +AtomicAddx2Ret(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float2 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float2 ret_val; + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } + return ret_val; + } +} + +TL_DEVICE void AtomicAddx4(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + // Since atomicAdd does not support memory order, atomic_ref does not + // support vectorized atomic operation we can only inline ptx code here + // Note: Vectorized atomic operations only support global 
space + float4 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float4 ret_val; + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } + } +} + +TL_DEVICE float4 +AtomicAddx4Ret(float *ref, float *val, + int memory_order = int(cuda::memory_order_relaxed)) { + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float4 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float4 ret_val; + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.global.gpu.release.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.global.gpu.acquire.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.global.gpu.acq_rel.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) + : "memory"); + } + return ret_val; + } +} +#endif + +template TL_DEVICE T AtomicLoad(T &ref, int memory_order) { + cuda::atomic_ref aref(ref); + return aref.load(cuda::memory_order(memory_order)); +} + +template +TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) { + using NT1 = typename normalize_atomic_type::type; + cuda::atomic_ref aref(ref); + aref.store(cuda_cast(value), cuda::memory_order(memory_order)); +} diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/barrier.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/barrier.h new file mode 100644 index 00000000000..79a57f7df1b --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/barrier.h @@ -0,0 +1,162 @@ +#pragma once + +#include "common.h" +#include + +// Reuse cutlass advanced barrier 
abstraction +using Barrier = cutlass::arch::ClusterTransactionBarrier; + +namespace tl { + +TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.init.shared.b64 [%1], %0;" + : + : "r"(arrive_count), "r"(smem_int_ptr)); +} + +TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) { + + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + uint32_t waitComplete; + + asm volatile("{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_int_ptr), "r"(phase_bit)); + + return waitComplete; +} + +TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) { + if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + // Arbitrarily large timer value after which try-wait expires and re-tries. + uint32_t ticks = 0x989680; + asm volatile("{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_int_ptr), "r"(phase_bit), "r"(ticks)); + } +} + +TL_DEVICE void mbarrier_test_wait(uint64_t &smem_barrier, int phase_bit) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures + // to save instruction issue slots + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" ::"r"(smem_int_ptr), + "r"(phase_bit)); +} + +TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr)); +} + +TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id, + uint32_t pred) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + if (pred) { + asm volatile("{\n\t" + ".reg .b32 remAddr32;\n\t" + "mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" + "}" + : + : "r"(smem_int_ptr), "r"(cta_id)); + } +} + +TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier, + uint32_t transaction_bytes) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;" + : + : "r"(transaction_bytes), "r"(smem_int_ptr)); +} + +TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier, + uint32_t transaction_bytes) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;" + : + : "r"(transaction_bytes), "r"(smem_int_ptr)); +} + +template +TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar) { + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];" + : + : "r"(smem_int_mbar)); +} + +template +TL_DEVICE void mbarrier_cp_async_arrive_noinc(BarrierType &smem_mbar) { + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = 
smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + asm volatile("{\n\t" + "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t" + "}" + : + : "r"(smem_int_mbar)); + cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar); +} + +TL_DEVICE void fence_proxy_async() { + asm volatile("fence.proxy.async.shared::cta;" : :); +} + +TL_DEVICE void fence_barrier_init() { + asm volatile("fence.mbarrier_init.release.cluster;" : :); +} + +// Indicate arrival of warp issuing TMA_STORE +TL_DEVICE void tma_store_arrive() { + asm volatile("cp.async.bulk.commit_group;"); +} + +template TL_DEVICE void tma_store_wait() { + asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory"); +} + +TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + uint64_t state = 0; + asm volatile("{\n" + ".reg .pred P1;\n" + "mbarrier.arrive.shared.b64 %1, [%0];\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.shared.b64 P1, [%0], %1;\n" + "@!P1 bra.uni LAB_WAIT;\n" + "}\n" + : + : "r"(smem_int_ptr), "l"(state)); +} +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/common.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/common.h new file mode 100644 index 00000000000..a42aa1bd0fb --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/common.h @@ -0,0 +1,347 @@ +#pragma once + +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "atomic.h" +#include +#include +#include +#include + +using cutlass::bfloat16_t; +using cutlass::half_t; +using cutlass::tfloat32_t; + +using cute::cast_smem_ptr_to_uint; + +using int4_t = int4; + +#define hexp cutlass::fast_exp +#define hlog cutlass::fast_log +#define hsqrt cutlass::fast_sqrt +#define hsin cutlass::fast_sin +#define hcos cutlass::fast_cos +#define htanh cutlass::fast_tanh +#define hpow powf + +#define uint unsigned int +#define uchar unsigned char +#define ushort unsigned short + +#define TL_DEVICE __forceinline__ __device__ +#define TL_DEVICE_NOINLINE __noinline__ __device__ +#define TL_PATCH + +#define TILELANG_CHECK(stmt) \ + do { \ + cudaError_t __err = (stmt); \ + if (__err != cudaSuccess) { \ + snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__, \ + __LINE__, cudaGetErrorName(__err), cudaGetErrorString(__err)); \ + return -1; \ + } \ + } while (0) + +#define TILELANG_CHECK_LAST_ERROR(kernel_name) \ + do { \ + cudaError_t __err = cudaGetLastError(); \ + if (__err != cudaSuccess) { \ + snprintf(error_buf, ERROR_BUF_SIZE, kernel_name ": %s - %s", \ + cudaGetErrorName(__err), cudaGetErrorString(__err)); \ + return -1; \ + } \ + } while (0) + +// using cutlass abs function for half_t +TL_PATCH TL_DEVICE half_t __habs(const half_t x) { return abs(x); } + +// using cutlass abs function for bfloat_t +TL_PATCH TL_DEVICE bfloat16_t __habs(const bfloat16_t x) { return abs(x); } + +// hrsqrt function for half_t +TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) { + return half_t(hrsqrt(x.to_half())); +} + +// Pack two half values. +TL_DEVICE unsigned __pack_half2(const half x, const half y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack two half_t values. +TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack two bfloat16_t values. 
+TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack two bfloat16_t values. +TL_DEVICE unsigned __pack_nv_bfloat162(const bfloat16_t x, const bfloat16_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// Pack four char values. +TL_DEVICE int make_int(signed char x0, signed char x1, signed char x2, + signed char x3) { + return (x3 << 24) | (x2 << 16) | (x1 << 8) | x0; +} + +// Pack eight char values. +TL_DEVICE int2 make_int2(signed char x0, signed char x1, signed char x2, + signed char x3, signed char y0, signed char y1, + signed char y2, signed char y3) { + int2 result; + result.x = make_int(x0, x1, x2, x3); + result.y = make_int(y0, y1, y2, y3); + return result; +} + +// Pack sixteen char values. +TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2, + signed char x3, signed char y0, signed char y1, + signed char y2, signed char y3, signed char z0, + signed char z1, signed char z2, signed char z3, + signed char w0, signed char w1, signed char w2, + signed char w3) { + int4_t result; + result.x = make_int(x0, x1, x2, x3); + result.y = make_int(y0, y1, y2, y3); + result.z = make_int(z0, z1, z2, z3); + result.w = make_int(w0, w1, w2, w3); + return result; +} + +// Pack eight int values. +TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0, + int z1, int w0, int w1) { + longlong4 result; + *((int2 *)&result.x) = make_int2(x0, x1); + *((int2 *)&result.y) = make_int2(y0, y1); + *((int2 *)&result.z) = make_int2(z0, z1); + *((int2 *)&result.w) = make_int2(w0, w1); + return result; +} + +// Helper to cast SMEM pointer to unsigned +TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) { + return static_cast(__cvta_generic_to_shared(ptr)); +} + +/** + * Convert a shared-memory pointer to a 32-bit unsigned integer address. + * + * Casts the given pointer (expected to reference shared memory) into a 32-bit + * unsigned integer using the device address-space conversion required for + * shared-memory pointers. + * + * @param smem_ptr Pointer into shared memory. + * @return 32-bit unsigned integer representation of the shared-memory address. + * + * @note The pointer must refer to shared memory; behavior is undefined for + * pointers in other address spaces. + */ +TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) { + unsigned int smem_int; + asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; " + "cvt.u32.u64 %0, smem_int; }" + : "=r"(smem_int) + : "l"(smem_ptr)); + return smem_int; +} + +// DP4A +template +TL_DEVICE /** + * Compute a 4×8-bit dot-product-accumulate using the CUDA DP4A + * intrinsic. + * + * Reads 32-bit packed values from `a` and `b` (each containing four + * signed 8-bit lanes), applies the __dp4a operation (dot product of + * the four lane pairs added to an accumulator), and stores the 32-bit + * integer result through `c`. + * + * @param a Pointer to a 32-bit packed input containing four signed + * 8-bit elements. + * @param b Pointer to a 32-bit packed input containing four signed + * 8-bit elements. + * @param c Pointer to a 32-bit accumulator; its current value is used + * as the initial accumulator and overwritten with the resulting int32 + * sum. 
+ */ + void + DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) { + const int a_int = *((int *)a); + const int b_int = *((int *)b); + const int c_int = *((int *)c); + *c = __dp4a(a_int, b_int, c_int); +} + +namespace tl { +/*! + * \brief PTX data type. + * \note + * PTX fundamental data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types + * PTX matrix data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat8_e4m3 = 10, + kFloat8_e5m2 = 11, + kFloat16 = 12, + kBFloat16 = 13, + kFloat16x2 = 14, + kFloat32 = 15, + kTensorFloat32 = 16, + kFloat64 = 17, + kBit1 = 18, + kBit8 = 19, + kBit16 = 20, + kBit32 = 21, + kBit64 = 22 +}; + +union GmmaDescriptor { + CUTE_HOST_DEVICE constexpr GmmaDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(uint64_t desc) noexcept + : desc_(desc) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept + : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept + : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr GmmaDescriptor & + operator=(GmmaDescriptor const &t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr GmmaDescriptor & + operator=(GmmaDescriptor &&t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + // For N: This is the stride from the first col to the second col of the 8x2 + // brick in INTERLEAVED + // Unused for all SWIZZLE_* layouts (and assumed to be 1) + // For T: This is the stride from the first 8 rows to the next 8 rows. + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + // For N: This is the stride from the first 8 rows to the next 8 rows. + // For T: This is the stride fro mthe first 8 cols to the next 8 cols. 
+ uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // base_offset, bit [49,52) + // Valid only for SWIZZLE_128B and SWIZZLE_64B + uint8_t : 1, base_offset_ : 3, + : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused + // layout type, bit [62,64) + // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) + } bitfield; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept { + return desc_; + } + template + CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const { + GmmaDescriptor ret; + ret.reg32_[0] = reg32_[0] + uint32_t(offset); + ret.reg32_[1] = reg32_[1]; + return ret; + } +}; + +// Any +template TL_DEVICE bool Any(T *a, int size) { + for (int i = 0; i < size; i++) { + if (a[i]) { + return true; + } + } + return false; +} + +// All +template TL_DEVICE bool All(T *a, int size) { + for (int i = 0; i < size; i++) { + if (!a[i]) { + return false; + } + } + return true; +} + +// Pow of int +template TL_DEVICE T pow_of_int(T x) { + T result = x; + for (int i = 1; i < y; i++) { + result *= x; + } + return result; +} + +// Thread partial barrier synchronization +// https://docs.nvidia.com/cuda/parallel-thread-execution/#memory-consistency-model +template +TL_DEVICE void __sync_thread_partial() { + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count)); +} + +template +TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, + T *start_address) { + descriptor.bitfield.start_address_ = + cute::cast_smem_ptr_to_uint(start_address) >> 4; + descriptor.bitfield.layout_type_ = layout_type; + descriptor.bitfield.base_offset_ = 0; + descriptor.bitfield.leading_byte_offset_ = leading_byte_offset; + descriptor.bitfield.stride_byte_offset_ = stride_byte_offset; +} + +template +TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, + T offset) { + descriptor.reg32_[0] += (offset >> 4); +} + +} // namespace tl + +namespace cutlass { +TL_DEVICE +bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); } +} // namespace cutlass diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/compress_sm90.cu b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/compress_sm90.cu new file mode 100644 index 00000000000..8bb236dd837 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/compress_sm90.cu @@ -0,0 +1,167 @@ +#include + +#include + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" + +using namespace cute; + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } +template +std::tuple compress_impl(torch::Tensor A) { + using ElementA = 
T; + using ElementE = uint8_t; + using LayoutTagA = conditional_t; + using ProblemShape = cute::Shape; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideE = StrideA; + + // NOTE: this is derived from sparse sm90 mma atoms + // Ref: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp + using SparseE = conditional_t<(sizeof_bits_v == 32), cute::sparse_elem<4, ElementE>, cute::sparse_elem<8, ElementE>>; + static constexpr GMMA::Major GmmaMajorA = transposed ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; + using SparseConfig = cutlass::Sm90GemmSparseConfig< + cute::sparse_elem<2, ElementA>, GmmaMajorA, + SparseE, cute::C>; + + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, ElementA, LayoutTagA, SparseConfig>; + + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, cutlass::arch::Sm90>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + TORCH_CHECK(A.is_contiguous(), "A need to be contiguous"); + TORCH_CHECK(A.dim() == 2, "Might support batch dim in the future "); + + int M = -1; + int K = -1; + int N = -1; // not used, but required for config + int L = 1; + if constexpr(transposed) { + M = A.size(1); + K = A.size(0); + } else { + M = A.size(0); + K = A.size(1); + } + + ProblemShape problem_shape = make_tuple(M, N, K, L); + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + + CompressorUtility compressor_utility(problem_shape, stride_A); + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + StrideE stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + auto dtype = A.dtype().toScalarType(); + torch::Tensor A_compressed = torch::zeros(KC * M, + torch::TensorOptions().dtype(dtype).device(A.device())); + torch::Tensor E = torch::zeros({ME, KE}, + torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = A.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Compressor::Arguments arguments{problem_shape, + { + A.data_ptr(), + stride_A, + A_compressed.data_ptr(), + E.data_ptr(), + }, + {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + if constexpr (transposed) { + return std::make_tuple(A_compressed.view({KC, M}), E); + } else { + return std::make_tuple(A_compressed.view({M, KC}), E); + } +} + +// block <= 128 +// Ref https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 +#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (BLOCK_K) { \ + case int(32 * FACTOR): return compress_impl(TENSOR); \ + case int(64 * FACTOR): return compress_impl(TENSOR); \ + case int(128 * FACTOR): return 
compress_impl(TENSOR); \ + default: \ + TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \ + } \ + }() + +#define DISPATCH_CONTIGUOUS(TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (dtype) { \ + case torch::kFloat32: \ + return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \ + case torch::kFloat16: \ + case torch::kBFloat16: \ + return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \ + case torch::kFloat8_e4m3fn: \ + return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \ + case torch::kFloat8_e5m2: \ + return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \ + case torch::kChar: \ + return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ + case torch::kByte: \ + return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ + default: \ + TORCH_CHECK(false, "Unsupported dtype"); \ + } \ + }() + +std::tuple compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) { + auto dtype = A.dtype().toScalarType(); + return transposed ? DISPATCH_CONTIGUOUS(true) : DISPATCH_CONTIGUOUS(false); +} + +#undef DISPATCH_BLOCK_K +#undef DISPATCH_CONTIGUOUS + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compress_sm90", torch::wrap_pybind_function(compress_sm90), + "compress_sm90"); +} diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy.h new file mode 100644 index 00000000000..1dd53843431 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy.h @@ -0,0 +1,80 @@ +#pragma once + +#include "common.h" + +#ifdef __CUDA_ARCH_LIST__ +#if __CUDA_ARCH_LIST__ >= 900 +#include "copy_sm90.h" +#endif +#if __CUDA_ARCH_LIST__ >= 1000 +#include "copy_sm100.h" +#endif +#endif + +namespace tl { + +TL_DEVICE void cp_async_commit() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template TL_DEVICE void cp_async_wait() { + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + } +} + +template +TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) { + static_assert(N == 16 || N == 8 || N == 4); + unsigned int addr = smem_ptr_to_uint(smem_addr); + if constexpr (N == 16) { + asm volatile( +#if TL_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" +#else + "cp.async.cg.shared.global [%0], [%1], %2;" +#endif + ::"r"(addr), + "l"((void *)(global_ptr)), "n"(N)); + } else { + asm volatile( +#if TL_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" +#else + "cp.async.ca.shared.global [%0], [%1], %2;" +#endif + ::"r"(addr), + "l"((void *)(global_ptr)), "n"(N)); + } +} + +template +TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr, + void *global_ptr, bool cond) { + static_assert(N == 16 || N == 8 || N == 4); + int bytes = cond ? 
N : 0; + unsigned int addr = smem_ptr_to_uint(smem_addr); + if constexpr (N == 16) { + asm volatile( +#if TL_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;" +#else + "cp.async.cg.shared.global [%0], [%1], %2, %3;" +#endif + ::"r"(addr), + "l"((void *)(global_ptr)), "n"(N), "r"(bytes)); + } else { + asm volatile( +#if TL_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;" +#else + "cp.async.ca.shared.global [%0], [%1], %2, %3;" +#endif + ::"r"(addr), + "l"((void *)(global_ptr)), "n"(N), "r"(bytes)); + } +} + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy_sm100.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy_sm100.h new file mode 100644 index 00000000000..c4047c34974 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy_sm100.h @@ -0,0 +1,134 @@ +#pragma once +#include "cuda_fp8.h" +#include "tcgen_05.h" +#include "tcgen_05_ld.h" + +namespace tl { + +__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) { + longlong4 ret; + asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_global_256(longlong4 *ptr, longlong4 &val) { + asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) { + ulonglong4 ret; + asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +// must be const &val, otherwise the compiler will generate a temporary variable +// and compilation will fail if we have st_global_256(ptr, ld_global_256(ptr)) +__device__ __forceinline__ void st_global_256(ulonglong4 *ptr, + const ulonglong4 &val) { + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e4_32_t *ptr) { + ulonglong4 ret; + asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr, + fp8_e4_32_t &val8) { + ulonglong4 &val = *((ulonglong4 *)&val8); + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +__device__ __forceinline__ unsigned long long +pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z, + const bfloat16_t w) { + unsigned long long v0 = *((unsigned short *)&x); + unsigned long long v1 = *((unsigned short *)&y); + unsigned long long v2 = *((unsigned short *)&z); + unsigned long long v3 = *((unsigned short *)&w); + return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48)); +} + +__device__ __forceinline__ unsigned long long +pack_float16x4(const half x, const half y, const half z, const half w) { + unsigned long long v0 = *((unsigned short *)&x); + unsigned long long v1 = *((unsigned short *)&y); + unsigned long long v2 = *((unsigned short *)&z); + unsigned long long v3 = *((unsigned short *)&w); + return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48)); +} + +// Helper function to find the largest K that 2**K <= N +// Requires N > 0 +template +__device__ 
__forceinline__ constexpr int get_floor_log2() { + static_assert(N > 0); + if constexpr ((1 << (K + 1)) > N) + return K; + else + return get_floor_log2(); +} + +template +__device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col, + dst_t *dst_ptr) { + static_assert(N > 0); + constexpr int LOG_N = get_floor_log2(); + constexpr int CUR_SEGMENT_LEN = 1 << (LOG_N > MAX_LOGN ? MAX_LOGN : LOG_N); + target_call_cls::copy(tmem_start_col, (uint32_t *)dst_ptr); + if constexpr (N - CUR_SEGMENT_LEN > 0) { + tcgen05_ld_core( + tmem_start_col + CUR_SEGMENT_LEN, dst_ptr + CUR_SEGMENT_LEN); + } +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core(tmem_start_col + tmem_col_offset, + dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core(tmem_start_col + tmem_col_offset, + dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core( + tmem_start_col + tmem_col_offset, dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core( + tmem_start_col + tmem_col_offset, dst_ptr); + tl::fence_view_async_tmem_load(); +} + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy_sm90.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy_sm90.h new file mode 100644 index 00000000000..b8b174dc4c2 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/copy_sm90.h @@ -0,0 +1,270 @@ +#pragma once + +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "barrier.h" +#include "common.h" + +namespace tl { +enum class CacheHintSm90 : uint64_t { + EVICT_NORMAL = 0x1000000000000000, + EVICT_FIRST = 0x12F0000000000000, + EVICT_LAST = 0x14F0000000000000, +}; + +template +TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, BarrierType &smem_mbar, + uint32_t size) { + uint32_t smem_int_mbar = + smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::" + "bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar) + :); +} + +TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr, + uint64_t &smem_mbar, uint32_t size, + uint16_t mask) { + uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes." 
+ "multicast::cluster [%0], [%1], %2, [%3], %4; \n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar), "h"(mask) + :); +} + +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +} +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3, int32_t const &crd4) { + uint64_t gmem_int_desc = 
reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), + "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void +tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, int32_t const &coord_c, + int32_t const &coord_w, int32_t const &coord_h, + int32_t const &coord_n, uint16_t const &offset_w, + uint16_t const &offset_h) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar = + smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:" + ":complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group" + ".L2::cache_hint [%0], [%1], %2, %3;" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint) + :); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2}], [%1], %3;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), + "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2, %3}], [%1], %4;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), + "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2, %3, %4}], [%1], %5;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), + "r"(crd2), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm 
volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2, %3, %4, %5}], [%1], %6;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), + "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_store(const CUtensorMap &descriptor, + void const *const smem_ptr, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3, int32_t const &crd4) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group " + ".L2::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), + "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +} + +TL_DEVICE void tma_store_add(float *const smem_ptr, float *gmem_ptr, + int32_t const &store_bytes) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 " + "[%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +} + +TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory"); +} + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_bf16_fallbacks.cuh b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_bf16_fallbacks.cuh new file mode 100644 index 00000000000..f5641f61609 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_bf16_fallbacks.cuh @@ -0,0 +1,257 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cuda_bf16_wrapper.h" +#include + +namespace fastertransformer { + +#ifdef ENABLE_BF16 +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; +#else + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; +#endif +} + +inline __device__ __nv_bfloat162 float22bf162(const float2 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __floats2bfloat162_rn(val.x, val.y); +#else + return __float22bfloat162_rn(val); +#endif +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; +#else + return __bfloat162bfloat162(val); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); +#else + return __hadd(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); +#else + return __hsub2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); +#else + return __hsub(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#else + return __hmul2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); +#else + return __hmul(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { +#if defined(__CUDA_ARCH__) && 
__CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); +#else + return __hfma2(x, y, z); +#endif +} + +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); +#else + return __hfma(x, y, z); +#endif +} + +inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x);; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); +#else + return h2exp(x); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; + +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ + __nv_bfloat162 t; t.x = x; t.y = y; return t; +} + +#endif + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); +#else + return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); 
+ fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); +#else + return a * b * c + d; +#endif +} + +#endif // ENABLE_BF16 + +} // namespace fastertransformer diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_bf16_wrapper.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_bf16_wrapper.h new file mode 100644 index 00000000000..efb6e798730 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_bf16_wrapper.h @@ -0,0 +1,23 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef ENABLE_BF16 +#include +#endif diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_fp8.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_fp8.h new file mode 100644 index 00000000000..8d2165822cf --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/cuda_fp8.h @@ -0,0 +1,199 @@ +#pragma once + +#include +#include + +using fp8_e4_t = cute::float_e4m3_t; +using fp8_e5_t = cute::float_e5m2_t; + +struct __CUDA_ALIGN__(2) fp8_e4_2_t { + fp8_e4_t x; + fp8_e4_t y; +}; + +struct __CUDA_ALIGN__(4) fp8_e4_4_t { + fp8_e4_t x; + fp8_e4_t y; + fp8_e4_t z; + fp8_e4_t w; +}; + +struct __CUDA_ALIGN__(8) fp8_e4_8_t { + fp8_e4_4_t x; + fp8_e4_4_t y; +}; + +struct __CUDA_ALIGN__(16) fp8_e4_16_t { + fp8_e4_8_t x; + fp8_e4_8_t y; +}; + +struct __CUDA_ALIGN__(32) fp8_e4_32_t { + fp8_e4_16_t x; + fp8_e4_16_t y; + + __device__ __forceinline__ fp8_e4_32_t &operator=(const ulonglong4 &rhs) { + x.x = *(fp8_e4_8_t *)&rhs.x; + x.y = *(fp8_e4_8_t *)&rhs.y; + y.x = *(fp8_e4_8_t *)&rhs.z; + y.y = *(fp8_e4_8_t *)&rhs.w; + return *this; + } +}; + +struct __CUDA_ALIGN__(2) fp8_e5_2_t { + fp8_e5_t x; + fp8_e5_t y; +}; + +struct __CUDA_ALIGN__(4) fp8_e5_4_t { + fp8_e5_t x; + fp8_e5_t y; + fp8_e5_t z; + fp8_e5_t w; +}; + +struct __CUDA_ALIGN__(8) fp8_e5_8_t { + fp8_e5_4_t x; + fp8_e5_4_t y; +}; + +struct __CUDA_ALIGN__(16) fp8_e5_16_t { + fp8_e5_8_t x; + fp8_e5_8_t y; +}; + +struct __CUDA_ALIGN__(32) fp8_e5_32_t { + fp8_e5_16_t x; + fp8_e5_16_t y; + + __device__ __forceinline__ fp8_e5_32_t &operator=(const ulonglong4 &rhs) { + x.x = *(fp8_e5_8_t *)&rhs.x; + x.y = *(fp8_e5_8_t *)&rhs.y; + y.x = *(fp8_e5_8_t *)&rhs.z; + y.y = *(fp8_e5_8_t *)&rhs.w; + return *this; + } +}; + +// Pack two fp8_e4_t values. +__forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) { + fp8_e4_2_t result; + result.x = x; + result.y = y; + return result; +} + +// Pack four fp8_e4_t values. 
+__forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1, + fp8_e4_t x2, + fp8_e4_t x3) { + fp8_e4_4_t result; + result.x = x0; + result.y = x1; + result.z = x2; + result.w = x3; + return result; +} + +// Pack eight fp8_e4_t values. +__forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1, + fp8_e4_t x2, fp8_e4_t x3, + fp8_e4_t x4, fp8_e4_t x5, + fp8_e4_t x6, + fp8_e4_t x7) { + fp8_e4_8_t result; + result.x = make_fp8_e4_4_t(x0, x1, x2, x3); + result.y = make_fp8_e4_4_t(x4, x5, x6, x7); + return result; +} + +// Pack sixteen fp8_e4_t values. +__forceinline__ __device__ fp8_e4_16_t +make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3, + fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7, + fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3, + fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) { + fp8_e4_16_t result; + result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7); + result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7); + return result; +} + +// Pack thirty-two fp8_e4_t values. +__forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t( + fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3, fp8_e4_t x4, + fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t x8, fp8_e4_t x9, + fp8_e4_t x10, fp8_e4_t x11, fp8_e4_t x12, fp8_e4_t x13, fp8_e4_t x14, + fp8_e4_t x15, fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3, + fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7, fp8_e4_t y8, + fp8_e4_t y9, fp8_e4_t y10, fp8_e4_t y11, fp8_e4_t y12, fp8_e4_t y13, + fp8_e4_t y14, fp8_e4_t y15) { + fp8_e4_32_t result; + result.x = make_fp8_e4_16_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, + x12, x13, x14, x15); + result.y = make_fp8_e4_16_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, + y12, y13, y14, y15); + return result; +} + +// Pack two fp8_e5_t values. +__forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) { + fp8_e5_2_t result; + result.x = x; + result.y = y; + return result; +} + +// Pack four fp8_e5_t values. +__forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1, + fp8_e5_t x2, + fp8_e5_t x3) { + fp8_e5_4_t result; + result.x = x0; + result.y = x1; + result.z = x2; + result.w = x3; + return result; +} + +// Pack eight fp8_e5_t values. +__forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1, + fp8_e5_t x2, fp8_e5_t x3, + fp8_e5_t x4, fp8_e5_t x5, + fp8_e5_t x6, + fp8_e5_t x7) { + fp8_e5_8_t result; + result.x = make_fp8_e5_4_t(x0, x1, x2, x3); + result.y = make_fp8_e5_4_t(x4, x5, x6, x7); + return result; +} + +// Pack sixteen fp8_e5_t values. +__forceinline__ __device__ fp8_e5_16_t +make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3, + fp8_e5_t x4, fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7, + fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3, + fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7) { + fp8_e5_16_t result; + result.x = make_fp8_e5_8_t(x0, x1, x2, x3, x4, x5, x6, x7); + result.y = make_fp8_e5_8_t(y0, y1, y2, y3, y4, y5, y6, y7); + return result; +} + +// Pack thirty-two fp8_e5_t values. 
+__forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t( + fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3, fp8_e5_t x4, + fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t x8, fp8_e5_t x9, + fp8_e5_t x10, fp8_e5_t x11, fp8_e5_t x12, fp8_e5_t x13, fp8_e5_t x14, + fp8_e5_t x15, fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3, + fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7, fp8_e5_t y8, + fp8_e5_t y9, fp8_e5_t y10, fp8_e5_t y11, fp8_e5_t y12, fp8_e5_t y13, + fp8_e5_t y14, fp8_e5_t y15) { + fp8_e5_32_t result; + result.x = make_fp8_e5_16_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, + x12, x13, x14, x15); + result.y = make_fp8_e5_16_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, + y12, y13, y14, y15); + return result; +} diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/debug.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/debug.h new file mode 100644 index 00000000000..7dbb31ea385 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/debug.h @@ -0,0 +1,268 @@ +#pragma once + +#include "./cuda_fp8.h" +#include "common.h" + +#ifndef __CUDACC_RTC__ +#include +#endif + +// Template declaration for device-side debug printing (variable only) +template __device__ void debug_print_var(const char *msg, T var); + +// Overload for pointer type (supports any cv-qualified T*) +template __device__ void debug_print_var(const char *msg, T *var) { + printf( + "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=pointer " + "value=%p\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + +// Specialization for signed char type +template <> +__device__ void debug_print_var(const char *msg, signed char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed " + "char " + "value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + +// Specialization for unsigned char type +template <> +__device__ void debug_print_var(const char *msg, + unsigned char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=unsigned char " + "value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + +// Specialization for integer type +template <> __device__ void debug_print_var(const char *msg, int var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int " + "value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + +// Specialization for unsigned integer type +template <> +__device__ void debug_print_var(const char *msg, + unsigned int var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int " + "value=%u\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + +// Specialization for float type +template <> __device__ void debug_print_var(const char *msg, float var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float " + "value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + +// Specialization for half type +template <> __device__ void debug_print_var(const char *msg, half var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half " + "value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (float)var); 
+} + +// Specialization for half_t type +template <> +__device__ void debug_print_var(const char *msg, half_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t " + "value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (float)var); +} + +// Specialization for bfloat16_t type +template <> +__device__ void debug_print_var(const char *msg, bfloat16_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " + "dtype=bfloat16_t value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (float)var); +} + +// Specialization for double type +template <> +__device__ void debug_print_var(const char *msg, double var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double " + "value=%lf\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + +// Specialization for fp8_e4_t type +template <> +__device__ void debug_print_var(const char *msg, fp8_e4_t var) { + printf( + "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e4_t " + "value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (float)var); +} + +// Specialization for fp8_e5_t type +template <> +__device__ void debug_print_var(const char *msg, fp8_e5_t var) { + printf( + "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e5_t " + "value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, (float)var); +} + +// Template declaration for device-side debug printing (buffer only) +template +__device__ void debug_print_buffer_value(const char *msg, const char *buf_name, + int index, T var); + +// Specialization for signed char type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, signed char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=signed char value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} + +// Specialization for unsigned char type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, unsigned char var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=char value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} + +// Specialization for integer type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + int var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=int value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} + +// Specialization for unsigned integer type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, unsigned int var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=int value=%u\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} + +// Specialization for float type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + float var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), 
ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=float value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} + +// Specialization for half type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, int index, + half var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=half value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (float)var); +} + +// Specialization for half_t type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, half_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=half_t value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (float)var); +} + +// Specialization for bfloat16_t type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, bfloat16_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=bfloat16_t value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (float)var); +} + +// Specialization for double type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, double var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=double value=%lf\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} + +// Specialization for fp8_e4_t type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, fp8_e4_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=fp8_e4_t value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (float)var); +} + +// Specialization for fp8_e5_t type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, fp8_e5_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=fp8_e5_t value=%f\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (float)var); +} + +// Specialization for int16 type +template <> +__device__ void debug_print_buffer_value(const char *msg, + const char *buf_name, + int index, int16_t var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=int16_t value=%d\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, (int32_t)var); +} + +TL_DEVICE void device_assert(bool cond) { assert(cond); } + +TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) { + if (!cond) { + printf("Device assert failed: %s\n", msg); + assert(0); + } +} diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm.h new file mode 100644 index 00000000000..b0b2a1b42e0 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm.h @@ -0,0 +1,18 @@ +#pragma once + +#if 
(defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200)) +#include "gemm_sm120.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1000)) +#include "gemm_sm100.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +#include "./instruction/wgmma.h" +#include "gemm_sm90.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890)) +#include "gemm_sm89.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750)) +#include "gemm_sm80.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 700)) +#include "gemm_sm70.h" +#else +// No matching architecture found +#endif diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_mma.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_mma.h new file mode 100644 index 00000000000..9462514f8c9 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_mma.h @@ -0,0 +1,482 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "cuda_fp8.h" +#include "intrin.h" + +namespace cute::tl_mma { + +template +struct DispatchInstruction; + +using _X = Underscore; + +} // namespace cute::tl_mma + +#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr) \ + namespace cute::tl_mma { \ + template \ + struct DispatchInstruction { \ + using MMA = MMA_Atom; \ + using MMA_Group = Tile<_X, Int, _X>; \ + }; \ + } +#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr) \ + namespace cute::tl_mma { \ + template \ + struct DispatchInstruction { \ + using MMA = MMA_Atom>; \ + using MMA_Group = Tile<_X, Int, _X>; \ + }; \ + } + +#ifdef __CUDA_ARCH_LIST__ +#if __CUDA_ARCH_LIST__ >= 1200 +#include "cuda_fp8.h" +#include +#include +TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN) +TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN) +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 1000 +#include "cuda_fp8.h" +#include +#include +#include +TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN) +TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN) +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 900 +#include "cuda_fp8.h" +#include +#include +TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN) +TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN) +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, 
float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 890 +#include "cuda_fp8.h" +#include +#include +TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN) +TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN) +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 800 +#include +TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) +TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN) +TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN) +TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN) +TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) +TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) +#elif __CUDA_ARCH_LIST__ >= 750 +TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN) +#endif +#endif +#undef TL_DISPATCH_MMA +#undef TL_DISPATCH_MMA_TEMPLATE + +namespace cute::tl_mma { + +template struct SelectCopy { + static constexpr int remainder = (N / num_warp_n) % 16; + using type = std::conditional_t< + remainder == 4 || remainder == 8 || remainder == 0, + std::conditional_t< + transpose, + std::conditional_t< + remainder == 4, SM75_U32x1_LDSM_N, + std::conditional_t>, + std::conditional_t< + remainder == 4, SM75_U16x2_LDSM_T, + std::conditional_t>>, + DefaultCopy>; +}; + +template +struct OperandTraits { + // Primary template, use padded layout and default copy + static constexpr int stride = leading_dim; + static constexpr int padded = + stride % (256 / Bits) == 0 ? 
stride + 128 / Bits : stride; + using Layout = typename std::conditional< + K_inner, Layout, Int>, Shape, _1>>, + Layout, Int>, Shape<_1, Int>>>::type; + using Copy = DefaultCopy; +}; + +template +struct OperandTraits<16, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 3, 3>{}, Layout, Stride<_32, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<16, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 3, 3>{}, Layout, Stride<_64, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<16, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 3, 3>{}, Layout, Stride<_1, _32>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<16, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 3, 3>{}, Layout, Stride<_1, _64>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<32, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 2, 3>{}, Layout, Stride<_32, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<32, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 2, 3>{}, Layout, Stride<_16, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<32, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 2, 3>{}, Layout, Stride<_1, _32>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = UniversalCopy; +}; + +template +struct OperandTraits<32, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 2, 3>{}, Layout, Stride<_1, _16>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = UniversalCopy; +}; + +template +struct OperandTraits<8, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 4, 3>{}, Layout, Stride<_64, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<8, N, K, true, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<3, 4, 3>{}, Layout, Stride<_128, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = typename SelectCopy::type; +}; + +template +struct OperandTraits<64, N, K, true, num_warp_n, leading_dim, + 
typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 0, 4>{}, Layout, Stride<_16, _1>>{})); + using Layout = + decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{})); + using Copy = DefaultCopy; +}; + +template +struct OperandTraits<64, N, K, false, num_warp_n, leading_dim, + typename std::enable_if::type> { + using LayoutAtom = decltype(composition( + Swizzle<2, 2, 2>{}, Layout, Stride<_1, _16>>{})); + using Layout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); + using Copy = DefaultCopy; +}; + +template +class GemmTensorOp { +public: + using A_type = + typename std::conditional::value, + tfloat32_t, A_type_raw>::type; + using B_type = + typename std::conditional::value, + tfloat32_t, A_type_raw>::type; + using C_type = C_type_raw; + + using Instruction = + DispatchInstruction; + + using OperandATraits = OperandTraits::value, M, K, + !trans_A, num_warp_m, lda>; + using OperandBTraits = + OperandTraits::value, N, K, trans_B, num_warp_n, ldb>; + + using SmemLayoutA = typename OperandATraits::Layout; + using SmemLayoutB = typename OperandBTraits::Layout; + using SmemCopyA = Copy_Atom; + using SmemCopyB = Copy_Atom; + + using TileMma = TiledMMA, Int, _1>>, + typename Instruction::MMA_Group>; + + template + static CUTE_DEVICE auto remove_swizzle(Layout const &layout) { + return layout; + } + // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0 + // the original layout fail to compile, currently using this as a workaround + template + static CUTE_DEVICE auto + remove_swizzle(ComposedLayout const &layout) { + if constexpr (sizeof(A_type) == 2) + return layout.layout_b(); + else + return layout; + } + + template + static CUTE_DEVICE auto get_region_tensor(Tensor &sa) { + if constexpr (offset == 0) { + return composition( + sa, + Layout, Int>, + Stride<_1, typename std::conditional, + Int>::type>>{}); + } else { + if constexpr (trans) { + static_assert(offset % KK == 0, "Offset must be a multiple of K"); + constexpr int offset_n = offset / KK; + return flat_divide(sa, Shape, Int>{})(_, _, _0{}, + Int{}); + } else { + static_assert(offset % NN == 0, "Offset must be a multiple of N"); + constexpr int offset_n = offset / NN; + return flat_divide(sa, Shape, Int>{})(_, _, Int{}, + _0{}); + } + } + } + + static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) { + const int tid = threadIdx.x; + Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + Tensor sA = get_region_tensor(sA_all); + Tensor sB = get_region_tensor(sB_all); + TileMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma); + auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma); + auto thr_copy_A = tiled_copy_A.get_thread_slice(tid); + auto thr_copy_B = tiled_copy_B.get_thread_slice(tid); + + Tensor tCrA = thr_mma.partition_fragment_A(sA); + Tensor tCrB = thr_mma.partition_fragment_B(sB); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCsB = thr_copy_B.partition_S(sB); + + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + + // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a + // workaround + auto tCrA_view = make_tensor(tCrA.data(), 
remove_swizzle(tCrA.layout())); + auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout())); + if constexpr (clear_accum) { + clear(acc); + } + CUTE_UNROLL + for (int k = 0; k < size<2>(tCrA); ++k) { + copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k)); + copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k)); + gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc); + } + } + + static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB, + C_type_raw *pC) { + const int tid = threadIdx.x; + Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + Tensor sB = get_region_tensor(sB_all); + TileMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma); + auto thr_copy_B = tiled_copy_B.get_thread_slice(tid); + + Tensor tCrB = thr_mma.partition_fragment_B(sB); + Tensor tCsB = thr_copy_B.partition_S(sB); + + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + Tensor tCrA = + make_tensor(make_rmem_ptr(reinterpret_cast(pA)), + partition_shape_A(tiled_mma, Shape, Int>{})); + auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout())); + if constexpr (clear_accum) { + clear(acc); + } + copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tCrA); ++k) { + if (k < size<2>(tCrA) - 1) { + copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1)); + } + gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc); + } + } + + static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB, + C_type_raw *pC) { + const int tid = threadIdx.x; + Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sA = get_region_tensor(sA_all); + TileMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma); + auto thr_copy_A = tiled_copy_A.get_thread_slice(tid); + + Tensor tCrA = thr_mma.partition_fragment_A(sA); + Tensor tCsA = thr_copy_A.partition_S(sA); + + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + Tensor tCrB = + make_tensor(make_rmem_ptr(reinterpret_cast(pB)), + partition_shape_B(tiled_mma, Shape, Int>{})); + auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout())); + if constexpr (clear_accum) { + clear(acc); + } + copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tCrA); ++k) { + if (k < size<2>(tCrA) - 1) { + copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1)); + } + gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc); + } + } +}; + +} // namespace cute::tl_mma + +namespace tl::tl_mma { + +template +CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body(pA, pB, accum); +} + +template +CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_rs(pA, pB, accum); +} + +template +CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_sr(pA, pB, accum); +} + +} // namespace tl::tl_mma diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm100.h 
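
body, body_rs and body_sr all walk the K dimension one MMA slice at a time, and the *_rs/*_sr variants copy slice k+1 while slice k is being consumed. A minimal host-side stand-in for that one-slice lookahead (plain C++, no CuTe types; the names are mine):

```cpp
// Sketch of the prologue-copy + lookahead loop used by body_rs/body_sr.
#include <array>
#include <cstdio>

int main() {
  constexpr int K_SLICES = 4;
  std::array<int, K_SLICES> smem{10, 20, 30, 40};     // stands in for tCsB(_,_,k)
  int frag = smem[0];                                 // prologue copy of slice 0
  int acc = 0;
  for (int k = 0; k < K_SLICES; ++k) {
    int next = (k < K_SLICES - 1) ? smem[k + 1] : 0;  // prefetch slice k+1
    acc += frag;                                      // "gemm" on slice k
    frag = next;                                      // rotate the fragment
  }
  std::printf("acc = %d\n", acc);                     // 100
}
```
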
b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm100.h new file mode 100644 index 00000000000..5b50fe72a9f --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm100.h @@ -0,0 +1,384 @@ +// Licensed under the MIT License. +#pragma once + +#include "common.h" +#include "gemm_mma.h" +#include "intrin.h" + +#include +#include +#include + +namespace cute { + +// Extensions to CuTe +// CuTe don't support TCGEN5MMA with .ws, so we add it here +// About why we need .ws, plz refer to comments in tl_tcgen5mma::GemmTensorOp + +template +struct SM100_MMA_F16BF16_WS_SS { + static_assert(M == 32 || M == 64 || M == 128, + "SM100_MMA_F16BF16 (with .ws) M-mode size should be 32, 64 or " + "128 for 1 CTA cluster MMA."); + static_assert( + N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16 (with .ws) N-mode size should be 32, 64 or 128"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scaleC, uint64_t const &idescE) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE >> 32)), + "r"(scaleC)); + } + } +}; + +template +struct MMA_Traits> { + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && + cute::sizeof_bits_v == 16, + "SM100_MMA_F16BF16_WS_SS supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + + UMMA::InstrDescriptor idesc_ = + UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), + idesc); + } +}; + +struct SM100_MMA_F8F6F4_WS_SS { + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scaleC, uint64_t const &idescE) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, " + "p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), + "r"(uint32_t(idescE >> 32)), "r"(scaleC)); + } + } +}; + +template +struct MMA_Traits, + cute::C, cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> { + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && + cute::sizeof_bits_v <= 8, + "SM100_MMA_F8F6F4_WS_SS supports types with leq 8bit types"); + static_assert(M == 32 || M == 64 || M == 128, + "SM100_MMA_F8F6F4_WS_SS M-mode size should be 32, 64 or 128 " + "for 1 CTA cluster MMA."); + static_assert( + N == 64 || N == 128 || N == 256, + "SM100_MMA_F8F6F4_WS_SS (with .ws) N-mode size should be 32, 64 or 128"); + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + + UMMA::InstrDescriptor idesc_ = + UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 
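
Both .ws traits pin the logical K extent of one instruction at 256 bits and convert it to elements of the operand type, giving K = 16 for half/bfloat16 and K = 32 for the 8-bit f8f6f4 path. A quick constexpr restatement (sizeof_bits_v here is a stand-in for the CuTe trait):

```cpp
// Illustrative check of the logical-K computation in the MMA_Traits above.
#include <cstdint>

template <typename T> constexpr int sizeof_bits_v = 8 * sizeof(T);

constexpr int logical_k(int elem_bits) { return 256 / elem_bits; }

static_assert(logical_k(sizeof_bits_v<uint16_t>) == 16);  // half_t / bfloat16_t
static_assert(logical_k(sizeof_bits_v<uint8_t>) == 32);   // fp8 e4m3 / e5m2

int main() { return 0; }
```
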
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_WS_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), idesc); + } +}; + +namespace tl_tcgen5mma { + +using cutlass::gemm::collective::detail::sm100_smem_selector; + +template +struct DispatchInstruction; + +template +struct DispatchInstruction> { + using MMA = SM100_MMA_F16BF16_SS; +}; + +template +struct DispatchInstruction> { + using MMA = SM100_MMA_F16BF16_WS_SS; +}; + +template +struct DispatchInstruction> { + using MMA = + SM100_MMA_F16BF16_SS; +}; + +template +struct DispatchInstruction> { + using MMA = + SM100_MMA_F16BF16_WS_SS; +}; + +template +struct DispatchInstruction> { + using MMA = MMA_Traits, + Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, + Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = MMA_Traits, + Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, + Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +class GemmTensorOp { +public: + using A_type = + typename std::conditional::value, + tfloat32_t, A_type_raw>::type; + using B_type = + typename std::conditional::value, + tfloat32_t, B_type_raw>::type; + using C_type = C_type_raw; + + static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32); + + static constexpr UMMA::Major UmmaMajorA = + trans_A ? UMMA::Major::MN : UMMA::Major::K; + static constexpr UMMA::Major UmmaMajorB = + trans_B ? 
UMMA::Major::K : UMMA::Major::MN; + + using SmemLayoutAtomA = + decltype(sm100_smem_selector, Int>()); + using SmemLayoutAtomB = + decltype(sm100_smem_selector, Int>()); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, Shape, Int>{}, + conditional_t, Step<_2, _1>>{})); + + static CUTE_DEVICE void body_ss(A_type_raw *pA, B_type_raw *pB, uint32_t pC, + uint64_t *umma_bar_ptr, bool clear_accum) { + Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + + // TODO (lei): Normal TCGEN5MMA (the one w/o ws) don't saturate all 128 + // lanes when M == 64 + // (see layout F in + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-f) + // So we use the .ws variant here + using MmaAtom = + typename DispatchInstruction::MMA; + auto tiled_mma = make_tiled_mma(MmaAtom{}, Layout>{}, + Tile, Int, Int>{}); + auto thr_mma = tiled_mma.get_slice(_0{}); + tiled_mma.accumulate_ = + clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + Tensor acc = partition_fragment_C(tiled_mma, Shape, Int>{}); + acc.data() = pC; + + Tensor sA_frag = thr_mma.partition_fragment_A(sA); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(sA_frag); ++k_block) { + cute::gemm(tiled_mma, sA_frag(_, _, k_block), sB_frag(_, _, k_block), + acc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + cutlass::arch::umma_arrive(umma_bar_ptr); + } +}; + +} // namespace tl_tcgen5mma + +} // namespace cute + +namespace tl { + +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; + +// TODO (lei): Implement gemm_ts +// template +// TL_DEVICE void gemm_ts(A_type *pA, B_type *pB, C_type *accum, uint64_t +// *umma_bar_ptr) { +// } + +template +TL_DEVICE void tcgen5mma_gemm_ss(A_type *pA, B_type *pB, uint32_t accum, + Barrier_type *umma_bar_ptr, bool clear_accum) { + using MMA = + cute::tl_tcgen5mma::GemmTensorOp; + MMA::body_ss(pA, pB, accum, reinterpret_cast(umma_bar_ptr), + clear_accum); +} + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm120.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm120.h new file mode 100644 index 00000000000..122f56642af --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm120.h @@ -0,0 +1,9 @@ +#pragma once + +#include "gemm_mma.h" + +namespace tl { +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm70.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm70.h new file mode 100644 index 00000000000..75127727935 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm70.h @@ -0,0 +1,188 @@ +#pragma once + +#include +#include + +#include "common.h" + +using cutlass::gemm::GemmShape; + +// Primary template +// Add 128 bits padding when the last dim is a multiple of 256 bits +template +struct DispatchSharedMemoryLayoutA { + using Layout = + typename std::conditional::type; + static int constexpr Dim = transpose ? M : K; + static int constexpr Stride = + (Dim * sizeof(T) % 32 == 0) ? 
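
In body_ss the accumulate flag starts at ScaleOut::Zero when clear_accum is set and is forced to ScaleOut::One after the first K-block, so only the first MMA overwrites stale accumulator contents. A scalar sketch of that flag handling (plain C++, not the CuTe call sequence):

```cpp
// Sketch: only the first K-block discards the old accumulator when clear_accum is set.
#include <cstdio>

int main() {
  bool clear_accum = true;
  int acc = 123;                       // stale accumulator contents
  int scale_c = clear_accum ? 0 : 1;   // ScaleOut::Zero vs ScaleOut::One
  for (int k_block = 0; k_block < 4; ++k_block) {
    acc = acc * scale_c + 7;           // stand-in for one MMA on this K-block
    scale_c = 1;                       // subsequent blocks always accumulate
  }
  std::printf("acc = %d\n", acc);      // 28; the stale 123 was discarded
}
```
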
Dim + 16 / sizeof(T) : Dim; +}; +template +struct DispatchSharedMemoryLayoutB { + using Layout = + typename std::conditional::type; + static int constexpr Dim = transpose ? K : N; + static int constexpr Stride = + (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim; +}; + +// Partial specialization for half_t +template +struct DispatchSharedMemoryLayoutA::type> { + using Layout = + cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>; + static int constexpr Stride = M; +}; + +template +struct DispatchSharedMemoryLayoutA { + using Layout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>; + static int constexpr Stride = M; +}; + +template struct DispatchSharedMemoryLayoutB { + using Layout = + cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>; + static int constexpr Stride = N; +}; + +template +struct DispatchSharedMemoryLayoutB::type> { + using Layout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>; + static int constexpr Stride = N; +}; + +template +class GemmTensorOp { +public: + using A_type = A_type_raw; + using B_type = B_type_raw; + using C_type = C_type_raw; + using InstructionShape = GemmShape<16, 16, 4>; + using SMemLayoutA = + typename DispatchSharedMemoryLayoutA::Layout; + using SMemLayoutB = + typename DispatchSharedMemoryLayoutB::Layout; + static constexpr int stride_A = + DispatchSharedMemoryLayoutA::Stride; + static constexpr int stride_B = + DispatchSharedMemoryLayoutB::Stride; + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape, 32, A_type, + typename std::conditional::type, + B_type, + typename std::conditional::type, + C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1>>; + + static_assert(Shape::kM % num_warp_m == 0); + static_assert(Shape::kN % num_warp_n == 0); + + using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp< + GemmShape, + A_type, SMemLayoutA, B_type, SMemLayoutB, C_type, + cutlass::layout::RowMajor, Policy>; + + using TensorRefA = typename MmaWarp::IteratorA::TensorRef; + using TensorRefB = typename MmaWarp::IteratorB::TensorRef; + using FragmentA = typename MmaWarp::FragmentA; + using FragmentB = typename MmaWarp::FragmentB; + using FragmentC = typename MmaWarp::FragmentC; + using IteratorA = typename MmaWarp::IteratorA; + using IteratorB = typename MmaWarp::IteratorB; + + static_assert(Shape::kK % InstructionShape::kK == 0); + static int constexpr kKgroups = Shape::kK / InstructionShape::kK; + + static CUTLASS_DEVICE void body(A_type_raw *pA, B_type_raw *pB, + FragmentC &accum, const int warp_idx_m, + const int warp_idx_n, const int lane_id) { + MmaWarp mma_op; + FragmentA frag_A; + FragmentB frag_B; + const TensorRefA ref_A((A_type *)pA, stride_A); + const TensorRefB ref_B((B_type *)pB, stride_B); + IteratorA iter_A(ref_A, lane_id); + IteratorB iter_B(ref_B, lane_id); + iter_A.add_tile_offset({warp_idx_m, 0}); + iter_B.add_tile_offset({0, warp_idx_n}); + if constexpr (clear_accum) { + accum.clear(); + } + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < kKgroups; ++k) { + iter_A.load(frag_A); + iter_B.load(frag_B); + ++iter_A; + ++iter_B; + mma_op(accum, frag_A, frag_B, accum); + } + } + + static CUTLASS_DEVICE void body_rs(const FragmentA *frag_A, B_type_raw *pB, + FragmentC &accum, const int warp_idx_n, + const int lane_id) { + MmaWarp mma_op; + FragmentB frag_B; + const TensorRefB ref_B((B_type *)pB, stride_B); + IteratorB iter_B(ref_B, lane_id); + iter_B.add_tile_offset({0, 
warp_idx_n}); + if constexpr (clear_accum) { + accum.clear(); + } + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < kKgroups; ++k) { + iter_B.load(frag_B); + ++iter_B; + mma_op(accum, frag_A[k], frag_B, accum); + } + } +}; + +namespace tl { + +template +CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) { + using MMA = GemmTensorOp, num_warp_m, num_warp_n, trans_A, + trans_B, clear_accum, A_type, B_type, C_type>; + using FragmentC = typename MMA::FragmentC; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + MMA::body(pA, pB, *(FragmentC *)(accum), warp_id / num_warp_n, + warp_id % num_warp_n, lane_id); +} + +template +CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { + using MMA = GemmTensorOp, num_warp_m, num_warp_n, trans_A, + trans_B, clear_accum, A_type, B_type, C_type>; + using FragmentA = typename MMA::FragmentA; + using FragmentC = typename MMA::FragmentC; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + MMA::body_rs((const FragmentA *)(pA), pB, *(FragmentC *)(accum), + warp_id % num_warp_n, lane_id); +} + +}; // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm80.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm80.h new file mode 100644 index 00000000000..122f56642af --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm80.h @@ -0,0 +1,9 @@ +#pragma once + +#include "gemm_mma.h" + +namespace tl { +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm89.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm89.h new file mode 100644 index 00000000000..d64ae9e2e69 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm89.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +#include "cuda_fp8.h" + +#include "gemm_mma.h" + +namespace tl { +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm90.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm90.h new file mode 100644 index 00000000000..1aa3ecff9e2 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sm90.h @@ -0,0 +1,385 @@ +#pragma once + +#include "common.h" +#include "gemm_mma.h" +#include "intrin.h" + +#include +#include +#include + +namespace cute { + +using namespace SM90; + +namespace tl_wgmma { + +using namespace cutlass::gemm::collective::detail; // ss_smem_selector + +template +class GemmTensorOp { +public: + using A_type = conditional_t::value, + tfloat32_t, A_type_raw>; + using B_type = conditional_t::value, + tfloat32_t, B_type_raw>; + using C_type = C_type_raw; + + static constexpr GMMA::Major GmmaMajorA = + trans_A ? GMMA::Major::MN : GMMA::Major::K; + static constexpr GMMA::Major GmmaMajorB = + trans_B ? 
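
The sm70 wrappers decompose threadIdx.x into a warp coordinate plus a lane: warp_idx_m = warp_id / num_warp_n, warp_idx_n = warp_id % num_warp_n, lane_id = threadIdx.x % 32. A small host-side illustration of that mapping, assuming the usual 32-thread warps:

```cpp
// Illustration of the thread -> (warp_m, warp_n, lane) decomposition in gemm_ss/gemm_rs.
#include <cstdio>

int main() {
  constexpr int num_warp_m = 2, num_warp_n = 2;
  for (int tid = 0; tid < num_warp_m * num_warp_n * 32; tid += 32) {
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    std::printf("tid %3d -> warp (%d, %d), lane %d\n", tid,
                warp_id / num_warp_n, warp_id % num_warp_n, lane_id);
  }
}
```
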
GMMA::Major::K : GMMA::Major::MN; + + using SmemLayoutAtomA = + decltype(ss_smem_selector, Int>()); + using SmemLayoutAtomB = + decltype(ss_smem_selector, Int>()); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, Shape, Int>{}, + conditional_t, Step<_2, _1>>{})); + + static_assert(num_warp_m % 4 == 0, + "num_warp_m must be a multiple of 4 for hopper wgmma"); + + template + static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) { + const int tid = threadIdx.x; + Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + auto tiled_mma = make_tiled_mma( + GMMA::ss_op_selector< + A_type, B_type, C_type, + Shape, Int, Int>, + GmmaMajorA, GmmaMajorB>(), + Layout, Int, _1>>{}); + auto thr_mma = tiled_mma.get_thread_slice(tid); + + // Allocate registers for pipelining + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + + warpgroup_fence_operand(acc); + warpgroup_arrive(); + if constexpr (clear_accum) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(acc); + } + + template + static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB, + C_type_raw *pC) { + // TODO: Move bar.sync out of body_rs + // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n * + // 32)); + const int tid = threadIdx.x; + Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + auto tiled_mma = make_tiled_mma( + GMMA::rs_op_selector< + A_type, B_type, C_type, + Shape, Int, Int>, + GmmaMajorA, GmmaMajorB>(), + Layout, Int, _1>>{}); + auto thr_mma = tiled_mma.get_thread_slice(tid); + + // Allocate registers for pipelining + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCrA = + make_tensor(make_rmem_ptr(reinterpret_cast(pA)), + partition_shape_A(tiled_mma, Shape, Int>{})); + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + + warpgroup_fence_operand(tCrA); + warpgroup_fence_operand(acc); + warpgroup_arrive(); + if constexpr (clear_accum) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(acc); + warpgroup_fence_operand(tCrA); + } +}; + +} // namespace tl_wgmma + +} // namespace cute +/** + 
* Execute a tiled GEMM where A is read from global memory and B is staged in + * shared memory. + * + * Dispatches to tl_mma::GemmTensorOp::body_rs to perform the + * computation. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Execute a tiled GEMM where A is staged in shared memory and B is read from + * global memory. + * + * Dispatches to tl_mma::GemmTensorOp::body_sr to perform the + * computation. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Perform a tiled GEMM (both operands in shared memory or selected backend) and + * write to accum. + * + * If use_wgmma is true, validates wgmma constraints (strides and offsets) and + * dispatches to the Hopper wgmma implementation; otherwise dispatches to the + * tl_mma implementation. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Perform a tiled GEMM with A in global memory and B in shared memory (or + * selected backend). + * + * If use_wgmma is true, validates wgmma constraints (strides and offsets) and + * dispatches to the Hopper wgmma read-share implementation; otherwise + * dispatches to the tl_mma read-share. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Perform a tiled GEMM with A staged in shared memory and B in global memory + * (tl_mma only). + * + * wgmma does not support this variant; caller must set use_wgmma == false. + * Dispatches to tl_mma::GemmTensorOp::body_sr. + * + * @param pA Pointer to the A tile region (device memory). + * @param pB Pointer to the B tile region (device memory). + * @param accum Pointer to the accumulator/output tile region (device memory). + */ +/** + * Wait for a warp-group of WMMA/MMA warps to complete. + * + * Wrapper around cute::warpgroup_wait for the specified number of MMA warps. + */ +/** + * Synchronize a named barrier across NumMmaThreads MMA threads. + * + * Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id. + */ +/** + * Arrive at a named barrier for NumMmaThreads MMA threads using + * architecture-aware mapping. + * + * Supported NumMmaThreads values: 256 or 384. The function issues one or two + * barrier arrives depending on the thread-group topology to ensure proper + * rendezvous ordering. + */ +/** + * Initialize named-barrier state for multi-warp MMA execution. + * + * For NumMmaThreads == 256 or 384, performs the required initial barrier + * arrivals for non-zero canonical warp-group indices to set up subsequent + * barrier synchronization. 
+ */ + +namespace tl { + +template +TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) { + if constexpr (use_wgmma) { + static_assert((trans_A && lda == M) || (!trans_A && lda == K), + "Hopper wgmma doesn't support custom stride for A"); + static_assert((trans_B && ldb == K) || (!trans_B && ldb == N), + "Hopper wgmma doesn't support custom stride for B"); + static_assert(offset_a == 0 && offset_b == 0, + "offset_a and offset_b must be zero for wgmma"); + using MMA = cute::tl_wgmma::GemmTensorOp; + MMA::body(pA, pB, accum); + } else { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body(pA, pB, accum); + } +} + +template +TL_DEVICE /** + * Perform a read-share (B in shared memory, A in global) tiled GEMM + * and accumulate into `accum`. + * + * Dispatches at compile time to either the Hopper wgmma + * implementation or the fallback MMA implementation depending on + * `use_wgmma`. The selected GemmTensorOp::body_rs performs the + * region-tiled GEMM loop and updates the accumulator in-place. + * + * When `use_wgmma == true`, this function enforces wgmma constraints + * at compile time: + * - A's leading dimension must equal (trans_A ? M : K) + * - B's leading dimension must equal (trans_B ? K : N) + * - offset_a and offset_b must be zero + * + * @param pA Pointer to operand A (global memory). Layout/stride + * expectations depend on template parameters. + * @param pB Pointer to operand B (base for shared-memory staging). + * Layout/stride expectations depend on template parameters. + * @param accum Pointer to the accumulator/output C buffer updated + * in-place. + */ + void + gemm_rs(A_type *pA, B_type *pB, C_type *accum) { + if constexpr (use_wgmma) { + static_assert((trans_A && lda == M) || (!trans_A && lda == K), + "Hopper wgmma doesn't support custom stride for A"); + static_assert((trans_B && ldb == K) || (!trans_B && ldb == N), + "Hopper wgmma doesn't support custom stride for B"); + static_assert(offset_a == 0 && offset_b == 0, + "offset_a and offset_b must be zero for wgmma"); + using MMA = cute::tl_wgmma::GemmTensorOp; + MMA::body_rs(pA, pB, accum); + } else { + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_rs(pA, pB, accum); + } +} + +template +TL_DEVICE /** + * Perform a non-wgmma tiled GEMM where A regions are staged into + * shared memory and B is read directly from global memory, + * accumulating into `accum`. + * + * This overload dispatches to the tl_mma::GemmTensorOp::body_sr + * implementation. Must be instantiated with `use_wgmma = false` + * (enforced via static_assert). + * + * @param pA Pointer to the A operand in global memory (source that + * will be staged to shared memory). + * @param pB Pointer to the B operand in global memory (read + * directly). + * @param accum Pointer to the output accumulator matrix in global + * memory. + */ + void + gemm_sr(A_type *pA, B_type *pB, C_type *accum) { + static_assert(!use_wgmma, "wgmma doesn't support gemm_sr"); + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_sr(pA, pB, accum); +} + +template +TL_DEVICE /** + * Wait for all WMMA/MMA warps in the current warp-group to + * synchronize. + * + * Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes + * completes, ensuring all participating warps have arrived before + * proceeding. 
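
The static_asserts in the wgmma paths of gemm_ss and gemm_rs encode the requirement that the A and B tiles be densely packed: lda must equal M when A is transposed and K otherwise, and symmetrically for B. A constexpr restatement of the lda rule with a few example parameterizations (the helper name is mine):

```cpp
// Illustration of the wgmma leading-dimension constraint checked above.
constexpr bool lda_ok(bool trans_A, int lda, int M, int K) {
  return (trans_A && lda == M) || (!trans_A && lda == K);
}

static_assert(lda_ok(false, 64, 128, 64));   // row-major A tile, lda == K
static_assert(lda_ok(true, 128, 128, 64));   // transposed A tile, lda == M
static_assert(!lda_ok(false, 72, 128, 64));  // padded stride: must fall back to tl_mma

int main() { return 0; }
```
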
+ */ + void + wait_wgmma() { + cute::warpgroup_wait(); +} + +template TL_DEVICE void warp_scheduler_barrier_sync() { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, + cutlass::canonical_warp_group_idx() /*id*/); +} + +template TL_DEVICE void warp_scheduler_barrier_arrive() { + static_assert(NumMmaThreads == 256 || NumMmaThreads == 384); + if constexpr (NumMmaThreads == 256) { + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/); + } else { + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + (cutlass::canonical_warp_group_idx() <= 1 + ? cutlass::canonical_warp_group_idx() + 1 + : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + (cutlass::canonical_warp_group_idx() <= 0 + ? cutlass::canonical_warp_group_idx() + 2 + : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); + } +} + +template TL_DEVICE void mma_init() { + static_assert(NumMmaThreads == 256 || NumMmaThreads == 384); + if (cutlass::canonical_warp_group_idx() > 0) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0); + } + if constexpr (NumMmaThreads == 384) { + if (cutlass::canonical_warp_group_idx() > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/); + } + } +} +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp.h new file mode 100644 index 00000000000..f40a7bd0f8e --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp.h @@ -0,0 +1,6 @@ +#pragma once +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +#include "gemm_sp_sm90.h" +#else(defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800)) +#include "gemm_sp_sm80.h" +#endif diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp_sm80.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp_sm80.h new file mode 100644 index 00000000000..f1fc860092e --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp_sm80.h @@ -0,0 +1,270 @@ +#include +#include + +namespace tl { + +static int const kSparse = 2; +template struct ShapeCheck { + static constexpr bool value = false; +}; + +template struct ShapeCheck { + static constexpr bool value = + (Shape::kM % 32 == 0) && (Shape::kN % 32 == 0) && (Shape::kK % 32 == 0); +}; + +template struct ShapeCheck { + static constexpr bool value = + ShapeCheck::value; // Same as half +}; + +template struct ShapeCheck { + static constexpr bool value = + (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0); +}; + +template struct ShapeCheck { + static constexpr bool value = + (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0); +}; + +// ref: +// https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h +template struct DispatchInstructionShape { + static_assert(!std::is_same_v, + "Unsupported type for DispatchInstructionShape"); +}; + +template <> struct DispatchInstructionShape { + using Shape = cutlass::gemm::GemmShape<16, 8, 32>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template <> struct DispatchInstructionShape { + using Shape = cutlass::gemm::GemmShape<16, 8, 32>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// TODO: Not supported for now +// template<> +// 
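
With 384 MMA threads, warp_scheduler_barrier_arrive makes warp group wg signal the barriers owned by the other two groups, i.e. ids (wg + 1) % 3 and (wg + 2) % 3 written without the modulo. A tiny host-side check of that id rotation:

```cpp
// Illustration of the barrier-id arithmetic for the three-warp-group case.
#include <cstdio>

int main() {
  for (int wg = 0; wg < 3; ++wg) {
    int first  = wg <= 1 ? wg + 1 : wg + 1 - 3;
    int second = wg <= 0 ? wg + 2 : wg + 2 - 3;
    std::printf("warp group %d arrives at barriers %d and %d\n", wg, first, second);
  }
}
```
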
struct DispatchInstructionShape { +// using Shape = cutlass::gemm::GemmShape<16, 8, 16>; +// using Operator = cutlass::arch::OpMultiplyAdd; +// }; + +template <> struct DispatchInstructionShape { + using Shape = cutlass::gemm::GemmShape<16, 8, 64>; + using Operator = cutlass::arch::OpMultiplyAddSaturate; +}; + +template <> struct DispatchInstructionShape { + using Shape = cutlass::gemm::GemmShape<16, 8, 64>; + using Operator = cutlass::arch::OpMultiplyAddSaturate; +}; + +// TODO: Not supported for now +// template<> +// struct DispatchInstructionShape { +// using Shape = cutlass::gemm::GemmShape<16, 8, 128>; +// using Operator = cutlass::arch::OpMultiplyAddSaturate; +// }; + +template +struct DispatchSharedMemoryLayoutA; + +template +struct DispatchSharedMemoryLayoutA { + using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, K / kSparse>; +}; + +template +struct DispatchSharedMemoryLayoutA { + static int const Crosswise_A = + cutlass::platform::min(int(128 / sizeof(T)), M); + using SmemLayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, Crosswise_A>; +}; + +template +struct DispatchSharedMemoryLayoutB; + +template +struct DispatchSharedMemoryLayoutB { + static_assert( + cutlass::sizeof_bits::value != 8, + "int8, uint8, float8 only support column major layout for matrix B"); + static int const Crosswise_B = + cutlass::platform::min(int(128 / sizeof(T)), N); + using SmemLayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, Crosswise_B>; +}; + +template +struct DispatchSharedMemoryLayoutB { + static int const kCrosswiseB = (K > (1024 / cutlass::sizeof_bits::value)) + ? (1024 / cutlass::sizeof_bits::value) + : K; + using SmemLayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, kCrosswiseB>; +}; + +template struct DispatchType { + static_assert(std::is_same::value, "Unsupported dtype"); +}; + +template <> struct DispatchType { + using Type = cutlass::half_t; +}; + +template <> struct DispatchType { + using Type = cutlass::bfloat16_t; +}; + +template <> struct DispatchType { + using Type = uint8_t; +}; + +template <> struct DispatchType { + using Type = int8_t; +}; + +template +class GemmTensorOp { +public: + static_assert(Shape::kM % num_warp_m == 0); + static_assert(Shape::kN % num_warp_n == 0); + using ElementA = typename DispatchType::Type; + using ElementB = typename DispatchType::Type; + using ElementC = C_type_raw; + + static_assert(std::is_same_v, + "A and B are not the same type"); + static_assert(ShapeCheck::value, + "Invalid shape for ElementA"); + + using LayoutA = + typename std::conditional_t; + using LayoutB = + typename std::conditional_t; + using LayoutC = cutlass::layout::RowMajor; + using ThreadblockShape = Shape; + using SmemLayoutA = + typename DispatchSharedMemoryLayoutA::SmemLayoutA; + using SmemLayoutB = + typename DispatchSharedMemoryLayoutB::SmemLayoutB; + + using WarpShape = cutlass::gemm::GemmShape; + using InstructionShape = typename DispatchInstructionShape::Shape; + using Operator = typename DispatchInstructionShape::Operator; + static_assert(WarpShape::kK % InstructionShape::kK == 0, + "K dimension must be divisible by instruction shape K."); + + // instruction/warp config + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::SparseMma, + cutlass::MatrixShape<1, 1>>; + using MmaWarp = + cutlass::gemm::warp::SparseMmaTensorOp; + static_assert(kSparse == 
MmaWarp::kSparse, "not 2:4 structured sparse"); + + using SmemLayoutE = typename MmaWarp::LayoutE; + static_assert(std::is_same_v, + "Meta data layout must be ColumnMajor for sparse mma."); + + // other traits + using FragmentA = typename MmaWarp::FragmentA; + using FragmentB = typename MmaWarp::FragmentB; + using FragmentC = typename MmaWarp::FragmentC; + using FragmentE = typename MmaWarp::FragmentE; + + using IteratorA = typename MmaWarp::IteratorA; + using IteratorB = typename MmaWarp::IteratorB; + using IteratorE = typename MmaWarp::IteratorE; + + using TensorRefA = typename IteratorA::TensorRef; + using TensorRefB = typename IteratorB::TensorRef; + using TensorRefE = typename IteratorE::TensorRef; + using ElementE = typename TensorRefE::Element; + + static int const kElementsPerElementE = MmaWarp::kElementsPerElementE; + static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse"); + + using ShapeA = cutlass::MatrixShape; + using ShapeB = cutlass::MatrixShape; + using ShapeE = + cutlass::MatrixShape; + + static int constexpr kKgroups = WarpShape::kK / InstructionShape::kK; + + template + static CUTLASS_DEVICE void + body(A_type_raw *pA, E_type_raw *pE, B_type_raw *pB, FragmentC &accum, + const int warp_idx_m, const int warp_idx_n, const int lane_id) { + MmaWarp mma_op; + FragmentA frag_a; + FragmentB frag_b; + FragmentE frag_e; + const TensorRefA ref_A( + (ElementA *)pA, + MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn})); + const TensorRefE ref_E( + (ElementE *)pE, + MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn})); + const TensorRefB ref_B( + (ElementB *)pB, + MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn})); + IteratorA iter_A(ref_A, lane_id); + IteratorE iter_E(ref_E, lane_id); + IteratorB iter_B(ref_B, lane_id); + iter_A.add_tile_offset({warp_idx_m, 0}); + iter_E.add_tile_offset({warp_idx_m, 0}); + iter_B.add_tile_offset({0, warp_idx_n}); + if constexpr (clear_accum) { + accum.clear(); + } + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < kKgroups; ++k) { + iter_A.load(frag_a); + iter_E.load(frag_e); + iter_B.load(frag_b); + ++iter_A; + ++iter_E; + ++iter_B; + mma_op(accum, frag_a, frag_b, accum, frag_e); + } + } +}; + +template +TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) { + using MMA = + GemmTensorOp, num_warp_m, num_warp_n, + trans_A, trans_B, clear_accum, A_type, B_type, C_type>; + using FragmentC = typename MMA::FragmentC; + + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m, + warp_id / num_warp_m, lane_id); +} + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp_sm90.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp_sm90.h new file mode 100644 index 00000000000..db55a21ecf5 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/gemm_sp_sm90.h @@ -0,0 +1,232 @@ +#pragma once + +#include +#include +#include + +namespace cute { +namespace tl_wgmma_sp { +template +class GemmTensorOp { +public: + static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4"); + + using A_type = conditional_t::value, + tfloat32_t, A_type_raw>; + using B_type = conditional_t::value, + tfloat32_t, B_type_raw>; + using C_type = C_type_raw; + + static constexpr bool need_tfloat32_cast = + std::is_same::value && + std::is_same::value; + + static constexpr GMMA::Major GmmaMajorA = + trans_A ? 
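
gemm_sp_sm80.h assumes 2:4 structured sparsity (kSparse == 2): the compressed A tile stores only the kept values (K / 2 per row) while the E operand carries their positions inside each group of four. The toy compressor below shows the idea only; it does not reproduce the CUTLASS metadata encoding, and all names are mine.

```cpp
// Toy 2:4 compression: keep at most two non-zeros per group of four, plus indices.
#include <array>
#include <cstdio>

struct Compressed24 {
  std::array<float, 2> vals;   // the two kept values
  std::array<int, 2> idx;      // their positions within the group of 4
};

Compressed24 compress_group(const std::array<float, 4>& g) {
  Compressed24 out{{0.f, 0.f}, {0, 1}};
  int n = 0;
  for (int i = 0; i < 4 && n < 2; ++i)
    if (g[i] != 0.f) { out.vals[n] = g[i]; out.idx[n] = i; ++n; }
  return out;
}

int main() {
  auto c = compress_group({0.f, 1.5f, 0.f, -2.f});
  std::printf("kept %.1f at %d and %.1f at %d\n", c.vals[0], c.idx[0], c.vals[1], c.idx[1]);
}
```
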
GMMA::Major::MN : GMMA::Major::K; + static constexpr GMMA::Major GmmaMajorB = + trans_B ? GMMA::Major::K : GMMA::Major::MN; + + using TiledMma = decltype(make_tiled_mma( + GMMA::ss_op_selector_sparse< + A_type, B_type, C_type, + Shape, Int, Int>, + GmmaMajorA, GmmaMajorB>(), + Layout, Int, _1>>{})); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaSparsity = Int; + using ElementBMma = typename TiledMma::ValTypeB; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementEMmaSparsity = Int; + using E_type_raw = typename ElementEMma::raw_type; + + using SparseConfig = + cutlass::Sm90GemmSparseConfig{}, _128{}))>; + + using LayoutA = decltype(SparseConfig::deduce_layoutA()); + using LayoutE = decltype(SparseConfig::deduce_layoutE()); + + using SmemLayoutAtomA = + decltype(cutlass::gemm::collective::detail::ss_smem_selector_sparse< + GmmaMajorA, A_type, Int, Int, ElementAMmaSparsity>()); + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorB, B_type, Int, Int>()); + + using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = + ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, Shape, Int>{}, + conditional_t, Step<_2, _1>>{})); + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + + using SmemCopyAtomE = AutoVectorizingCopy; + + template + static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC, + E_type_raw *pE) { + const int tid = threadIdx.x; + Tensor sA = + make_tensor(make_smem_ptr(recast_ptr(pA)), SmemLayoutA{}); + Tensor sB = + make_tensor(make_smem_ptr(recast_ptr(pB)), SmemLayoutB{}); + Tensor sE = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(recast_ptr(pE)), SmemLayoutE{})); + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + + Tensor tCsA = thr_mma.partition_A(sA); + Tensor tCsB = thr_mma.partition_B(sB); + Tensor tCsE = partition_E(thr_mma, sE(_, _)); + + Tensor tCrA = thr_mma.make_fragment_A(tCsA); + Tensor tCrB = thr_mma.make_fragment_B(tCsB); + Tensor tCrE = make_fragment_like(tCsE); + + auto copy_atom_E = Copy_Atom{}; + auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(tid); + Tensor tEsE = smem_thr_copy_E.partition_S(sE); + Tensor tErE = smem_thr_copy_E.retile_D(tCrE); + + Tensor acc = + make_tensor(make_rmem_ptr(pC), + partition_shape_C(tiled_mma, Shape, Int>{})); + + warpgroup_fence_operand(acc); + warpgroup_arrive(); + if constexpr (clear_accum) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + copy(smem_tiled_copy_E, tEsE, tErE); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, make_zip_tensor(tCrA(_, _, k_block), tCrE(_, _, k_block)), + tCrB(_, _, k_block), acc); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(acc); + } + + template + CUTE_HOST_DEVICE static constexpr auto + thrfrg_E(TiledMMA const &mma, + ETensor &&etensor) { + using TiledMma = TiledMMA; + + CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{}); + + // Reorder the tensor 
for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto e_tile = + make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})), + make_layout(size<2>(typename TiledMma::AtomShape_MNK{}))); + auto e_tensor = + zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout; + auto tv_tensor = + e_tensor.compose(AtomLayoutE_TV{}, _); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = + make_tile(_, make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)), + make_layout(size<3>(mma.thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide( + tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE static constexpr auto + get_layoutE_TV(TiledMMA const &mma) { + // (M,K) -> (M,K) + auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma))); + // (ethrid,val) -> (M,K) + auto layoutE_TV = thrfrg_E(mma, ref_E); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile( + _, make_tile(make_layout(make_shape(size<1>(mma.thr_layout_vmnk_), + size<2>(mma.thr_layout_vmnk_)), + make_stride(Int<1>{}, Int<0>{})), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE static constexpr auto + partition_E(ThrMMA const &thr_mma, ETensor &&etensor) { + auto thr_tensor = make_tensor(static_cast(etensor).data(), + thrfrg_E(thr_mma, etensor.layout())); + + auto thr_vmk = make_coord( + get<0>(thr_mma.thr_vmnk_), + make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_))); + return thr_tensor(thr_vmk, + make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE static constexpr auto + make_tiled_copy_E(Copy_Atom const ©_atom, + TiledMMA const &mma) { + return make_tiled_copy_impl( + copy_atom, get_layoutE_TV(mma), + make_shape(tile_size<0>(mma), tile_size<2>(mma))); + } +}; + +} // namespace tl_wgmma_sp +} // namespace cute + +namespace tl { +template , + typename E_type = typename GMMA::ElementEMma::raw_type> +TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) { + static_assert(use_wgmma, "only wgmma is supported for now"); + if constexpr (use_wgmma) { + GMMA::body(pA, pB, accum, pE); + } else { + CUTE_GCC_UNREACHABLE; + } +} +} // namespace tl \ No newline at end of file diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/instruction/wgmma.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/instruction/wgmma.h new file mode 100644 index 00000000000..0e971728090 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/instruction/wgmma.h @@ -0,0 +1,647 @@ +#pragma once +#include "../common.h" +#include "cute/arch/mma_sm90_gmma.hpp" + +namespace tl { + +template inline constexpr bool always_false_v = false; + +// 主类模板 - 移除默认参数,因为特化不能有默认参数 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, " + "C_type=%d, M=%d, N=%d, 
K=%d, tnspA=%d, tnspB=%d, scaleA=%d, " + "scaleB=%d\n", + (int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N, + K, (int)tnspA, (int)tnspB, scaleA, scaleB); + // 暂时注释掉 static_assert 来看调试输出 + // static_assert(always_false_v, + // "wgmma_ss: No specialization available for given template + // parameters!"); + }; +}; + +// ================================= F16 x F16 -> F16 +// ================================= + +// M64N8K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N16K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N32K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), + "n"(int32_t(tnspB))); + } +}; + +// M64N64K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15}," + " %16, %17, p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), + "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), + "+r"(c[14]), "+r"(c[15]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N96K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23}, " + "%24, %25, p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), + "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), + "+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), + 
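
The accumulator register lists in these WgmmaSSImpl specializations follow directly from the tile size: a 64xN accumulator is spread over the 128 threads of one warp group, with two f16 values or one f32 value per 32-bit register. A constexpr sanity check of the resulting counts (plain C++):

```cpp
// Back-of-the-envelope check of the per-thread accumulator register counts.
constexpr int acc_regs(int N, int acc_bits) {
  return 64 * N * acc_bits / (128 * 32);
}

static_assert(acc_regs(8, 16) == 2);     // m64n8k16  f16 -> {c0, c1}
static_assert(acc_regs(64, 16) == 16);   // m64n64k16 f16 -> c0..c15
static_assert(acc_regs(256, 16) == 64);  // m64n256k16 f16 -> c0..c63
static_assert(acc_regs(8, 32) == 4);     // m64n8k16  f32 -> c0..c3
static_assert(acc_regs(64, 32) == 32);   // m64n64k16 f32 -> c0..c31

int main() { return 0; }
```
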
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]), + "+r"(c[22]), "+r"(c[23]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N128K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23, " + "%24, %25, %26, %27, %28, %29, %30, %31}, " + "%32, %33, p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), + "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), + "+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), + "+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]), + "+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]), + "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]), + "+r"(c[30]), "+r"(c[31]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N192K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23, " + "%24, %25, %26, %27, %28, %29, %30, %31, " + "%32, %33, %34, %35, %36, %37, %38, %39, " + "%40, %41, %42, %43, %44, %45, %46, %47}, " + "%48, %49, p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), + "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), + "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), + "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]), + "+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]), + "+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]), + "+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]), + "+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]), + "+r"(c[45]), "+r"(c[46]), "+r"(c[47]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), + "n"(int32_t(tnspB))); + } +}; + +// M64N256K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23, " + "%24, %25, %26, %27, %28, %29, %30, %31, " + "%32, %33, %34, %35, %36, %37, %38, %39, " + "%40, %41, %42, %43, %44, %45, %46, %47, " + "%48, %49, %50, %51, %52, %53, %54, %55, " + "%56, %57, %58, %59, %60, %61, %62, %63}, " + "%64, %65, p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), + "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), 
"+r"(c[14]), + "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), + "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]), + "+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]), + "+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]), + "+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]), + "+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]), + "+r"(c[45]), "+r"(c[46]), "+r"(c[47]), "+r"(c[48]), "+r"(c[49]), + "+r"(c[50]), "+r"(c[51]), "+r"(c[52]), "+r"(c[53]), "+r"(c[54]), + "+r"(c[55]), "+r"(c[56]), "+r"(c[57]), "+r"(c[58]), "+r"(c[59]), + "+r"(c[60]), "+r"(c[61]), "+r"(c[62]), "+r"(c[63]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), + "n"(int32_t(tnspB))); + } +}; + +// ================================= F16 x F16 -> F32 +// ================================= + +// M64N8K16 F16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N16K16 F16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), + "n"(int32_t(tnspB))); + } +}; + +// M64N32K16 F16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15}, " + "%16, %17, p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), + "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), + "+r"(c[14]), "+r"(c[15]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N64K16 F16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23, " + "%24, %25, %26, %27, %28, %29, %30, %31}, " + "%32, %33, p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), + "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), + "+r"(c[14]), 
"+r"(c[15]), "+r"(c[16]), "+r"(c[17]), + "+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]), + "+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]), + "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]), + "+r"(c[30]), "+r"(c[31]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// ================================= BF16 x BF16 -> F32 +// ================================= + +// M64N8K16 BF16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N16K16 BF16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), + "n"(int32_t(tnspB))); + } +}; + +// ================================= TF32 x TF32 -> F32 +// ================================= + +// M64N8K8 TF32->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N16K8 TF32->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), + "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), + "n"(int32_t(tnspB))); + } +}; + +// ================================= INT8 x INT8 -> INT32 +// ================================= + +// M64N8K32 S8->S32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// M64N16K32 S8->S32 +template 
+struct WgmmaSSImpl {
+ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
+ bool scale_out) {
+ asm volatile("{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %6, 0;\n"
+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 "
+ "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
+ "}\n"
+ : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
+ : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
+ "n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
+ "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+ }
+};
+
+// ================================= FP8 x FP8 -> F16/F32
+// =================================
+
+// M64N8K32 E4M3->F16
+template
+struct WgmmaSSImpl {
+ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
+ bool scale_out) {
+ asm volatile("{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %4, 0;\n"
+ "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 "
+ "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
+ "}\n"
+ : "+r"(c[0]), "+r"(c[1])
+ : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
+ "n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
+ "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+ }
+};
+
+// M64N8K32 E4M3->F32
+template
+struct WgmmaSSImpl {
+ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
+ bool scale_out) {
+ asm volatile("{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %6, 0;\n"
+ "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 "
+ "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
+ "}\n"
+ : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
+ : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
+ "n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
+ "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+ }
+};
+
+// Function template that delegates to the class-template specializations
+template
+TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
+ bool scale_out) {
+ WgmmaSSImpl::execute(desc_a, desc_b, c, scale_out);
+}
+
+// ================================= Mixed Precision Support
+// =================================
+
+// Mixed precision: S8 x U8 -> S32
+template
+struct WgmmaSSImpl {
+ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
+ bool scale_out) {
+ asm volatile("{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %4, 0;\n"
+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 "
+ "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
+ "}\n"
+ : "+r"(c[0]), "+r"(c[1])
+ : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
+ "n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
+ "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+ }
+};
+
+// Mixed precision: U8 x S8 -> S32
+template
+struct WgmmaSSImpl {
+ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
+ bool scale_out) {
+ asm volatile("{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %4, 0;\n"
+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 "
+ "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
+ "}\n"
+ : "+r"(c[0]), "+r"(c[1])
+ : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
+ "n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
+ "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+ }
+};
+
+// Mixed precision: U8 x U8 -> S32
+template
+struct WgmmaSSImpl {
+ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
+ bool scale_out) {
+ asm volatile("{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %4, 0;\n"
+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 "
+ "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
+ "}\n"
+ : "+r"(c[0]), "+r"(c[1])
+ : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
+ "n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
+ "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+ }
+};
+
+// Mixed precision FP8: E4M3 x E5M2
-> F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// Mixed precision FP8: E5M2 x E4M3 -> F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), + "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + } +}; + +// ================================= Convenience Templates +// ================================= + +// Type trait to determine the number of output registers needed +template struct WgmmaOutputRegs { + static constexpr int value = + (M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8); +}; + +// Type trait to get element size in bits +template struct ElementBits { + static constexpr int value = + (dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 || + dtype == DataType::kInt32) + ? 32 + : (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 || + dtype == DataType::kInt16 || dtype == DataType::kUInt16) + ? 16 + : (dtype == DataType::kInt8 || dtype == DataType::kUInt8 || + dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2) + ? 8 + : (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4 + : 8; +}; + +} // namespace tl \ No newline at end of file diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/intrin.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/intrin.h new file mode 100644 index 00000000000..ef1afa7f935 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/intrin.h @@ -0,0 +1,119 @@ +#pragma once + +#include "common.h" +#include "cutlass/cutlass.h" + +#if __CUDA_ARCH_LIST__ >= 900 +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/mma_sm90_gmma.hpp" +#endif + +namespace tl { + +namespace detail { + +// Provide architecture-specific defaults so callers may omit arguments. +TL_DEVICE constexpr int default_warp_size() { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_DEVICE_COMPILE__) + return 64; +#else + return 32; +#endif +} + +TL_DEVICE constexpr int default_warps_per_group() { return 4; } + +TL_DEVICE int linear_thread_idx_in_block() { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); +#else + return 0; +#endif +} + +} // namespace detail + +TL_DEVICE int get_lane_idx(int warp_size = detail::default_warp_size()) { + warp_size = warp_size > 0 ? warp_size : detail::default_warp_size(); + return detail::linear_thread_idx_in_block() % warp_size; +} + +TL_DEVICE int get_warp_idx_sync(int warp_size = detail::default_warp_size()) { + warp_size = warp_size > 0 ? 
warp_size : detail::default_warp_size(); + return detail::linear_thread_idx_in_block() / warp_size; +} + +TL_DEVICE int get_warp_idx(int warp_size = detail::default_warp_size()) { + warp_size = warp_size > 0 ? warp_size : detail::default_warp_size(); + return detail::linear_thread_idx_in_block() / warp_size; +} + +TL_DEVICE int +get_warp_group_idx(int warp_size = detail::default_warp_size(), + int warps_per_group = detail::default_warps_per_group()) { + warp_size = warp_size > 0 ? warp_size : detail::default_warp_size(); + warps_per_group = + warps_per_group > 0 ? warps_per_group : detail::default_warps_per_group(); + int threads_per_group = warp_size * warps_per_group; + threads_per_group = threads_per_group > 0 ? threads_per_group : warp_size; + return detail::linear_thread_idx_in_block() / threads_per_group; +} + +#if __CUDA_ARCH_LIST__ >= 900 +TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); } +TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); } + +template TL_DEVICE void warpgroup_wait() { + cute::warpgroup_wait(); +} + +// Template parameter: +// thread_extent: the logical size (in number of threads) of each "group" +// within which we want to elect exactly ONE representative +// thread. +template TL_DEVICE bool tl_shuffle_elect() { + + // Special case: thread_extent == 0 means "elect exactly one thread + // in the entire thread block", i.e., the leader of the first warp of the + // block. + if constexpr (thread_extent == 0) { + // cutlass::canonical_warp_idx_sync(): + // Returns the warp ID within the thread block in a "canonical" way + // (0 for the first warp, 1 for the second, ...). + // cute::elect_one_sync(): + // Elect exactly one lane in the warp to return true (typically lane 0), + // other lanes return false. + // The condition ensures that: + // (1) We are in warp 0 of the block. + // (2) We are the elected lane in this warp. + return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync(); + } + + // General case: thread_extent != 0 + // (threadIdx.x / 32) is the warp index in the block. + // (thread_extent / 32) is the number of warps in one group of size + // thread_extent. We take warp_id % num_warps_in_group to get the warp's index + // within the group. + // __shfl_sync(mask, value, srcLane): broadcast 'value' from srcLane to all + // lanes in the warp. Here it broadcasts the group-local warp index from lane + // 0. Comparing to 0 selects only the group's warp 0. + return __shfl_sync(0xffffffff, // full warp mask + (threadIdx.x / 32) % + (thread_extent / 32), // warp index within group + 0 // take the value from lane 0 + ) == 0 && + // Within that group leader warp, elect exactly one lane (typically + // lane 0) to be the single representative for the group. 
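+ // Worked example (editorial note, not from the upstream source, assuming
+ // 32-lane warps): with thread_extent = 128, each group spans 128 / 32 = 4
+ // warps, so only warps 0, 4, 8, ... see a group-local warp index of 0, and
+ // within each of those warps a single lane is elected; exactly one thread
+ // per 128 threads returns true.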
+ cute::elect_one_sync(); +} + +template TL_DEVICE void warpgroup_reg_alloc() { + asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount)); +} + +template TL_DEVICE void warpgroup_reg_dealloc() { + asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount)); +} +#endif + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/ldsm.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/ldsm.h new file mode 100644 index 00000000000..4d6af8a0998 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/ldsm.h @@ -0,0 +1,121 @@ +#pragma once + +#include "common.h" + +namespace tl { + +TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(value[0]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(value[0]), "=r"(value[1]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(value[0]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(value[0]), "=r"(value[1]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_stmatrix_x1(void const *const smem_ptr, + const int32_t &value0) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"( + smem_int_ptr), + "r"(value0)); +} + +TL_DEVICE void ptx_stmatrix_x2(void const *const smem_ptr, + const int32_t &value0, const int32_t &value1) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"( + smem_int_ptr), + "r"(value0), "r"(value1)); +} + +TL_DEVICE void ptx_stmatrix_x4(void const *const smem_ptr, + const int32_t &value0, const int32_t &value1, + const int32_t &value2, const int32_t &value3) { + uint32_t smem_int_ptr = 
smem_ptr_to_uint(smem_ptr); + asm volatile( + "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: + "r"(smem_int_ptr), + "r"(value0), "r"(value1), "r"(value2), "r"(value3)); +} + +TL_DEVICE void ptx_stmatrix_x1_trans(void const *const smem_ptr, + const int32_t &value0) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"( + smem_int_ptr), + "r"(value0)); +} + +TL_DEVICE void ptx_stmatrix_x2_trans(void const *const smem_ptr, + const int32_t &value0, + const int32_t &value1) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"( + smem_int_ptr), + "r"(value0), "r"(value1)); +} + +TL_DEVICE void ptx_stmatrix_x4_trans(void const *const smem_ptr, + const int32_t &value0, + const int32_t &value1, + const int32_t &value2, + const int32_t &value3) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, " + "%3, %4};\n" ::"r"(smem_int_ptr), + "r"(value0), "r"(value1), "r"(value2), "r"(value3)); +} + +} // namespace tl \ No newline at end of file diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/nvrtc_std.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/nvrtc_std.h new file mode 100644 index 00000000000..9930c220036 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/nvrtc_std.h @@ -0,0 +1,123 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#ifdef __CUDACC_RTC__ + +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using int32_t = signed int; +using uint32_t = unsigned int; +using int64_t = signed long long; +using uint64_t = unsigned long long; +using cuuint64_t = unsigned long long; + +#ifndef CU_TENSOR_MAP_NUM_QWORDS +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +struct CUtensorMap_st { +#if defined(__cplusplus) && (__cplusplus >= 201103L) + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +}; + +using CUtensorMap = CUtensorMap_st; +#endif + +namespace std { + +template struct integral_constant { + static constexpr T value = v; + + using value_type = T; + using type = integral_constant; + + __device__ constexpr operator value_type() const noexcept { return value; } + + __device__ constexpr value_type operator()() const noexcept { return value; } +}; + +using false_type = integral_constant; +using true_type = integral_constant; + +template struct is_same : false_type {}; + +template struct is_same : true_type {}; + +template +inline constexpr bool is_same_v = is_same::value; + +namespace index_sequence_impl { + +// Based on https://stackoverflow.com/a/32223343/11717224 +template struct index_sequence { + using type = index_sequence; + using value_type = size_t; + static constexpr size_t size() noexcept { return sizeof...(Ints); } +}; + +template struct _merge_and_renumber; + +template +struct _merge_and_renumber, index_sequence> + : index_sequence {}; + +template +struct make_index_sequence + : _merge_and_renumber::type, + typename make_index_sequence::type> {}; + +template <> struct make_index_sequence<0> : index_sequence<> {}; +template <> struct make_index_sequence<1> : index_sequence<0> {}; + +} // namespace index_sequence_impl + +template +using index_sequence = index_sequence_impl::index_sequence; + +template +using make_index_sequence = index_sequence_impl::make_index_sequence; + +template constexpr T min(T a, T b) { return a < b ? a : b; } + +template constexpr T max(T a, T b) { return a > b ? 
a : b; } + +template struct conditional { + using type = T; +}; + +template struct conditional { + using type = F; +}; + +template +using conditional_t = typename conditional::type; + +template struct enable_if {}; + +template struct enable_if { + using type = T; +}; +} // namespace std + +#endif \ No newline at end of file diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/reduce.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/reduce.h new file mode 100644 index 00000000000..331da6dc871 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/reduce.h @@ -0,0 +1,283 @@ +#pragma once + +#include "common.h" + +namespace tl { + +struct SumOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x + y; + } +}; + +struct MaxOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return cutlass::fast_max(x, y); + } +}; + +struct MinOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return cutlass::fast_min(x, y); + } +}; + +struct BitAndOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x & y; + } +}; + +struct BitOrOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x | y; + } +}; + +struct BitXorOp { + template TL_DEVICE T operator()(T const &x, T const &y) { + return x ^ y; + } +}; + +template +struct SharedReduceWarp { + template + static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst, + int total_dest, int reduce_extent, int tail, + T init_value) { + if (total_dest <= 0 || reduce_extent <= 0) + return; + constexpr int kWarpSize = 32; + static_assert(Threads % kWarpSize == 0, + "SharedReduceWarp expects blockDim.x to be a multiple of " + "warp size on CUDA."); + const int tid = threadIdx.x; + const int warp_id = tid / kWarpSize; + const int lane = tid % kWarpSize; + const int num_warps = Threads / kWarpSize; + for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) { + const int prefix = tail == 1 ? dest_idx : dest_idx / tail; + const int suffix = tail == 1 ? 0 : dest_idx % tail; + const int src_base = (prefix * reduce_extent) * tail + suffix; + const int dst_index = prefix * tail + suffix; + + T partial = init_value; + for (int rv = lane; rv < reduce_extent; rv += kWarpSize) { + T val = src[src_base + rv * tail]; + if constexpr (UseAbs) { + val = val < T(0) ? 
-val : val; + } + partial = Reducer()(partial, val); + } + + unsigned mask = __activemask(); + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + T other = __shfl_down_sync(mask, partial, offset); + partial = Reducer()(partial, other); + } + + if (lane == 0) { + if constexpr (NeedAccumulate) { + partial = Reducer()(dst[dst_index], partial); + } + dst[dst_index] = partial; + } + } + } +}; + +template +struct AllReduce { + static_assert(threads == 1024 or threads == 512 or threads == 256 or + threads == 128 or threads == 64 or threads == 32 or + threads == 16 or threads == 8 or threads == 4 or threads == 2); + static_assert(threads % scale == 0); + template static TL_DEVICE T run(T x, T *red_buf = nullptr) { + constexpr int offset = threads / 2; + if constexpr (offset >= 32) { + __syncthreads(); + red_buf[threadIdx.x - thread_offset] = x; + __syncthreads(); + x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); + } else { + x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset))); + } + if constexpr (offset == scale) { + return x; + } else { + return AllReduce::run( + x, red_buf); + } + } + + template + static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) { + constexpr int offset = threads / 2; + if constexpr (offset >= 32) { + asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads)); + red_buf[threadIdx.x - thread_offset] = x; + // TODO(lei): maybe we can merge the two bar.sync into one? + asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads)); + x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); + } else { + x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset))); + } + if constexpr (offset == scale) { + return x; + } else { + return AllReduce::run_hopper(x, red_buf); + } + } +}; + +template struct CumSum1D { + static_assert(threads == 1024 or threads == 512 or threads == 256 or + threads == 128 or threads == 64 or threads == 32); + template + static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst, + int N) { + if (N <= 0) + return; + + constexpr unsigned MASK = 0xffffffff; + const int tid = threadIdx.x; + const int lane = tid % SEG; + + if (tid >= SEG) + return; + + T carry = (T)0; + + if (reverse) { + const int num_segments = (N + SEG - 1) / SEG; + for (int seg = num_segments - 1; seg >= 0; --seg) { + const int idx = seg * SEG + lane; + T val = (idx < N) ? src[idx] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = (T)__shfl_down_sync(MASK, val, off); + if (lane < SEG - off) + val += n; + } + + val += carry; + + if (idx < N) + dst[idx] = val; + + T segSum = (T)__shfl_sync(MASK, val, 0); + if (lane == 0) + carry = segSum; + carry = (T)__shfl_sync(MASK, carry, 0); + } + } else { + const int num_segments = (N + SEG - 1) / SEG; + for (int seg = 0; seg < num_segments; ++seg) { + const int idx = seg * SEG + lane; + T val = (idx < N) ? 
src[idx] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = (T)__shfl_up_sync(MASK, val, off); + if (lane >= off) + val += n; + } + + val += carry; + + if (idx < N) + dst[idx] = val; + + T segSum = (T)__shfl_sync(MASK, val, SEG - 1); + if (lane == SEG - 1) + carry = segSum; + carry = (T)__shfl_sync(MASK, carry, SEG - 1); + } + } + } +}; + +template struct CumSum2D { + static_assert(threads == 1024 or threads == 512 or threads == 256 or + threads == 128 or threads == 64 or threads == 32); + template + static TL_DEVICE T run(const T *__restrict__ src, T *__restrict__ dst, int H, + int W) { + + constexpr int TILE_H = threads / SEG; + constexpr unsigned MASK = 0xffffffff; + const int num_blocks = (H + TILE_H - 1) / TILE_H; + const int tid = threadIdx.x; + const int lane = tid % 32; + const int row = tid / 32; + + for (int b = 0; b < num_blocks; ++b) { + const int gRow = b * TILE_H + row; + if (gRow >= H) + return; + + T carry = (T)0; + + if (reverse) { + // Start from the last segment for reverse mode + for (int seg = (W + SEG - 1) / SEG - 1; seg >= 0; --seg) { + const int col = seg * SEG + lane; + + const int real_row = Axis == 1 ? gRow : col; + const int real_col = Axis == 1 ? col : gRow; + + T val = (col < W) ? src[real_row * W + real_col] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = (T)__shfl_down_sync(MASK, val, off); + if (lane < SEG - off) + val += n; + } + + val += carry; + + if (real_col < W) + dst[real_row * W + real_col] = val; + + T segSum = (T)__shfl_sync(MASK, val, (T)0); + if (lane == 0) + carry = segSum; + carry = (T)__shfl_sync(MASK, carry, (T)0); + } + } else { + for (int seg = 0; seg * SEG < W; ++seg) { + const int col = seg * SEG + lane; + + const int real_row = Axis == 1 ? gRow : col; + const int real_col = Axis == 1 ? col : gRow; + + T val = (col < W) ? 
src[real_row * W + real_col] : (T)0; + +#pragma unroll + for (int off = 1; off < SEG; off <<= 1) { + T n = (T)__shfl_up_sync(MASK, val, off); + if (lane >= off) + val += n; + } + + val += carry; + + if (real_col < W) + dst[real_row * W + real_col] = val; + + T segSum = (T)__shfl_sync(MASK, val, SEG - 1); + if (lane == SEG - 1) + carry = segSum; + carry = (T)__shfl_sync(MASK, carry, SEG - 1); + } + } + } + } +}; + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/tcgen_05.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/tcgen_05.h new file mode 100644 index 00000000000..1211bc246c8 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/tcgen_05.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "common.h" + +namespace tl { + +TL_DEVICE void tmem_allocate(void *dst_ptr, int num_columns) { + uint32_t dst_intptr = smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); +} + +TL_DEVICE void tmem_deallocate(uint32_t *tmem_ptr, int num_columns) { + asm volatile("{\n\t" + "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(*tmem_ptr), "r"(num_columns)); +} + +inline void __device__ fence_view_async_tmem_load() { + asm volatile("tcgen05.wait::ld.sync.aligned; " ::); +} + +inline void __device__ fence_view_async_tmem_store() { + asm volatile("tcgen05.wait::st.sync.aligned; " ::); +} + +template +inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a, + uint64_t const desc_b, + uint32_t const tmem_c, + uint32_t const idesc, + uint32_t const addC = 1) { + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be " + "64 or 128 for 1 CTA cluster MMA."); + static_assert( + (M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, " + "%7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(idesc), "r"(addC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); +} + +inline __device__ void amma_commit(uint64_t const *smem_ptr) { + uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr); + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::" + "cluster.b64 [%0];" + : + : "r"(bar_intptr)); +} + +} // namespace tl \ No newline at end of file diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/tcgen_05_ld.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/tcgen_05_ld.h new file mode 100644 index 00000000000..b2eb2f81603 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/tcgen_05_ld.h @@ -0,0 +1,713 @@ +#pragma once + +#include +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "common.h" + +namespace tl { + +// 32 data path lanes, 32-bit pattern, repeated N times +class tmem_ld_32dp32bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 
2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), 
"=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), 
"=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 64-bit pattern, repeated N times +class tmem_ld_16dp64bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), 
"=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), 
"=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 128-bit pattern, repeated N times +class tmem_ld_16dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x32.b32" + 
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), 
"=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 256-bit pattern, repeated N times +class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + 
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), 
"=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 32 data path lanes, 64-bit pattern, repeated N times +// (conducted with 2x16dp64bNx) +class tmem_ld_32dp64bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); + } +}; + +// 32 data path lanes, 128-bit pattern, repeated N times +class tmem_ld_32dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); + } +}; + +// 32 data path lanes, 256-bit pattern, repeated N times +class tmem_ld_32dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); + } +}; + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/threadblock_swizzle.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/threadblock_swizzle.h new file mode 100644 index 00000000000..60fa0ad1f05 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/tl_templates/cuda/threadblock_swizzle.h @@ -0,0 +1,43 @@ +#pragma once + +#include "common.h" + +namespace tl { + +template TL_DEVICE dim3 rasterization2DRow() { + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.x; + const unsigned int panel_offset = 
block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size); + const unsigned int stride = + panel_idx + 1 < total_panel + ? panel_width + : (grid_size - panel_idx * panel_size) / gridDim.x; + const unsigned int col_idx = (panel_idx & 1) + ? gridDim.x - 1 - panel_offset / stride + : panel_offset / stride; + const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width; + return {col_idx, row_idx, blockIdx.z}; +} + +template TL_DEVICE dim3 rasterization2DColumn() { + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.y; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size); + const unsigned int stride = + panel_idx + 1 < total_panel + ? panel_width + : (grid_size - panel_idx * panel_size) / gridDim.y; + const unsigned int row_idx = (panel_idx & 1) + ? gridDim.y - 1 - panel_offset / stride + : panel_offset / stride; + const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width; + return {col_idx, row_idx, blockIdx.z}; +} + +} // namespace tl diff --git a/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/vendor_config.h b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/vendor_config.h new file mode 100644 index 00000000000..f3849d073e0 --- /dev/null +++ b/ggml/src/ggml-cuda/vendors/tilelang/fp8_lightning_indexer/vendor_config.h @@ -0,0 +1,7 @@ +#pragma once +#ifndef __grid_constant__ +#define __grid_constant__ +#endif +#ifndef CUTLASS_CUDA_DRIVER_WRAPPER_CALL +#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(x) x +#endif diff --git a/ggml/src/ggml-fp8.cpp b/ggml/src/ggml-fp8.cpp new file mode 100644 index 00000000000..406c1b56286 --- /dev/null +++ b/ggml/src/ggml-fp8.cpp @@ -0,0 +1,249 @@ +#include +#include +#include +#include + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" +#include "ggml.h" + +#include "ggml-fp8.h" + +union fp32_int32 { + float f; + uint32_t bits; +}; + + +static inline uint8_t float_to_e4m3_bits(float flt) { + constexpr int FP32_NUM_BITS = 32; + constexpr int FP32_NUM_MANTISSA_BITS = 23; + constexpr int FP32_EXPONENT_BIAS = 127; + + constexpr int FP8_NUM_EXPONENT_BITS = 4; + constexpr int FP8_NUM_MANTISSA_BITS = 3; + constexpr uint8_t FP8_NAN = 0x7fu; + constexpr int FP8_MAX_EXPONENT = 7; + constexpr int FP8_MIN_EXPONENT = -6; + constexpr int FP8_EXPONENT_BIAS = 7; + constexpr uint8_t FP8_EXPONENT_MASK = (1u << FP8_NUM_EXPONENT_BITS) - 1u; + constexpr uint8_t FP8_MANTISSA_MASK = (1u << FP8_NUM_MANTISSA_BITS) - 1u; + constexpr uint8_t FP8_MAX_FLT = 0x7eu; + + auto is_nan_f = [](float v) { + uint32_t s; std::memcpy(&s, &v, sizeof(s)); + return (s & 0x7fffffffu) > 0x7f800000u; + }; + auto is_inf_f = [](float v) { + uint32_t s; std::memcpy(&s, &v, sizeof(s)); + return (s == 0x7f800000u) || (s == 0xff800000u); + }; + + uint32_t s; std::memcpy(&s, &flt, sizeof(s)); + uint8_t sign = uint8_t((s >> 24) & 0x80u); + int32_t exp = int32_t((s >> FP32_NUM_MANTISSA_BITS) & 0xffu) - FP32_EXPONENT_BIAS; + int mantissa = int(s & 0x7fffffu); + uint8_t u = 0; + uint8_t const kF8_NaN = FP8_NAN; + + if (is_nan_f(flt)) { + return kF8_NaN; + } + if (is_inf_f(flt)) { + return uint8_t(sign | FP8_MAX_FLT); + } + if (exp == -128) { + return uint8_t(sign | FP8_MAX_FLT); + } + + int 
sticky_bit = 0; + bool skip_sign = false; + bool may_be_nan = false; + + if (exp >= FP8_MIN_EXPONENT && exp <= FP8_MAX_EXPONENT) { + exp = exp + FP8_EXPONENT_BIAS; + u = uint8_t((uint32_t(exp) & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS); + u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS))); + } else if (exp < FP8_MIN_EXPONENT) { + int rshift = (FP8_MIN_EXPONENT - exp); + if (rshift < FP32_NUM_BITS) { + mantissa |= (1 << FP32_NUM_MANTISSA_BITS); + sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); + mantissa = (mantissa >> rshift); + u = uint8_t((mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)) & FP8_MANTISSA_MASK); + } else { + mantissa = 0; + u = 0; + } + } else { + if (exp == (FP8_MAX_EXPONENT + 1)) { + uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)); + if (mantissa_tmp < FP8_MANTISSA_MASK) { + exp = exp + FP8_EXPONENT_BIAS; + u = uint8_t(uint32_t(exp) << FP8_NUM_MANTISSA_BITS) | mantissa_tmp; + may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK - 1)); + } else { + return uint8_t(sign | FP8_MAX_FLT); + } + } else { + return uint8_t(sign | FP8_MAX_FLT); + } + } + + int NUM_BITS_SHIFT = FP32_NUM_MANTISSA_BITS - (FP8_NUM_MANTISSA_BITS + 1); + int round_bit = ((mantissa >> NUM_BITS_SHIFT) & 1); + sticky_bit |= ((mantissa & ((1 << NUM_BITS_SHIFT) - 1)) != 0); + + if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { + u = uint8_t(u + 1); + if (may_be_nan) { + skip_sign = true; + } + } + + if (u > FP8_MAX_FLT) { + u = uint8_t(sign | FP8_MAX_FLT); + } + if (!skip_sign) { + u = uint8_t(u | sign); + } + return u; +} + +template +inline FP8 float_to_fp8(float value) { + if constexpr (E == 4) { + FP8 out; + out.bits = float_to_e4m3_bits(value); + return out; + } + + FP8 out; + fp32_int32 in = {value}; + // the sign + out.bits = (in.bits >> 24) & 0x80; + // value without sign + in.bits &= 0x7fffffff; + //GGML_ASSERT(in.bits < 0x7f800000); // +/- infinity or NAN + if (in.f >= FP8::MAX) { + out.bits |= 0x7E; + } else if (in.f < FP8::MIN) { // => 0. + // OK: S.0000000 + } else { + in.f *= exp_f2::E_BIAS-127>(); + // - trunc + //uint32_t eps = 0; + // - rounding half away from zero + //uint32_t eps = 0x400000>>FP8::M; + // - rounding half toward zero + //uint32_t eps = 0x3fffff>>FP8::M; + // - rounding to nearest even + uint32_t eps = (0x3fffff>>FP8::M) + ((in.bits >> (23-FP8::M)) & 0x1); + // shift mantissa. 
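+        // How the rounding works (a reading of the bit arithmetic above and below,
+        // worked through for the e5m2 layout, i.e. M == 2, purely as an illustration):
+        //   - the kept mantissa is the top M bits of the 23-bit FP32 mantissa, so the
+        //     discarded field is the low (23 - M) = 21 bits;
+        //   - half a ULP at the target precision is 1 << 20 = 0x100000;
+        //   - eps = (0x3fffff >> M) = 0x0fffff is half-ULP-minus-one, plus the current
+        //     LSB of the kept field, ((in.bits >> 21) & 1).
+        // Adding eps carries into the kept field when the discarded part is strictly
+        // above the halfway point, and on an exact tie only when the kept LSB is 1,
+        // which is round-to-nearest, ties-to-even. A carry out of the mantissa bumps
+        // the (already re-biased) exponent to the next binade, as expected.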
+ in.bits += eps; + out.bits |= (in.bits >> (23-FP8::M)) & 0x7F; + } + return out; +} + +template +inline float fp8_to_float(const FP8& in) { + if constexpr (E == 4) { + const uint8_t v = in.bits; + const uint32_t exp = (v >> 3) & 0x0F; + const uint32_t mant = v & 0x07; + if (exp == 0x0F && mant == 0x07) { + fp32_int32 out_nan; + out_nan.bits = 0x7fffffff; // FP32_NAN pattern used in tests + return out_nan.f; + } + } + + fp32_int32 out = {0}; + out.bits = in.bits & 0x80; + out.bits <<= 24; + uint32_t _bits = in.bits & 0x7F; + _bits <<= (23-FP8::M); + out.bits |= _bits; + out.f *= exp_f2<127-FP8::E_BIAS>(); + return out.f; +} + +template +static inline void conv(const FP8* x, float* y, int64_t size) { + for (int64_t i=0; i +static inline void conv(const float* x, FP8* y, int64_t size) { + for (int64_t i=0; i(x[i]); + } +} + +template +struct bloc_fp8 { + float d; + FP8 qs[QK]; +}; + +template +static inline void conv(const bloc_fp8* x, float* y, int64_t size) { + const auto qk_size = size / QK; + for (int64_t q=0; q +static inline void conv(const float* x, bloc_fp8* y, int64_t size) { + const auto qk_size = size / QK; + for (int64_t q=0; q::MAX/m; + y[q].d = m/FP8::MAX; + for (int64_t i=0; i(x[q*QK+i]*D); + } + } +} + +// the C API. +void ggml_e5m2_to_fp32_row(const ggml_e5m2_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + conv(reinterpret_cast*>(x), y, k); +} +void ggml_fp32_to_e5m2_row_ref(const float * GGML_RESTRICT x, ggml_e5m2_t * GGML_RESTRICT y, int64_t k) { + conv(x, reinterpret_cast*>(y), k); +} + +void ggml_e4m3_to_fp32_row(const ggml_e4m3_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + conv(reinterpret_cast*>(x), y, k); +} +void ggml_fp32_to_e4m3_row_ref(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML_RESTRICT y, int64_t k) { + conv(x, reinterpret_cast*>(y), k); +} + +void dequantize_row_e4m3_q(const block_e4m3_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + conv(reinterpret_cast*>(x), y, k); +} +void quantize_row_e4m3_q_ref(const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + conv(x, reinterpret_cast*>(y), k); +} + +void dequantize_row_e3m4_q(const block_e3m4_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + conv(reinterpret_cast*>(x), y, k); +} +void quantize_row_e3m4_q_ref(const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + conv(x, reinterpret_cast*>(y), k); +} diff --git a/ggml/src/ggml-fp8.h b/ggml/src/ggml-fp8.h new file mode 100644 index 00000000000..da7784d4601 --- /dev/null +++ b/ggml/src/ggml-fp8.h @@ -0,0 +1,45 @@ +// this is more a .inc. +#ifdef __cplusplus +template +constexpr int exp_i2() { + return 1 << N; +} + +template +constexpr float exp_f2() { + if constexpr (N>0) return exp_f2()*2; + if constexpr (N<0) return exp_f2()/2; + if constexpr (N==0) return 1.; +} + + +template //, int M=7-E> 1.7 bits! 
+struct FP8 { + uint8_t bits; + using type = FP8<_E>; + static constexpr int E = _E; + static constexpr int M = (7-_E); + static constexpr int E_BIAS = exp_i2()-1; + static constexpr float MAX = (2-exp_f2<-M+1>())*exp_f2()>(); + static constexpr float MIN = exp_f2<-M>()*exp_f2<2-exp_i2()>(); +}; + +extern "C" { +#endif + + // Note: types are define in ggml-common.h + GGML_API void ggml_e5m2_to_fp32_row(const ggml_e5m2_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void ggml_fp32_to_e5m2_row_ref(const float * GGML_RESTRICT x, ggml_e5m2_t * GGML_RESTRICT y, int64_t k); + + GGML_API void ggml_e4m3_to_fp32_row(const ggml_e4m3_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void ggml_fp32_to_e4m3_row_ref(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML_RESTRICT y, int64_t k); + + GGML_API void dequantize_row_e4m3_q(const block_e4m3_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void quantize_row_e4m3_q_ref(const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k); + + GGML_API void dequantize_row_e3m4_q(const block_e3m4_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void quantize_row_e3m4_q_ref(const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index de5cbd75e86..a4499869d5b 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -5307,7 +5307,26 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; - + case GGML_TYPE_E4M3_Q: + case GGML_TYPE_E3M4_Q: + { + // Note realy clean, but it is the same test for E4M3. + const block_e3m4_q * q = (const block_e3m4_q *) data; + int nans = 0; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(q[i].d, i)) { + return false; + } + // NAN + for (size_t k = 0; k < QK_K; ++k) { + nans += (q[i].qs[k] & 0x7f) == 0x7f; + } + } + if (nans) { + fprintf(stderr, "%s: found %d NaNs in row of %zu FP8 values\n", __func__, nans, nb*QK_K); + return false; + } + } break; case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index aecbdad5a3d..a57c0a1cec1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -9,6 +9,7 @@ // FIXME: required here for quantization functions #include "ggml-quants.h" +#include "ggml-fp8.h" #ifdef GGML_USE_CPU_HBM #include @@ -23,6 +24,8 @@ #include #include #include +#include + #include #include #include @@ -873,6 +876,38 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .type_size = 0, .is_quantized = false, }, + [GGML_TYPE_E5M2] = { + .type_name = "fp8_e5m2", + .blck_size = 1, + .type_size = sizeof(ggml_e5m2_t), + .is_quantized = true, + .to_float = (ggml_to_float_t) ggml_e5m2_to_fp32_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_e5m2_row_ref, + }, + [GGML_TYPE_E4M3] = { + .type_name = "fp8_e4m3", + .blck_size = 1, + .type_size = sizeof(ggml_e4m3_t), + .is_quantized = true, + .to_float = (ggml_to_float_t) ggml_e4m3_to_fp32_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_e4m3_row_ref, + }, + [GGML_TYPE_E4M3_Q] = { + .type_name = "fp8_e4m3_q", + .blck_size = QK_K, + .type_size = sizeof(block_e4m3_q), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_e4m3_q, + .from_float_ref = (ggml_from_float_t) quantize_row_e4m3_q_ref, + }, + [GGML_TYPE_E3M4_Q] = { + .type_name = "fp8_e3m4_q", + .blck_size = QK_K, + 
.type_size = sizeof(block_e3m4_q), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_e3m4_q, + .from_float_ref = (ggml_from_float_t) quantize_row_e3m4_q_ref, + }, }; const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { @@ -1016,10 +1051,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", "OPT_STEP_SGD", + "SPARSE_TOPK_RADIX", + "INDEXER_FUSED", "GLU", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 92"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1120,10 +1157,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", "sgd(x)", + "sparse_topk_radix(x)", + "indexer_fused(x)", "glu(x)", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 92"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1340,6 +1379,10 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; + case GGML_FTYPE_MOSTLY_E5M2: wtype = GGML_TYPE_E5M2; break; + case GGML_FTYPE_MOSTLY_E4M3: wtype = GGML_TYPE_E4M3; break; + case GGML_FTYPE_MOSTLY_E4M3_Q: wtype = GGML_TYPE_E4M3_Q; break; + case GGML_FTYPE_MOSTLY_E3M4_Q: wtype = GGML_TYPE_E3M4_Q; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -1920,6 +1963,11 @@ static struct ggml_tensor * ggml_add_impl( bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); + // Ensure RHS has CUDA-friendly stride alignment for broadcast add + if (ggml_type_size(b->type) > 0 && (b->nb[1] % ggml_type_size(b->type)) != 0) { + b = ggml_cont(ctx, b); + } + struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); result->op = GGML_OP_ADD; @@ -3394,6 +3442,12 @@ struct ggml_tensor * ggml_reshape_2d( int64_t ne0, int64_t ne1) { GGML_ASSERT(ggml_is_contiguous(a)); + /* + printf("ggml_reshape_2d: a=[%5" PRId64 ", %5" PRId64 "], ne0=%5" PRId64 ", ne1=%5" PRId64 "\n", + a->ne[0], a->ne[1], + ne0, ne1); + fflush(stdout); + */ GGML_ASSERT(ggml_nelements(a) == ne0*ne1); const int64_t ne[2] = { ne0, ne1 }; @@ -3413,6 +3467,12 @@ struct ggml_tensor * ggml_reshape_3d( int64_t ne1, int64_t ne2) { GGML_ASSERT(ggml_is_contiguous(a)); + /* + printf("ggml_reshape_3d: a=[%5" PRId64 ", %5" PRId64 ", %5" PRId64 "], ne0=%5" PRId64 ", ne1=%5" PRId64 ", ne2=%5" PRId64 "\n", + a->ne[0], a->ne[1], a->ne[2], + ne0, ne1, ne2); + fflush(stdout); + */ GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); const int64_t ne[3] = { ne0, ne1, ne2 }; @@ -7152,6 +7212,26 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_E5M2 : + { // move to ggml-cpu.c : type_traits[type].from_float(src + start, (char *) dst + start_row * row_size, (int64_t)nrows*n_per_row); + ggml_fp32_to_e5m2_row_ref(src + start, (ggml_e5m2_t*)((char *) dst + start_row * row_size), (int64_t)nrows*n_per_row); + result = nrows * row_size; + } break; + case GGML_TYPE_E4M3 : + { // move to ggml-cpu.c : type_traits[type].from_float(src + start, (char *) dst + start_row * row_size, (int64_t)nrows*n_per_row); + ggml_fp32_to_e4m3_row_ref(src + start, (ggml_e4m3_t*)((char *) dst + start_row * row_size), (int64_t)nrows*n_per_row); + result = nrows * row_size; + } break; + case GGML_TYPE_E4M3_Q: + { // move to ggml-cpu.c : type_traits[type].from_float(src + start, (char *) dst + start_row * row_size, (int64_t)nrows*n_per_row); + quantize_row_e4m3_q_ref(src + start, (block_e4m3_q*)((char *) dst + start_row * row_size), (int64_t)nrows*n_per_row); + result = nrows * row_size; + } break; + case GGML_TYPE_E3M4_Q: + { // move to ggml-cpu.c : type_traits[type].from_float(src + start, (char *) dst + start_row * row_size, (int64_t)nrows*n_per_row); + quantize_row_e3m4_q_ref(src + start, (block_e3m4_q*)((char *) dst + start_row * row_size), (int64_t)nrows*n_per_row); + result = nrows * row_size; + } break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -7208,3 +7288,142 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons if (p0->strict_cpu != p1->strict_cpu ) return false; return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; } + + +// ggml_sparse_topk_radix + +GGML_API struct ggml_tensor * ggml_sparse_topk_radix( + struct ggml_context * ctx, + struct ggml_tensor * scores, + int k) { + GGML_ASSERT(scores->type == GGML_TYPE_F32 || scores->type == GGML_TYPE_F16); + GGML_ASSERT(scores->ne[2] == 1 && scores->ne[3] == 1); + GGML_ASSERT(k > 0 && k <= scores->ne[0]); + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, k, scores->ne[1]); + ggml_set_op_params_i32(result, 0, k); + result->op = GGML_OP_SPARSE_TOPK_RADIX; + result->src[0] = scores; + return result; +} + +// ggml_indexer_logits_fused +GGML_API struct ggml_tensor * ggml_indexer_logits_fused( + struct ggml_context * 
ctx, + struct ggml_tensor * q2d, + struct ggml_tensor * k2d, + struct ggml_tensor * w2d, + struct ggml_tensor * k_scale) { + GGML_ASSERT(q2d->type == GGML_TYPE_F32 || q2d->type == GGML_TYPE_F16 || q2d->type == GGML_TYPE_BF16); + GGML_ASSERT(k2d->type == GGML_TYPE_F32 || k2d->type == GGML_TYPE_F16 || k2d->type == GGML_TYPE_BF16); + GGML_ASSERT(w2d->type == GGML_TYPE_F32); + GGML_ASSERT(k_scale->type == GGML_TYPE_F32); + // Shapes: q2d:[D, Tc*H], k2d:[D, kv], w2d:[H, Tc], k_scale:[kv,1] or [kv] + GGML_ASSERT(q2d->ne[0] == k2d->ne[0]); + const int64_t D = q2d->ne[0]; + const int64_t TcH = q2d->ne[1]; + const int64_t kv = k2d->ne[1]; + const int64_t Tc = w2d->ne[1]; + const int64_t H = w2d->ne[0]; + GGML_ASSERT(TcH == Tc*H); + + struct ggml_tensor * out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv, Tc); + out->op = GGML_OP_INDEXER_FUSED; + out->src[0] = q2d; + out->src[1] = k2d; + out->src[2] = w2d; + out->src[3] = k_scale; + (void)D; // silence unused warning in this TU + return out; +} + +// ggml_indexer_logits_fused_ex +GGML_API struct ggml_tensor * ggml_indexer_logits_fused_ex( + struct ggml_context * ctx, + struct ggml_tensor * q2d, + struct ggml_tensor * k2d, + struct ggml_tensor * w2d, + struct ggml_tensor * k_scale, + struct ggml_tensor * starts, + struct ggml_tensor * ends) { + GGML_ASSERT(q2d->type == GGML_TYPE_F32 || q2d->type == GGML_TYPE_F16 || q2d->type == GGML_TYPE_BF16); + GGML_ASSERT(k2d->type == GGML_TYPE_F32 || k2d->type == GGML_TYPE_F16 || k2d->type == GGML_TYPE_BF16); + GGML_ASSERT(w2d->type == GGML_TYPE_F32); + GGML_ASSERT(k_scale->type == GGML_TYPE_F32); + // Shapes: q2d:[D, Tc*H], k2d:[D, kv], w2d:[H, Tc], k_scale:[kv,1] or [kv] + GGML_ASSERT(q2d->ne[0] == k2d->ne[0]); + const int64_t D = q2d->ne[0]; + const int64_t TcH = q2d->ne[1]; + const int64_t kv = k2d->ne[1]; + const int64_t Tc = w2d->ne[1]; + const int64_t H = w2d->ne[0]; + GGML_ASSERT(TcH == Tc*H); + + struct ggml_tensor * out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv, Tc); + out->op = GGML_OP_INDEXER_FUSED; + out->src[0] = q2d; + out->src[1] = k2d; + out->src[2] = w2d; + out->src[3] = k_scale; + out->src[4] = starts; + out->src[5] = ends; + (void)D; + return out; +} + + + +// ggml_sparse_mla_decode_fused +GGML_API struct ggml_tensor * ggml_sparse_mla_decode_fused( + struct ggml_context * ctx, + struct ggml_tensor * q2d, // [Dq, Hq] + struct ggml_tensor * k_cache, // [Dk, Hkv, N_kv] + struct ggml_tensor * v_cache, // [Dv, Hkv, N_kv] + struct ggml_tensor * idx_topk, // [K] + float kq_scale, + float attn_softcap) { + GGML_ASSERT(q2d->ne[2] == 1 && q2d->ne[3] == 1); + GGML_ASSERT(k_cache->ne[2] == v_cache->ne[2]); + GGML_ASSERT(k_cache->ne[1] == v_cache->ne[1]); + struct ggml_tensor * out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, v_cache->ne[0], q2d->ne[1]); // [Dv, Hq] + out->op = GGML_OP_SPARSE_MLA_DECODE; + out->src[0] = q2d; + out->src[1] = k_cache; + out->src[2] = v_cache; + out->src[3] = idx_topk; + ggml_set_op_params_f32(out, 0, kq_scale); + ggml_set_op_params_f32(out, 1, attn_softcap); + return out; +} + + + +GGML_API struct ggml_tensor * ggml_sparse_topk_radix_ex( + struct ggml_context * ctx, + struct ggml_tensor * scores, + int k, + struct ggml_tensor * starts, + struct ggml_tensor * ends) { + GGML_ASSERT(scores->type == GGML_TYPE_F32 || scores->type == GGML_TYPE_F16); + GGML_ASSERT(scores->ne[2] == 1 && scores->ne[3] == 1); + GGML_ASSERT(k > 0 && k <= scores->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, k, scores->ne[1]); + ggml_set_op_params_i32(result, 0, k); + 
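+    // Operator wiring (same pattern as ggml_sparse_topk_radix above): op_params[0]
+    // carries k, src[0] carries the scores, and the optional starts/ends tensors are
+    // threaded through src[1]/src[2] below, presumably so a backend can restrict the
+    // per-row search range.
+    //
+    // Usage sketch (hypothetical caller, assuming a live ggml_context `ctx` and an F32
+    // scores tensor of shape [n_kv, n_tokens]):
+    //   ggml_tensor * idx = ggml_sparse_topk_radix(ctx, scores, /*k=*/2048);
+    //   // idx is I32 with shape [k, n_tokens]: the selected KV indices per token.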
result->op = GGML_OP_SPARSE_TOPK_RADIX; + result->src[0] = scores; + // Thread optional starts/ends to backend via src[1]/src[2]; CPU path ignores them + if (starts) { + // starts must be 1D of length T (scores->ne[1]) + GGML_ASSERT(starts->ne[0] == scores->ne[1]); + result->src[1] = starts; + } else { + result->src[1] = NULL; + } + if (ends) { + GGML_ASSERT(ends->ne[0] == scores->ne[1]); + result->src[2] = ends; + } else { + result->src[2] = NULL; + } + return result; +} diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 88ea9f32f8c..7943799395e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -372,6 +372,7 @@ class MODEL_ARCH(IntEnum): ARCTIC = auto() DEEPSEEK = auto() DEEPSEEK2 = auto() + DEEPSEEK3_2 = auto() CHATGLM = auto() GLM4 = auto() GLM4_MOE = auto() @@ -460,6 +461,10 @@ class MODEL_TENSOR(IntEnum): FFN_EXP_PROBS_B = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() + ATTN_INDEXER_K_NORM = auto() # deepseek3_2 + ATTN_INDEXER_WEIGHTS_PROJ = auto() # deepseek3_2 + ATTN_INDEXER_WK = auto() # deepseek3_2 + ATTN_INDEXER_WQ_B = auto() # deepseek3_2 LAYER_OUT_NORM = auto() PER_LAYER_TOKEN_EMBD = auto() # gemma3n PER_LAYER_MODEL_PROJ = auto() # gemma3n @@ -712,6 +717,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK: "deepseek", MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.DEEPSEEK3_2: "deepseek3_2", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.GLM4_MOE: "glm4moe", @@ -779,6 +785,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_INDEXER_K_NORM: "blk.{bid}.attn_indexer_k_norm", + MODEL_TENSOR.ATTN_INDEXER_WEIGHTS_PROJ: "blk.{bid}.attn_indexer_weights_proj", + MODEL_TENSOR.ATTN_INDEXER_WK: "blk.{bid}.attn_indexer_wk", + MODEL_TENSOR.ATTN_INDEXER_WQ_B: "blk.{bid}.attn_indexer_wq_b", MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm", MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", @@ -2746,6 +2756,14 @@ class MODEL_TENSOR(IntEnum): # TODO } +MODEL_TENSORS[MODEL_ARCH.DEEPSEEK3_2] = [ + *MODEL_TENSORS[MODEL_ARCH.DEEPSEEK2], + MODEL_TENSOR.ATTN_INDEXER_K_NORM, + MODEL_TENSOR.ATTN_INDEXER_WEIGHTS_PROJ, + MODEL_TENSOR.ATTN_INDEXER_WK, + MODEL_TENSOR.ATTN_INDEXER_WQ_B, +] + # tensors that will not be serialized MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_ARCH.LLAMA: [ @@ -2788,6 +2806,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DEEPSEEK3_2: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], MODEL_ARCH.CHATGLM: [ MODEL_TENSOR.ROPE_FREQS, ], diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index c533b55c012..12a0fb2f7ad 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -562,6 +562,22 @@ class TensorNameMap: "layers.{bid}.self_attn.k_norm", # qwen3-embedding ), + MODEL_TENSOR.ATTN_INDEXER_K_NORM: ( + "model.layers.{bid}.self_attn.indexer.k_norm", # deepseek3_2 + ), + + MODEL_TENSOR.ATTN_INDEXER_WEIGHTS_PROJ: ( + "model.layers.{bid}.self_attn.indexer.weights_proj", # deepseek3_2 + ), + + MODEL_TENSOR.ATTN_INDEXER_WK: ( + "model.layers.{bid}.self_attn.indexer.wk", # deepseek3_2 + ), + + MODEL_TENSOR.ATTN_INDEXER_WQ_B: ( + "model.layers.{bid}.self_attn.indexer.wq_b", # deepseek3_2 + ), + MODEL_TENSOR.ROPE_FREQS: ( 
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon ), @@ -1475,6 +1491,23 @@ class TensorNameMap: MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: ( "model.layers.{bid}.shared_head.norm", ), + + # DeepSeek V3.2 sparse attention indexer tensors + MODEL_TENSOR.ATTN_INDEXER_K_NORM: ( + "model.layers.{bid}.self_attn.indexer.k_norm", # deepseek3_2 + ), + + MODEL_TENSOR.ATTN_INDEXER_WEIGHTS_PROJ: ( + "model.layers.{bid}.self_attn.indexer.weights_proj", # deepseek3_2 + ), + + MODEL_TENSOR.ATTN_INDEXER_WK: ( + "model.layers.{bid}.self_attn.indexer.wk", # deepseek3_2 + ), + + MODEL_TENSOR.ATTN_INDEXER_WQ_B: ( + "model.layers.{bid}.self_attn.indexer.wq_b", # deepseek3_2 + ), } # architecture-specific block mappings diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index f6076142cee..94b766b1a60 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -5,7 +5,7 @@ sentencepiece~=0.2.0 # https://github.com/huggingface/transformers/releases/tag/v4.56.0-Embedding-Gemma-preview # The version is needed to be able to convert Embedding Gemma models to GGUF format: -git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview +git+https://github.com/hnyls2002/transformers@2a8a2a9 # Once Embedding Gemma is officially released, we can switch to: #transformers>=4.57.1,<5.0.0 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 18cfc76564d..18019488550 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,6 +31,9 @@ add_library(llama llama-model.cpp llama-quant.cpp llama-sampling.cpp + llama-sparse-indexer.cpp + llama-sparse-mla-fwd.cpp + llama-sparse-topk.cpp llama-vocab.cpp unicode-data.cpp unicode.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4e8d54c4193..dedb01b9bde 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -62,6 +62,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_DEEPSEEK3_2, "deepseek3_2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -1401,6 +1402,42 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, + { + LLM_ARCH_DEEPSEEK3_2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, 
"blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_ATTN_INDEXER_K_NORM,"blk.%d.attn_indexer_k_norm" }, + { LLM_TENSOR_ATTN_INDEXER_WEIGHTS_PROJ,"blk.%d.attn_indexer_weights_proj" }, + { LLM_TENSOR_ATTN_INDEXER_WK, "blk.%d.attn_indexer_wk" }, + { LLM_TENSOR_ATTN_INDEXER_WQ_B, "blk.%d.attn_indexer_wq_b" }, + }, + }, { LLM_ARCH_PLM, { @@ -2395,6 +2432,12 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + + // DeepSeek V3.2 sparse attention indexer tensors + {LLM_TENSOR_ATTN_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NORM}}, + {LLM_TENSOR_ATTN_INDEXER_WEIGHTS_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_INDEXER_WK, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_INDEXER_WQ_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index b5c6f3d76a6..a43903d758e 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,6 +66,7 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_DEEPSEEK3_2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, @@ -442,6 +443,10 @@ enum llm_tensor { LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + LLM_TENSOR_ATTN_INDEXER_K_NORM, // deepseek3_2 + LLM_TENSOR_ATTN_INDEXER_WEIGHTS_PROJ, // deepseek3_2 + LLM_TENSOR_ATTN_INDEXER_WK, // deepseek3_2 + LLM_TENSOR_ATTN_INDEXER_WQ_B, // deepseek3_2 }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d8a8b5e647a..8b2ecdb73d7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1362,7 +1362,14 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { - return std::max(1024u, 8u*model.n_tensors()); + uint32_t base = std::max(1024u, 8u*model.n_tensors()); + if (model.arch == LLM_ARCH_DEEPSEEK3_2) { + // The DeepSeek V3.2 sparse-attention graph has significantly higher node fanout. + // Provide a cushion to avoid near-boundary meta pool overflows. 
+ uint32_t ds_base = std::max(base, 7168u*model.n_tensors()); + base = ds_base + 16384u; + } + return base; } llm_graph_result * llama_context::get_gf_res_reserve() const { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 90cd885a60a..26122a7f32e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -347,6 +347,25 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + /* + printf("[SET_INPUT_KV][base] micro=[%lld,%lld,%lld,%lld] full2d=%p\n", + (long long) self_kq_mask->ne[0], (long long) self_kq_mask->ne[1], + (long long) self_kq_mask->ne[2], (long long) self_kq_mask->ne[3], + (void*) self_kq_mask_full_2d); + fflush(stdout); + */ + + if (self_kq_mask_full_2d) { + /* + printf("[SET_INPUT_KV][base] full2d=[%lld,%lld,%lld,%lld]\n", + (long long) self_kq_mask_full_2d->ne[0], (long long) self_kq_mask_full_2d->ne[1], + (long long) self_kq_mask_full_2d->ne[2], (long long) self_kq_mask_full_2d->ne[3]); + fflush(stdout); + */ + mctx->set_input_kq_mask_full_2d(self_kq_mask_full_2d, ubatch, cparams.causal_attn); + } + } bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { @@ -356,6 +375,14 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { bool res = true; + /* + printf("[SET_INPUT_KV][iswa] micro(base)=[%lld,%lld,%lld,%lld]\n", + (long long) self_kq_mask->ne[0], (long long) self_kq_mask->ne[1], + (long long) self_kq_mask->ne[2], (long long) self_kq_mask->ne[3]); + fflush(stdout); + */ + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there @@ -1497,6 +1524,13 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); + // Full-width KV-aware 2D mask for indexer/sparse MLA (DeepSeek V3.2 only) + if (mctx_cur->is_arch_deepseek_v3_2()) { + const auto kv_size_full = mctx_cur->get_kv_size(); + inp->self_kq_mask_full_2d = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, kv_size_full, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(inp->self_kq_mask_full_2d); + } + inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 34b984afeb0..f32363ce854 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -286,6 +286,7 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * get_v_idxs() const { return self_v_idxs; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + ggml_tensor * get_kq_mask_full_2d() const { return self_kq_mask_full_2d; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] @@ -293,6 +294,9 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + // Full-width KV-aware mask for indexer and sparse MLA: F32 [kv_size, PAD(T)] 2D + ggml_tensor * self_kq_mask_full_2d = nullptr; // F32 [kv_size, PAD(T)] + // note: these have to be copies because in order to be able to reuse a graph, its inputs // need to carry these parameters with them. otherwise, they can point to freed // llm_graph_params from a previous batch, causing stack-use-after-return diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 816f2d5de59..61ddfa10fd2 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -43,7 +43,7 @@ llama_kv_cache::llama_kv_cache( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(3u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -142,7 +142,30 @@ llama_kv_cache::llama_kv_cache( map_layer_ids[il] = layers.size(); - layers.push_back({ il, k, v, k_stream, v_stream, }); + // initialize kv_layer entry + layers.push_back({}); + auto & lyr = layers.back(); + lyr.il = il; + lyr.k = k; + lyr.v = v; + lyr.k_stream = std::move(k_stream); + lyr.v_stream = std::move(v_stream); + + // Allocate Indexer K cache if the model layer has indexer tensors + if (model.layers[il].attn_indexer_wk != nullptr) { + const int64_t index_head_dim = model.layers[il].attn_indexer_wk->ne[1]; + // Use F32 for indexer K cache to preserve top-k ranking stability + ggml_tensor * kidx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, index_head_dim, kv_size, n_stream); + ggml_format_name(kidx, "cache_k_indexer_l%d", il); + + std::vector kidx_stream; + for (uint32_t s = 0; s < n_stream; ++s) { + kidx_stream.push_back(ggml_view_2d(ctx, kidx, index_head_dim, kv_size, kidx->nb[1], s*kidx->nb[2])); + } + + lyr.k_indexer = kidx; + lyr.k_indexer_stream = std::move(kidx_stream); + } } if (reuse) { @@ -199,6 +222,10 @@ llama_kv_cache::llama_kv_cache( debug = LLAMA_KV_CACHE_DEBUG ? 
atoi(LLAMA_KV_CACHE_DEBUG) : 0; } +bool llama_kv_cache::is_arch_deepseek_v3_2() const { + return model.arch == LLM_ARCH_DEEPSEEK3_2; +} + void llama_kv_cache::clear(bool data) { for (uint32_t s = 0; s < n_stream; ++s) { v_cells[s].reset(); @@ -1056,6 +1083,45 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm return ggml_set_rows(ctx, k, k_cur, k_idxs); } + + +ggml_tensor * llama_kv_cache::get_k_indexer(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { + const int32_t ikv = map_layer_ids.at(il); + ggml_tensor * kidx = layers[ikv].k_indexer; + GGML_ASSERT(kidx && "Indexer K cache not allocated for this layer"); + + const uint64_t kv_size = get_size(); + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + + // view as [D_index, n_kv, ns] + return ggml_view_3d(ctx, kidx, + kidx->ne[0], n_kv, ns, + ggml_row_size(kidx->type, kidx->ne[0]), + ggml_row_size(kidx->type, kidx->ne[0])*kv_size, + ggml_row_size(kidx->type, kidx->ne[0])*kv_size*sinfo.s0); +} + +ggml_tensor * llama_kv_cache::get_k_indexer_full(ggml_context * ctx, int32_t il, const slot_info & sinfo) const { + const int32_t ikv = map_layer_ids.at(il); + ggml_tensor * kidx = layers[ikv].k_indexer; + GGML_ASSERT(kidx && "Indexer K cache not allocated for this layer"); + const uint64_t kv_size = get_size(); + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + if (kidx->ne[2] > 1) { + ggml_tensor * kidx_2d = ggml_reshape_2d(ctx, kidx, kidx->ne[0], kv_size * kidx->ne[2]); + return ggml_view_3d(ctx, kidx_2d, + kidx->ne[0], kv_size, ns, + ggml_row_size(kidx->type, kidx->ne[0]), + ggml_row_size(kidx->type, kidx->ne[0]) * kv_size, + ggml_row_size(kidx->type, kidx->ne[0]) * kv_size * sinfo.s0); + } + return ggml_view_3d(ctx, kidx, + kidx->ne[0], kv_size, ns, + ggml_row_size(kidx->type, kidx->ne[0]), + ggml_row_size(kidx->type, kidx->ne[0]) * kv_size, + ggml_row_size(kidx->type, kidx->ne[0]) * kv_size * sinfo.s0); +} + ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const { GGML_UNUSED(sinfo); @@ -1172,7 +1238,6 @@ void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ub } else { // note: the V cache is transposed when not using flash attention const int64_t kv_size = get_size(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max(); for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { @@ -1308,6 +1373,40 @@ size_t llama_kv_cache::total_size() const { return size; } +void llama_kv_cache::set_input_kq_mask_full_2d(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const int64_t kv_size = get_size(); + const int64_t T_pad = dst->ne[1]; + const int64_t T = ubatch->n_tokens; + + std::fill(data, data + kv_size * T_pad, -INFINITY); + + // For each token i in the batch, mark valid KV positions across the full cache width + for (uint32_t i = 0; i < T; ++i) { + // Use first seq_id for mask (consistent with existing set_input_kq_mask) + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + const auto & cells = v_cells[seq_to_stream[seq_id]]; + const llama_pos p1 = ubatch->pos[i]; + + for (uint32_t j = 0; j < kv_size; ++j) { + if (cells.is_empty(j)) continue; + if (!cells.seq_has(j, seq_id)) continue; + const llama_pos p0 = cells.pos_get(j); + if (causal_attn && p0 > p1) continue; + if (is_masked_swa(p0, p1)) continue; + + // Row major: ne[0] = kv_size, ne[1] = T_pad + // offset 
= row + col*row_stride + // Row major: ne[0] = kv_size, ne[1] = T_pad + // offset = row + col*row_stride + // For valid entries: apply ALiBi if enabled, else 0.0f + data[j + i * kv_size] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + } + } +} + size_t llama_kv_cache::size_k_bytes() const { size_t size_k_bytes = 0; @@ -1949,7 +2048,6 @@ bool llama_kv_cache_context::apply() { return true; } - kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); n_kv = kv->get_n_kv(sinfos[i_cur]); @@ -1970,10 +2068,71 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } +bool llama_kv_cache_context::is_arch_deepseek_v3_2() const { + return kv->is_arch_deepseek_v3_2(); +} + +void llama_kv_cache_context::set_input_kq_mask_full_2d(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv->set_input_kq_mask_full_2d(dst, ubatch, causal_attn); +} + +uint32_t llama_kv_cache_context::get_kv_size() const { + return kv->get_size(); +} + ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } +ggml_tensor * llama_kv_cache_context::get_k_indexer(ggml_context * ctx, int32_t il) const { + return kv->get_k_indexer(ctx, il, n_kv, sinfos[i_cur]); +} + +ggml_tensor * llama_kv_cache_context::get_k_full(ggml_context * ctx, int32_t il) const { + const auto & sinfo = sinfos[i_cur]; + return kv->get_k(ctx, il, kv->get_size(), sinfo); +} + +ggml_tensor * llama_kv_cache_context::get_v_full(ggml_context * ctx, int32_t il) const { + const auto & sinfo = sinfos[i_cur]; + return kv->get_v(ctx, il, kv->get_size(), sinfo); +} + +ggml_tensor * llama_kv_cache::get_k_full(ggml_context * ctx, int32_t il) const { + const int32_t ikv = map_layer_ids.at(il); + auto * k = layers[ikv].k; + const uint64_t kv_size = get_size(); + const uint64_t n_embd_k_gqa = k->ne[0]; + return ggml_view_3d(ctx, k, + n_embd_k_gqa, kv_size, k->ne[2], + ggml_row_size(k->type, n_embd_k_gqa), + ggml_row_size(k->type, n_embd_k_gqa), + ggml_row_size(k->type, n_embd_k_gqa)*kv_size); +} + +ggml_tensor * llama_kv_cache::get_v_full(ggml_context * ctx, int32_t il) const { + const int32_t ikv = map_layer_ids.at(il); + auto * v = layers[ikv].v; + const uint64_t kv_size = get_size(); + const uint64_t n_embd_v_gqa = v->ne[0]; + return ggml_view_3d(ctx, v, + n_embd_v_gqa, kv_size, v->ne[2], + ggml_row_size(v->type, n_embd_v_gqa), + ggml_row_size(v->type, n_embd_v_gqa), + ggml_row_size(v->type, n_embd_v_gqa)*kv_size); +} + +ggml_tensor * llama_kv_cache_context::get_k_indexer_full(ggml_context * ctx, int32_t il) const { + // Full-width indexer K view for DeepSeek V3.2 + const auto & sinfo = sinfos[i_cur]; + return kv->get_k_indexer_full(ctx, il, sinfo); +} + +ggml_tensor * llama_kv_cache_context::cpy_k_indexer(ggml_context * ctx, ggml_tensor * kidx_cur, ggml_tensor * k_idxs, int32_t il) const { + const auto & sinfo = sinfos[i_cur]; + return kv->cpy_k_indexer(ctx, kidx_cur, k_idxs, il, sinfo); +} + ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const { return kv->get_v(ctx, il, n_kv, sinfos[i_cur]); } @@ -2018,3 +2177,23 @@ uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 
256u : 32u; } + + +// Lightning Indexer: write K_indexer rows into cache +ggml_tensor * llama_kv_cache::cpy_k_indexer(ggml_context * ctx, ggml_tensor * kidx_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { + GGML_UNUSED(sinfo); + const int32_t ikv = map_layer_ids.at(il); + ggml_tensor * kidx = layers[ikv].k_indexer; + GGML_ASSERT(kidx && "Indexer K cache not allocated for this layer"); + ggml_tensor * cur2d = kidx_cur; + if (kidx_cur->ne[2] > 1) { + GGML_ASSERT(ggml_row_size(kidx_cur->type, kidx_cur->ne[0]) == kidx_cur->nb[1]); + cur2d = ggml_view_2d(ctx, kidx_cur, kidx_cur->ne[0], kidx_cur->ne[2], kidx_cur->nb[2], 0); + } + const int64_t n_stream = kidx->ne[2]; + if (n_stream > 1) { + const int64_t kv_size = get_size(); + kidx = ggml_reshape_2d(ctx, kidx, kidx->ne[0], kv_size*n_stream); + } + return ggml_set_rows(ctx, kidx, cur2d, k_idxs); +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 85f0663d8c1..dc55ea8a012 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -21,7 +21,11 @@ class llama_kv_cache : public llama_memory_i { public: static uint32_t get_padding(const llama_cparams & cparams); + // Return true if the underlying model architecture is DeepSeek V3.2 (sparse attention) + bool is_arch_deepseek_v3_2() const; + struct stream_copy_info { + bool empty() const { assert(ssrc.size() == sdst.size()); return ssrc.empty(); @@ -146,10 +150,19 @@ class llama_kv_cache : public llama_memory_i { // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; + ggml_tensor * get_k_indexer(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; + + // DeepSeek V3.2: full-width indexer K view + ggml_tensor * get_k_indexer_full(ggml_context * ctx, int32_t il, const slot_info & sinfo) const; + + // Full-width KV accessors for sparse paths (ignore micro-window) + ggml_tensor * get_k_full(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v_full(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const; + ggml_tensor * cpy_k_indexer(ggml_context * ctx, ggml_tensor * kidx_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; // // preparation API @@ -181,11 +194,14 @@ class llama_kv_cache : public llama_memory_i { void set_input_k_shift(ggml_tensor * dst) const; + void set_input_kq_mask_full_2d(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; private: const llama_model & model; + const llama_hparams & hparams; struct kv_layer { @@ -198,6 +214,10 @@ class llama_kv_cache : public llama_memory_i { std::vector k_stream; std::vector v_stream; + + // DeepSeek V3.2: Lightning Indexer K cache per layer + ggml_tensor * k_indexer = nullptr; // [D_index, kv_size, n_stream] + std::vector k_indexer_stream; // views per stream }; bool v_trans = true; // the value tensor is transposed @@ -313,10 +333,21 @@ class llama_kv_cache_context : 
public llama_memory_context_i { // uint32_t get_n_kv() const; + uint32_t get_kv_size() const; + + + bool is_arch_deepseek_v3_2() const; // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_k_indexer(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_k_indexer_full(ggml_context * ctx, int32_t il) const; // DeepSeek V3.2 only + // Full-width KV accessors for sparse paths (ignore micro-window) + void set_input_kq_mask_full_2d(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + + ggml_tensor * get_k_full(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v_full(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location // note: the heads in k_cur and v_cur should be layed out contiguously in memory @@ -326,6 +357,7 @@ class llama_kv_cache_context : public llama_memory_context_i { // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; + ggml_tensor * cpy_k_indexer(ggml_context * ctx, ggml_tensor * kidx_cur, ggml_tensor * k_idxs, int32_t il) const; // create destination indices for each head of the current batch for where it would be written in the KV cache // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 63655bf6517..312e347a4c7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -10,6 +10,9 @@ #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" #include "llama-memory-recurrent.h" +#include "llama-sparse-indexer.h" +#include "llama-sparse-mla-fwd.h" +#include "llama-sparse-topk.h" #include "ggml-cpp.h" @@ -24,6 +27,31 @@ #include #include #include +#include + + +// Debug helpers for tracking add() operand layouts during sparse attention +static void llama_dbg_tensor(const char * tag, struct ggml_tensor * t, int il) { + if (!t) { + printf("DBG %s L%d: null\n", tag, il); + fflush(stdout); + return; + } + printf("DBG %s L%d: ne=[%lld,%lld,%lld,%lld] nb=[%zu,%zu,%zu,%zu] type=%d cont=%d rowcont=%d\n", + tag, il, + (long long) t->ne[0], (long long) t->ne[1], (long long) t->ne[2], (long long) t->ne[3], + t->nb[0], t->nb[1], t->nb[2], t->nb[3], + (int) t->type, + (int) ggml_is_contiguous(t), (int) ggml_is_contiguous_rows(t)); + fflush(stdout); +} + +static struct ggml_tensor * llama_add_dbg(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + const char * where, int il) { + llama_dbg_tensor(where, a, il); + llama_dbg_tensor(where, b, il); + return ggml_add(ctx, a, b); +} const char * llm_type_name(llm_type type) { switch (type) { @@ -278,6 +306,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w { op_tensor = ggml_scale(ctx, w, 1.0f); } break; + case GGML_OP_NORM: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_norm(ctx, a, 1e-5f); + } break; default: GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); } @@ -1491,6 +1524,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case 
LLM_ARCH_DEEPSEEK3_2: { bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4464,6 +4498,99 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_DEEPSEEK3_2: + { + const bool is_lite = (hparams.n_layer == 27); + + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (!is_lite) { + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + } + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (!is_lite) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw 
std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + + // Sparse attention tensors (ATTN_INDEXER_*) + // These are part of the DeepSeek V3.2 sparse attention mechanism (indexer) + const int64_t index_head_dim = 128; // From VLLM: config.index_head_dim + const int64_t index_n_heads = 64; // From VLLM: config.index_n_heads + layer.attn_indexer_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_INDEXER_K_NORM, "weight", i), {index_head_dim}, 0); + layer.attn_indexer_k_norm_bias = create_tensor(tn(LLM_TENSOR_ATTN_INDEXER_K_NORM, "bias", i), {index_head_dim}, TENSOR_NOT_REQUIRED); + layer.attn_indexer_weights_proj = create_tensor(tn(LLM_TENSOR_ATTN_INDEXER_WEIGHTS_PROJ, "weight", i), {n_embd, index_n_heads}, 0); + layer.attn_indexer_wk = create_tensor(tn(LLM_TENSOR_ATTN_INDEXER_WK, "weight", i), {n_embd, index_head_dim}, 0); + layer.attn_indexer_wq_b = create_tensor(tn(LLM_TENSOR_ATTN_INDEXER_WQ_B, "weight", i), {q_lora_rank, index_n_heads * index_head_dim}, 0); + } + } break; case LLM_ARCH_PLM: { const int64_t n_embd_head_qk_rope = hparams.n_rot; @@ -6202,7 +6329,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK3_2) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -13601,6 +13728,431 @@ struct llm_build_deepseek2 : public llm_graph_context { } }; +struct llm_build_deepseek3_2 : public llm_graph_context { + llm_build_deepseek3_2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const char * ENV_SPARSE_DEBUG = getenv("LLAMA_SPARSE_DEBUG"); + const bool dbg = (ENV_SPARSE_DEBUG && atoi(ENV_SPARSE_DEBUG) != 0); + + bool is_lite = (hparams.n_layer == 27); + + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. 
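// [editorial sketch] Worked example of the pre-scaling below, assuming an illustrative YaRN setup
// with freq_scale = 0.25 (4x context extension) and hparams.rope_yarn_log_mul = 0.1:
//
//     mscale      = attn_factor_in * (1 + 0.1 * ln(1/0.25)) = attn_factor_in * 1.1386
//     kq_scale    = mscale^2 / sqrt(n_embd_head_k)
//     attn_factor = 1 / (1 + 0.1 * ln(1/0.25))  ~= 0.878     // local value passed to ggml_rope_ext
//
// i.e. the YaRN magnitude correction is folded once into the KQ scale, while the attention factor
// handed to RoPE divides it back out so the rotation is not double-scaled. Note that mscale reads
// the incoming attn_factor graph parameter; the local const declared just after it shadows that
// value for the ggml_rope_ext calls. The numbers above are illustrative, not model values.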
+ const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k)); + const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + + if (dbg) printf("[deepseek3_2] layer init: attn_factor=%g mscale=%g dense_kq_scale=%g (n_embd_head_k=%lld)\n", attn_factor, mscale, kq_scale, (long long) n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Sparse attention indexer for DeepSeek V3.2 + // This should be computed BEFORE the regular attention using the normalized hidden state + bool use_sparse_attention = false; + int64_t top_k = 0; + auto cb_wrapper = [this](ggml_tensor * cur, const char * name, int il) { + this->cb(cur, name, il); + }; + + if (model.layers[il].attn_indexer_k_norm != nullptr) { + // Use the new sparse attention implementation for indexer computation + + // Defer KV-aware top-k computation to the attention block using KV cache + use_sparse_attention = true; + top_k = 0; + } + + // self_attention + { + ggml_tensor * q = NULL; + if (!is_lite) { + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = build_norm(q, + model.layers[il].attn_q_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(q, "q", il); + + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + } else { + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + } + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, + 0); + cb(q_nope, "q_nope", il); + + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, + n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + 0); + cb(kv_cmpr, "kv_cmpr", il); + + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, + n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + kv_cmpr = 
build_norm(kv_cmpr, + model.layers[il].attn_kv_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); + + // Declare Qcur, Kcur, Vcur at higher scope for sparse attention + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (is_mla) { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // Apply sparse attention if available, otherwise use regular attention + if (use_sparse_attention) {{ + const auto * mctx_cur = inp_attn->mctx; + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, Kcur, inp_attn->get_k_idxs(), il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, Vcur, inp_attn->get_v_idxs(), il)); + ggml_build_forward_expand(gf, inp_attn->get_kq_mask()); + } + + // Use sparse attention with top-k tokens (KV-aware) + { + const auto * mctx_cur2 = inp_attn->mctx; + // Use full-width KV cache for sparse MLA to match indexer indices + ggml_tensor * Kcache = mctx_cur2->get_k_full(ctx0, il); + ggml_tensor * Vcache = mctx_cur2->get_v_full(ctx0, il); + ggml_tensor * KQmask2 = inp_attn->get_kq_mask_full_2d(); + + const char *env_topk = getenv("LLAMA_SPARSE_TOPK"); + top_k = env_topk ? std::max(1, atoll(env_topk)) : 2048; + ggml_build_forward_expand(gf, KQmask2); + { + int64_t used_kv = mctx_cur2->get_n_kv(); + int64_t n_kv_cache = (int64_t) Kcache->ne[2]; + ggml_tensor * Kindexer_full = mctx_cur2->get_k_indexer_full(ctx0, il); + int64_t n_kv_indexer = Kindexer_full ? (int64_t) Kindexer_full->ne[1] : n_kv_cache; + int64_t available_kv = std::min(used_kv, std::min(n_kv_cache, n_kv_indexer)); + top_k = std::min(top_k, available_kv); + } + + ggml_tensor * kvaware_indices = llama::sparse_attn_indexer::build_kvaware_topk_indices( + ctx0, model, il, cur, n_tokens, mctx_cur2, inp_attn->get_k_idxs(), KQmask2, top_k, + inp_pos, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, + cb_wrapper, gf, sched, backend_cpu); + cur = llama::sparse_mla_fwd::apply_sparse_attention_kvaware( + ctx0, Qcur, Kcache, Vcache, kvaware_indices, n_tokens, top_k, kq_scale, KQmask2, hparams.attn_soft_cap ? 
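// [editorial sketch] The clamp above ensures the indexer never requests more rows than exist in
// either the KV cache or the indexer-K cache. Equivalent standalone logic (illustrative name,
// not part of this patch):
//
//     static int64_t clamp_top_k(int64_t requested, int64_t used_kv, int64_t kv_rows, int64_t idx_rows) {
//         return std::min(requested, std::min(used_kv, std::min(kv_rows, idx_rows)));
//     }
//
// e.g. clamp_top_k(2048, 512, 4096, 4096) == 512 during early prefill, so short contexts fall
// back to effectively dense attention over the small cache.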
hparams.f_attn_logit_softcapping : 0.0f, cb_wrapper); + // Sanity checks for MLA sparse attention output vs expected V-dim (kv_lora_rank) + if (cur->ne[0] != (int64_t) kv_lora_rank) { + printf("[SPARSE-DBG-MLA] L%d: sparse attn out Dv=%" PRId64 " but kv_lora_rank=%u (mismatch)\n", il, cur->ne[0], kv_lora_rank); + } + if (dbg && model.layers[il].wv_b) { + printf("[SPARSE-DBG-MLA] L%d: wv_b dims=[%" PRId64 ", %" PRId64 "] expected=[%u, %" PRId64 "]\n", + il, (int64_t) model.layers[il].wv_b->ne[0], (int64_t) model.layers[il].wv_b->ne[1], kv_lora_rank, (int64_t) n_embd_head_v); + } + GGML_ASSERT(cur->ne[0] == (int64_t) kv_lora_rank); + if (model.layers[il].wv_b) { + GGML_ASSERT(model.layers[il].wv_b->ne[0] == (int64_t) kv_lora_rank); + GGML_ASSERT(model.layers[il].wv_b->ne[1] == (int64_t) n_embd_head_v); + } + } + + /* keep sparse attention output on device to avoid backend hops */ + + // Project kv_lora_rank -> n_embd_head_v per head using wv_b and flatten heads before WO + ggml_tensor * cur_perm = ggml_permute(ctx0, cur, 0, 2, 1, 3); // [kv_lora_rank, n_tokens, n_head] + cb(cur_perm, "sparse_attn_perm_kvT_H", il); + + ggml_tensor * cur_proj = ggml_mul_mat(ctx0, model.layers[il].wv_b, cur_perm); // [n_embd_head_v, n_tokens, n_head] + cb(cur_proj, "sparse_attn_vproj", il); + + ggml_tensor * cur_proj_perm = ggml_permute(ctx0, cur_proj, 0, 2, 1, 3); // [n_embd_head_v, n_head, n_tokens] + cb(cur_proj_perm, "sparse_attn_vproj_perm", il); + + cur_proj_perm = ggml_cont(ctx0, cur_proj_perm); + ggml_tensor * cur2d = ggml_reshape_2d(ctx0, cur_proj_perm, n_head * n_embd_head_v, n_tokens); + cb(cur2d, "sparse_attn_flat", il); + + // Apply output projection for sparse attention + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur2d); + cb(cur, "sparse_attn_out", il); + + // Log that we're using sparse attention + LLAMA_LOG_DEBUG("DeepSeek V3.2: Using sparse attention with top-%d tokens for layer %d\n", + (int)top_k, il); + } else { + // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); + } + } else { + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); + cb(kv, "kv", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + 0); + cb(k_nope, "k_nope_view", il); + + // and {n_embd_head_v, n_head, n_tokens} + Vcur = ggml_view_3d(ctx0, kv, + n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + ggml_row_size(kv->type, n_embd_head_qk_nope)); + cb(Vcur, "Vcur_view", il); + + Vcur = ggml_cont(ctx0, Vcur); + cb(Vcur, "Vcur_cont", il); + + // note: rope must go first for in-place context shifting in build_rope_shift() + Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + cb(Qcur, "Qcur", il); + + Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + cb(Kcur, "Kcur", il); + + // Apply sparse attention if available, otherwise use regular attention + if (use_sparse_attention) {{ + const auto * mctx_cur = inp_attn->mctx; + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, Kcur, inp_attn->get_k_idxs(), il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, Vcur, 
inp_attn->get_v_idxs(), il)); + ggml_build_forward_expand(gf, inp_attn->get_kq_mask()); + } + + // Use sparse attention with top-k tokens (KV-aware) + { + const auto * mctx_cur2 = inp_attn->mctx; + // Use full-width KV cache for sparse MLA to match indexer indices + ggml_tensor * Kcache = mctx_cur2->get_k_full(ctx0, il); + ggml_tensor * Vcache = mctx_cur2->get_v_full(ctx0, il); + ggml_tensor * KQmask2 = inp_attn->get_kq_mask_full_2d(); + + const char *env_topk = getenv("LLAMA_SPARSE_TOPK"); + top_k = env_topk ? std::max(1, atoll(env_topk)) : 2048; + ggml_build_forward_expand(gf, KQmask2); + { + int64_t used_kv = mctx_cur2->get_n_kv(); + int64_t n_kv_cache = (int64_t) Kcache->ne[2]; + ggml_tensor * Kindexer_full = mctx_cur2->get_k_indexer_full(ctx0, il); + int64_t n_kv_indexer = Kindexer_full ? (int64_t) Kindexer_full->ne[1] : n_kv_cache; + int64_t available_kv = std::min(used_kv, std::min(n_kv_cache, n_kv_indexer)); + top_k = std::min(top_k, available_kv); + } + ggml_tensor * kvaware_indices = llama::sparse_attn_indexer::build_kvaware_topk_indices( + ctx0, model, il, cur, n_tokens, mctx_cur2, inp_attn->get_k_idxs(), KQmask2, top_k, + inp_pos, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, + cb_wrapper, gf, sched, backend_cpu); + cur = llama::sparse_mla_fwd::apply_sparse_attention_kvaware( + ctx0, Qcur, Kcache, Vcache, kvaware_indices, n_tokens, top_k, kq_scale, KQmask2, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f, cb_wrapper); + // Sanity checks for MHA sparse attention output vs expected V-dim (n_embd_head_v) + if (cur->ne[0] != (int64_t) n_embd_head_v) { + printf("[SPARSE-DBG-MHA] L%d: sparse attn out Dv=%" PRId64 " but n_embd_head_v=%" PRId64 " (mismatch)\n", il, (int64_t) cur->ne[0], (int64_t) n_embd_head_v); + } + GGML_ASSERT(cur->ne[0] == (int64_t) n_embd_head_v); + } + + /* keep sparse attention output on device to avoid backend hops */ + + // Flatten heads before WO + ggml_tensor * cur_perm2 = ggml_permute(ctx0, cur, 0, 2, 1, 3); // [n_embd_head_v, n_tokens, n_head] + cb(cur_perm2, "sparse_attn_perm_vT_H", il); + + cur_perm2 = ggml_cont(ctx0, cur_perm2); + ggml_tensor * cur2d2 = ggml_reshape_2d(ctx0, cur_perm2, n_head * n_embd_head_v, n_tokens); + cb(cur2d2, "sparse_attn_flat", il); + + // Apply output projection for sparse attention + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur2d2); + cb(cur, "sparse_attn_out", il); + // ensure contiguous layout for subsequent broadcast adds (CUDA) + cur = ggml_cont(ctx0, cur); + + // Log that we're using sparse attention + LLAMA_LOG_DEBUG("DeepSeek V3.2: Using sparse attention with top-%d tokens for layer %d\n", + (int)top_k, il); + } else { + // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } + } + + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", 
il); + } else { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_bitnet : public llm_graph_context { llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -19465,6 +20017,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_DEEPSEEK3_2: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_CHATGLM: { llm = std::make_unique(*this, params); @@ -19765,6 +20321,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK3_2: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GLM4: diff --git a/src/llama-model.h b/src/llama-model.h index d73ce969323..d609f5f5443 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -380,6 +380,13 @@ struct llama_layer { // openai-moe struct ggml_tensor * attn_sinks = nullptr; + // deepseek3_2 sparse attention + struct ggml_tensor * attn_indexer_k_norm = nullptr; + struct ggml_tensor * attn_indexer_k_norm_bias = nullptr; + struct ggml_tensor * attn_indexer_weights_proj = nullptr; + struct ggml_tensor * attn_indexer_wk = nullptr; + struct ggml_tensor * attn_indexer_wq_b = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; diff --git a/src/llama-sparse-indexer.cpp b/src/llama-sparse-indexer.cpp new file mode 100644 index 00000000000..eef79a4ee57 --- /dev/null +++ b/src/llama-sparse-indexer.cpp @@ -0,0 +1,497 @@ +#include "llama-sparse-indexer.h" +#include "llama-model.h" +#include "llama-impl.h" +#include "llama-sparse-topk.h" + +#include +#include +#include +#include +#include + +// Helper function to get memory usage in human-readable format +static std::string format_memory_size(size_t bytes) { + const char* units[] = {"B", "KB", "MB", "GB", "TB"}; + size_t unit_idx = 0; + double size = bytes; + + while (size >= 1024.0 && unit_idx < 4) { + size /= 1024.0; + unit_idx++; + } + + char buffer[32]; + snprintf(buffer, sizeof(buffer), "%.2f %s", size, units[unit_idx]); + return std::string(buffer); +} + +namespace llama { +extern "C" { + struct ggml_e4m3_t; + void ggml_e4m3_to_fp32_row(const 
ggml_e4m3_t * x, float * y, int64_t k); + void ggml_fp32_to_e4m3_row_ref(const float * x, ggml_e4m3_t * y, int64_t k); +} + +static inline float f32_to_e4m3_to_f32(float x) { + unsigned char q_byte = 0; + ggml_fp32_to_e4m3_row_ref(&x, (ggml_e4m3_t *) &q_byte, 1); + float y = 0.0f; + ggml_e4m3_to_fp32_row((const ggml_e4m3_t *) &q_byte, &y, 1); + return y; +} + + +using std::function; + +ggml_tensor * sparse_attn_indexer::idx_compute_scores_tile( + ggml_context * ctx, + ggml_tensor * q3d, + ggml_tensor * a_k, + ggml_tensor * weights, + ggml_tensor * k_scale_2d, + int64_t D, int64_t H, + int64_t Tc, int64_t kv_end, + int64_t t0) { + const char * __prof_env = getenv("LLAMA_SPARSE_PROF"); + bool prof = (__prof_env && atoi(__prof_env) != 0); + double t0_us = 0.0; + if (prof) { + t0_us = ggml_time_us(); + } + + // CPU FP8 Lightning Indexer reference, using GGML FP8 helpers. + // Layout conventions: + // q3d : [D, T_total, H] + // a_k : [D, N_kv] + // weights : [H, T_total] + // k_scale_2d : [N_kv, 1] + const int64_t kv = kv_end; + + // Pack Q tile [D, Tc, H] into flat Q[D * Tc * H] with layout Q[d + D*(tc*H + h)] + std::vector Q((size_t)D * (size_t)Tc * (size_t)H); + for (int64_t tc = 0; tc < Tc; ++tc) { + for (int64_t h = 0; h < H; ++h) { + for (int64_t d = 0; d < D; ++d) { + size_t dst = (size_t)d + (size_t)D * ((size_t)tc * (size_t)H + (size_t)h); + // q3d is [D, T_total, H] + size_t off = + (size_t)d * q3d->nb[0] + + (size_t)(t0 + tc) * q3d->nb[1] + + (size_t)h * q3d->nb[2]; + Q[dst] = *(float *)((char *) q3d->data + off); + } + } + } + + // Pack K slice [D, kv_end] into K[D*kv] row-major per kv row + std::vector K((size_t)D * (size_t)kv); + for (int64_t i = 0; i < kv; ++i) { + for (int64_t d = 0; d < D; ++d) { + size_t dst = (size_t)d + (size_t)D * (size_t)i; + size_t off = + (size_t)d * a_k->nb[0] + + (size_t)i * a_k->nb[1]; + K[dst] = *(float *)((char *) a_k->data + off); + } + } + + // Precompute FP8-dequantized Q: Qq = dequant(quant(Q)) + std::vector Qq(Q.size()); + for (int64_t tc = 0; tc < Tc; ++tc) { + for (int64_t h = 0; h < H; ++h) { + for (int64_t d = 0; d < D; ++d) { + size_t idx_q = (size_t)d + (size_t)D * ((size_t)tc * (size_t)H + (size_t)h); + Qq[idx_q] = f32_to_e4m3_to_f32(Q[idx_q]); + } + } + } + + + // Pack weights [H, Tc] for this tile: W[h + H*tc] + std::vector W((size_t)H * (size_t)Tc); + for (int64_t tc = 0; tc < Tc; ++tc) { + for (int64_t h = 0; h < H; ++h) { + size_t dst = (size_t)h + (size_t)H * (size_t)tc; + size_t off = + (size_t)h * weights->nb[0] + + (size_t)(t0 + tc) * weights->nb[1]; + W[dst] = *(float *)((char *) weights->data + off); + } + } + + // Pack k_scale (IndexKScale proxy) for first kv rows + std::vector KS((size_t)kv); + for (int64_t i = 0; i < kv; ++i) { + // k_scale_2d has shape [N_kv, 1] with nb[0] = sizeof(float), nb[1] = sizeof(float)*N_kv + // The per-row scale is at (i, 0), i.e. 
offset = i * nb[0] + size_t off = (size_t)i * k_scale_2d->nb[0]; + KS[i] = *(float *)((char *) k_scale_2d->data + off); + } + + // Per-row amax and K_sf exactly as in cpu_indexer_logits_fp8like + std::vector K_sf((size_t)kv); + for (int64_t i = 0; i < kv; ++i) { + float maxv = 0.0f; + const float *kvp = K.data() + (size_t)D * (size_t)i; + for (int64_t d = 0; d < D; ++d) { + float v = std::fabs(kvp[d]); + if (v > maxv) maxv = v; + } + if (maxv < 1e-4f) maxv = 1e-4f; + K_sf[i] = maxv / 448.0f; + } + + // Precompute FP8-dequantized K with per-row scaling: Kh = dequant(quant(K / K_sf[row])) + std::vector Kh(K.size()); + for (int64_t i = 0; i < kv; ++i) { + float sf = K_sf[i]; + const float *kvp = K.data() + (size_t)D * (size_t)i; + float *khp = Kh.data() + (size_t)D * (size_t)i; + for (int64_t d = 0; d < D; ++d) { + float v = kvp[d] / sf; + khp[d] = f32_to_e4m3_to_f32(v); + } + } + + + // Compute FP8-like logits into host buffer using precomputed Qq and Kh + std::vector out((size_t)kv * (size_t)Tc, 0.0f); + for (int64_t tc = 0; tc < Tc; ++tc) { + for (int64_t i = 0; i < kv; ++i) { + float acc = 0.0f; + const float *kvp = Kh.data() + (size_t)D * (size_t)i; + float sf_k = K_sf[i]; + for (int64_t h = 0; h < H; ++h) { + const float *qv = Qq.data() + (size_t)D * ((size_t)tc * (size_t)H + (size_t)h); + float dot = 0.0f; + for (int64_t d = 0; d < D; ++d) { + dot += qv[d] * kvp[d]; + } + if (dot < 0.0f) dot = 0.0f; // ReLU + acc += dot * W[(size_t)h + (size_t)H * (size_t)tc]; + } + out[(size_t)i + (size_t)kv * (size_t)tc] = acc * KS[i] * sf_k; + } + } + + if (getenv("LLAMA_INDEXER_TL_FP8_DEBUG")) { + int maxk = (int) (kv < 8 ? kv : 8); + int maxt = (int) (Tc < 2 ? Tc : 2); + fprintf(stderr, "[IDX_FP8_CPU] Logits (kv x T) sample:"); + for (int i = 0; i < maxk; ++i) { + for (int tc = 0; tc < maxt; ++tc) { + float v = out[(size_t)i + (size_t)kv * (size_t)tc]; + fprintf(stderr, " C[%d,%d]= % .6f", i, tc, v); + } + } + } + + // Materialize scores_tc as a new F32 tensor [kv_end, Tc] + ggml_tensor * scores_tc = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv, Tc); + std::memcpy(scores_tc->data, out.data(), out.size() * sizeof(float)); + scores_tc->op = GGML_OP_NONE; + + if (prof) { + double t1_us = ggml_time_us(); + double dt_ms = (t1_us - t0_us) / 1000.0; + fprintf(stderr, "[PROFILE_IDX_CPU] D=%lld H=%lld Tc=%lld kv=%lld ms=%.3f\n", + (long long) D, (long long) H, (long long) Tc, (long long) kv_end, (float) dt_ms); + } + + return scores_tc; +} + +IndexerKVTriplet sparse_attn_indexer::compute_indexer_triplet( + ggml_context * ctx, + const llama_model & model, + int layer_idx, + ggml_tensor * cur, + int64_t n_tokens, + const llama_kv_cache_context * mctx, + ggml_tensor * k_idxs, + ggml_tensor * inp_pos, + int64_t n_rot, + int rope_type, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + const function & cb, + ggml_cgraph * gf) { + + const char * ENV_SPARSE_DEBUG = getenv("LLAMA_SPARSE_DEBUG"); + const bool dbg = (ENV_SPARSE_DEBUG && atoi(ENV_SPARSE_DEBUG) != 0); + + // Compute Indexer K for current tokens and (optionally) write to cache + ggml_tensor * Kindexer_cur = ggml_mul_mat(ctx, model.layers[layer_idx].attn_indexer_wk, cur); + + // Apply LayerNorm over D_index (per token), then gamma/beta + const int64_t D_index = model.layers[layer_idx].attn_indexer_wk->ne[1]; + ggml_tensor * K3d = ggml_reshape_3d(ctx, Kindexer_cur, D_index, 1, n_tokens); // [D,1,T] + ggml_tensor * K_mean = ggml_sum_rows(ctx, K3d); // [1,1,T] + K_mean = 
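// [editorial sketch] Summary of the FP8-like scoring computed by idx_compute_scores_tile above,
// for query token t and cached row i (notation only, not additional code):
//
//     k_sf[i]     = max(|K[:, i]|) / 448           // 448 = largest finite E4M3 value
//     score(i, t) = KS[i] * k_sf[i] * sum_h W[h, t] * relu( dq(q(Q[:, t, h])) . dq(q(K[:, i] / k_sf[i])) )
//
// where q()/dq() is the E4M3 quantize/dequantize round-trip. This is a CPU reference approximation
// of the Lightning Indexer logits and is not guaranteed to be bit-exact with the fused kernels.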
ggml_scale(ctx, K_mean, 1.0f / (float) D_index); // [1,1,T] + ggml_tensor * K_mean_rep = ggml_repeat(ctx, K_mean, K3d); // [D,1,T] + ggml_tensor * K_centered = ggml_sub(ctx, K3d, K_mean_rep); // [D,1,T] + ggml_tensor * K_var = ggml_sum_rows(ctx, ggml_sqr(ctx, K_centered)); // [1,1,T] + K_var = ggml_scale(ctx, K_var, 1.0f / (float) D_index); // [1,1,T] + ggml_tensor * K_var_eps = ggml_clamp(ctx, K_var, 1e-6f, 1e9f); // [1,1,T] + ggml_tensor * K_std = ggml_sqrt(ctx, K_var_eps); // [1,1,T] + ggml_tensor * K_std_rep = ggml_repeat(ctx, K_std, K_centered); // [D,1,T] + ggml_tensor * K_normed = ggml_div(ctx, K_centered, K_std_rep); // [D,1,T] + if (model.layers[layer_idx].attn_indexer_k_norm != nullptr) { + ggml_tensor * gamma = model.layers[layer_idx].attn_indexer_k_norm; // [D] + ggml_tensor * gamma_r = ggml_repeat(ctx, gamma, K_normed); // [D,1,T] + K_normed = ggml_mul(ctx, K_normed, gamma_r); + } + if (model.layers[layer_idx].attn_indexer_k_norm_bias != nullptr) { + ggml_tensor * beta = model.layers[layer_idx].attn_indexer_k_norm_bias; // [D] + ggml_tensor * beta_r = ggml_repeat(ctx, beta, K_normed); // [D,1,T] + K_normed = ggml_add(ctx, K_normed, beta_r); + } + // reshape back to [D, T] + Kindexer_cur = ggml_reshape_2d(ctx, K_normed, D_index, n_tokens); + cb(Kindexer_cur, "indexer_k_norm", layer_idx); + + // Apply RoPE to the first n_rot dims of K-indexer: view as [n_rot, 1, T] + if (n_rot > 0) { + ggml_tensor * Kidx_pe = ggml_view_3d(ctx, Kindexer_cur, + n_rot, 1, n_tokens, + ggml_row_size(Kindexer_cur->type, Kindexer_cur->ne[0]), + ggml_row_size(Kindexer_cur->type, Kindexer_cur->ne[0]), + 0); + Kidx_pe = ggml_rope_ext(ctx, Kidx_pe, inp_pos, nullptr, + n_rot, (enum llama_rope_type) rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + // Reuse Kindexer_cur as concatenation of [pe, nope] + ggml_tensor * Kidx_nope = ggml_view_3d(ctx, Kindexer_cur, + Kindexer_cur->ne[0] - n_rot, 1, n_tokens, + ggml_row_size(Kindexer_cur->type, Kindexer_cur->ne[0]), + ggml_row_size(Kindexer_cur->type, Kindexer_cur->ne[0]), + ggml_row_size(Kindexer_cur->type, n_rot)); + Kindexer_cur = ggml_concat(ctx, Kidx_pe, Kidx_nope, 0); + cb(Kindexer_cur, "indexer_k_rope", layer_idx); + } + + // Removed rotate_activation on Kindexer_cur to match reference implementations + + if (mctx && gf) { + ggml_tensor * Kindexer_cur_3d = ggml_reshape_3d(ctx, Kindexer_cur, Kindexer_cur->ne[0], 1, n_tokens); + ggml_build_forward_expand(gf, mctx->cpy_k_indexer(ctx, Kindexer_cur_3d, k_idxs, layer_idx)); + } + // Build q_indexer and weights + ggml_tensor * qsrc = nullptr; + const bool has_wq_a = (model.layers[layer_idx].wq_a != nullptr); + if (has_wq_a) { + qsrc = ggml_mul_mat(ctx, model.layers[layer_idx].wq_a, cur); + // Apply learned RMSNorm (attn_q_a_norm) like main attention path + // This aligns with TileLang qr = q_norm(wq_a(x)) + qsrc = ggml_norm(ctx, qsrc, 1e-5f); + if (model.layers[layer_idx].attn_q_a_norm) { + ggml_tensor * gamma_q = model.layers[layer_idx].attn_q_a_norm; + ggml_tensor * gamma_q_r = ggml_repeat(ctx, gamma_q, qsrc); + qsrc = ggml_mul(ctx, qsrc, gamma_q_r); + if (dbg) { + printf("[SPARSE-IDX-Q] L%d: applied attn_q_a_norm to indexer qsrc\n", layer_idx); + fflush(stdout); + } + } else { + printf("[SPARSE-IDX-Q][WARN] L%d: attn_q_a_norm not found; using plain RMSNorm for indexer qsrc\n", layer_idx); + fflush(stdout); + } + } else { + qsrc = ggml_norm(ctx, cur, 1e-5f); + } + + if (dbg) { + // Logging and sanity checks for potential lite-config mismatch + const int64_t 
qsrc_in_dim = qsrc ? qsrc->ne[0] : -1; + const int64_t wq_b_in_dim = model.layers[layer_idx].attn_indexer_wq_b ? model.layers[layer_idx].attn_indexer_wq_b->ne[0] : -1; + const int64_t wq_b_out_dim = model.layers[layer_idx].attn_indexer_wq_b ? model.layers[layer_idx].attn_indexer_wq_b->ne[1] : -1; + printf("[SPARSE-IDX-Q] L%d: has_wq_a=%d qsrc_in=%lld wq_b_in=%lld wq_b_out=%lld\n", + layer_idx, (int)has_wq_a, (long long)qsrc_in_dim, (long long)wq_b_in_dim, (long long)wq_b_out_dim); + fflush(stdout); + + if (model.layers[layer_idx].attn_indexer_wq_b && qsrc) { + if (wq_b_in_dim != qsrc_in_dim) { + printf("[SPARSE-IDX-Q][WARN] L%d: attn_indexer_wq_b input dim (%lld) != qsrc dim (%lld). Lite config?\n", + layer_idx, (long long) wq_b_in_dim, (long long) qsrc_in_dim); + fflush(stdout); + } + } + } + + ggml_tensor * q_indexer = ggml_mul_mat(ctx, model.layers[layer_idx].attn_indexer_wq_b, qsrc); + + // index head dim (head_dim in Tilelang) - already defined earlier as D_index + // indexer head count (n_heads in Tilelang) + const int64_t H_index = model.layers[layer_idx].attn_indexer_wq_b->ne[1] / D_index; + if ((model.layers[layer_idx].attn_indexer_wq_b->ne[1] % D_index) != 0) { + printf("[SPARSE-IDX-Q][WARN] L%d: wq_b_out_dim (%lld) is not divisible by D_index (%lld)\n", + layer_idx, (long long) model.layers[layer_idx].attn_indexer_wq_b->ne[1], (long long) D_index); + fflush(stdout); + } + q_indexer = ggml_reshape_3d(ctx, q_indexer, D_index, H_index, n_tokens); + + // Apply RoPE to the first n_rot dims of q_indexer: view as [n_rot, H, T] + if (n_rot > 0) { + ggml_tensor * qidx_pe = ggml_view_3d(ctx, q_indexer, + n_rot, H_index, n_tokens, + ggml_row_size(q_indexer->type, D_index), + ggml_row_size(q_indexer->type, D_index) * H_index, + 0); + qidx_pe = ggml_rope_ext(ctx, qidx_pe, inp_pos, nullptr, + n_rot, (enum llama_rope_type) rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + ggml_tensor * qidx_nope = ggml_view_3d(ctx, q_indexer, + D_index - n_rot, H_index, n_tokens, + ggml_row_size(q_indexer->type, D_index), + ggml_row_size(q_indexer->type, D_index) * H_index, + ggml_row_size(q_indexer->type, n_rot)); + q_indexer = ggml_concat(ctx, qidx_pe, qidx_nope, 0); + cb(q_indexer, "indexer_q_rope", layer_idx); + } + + // Removed rotate_activation on q_indexer to match reference implementations + + cb(q_indexer, "indexer_q", layer_idx); + + // Diagnostic: sample small window of q_indexer [D_index, H_index, T] + { + const int64_t sd0 = std::min(q_indexer->ne[0], (int64_t)8); + const int64_t sd1 = std::min(q_indexer->ne[1], (int64_t)8); + const int64_t sd2 = std::min(q_indexer->ne[2], (int64_t)8); + ggml_tensor * q_sample = ggml_view_3d(ctx, q_indexer, + sd0, sd1, sd2, + q_indexer->nb[1], q_indexer->nb[2], 0); + cb(q_sample, "indexer_q_sample", layer_idx); + } + + + // Approximate q_scale via per-(head, token) RMS of q_indexer across D_index + // q_indexer: [D_index, H_index, T] + ggml_tensor * q_sqr = ggml_sqr(ctx, q_indexer); // [D_index, H, T] + ggml_tensor * q_sum = ggml_sum_rows(ctx, q_sqr); // [1, H, T] + ggml_tensor * q_mean= ggml_scale(ctx, q_sum, 1.0f / (float) D_index); // [1, H, T] + ggml_tensor * q_rms = ggml_sqrt(ctx, q_mean); // [1, H, T] + if (dbg) printf("[SPARSE-IDX-QRMS] L%d: computed q_rms over D_index; D_index=%" PRId64 " H=%" PRId64 " T=%" PRId64 "\n", + layer_idx, D_index, H_index, n_tokens); + + // Build base weights from projection on cur + ggml_tensor * idx_weights = ggml_mul_mat(ctx, model.layers[layer_idx].attn_indexer_weights_proj, 
cur); // [H, T] + + + // Diagnostic: sample small windows of K indexer cache [D_index, N_kv] + if (mctx) { + ggml_tensor * kidx_cache = mctx->get_k_indexer_full(const_cast(ctx), layer_idx); + if (kidx_cache) { + const int64_t d0 = std::min(kidx_cache->ne[0], (int64_t)8); + const int64_t c0 = std::min(kidx_cache->ne[1], (int64_t)8); + ggml_tensor * kcache_head = ggml_view_2d(ctx, kidx_cache, d0, c0, kidx_cache->nb[1], 0); + cb(kcache_head, "indexer_k_cache_head", layer_idx); + if (kidx_cache->ne[1] > c0) { + size_t off_tail = (kidx_cache->ne[1] - c0) * kidx_cache->nb[1]; + ggml_tensor * kcache_tail = ggml_view_2d(ctx, kidx_cache, d0, c0, kidx_cache->nb[1], off_tail); + cb(kcache_tail, "indexer_k_cache_tail", layer_idx); + } + } + } + + // Diagnostic: sample small window of idx_weights [H_index, T] + { + ggml_tensor * idxw = idx_weights; + const int64_t sw0 = std::min(idxw->ne[0], (int64_t)8); + const int64_t sw1 = std::min(idxw->ne[1], (int64_t)8); + ggml_tensor * idxw_sample = ggml_view_2d(ctx, idxw, sw0, sw1, idxw->nb[1], 0); + cb(idxw_sample, "indexer_weights_sample", layer_idx); + } + + fflush(stdout); + // Scale weights by 1/sqrt(H_index) and 1/sqrt(D_index), then multiply by q_rms + idx_weights = ggml_scale(ctx, idx_weights, 1.0f / sqrtf((float) H_index)); + idx_weights = ggml_scale(ctx, idx_weights, 1.0f / sqrtf((float) D_index)); + + // Broadcast q_scale proxy [1,H,T] to [H,T] and multiply + ggml_tensor * q_scale_proxy = ggml_reshape_2d(ctx, q_rms, H_index, n_tokens); // [H, T] + idx_weights = ggml_mul(ctx, idx_weights, q_scale_proxy); // [H, T] + + cb(idx_weights, "indexer_weights", layer_idx); + ggml_tensor * Kindexer_cache = mctx ? mctx->get_k_indexer_full(ctx, layer_idx) + : ggml_reshape_2d(ctx, Kindexer_cur, D_index, n_tokens); + IndexerKVTriplet out{ q_indexer, Kindexer_cache, idx_weights }; + return out; +} + +ggml_tensor * sparse_attn_indexer::build_kvaware_topk_indices( + ggml_context * ctx, + const llama_model & model, + int layer_idx, + ggml_tensor * cur, + int64_t n_tokens, + const llama_kv_cache_context * mctx, + ggml_tensor * k_idxs, + ggml_tensor * kq_mask, + int64_t top_k, + ggml_tensor * inp_pos, + int64_t n_rot, + int rope_type, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + const function & cb, + ggml_cgraph * gf, + ggml_backend_sched_t sched, + ggml_backend_t backend_cpu) +{ + const char * ENV_SPARSE_DEBUG = getenv("LLAMA_SPARSE_DEBUG"); + const bool dbg = (ENV_SPARSE_DEBUG && atoi(ENV_SPARSE_DEBUG) != 0); + + size_t initial_mem = 0; + if (dbg) { + printf("=== SPARSE INDEXER: build_kvaware_topk_indices L%d ===\n", layer_idx); + initial_mem = ggml_used_mem(ctx); + printf("Initial memory usage: %s\n", format_memory_size(initial_mem).c_str()); + fflush(stdout); + + // Dump indexer dims and sanity-check shapes + const int64_t D_index_dbg = model.layers[layer_idx].attn_indexer_wk->ne[1]; + const int64_t H_index_dbg = model.layers[layer_idx].attn_indexer_wq_b->ne[1] / D_index_dbg; + printf("[SPARSE-DBG-IDX] L%d: D_index=%" PRId64 " H_index=%" PRId64 " n_tokens=%" PRId64 "\n", layer_idx, (int64_t) D_index_dbg, (int64_t) H_index_dbg, (int64_t) n_tokens); + fflush(stdout); + } + IndexerKVTriplet trip = compute_indexer_triplet(ctx, model, layer_idx, cur, n_tokens, mctx, k_idxs, + inp_pos, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, + cb, gf); + // Use full-width indexer K view only for DeepSeek V3.2 + if (model.arch == 
LLM_ARCH_DEEPSEEK3_2 && mctx) { + ggml_tensor * kidx_full = mctx->get_k_indexer_full(const_cast(ctx), layer_idx); + if (kidx_full && kidx_full->ne[1] >= trip.k_indexer_cache->ne[1]) { + trip.k_indexer_cache = kidx_full; + } + } + ggml_tensor * Kindexer_cache = trip.k_indexer_cache; + + if (top_k <= 0) { + top_k = std::max(64, std::min(1024, Kindexer_cache->ne[1])); + } + ggml_tensor * kvaware_indices = llama::sparse_attn_topk::select_topk_tokens_indexer_kvaware( + ctx, trip.q_indexer, Kindexer_cache, trip.idx_weights, kq_mask, top_k, cb, gf, sched, backend_cpu); + if (dbg) { + printf("SPARSE INDEXER: Final topk_indices [k,T]=[%" PRId64 ", %" PRId64 "]\n", + kvaware_indices->ne[0], kvaware_indices->ne[1]); + printf("Final memory usage: %s (total delta: %s)\n", format_memory_size(ggml_used_mem(ctx)).c_str(), + format_memory_size(ggml_used_mem(ctx) - initial_mem).c_str()); + fflush(stdout); + } + return kvaware_indices; +} + + +} // namespace llama + diff --git a/src/llama-sparse-indexer.h b/src/llama-sparse-indexer.h new file mode 100644 index 00000000000..9b660d0a381 --- /dev/null +++ b/src/llama-sparse-indexer.h @@ -0,0 +1,88 @@ +#ifndef LLAMA_SPARSE_INDEXER_H +#define LLAMA_SPARSE_INDEXER_H + +#include +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-backend.h" +#include "llama-kv-cache.h" + +// Forward declarations +struct llama_model; + +namespace llama { + +using std::function; + +// Triplet outputs for KV-aware Lightning Indexer +struct IndexerKVTriplet { + ggml_tensor * q_indexer; + ggml_tensor * k_indexer_cache; + ggml_tensor * idx_weights; +}; + + +// Lightning indexer helpers for DeepSeek V3.2 +struct sparse_attn_indexer { + static ggml_tensor * idx_compute_scores_tile( + ggml_context * ctx, + ggml_tensor * q3d, + ggml_tensor * a_k, + ggml_tensor * weights, + ggml_tensor * k_scale_2d, + int64_t D, int64_t H, + int64_t Tc, int64_t kv_end, + int64_t t0); + + // Build KV-aware top-k token indices using the Lightning Indexer tensors. + // If mctx is nullptr, uses freshly computed K_indexer directly without cache writes. 
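// [editorial sketch] Typical call sequence from the graph builder (illustrative variable names;
// the actual call site is in llm_build_deepseek3_2):
//
//     ggml_tensor * idx = llama::sparse_attn_indexer::build_kvaware_topk_indices(
//         ctx0, model, il, cur, n_tokens, mctx, k_idxs, kq_mask_full_2d, /*top_k=*/2048,
//         inp_pos, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
//         ext_factor, attn_factor, beta_fast, beta_slow, cb_wrapper, gf, sched, backend_cpu);
//     // idx: I32 indices of shape [top_k, n_tokens] addressing rows of the full-width KV cache
//     // (per the debug print above).
//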
+ static IndexerKVTriplet compute_indexer_triplet( + ggml_context * ctx, + const llama_model & model, + int layer_idx, + ggml_tensor * cur, + int64_t n_tokens, + const llama_kv_cache_context * mctx, + ggml_tensor * k_idxs, + ggml_tensor * inp_pos, + int64_t n_rot, + int rope_type, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + const function & cb, + ggml_cgraph * gf); + + static ggml_tensor * build_kvaware_topk_indices( + ggml_context * ctx, + const llama_model & model, + int layer_idx, + ggml_tensor * cur, // [n_embd, T] + int64_t n_tokens, + const llama_kv_cache_context * mctx, + ggml_tensor * k_idxs, + ggml_tensor * kq_mask, + int64_t top_k, + ggml_tensor * inp_pos, + int64_t n_rot, + int rope_type, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + const function & cb, + ggml_cgraph * gf, + ggml_backend_sched_t sched, + ggml_backend_t backend_cpu); +}; + +} // namespace llama + +#endif // LLAMA_SPARSE_INDEXER_H diff --git a/src/llama-sparse-mla-fwd.cpp b/src/llama-sparse-mla-fwd.cpp new file mode 100644 index 00000000000..2a3d79ac27b --- /dev/null +++ b/src/llama-sparse-mla-fwd.cpp @@ -0,0 +1,164 @@ +#include "llama-sparse-mla-fwd.h" +#include "llama-impl.h" +#include +#include +#include +namespace llama { +using std::function; + ggml_tensor * sparse_mla_fwd::apply_sparse_attention_kvaware( + ggml_context * ctx, + ggml_tensor * q_cur, + ggml_tensor * k_cache, + ggml_tensor * v_cache, + ggml_tensor * topk_indices, + int64_t n_tokens, + int64_t top_k, + float kq_scale, + ggml_tensor * kq_mask, + float attn_softcap, + const function & cb) { + (void)n_tokens; + int64_t Dk = k_cache->ne[0]; + int64_t Hkv = k_cache->ne[1]; + int64_t N_kv = k_cache->ne[2]; + int64_t Dv = v_cache->ne[0]; + int64_t Hkv_v= v_cache->ne[1]; + int64_t N_kv_v= v_cache->ne[2]; + const char * ENV_SPARSE_DEBUG = getenv("LLAMA_SPARSE_DEBUG"); + const bool dbg = (ENV_SPARSE_DEBUG && atoi(ENV_SPARSE_DEBUG) != 0); + // Normalize V layout: expected effective layout is [Dv, Hkv_v, N_kv] + // Some builds return V cache with transposed layout [N_kv, Hkv_v, Dv, ns]. 
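// [editorial sketch] The disambiguation below is purely shape-driven: if V's third extent matches
// the number of cached rows seen through K, V is assumed to already be [Dv, Hkv_v, N_kv]; if
// instead V's first extent equals N_kv, it is treated as the transposed [N_kv, Hkv_v, Dv] layout
// and permuted back. Equivalent standalone predicate (illustrative, not part of this patch):
//
//     static bool v_cache_is_transposed(const ggml_tensor * v, int64_t n_kv) {
//         return v->ne[2] != n_kv && v->ne[0] == n_kv;
//     }
//
// Note the check can misfire if the true Dv happens to equal N_kv; the warning branch below
// covers the remaining ambiguous shapes on a best-effort basis.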
+ ggml_tensor * V_gather_src = nullptr; + if (N_kv_v == N_kv) { + // Normal layout: [Dv, Hkv_v, N_kv, ns] + V_gather_src = v_cache; + } else if (Dv == N_kv) { + // Transposed layout: [N_kv, Hkv_v, Dv, ns] -> permute to [Dv, Hkv_v, N_kv, ns] + ggml_tensor * v_perm = ggml_permute(ctx, v_cache, 2, 1, 0, 3); + v_perm = ggml_cont(ctx, v_perm); + Dv = v_perm->ne[0]; + Hkv_v = v_perm->ne[1]; + N_kv_v = v_perm->ne[2]; + V_gather_src = v_perm; + } else { + // Unexpected; proceed without permute but warn + printf("[SPARSE-MLA][WARN] V cache unexpected layout: v_cache=[%lld,%lld,%lld,%lld], K N_kv=%lld\n", + (long long) v_cache->ne[0], (long long) v_cache->ne[1], (long long) v_cache->ne[2], (long long) v_cache->ne[3], (long long) N_kv); + fflush(stdout); + V_gather_src = v_cache; + // best effort: if v_cache->ne[0] == N_kv, treat as transposed + if (v_cache->ne[0] == N_kv) { + ggml_tensor * v_perm = ggml_permute(ctx, v_cache, 2, 1, 0, 3); + v_perm = ggml_cont(ctx, v_perm); + Dv = v_perm->ne[0]; + Hkv_v = v_perm->ne[1]; + N_kv_v = v_perm->ne[2]; + V_gather_src = v_perm; + } + } + const int64_t Dq = q_cur->ne[0]; + const int64_t Hq = q_cur->ne[1]; + const int64_t T = q_cur->ne[2]; + // Fused decode path: use custom CUDA op when T == 1 + const char * env_fused_dec = getenv("LLAMA_SPARSE_MLA_FUSED_DECODE"); + if (T == 1 && (env_fused_dec == nullptr || atoi(env_fused_dec) != 0)) { + // Build q_t [Dq, Hq] + ggml_tensor * q_cur_cont2 = ggml_cont(ctx, q_cur); + ggml_tensor * q_all_2d2 = ggml_reshape_2d(ctx, q_cur_cont2, Dq, Hq*T); + ggml_tensor * q_t_2d2 = ggml_view_2d(ctx, q_all_2d2, Dq, Hq, q_all_2d2->nb[1], 0); + // Top-k indices for t=0 -> [K] + ggml_tensor * idx0_2d = ggml_view_2d(ctx, topk_indices, top_k, 1, topk_indices->nb[1], 0); + idx0_2d = ggml_cont(ctx, idx0_2d); + ggml_tensor * idx0_1d = ggml_reshape_1d(ctx, idx0_2d, top_k); + // Call fused decode: returns [Dv, Hq] + ggml_tensor * out2d = ggml_sparse_mla_decode_fused(ctx, q_t_2d2, k_cache, V_gather_src, idx0_1d, kq_scale, attn_softcap); + ggml_tensor * out3d = ggml_reshape_3d(ctx, out2d, Dv, Hq, 1); + cb(out3d, "kvaware_sparse_attn_out", -1); + return out3d; + } + + if (dbg) { + cb(k_cache, "kvaware_k_cache", -1); + cb(v_cache, "kvaware_v_cache", -1); + cb(q_cur, "kvaware_q_cur", -1); + cb(topk_indices, "kvaware_topk_indices", -1); + printf("[SPARSE-MLA] Dq=%lld Hq=%lld T=%lld Dk=%lld Hkv=%lld N_kv=%lld Dv=%lld Hkv_v=%lld\n", + (long long) Dq, (long long) Hq, (long long) T, + (long long) Dk, (long long) Hkv, (long long) N_kv, + (long long) Dv, (long long) Hkv_v); + fflush(stdout); + printf("SPARSE MLA KV-AWARE DBG: Q=[%" PRId64 ",%" PRId64 ",%" PRId64 "] K=[%" PRId64 ",%" PRId64 ",%" PRId64 "] V=[%" PRId64 ",%" PRId64 ",%" PRId64 "] topk=[%" PRId64 ",%" PRId64 ",%" PRId64 ",%" PRId64 "]\n", + Dq, Hq, T, Dk, Hkv, N_kv, Dv, Hkv_v, N_kv_v, + topk_indices->ne[0], topk_indices->ne[1], topk_indices->ne[2], topk_indices->ne[3]); + fflush(stdout); + } + ggml_tensor * K4d = ggml_reshape_4d(ctx, k_cache, Dk*Hkv, N_kv, 1, 1); + ggml_tensor * V4d = ggml_reshape_4d(ctx, V_gather_src, Dv*Hkv_v, N_kv_v, 1, 1); + // For safety, contiguize topk_indices and q_cur for consistent views + ggml_tensor * idx_cont = ggml_cont(ctx, topk_indices); + ggml_tensor * q_cur_cont = ggml_cont(ctx, q_cur); + ggml_tensor * q_all_2d = ggml_reshape_2d(ctx, q_cur_cont, Dq, Hq*T); + // Optional FP16 path for sparse MLA GEMMs + const char *env_mla_fp16 = getenv("LLAMA_SPARSE_MLA_FP16"); + const bool use_mla_fp16 = (env_mla_fp16 == nullptr || atoi(env_mla_fp16) != 0); + ggml_tensor * 
output_acc = nullptr; + for (int64_t t = 0; t < T; ++t) { + GGML_ASSERT(topk_indices->ne[0] == top_k); + ggml_tensor * idx_t_2d = ggml_view_2d(ctx, idx_cont, top_k, 1, idx_cont->nb[1], t * idx_cont->nb[1]); + idx_t_2d = ggml_cont(ctx, idx_t_2d); + ggml_tensor * idx_t_4d = ggml_reshape_4d(ctx, idx_t_2d, top_k, 1, 1, 1); + + ggml_tensor * k_sel_4d = ggml_get_rows(ctx, K4d, idx_t_4d); // [Dk*Hkv, top_k] + ggml_tensor * v_sel_4d = ggml_get_rows(ctx, V4d, idx_t_4d); // [Dv*Hkv_v, top_k] + + ggml_tensor * k_sel_2d = ggml_reshape_2d(ctx, k_sel_4d, Dk, Hkv*top_k); + ggml_tensor * v_sel_2d0 = ggml_reshape_2d(ctx, v_sel_4d, Dv, Hkv_v*top_k); + // CUDA matmul requires row-contiguous inputs + k_sel_2d = ggml_cont(ctx, k_sel_2d); + // ensure v rows [Hkv_v*top_k, Dv] + ggml_tensor * v_sel_2d = ggml_cont(ctx, ggml_transpose(ctx, v_sel_2d0)); + + size_t q_off = (size_t) t * Hq * q_all_2d->nb[1]; + ggml_tensor * q_t_2d = ggml_view_2d(ctx, q_all_2d, Dq, Hq, q_all_2d->nb[1], q_off); + q_t_2d = ggml_cont(ctx, q_t_2d); + + if (use_mla_fp16) { + if (k_sel_2d->type != GGML_TYPE_F16) { k_sel_2d = ggml_cast(ctx, k_sel_2d, GGML_TYPE_F16); k_sel_2d = ggml_cont(ctx, k_sel_2d); } + if (q_t_2d->type != GGML_TYPE_F16) { q_t_2d = ggml_cast(ctx, q_t_2d, GGML_TYPE_F16); q_t_2d = ggml_cont(ctx, q_t_2d); } + } + + ggml_tensor * scores_t = ggml_mul_mat(ctx, k_sel_2d, q_t_2d); // [Hkv*top_k, Hq] + scores_t = ggml_scale(ctx, scores_t, kq_scale); + + if (kq_mask && kq_mask->ne[0] == N_kv && kq_mask->ne[1] >= T) { + ggml_tensor * mask_col = ggml_view_2d(ctx, kq_mask, kq_mask->ne[0], 1, kq_mask->nb[1], t * kq_mask->nb[1]); + ggml_tensor * mask_vec = ggml_transpose(ctx, mask_col); // [1, N_kv] + if (mask_vec->type != scores_t->type) mask_vec = ggml_cast(ctx, mask_vec, scores_t->type); + mask_vec = ggml_cont(ctx, mask_vec); + ggml_tensor * mask_rows_4d = ggml_get_rows(ctx, mask_vec, idx_t_4d); // [1, top_k, 1, 1] + ggml_tensor * mask_rows_2d = ggml_reshape_2d(ctx, mask_rows_4d, top_k, 1); + ggml_tensor * scores_col_view = ggml_view_2d(ctx, scores_t, Hkv*top_k, 1, scores_t->nb[1], 0); + ggml_tensor * mask_bias = ggml_repeat(ctx, mask_rows_2d, scores_col_view); + scores_t = ggml_add(ctx, scores_t, mask_bias); + } + + if (attn_softcap > 0.0f) { + scores_t = ggml_scale(ctx, scores_t, 1.0f / attn_softcap); + scores_t = ggml_tanh(ctx, scores_t); + scores_t = ggml_scale(ctx, scores_t, attn_softcap); + } + scores_t = ggml_clamp(ctx, scores_t, -1e30f, 1e30f); + + ggml_tensor * weights_t = ggml_soft_max(ctx, scores_t); + weights_t = ggml_cont(ctx, weights_t); + v_sel_2d = ggml_cont(ctx, v_sel_2d); + + ggml_tensor * out2d_t = ggml_mul_mat(ctx, weights_t, v_sel_2d); // [Hq, Dv] + ggml_tensor * out2d_t_T = ggml_cont(ctx, ggml_transpose(ctx, out2d_t)); // [Dv, Hq] + ggml_tensor * out3d_t = ggml_reshape_3d(ctx, out2d_t_T, Dv, Hq, 1); + output_acc = output_acc ? 
ggml_concat(ctx, output_acc, out3d_t, 2) : out3d_t; + } + cb(output_acc, "kvaware_sparse_attn_out", -1); + return output_acc; + } +} // namespace llama diff --git a/src/llama-sparse-mla-fwd.h b/src/llama-sparse-mla-fwd.h new file mode 100644 index 00000000000..e72cfc662e1 --- /dev/null +++ b/src/llama-sparse-mla-fwd.h @@ -0,0 +1,35 @@ +#ifndef LLAMA_SPARSE_MLA_FWD_H +#define LLAMA_SPARSE_MLA_FWD_H + +#include +#include "ggml-cpp.h" + +// Forward declarations +struct llm_graph_params; + +namespace llama { + +using std::function; + +// Sparse Multi-Query Attention Forward implementation for DeepSeek V3.2 +// Corresponds to tilelang's sparse_mla_fwd.py +struct sparse_mla_fwd { + // KV-aware variant: gather from full KV cache tensors instead of current block + static ggml_tensor * apply_sparse_attention_kvaware( + ggml_context * ctx, + ggml_tensor * q_cur, // [Dq, Hq, T] + ggml_tensor * k_cache, // [Dk, Hkv, N_kv] + ggml_tensor * v_cache, // [Dv, Hkv, N_kv] + ggml_tensor * topk_indices, // [top_k, T] + int64_t n_tokens, + int64_t top_k, + float kq_scale, + ggml_tensor * kq_mask, + float attn_softcap, + const function & cb); +}; + + +} // namespace llama + +#endif // LLAMA_SPARSE_MLA_FWD_H diff --git a/src/llama-sparse-topk.cpp b/src/llama-sparse-topk.cpp new file mode 100644 index 00000000000..1f5842b94a3 --- /dev/null +++ b/src/llama-sparse-topk.cpp @@ -0,0 +1,780 @@ +#include "llama-sparse-topk.h" +#include "llama-sparse-indexer.h" +#include +#include +#include + +#include "llama-impl.h" +#ifdef GGML_USE_CUDA +#include "ggml-cuda-indexer.h" +#endif + + +#include + +#include +#include +#include +#include + +#include +#include +namespace { + +static inline uint32_t float_to_key_desc(float x) { + uint32_t u; + memcpy(&u, &x, sizeof(u)); + // Map float bits to monotonically increasing unsigned keys (ascending order): + // TileLang-compatible mapping: negative -> bitwise NOT, non-negative -> set sign bit + if ((int32_t)u < 0) { + u = ~u; + } else { + u |= 0x80000000u; + } + return u; +} + +struct radix_topk_userdata { + // currently unused; k is taken from dst->ne[0] +}; + +static void radix_topk_custom(ggml_tensor * dst, int ith, int nth, void * userdata) { + (void)userdata; + const char * ENV_SPARSE_DEBUG = getenv("LLAMA_SPARSE_DEBUG"); + const bool dbg = (ENV_SPARSE_DEBUG && atoi(ENV_SPARSE_DEBUG) != 0); + ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + const int64_t N = src0->ne[0]; + const int64_t k = dst->ne[0]; + const int64_t nr = ggml_nrows(src0); + const size_t src_nb1 = src0->nb[1]; + const size_t src_nb0 = src0->nb[0]; + const size_t dst_nb1 = dst->nb[1]; + + for (int64_t r = ith; r < nr; r += nth) { + const char * row0 = (const char *)src0->data + r * src_nb1; + int32_t * out_idx = (int32_t *)((char *)dst->data + r * dst_nb1); + const int64_t KK = k < N ? 
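// [editorial sketch] float_to_key_desc() above maps IEEE-754 floats to unsigned keys whose natural
// ascending order matches the ascending order of the floats (negatives are bit-inverted,
// non-negatives get the sign bit set), so "larger float" == "larger key" and the radix passes can
// select the top-K by scanning histogram bins from high to low. Quick standalone check
// (illustrative):
//
//     assert(float_to_key_desc(-1.0f) < float_to_key_desc(-0.5f));
//     assert(float_to_key_desc(-0.5f) < float_to_key_desc( 0.0f));
//     assert(float_to_key_desc( 0.0f) < float_to_key_desc( 2.0f));
//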
k : N; + + // Precompute keys for this row + std::vector keys(N); + for (int64_t i = 0; i < N; ++i) { + float v = *(const float *)(row0 + (size_t)i*src_nb0); + keys[i] = float_to_key_desc(v); + } + + // Stage 1: histogram of high 8 bits (bits 31..24) + uint32_t counts[256] = {0}; + for (int64_t i = 0; i < N; ++i) { + uint32_t bin = (keys[i] >> 24) & 0xFFu; + counts[bin]++; + } + // Find threshold bin: number of items with bin > thr0 is sum of counts above thr0 + auto sum_greater = [&](int b){ uint32_t s=0; for (int bb=b+1; bb<256; ++bb) s += counts[bb]; return s; }; + int thr0 = 0; + uint32_t gt = 0; + for (int b = 255; b >= 0; --b) { + uint32_t sgt = sum_greater(b); + uint32_t eq = counts[b]; + if (sgt < (uint32_t)KK && sgt + eq >= (uint32_t)KK) { thr0 = b; gt = sgt; break; } + } + uint32_t eq0 = counts[thr0]; + int64_t remaining = (int64_t)KK - (int64_t)gt; + if (remaining < 0) remaining = 0; + + // Collect selected (> thr0) and eq candidates + std::vector selected; selected.reserve(KK); + std::vector eq_list; eq_list.reserve(eq0); + for (int64_t i = 0; i < N; ++i) { + uint32_t bin = (keys[i] >> 24) & 0xFFu; + if ((int)bin > thr0) { + if ((int64_t)selected.size() < KK) selected.push_back((int32_t)i); + } else if ((int)bin == thr0) { + eq_list.push_back((int32_t)i); + } + } + remaining = (int64_t)KK - (int64_t)selected.size(); + + // Safety check: ensure we have enough candidates to fill K + if ((int64_t)selected.size() + (int64_t)eq_list.size() < KK) { + // Fallback: use partial_sort to guarantee correctness + std::vector idx(N); + for (int64_t i = 0; i < N; ++i) idx[i] = (int32_t)i; + auto cmp = [&](int32_t a, int32_t b){ + float va = *(const float *)(row0 + (size_t)a*src_nb0); + float vb = *(const float *)(row0 + (size_t)b*src_nb0); + if (va != vb) return va > vb; + return a < b; + }; + std::partial_sort(idx.begin(), idx.begin() + KK, idx.end(), cmp); + for (int64_t i = 0; i < KK; ++i) out_idx[i] = idx[i]; + continue; + } + + // Tail passes for equal bin + int shifts[3] = {16, 8, 0}; + for (int pass = 0; pass < 3 && remaining > 0 && !eq_list.empty(); ++pass) { + uint32_t c2[256] = {0}; + for (int idx : eq_list) { + uint32_t bin = (keys[idx] >> shifts[pass]) & 0xFFu; + c2[bin]++; + } + auto sum_greater2 = [&](int b){ uint32_t s=0; for (int bb=b+1; bb<256; ++bb) s += c2[bb]; return s; }; + int thr = 255; + for (int b = 255; b >= 0; --b) { + uint32_t sgt = sum_greater2(b); + uint32_t eq = c2[b]; + if (sgt < (uint32_t)remaining && sgt + eq >= (uint32_t)remaining) { thr = b; break; } + } + std::vector next_eq; next_eq.reserve(c2[thr]); + // Add strictly greater than thr + for (int idx : eq_list) { + uint32_t bin = (keys[idx] >> shifts[pass]) & 0xFFu; + if ((int)bin > thr) { + if ((int64_t)selected.size() < KK) { selected.push_back(idx); } + } else if ((int)bin == thr) { + next_eq.push_back(idx); + } + } + eq_list.swap(next_eq); + remaining = (int64_t)KK - (int64_t)selected.size(); + if ((int64_t)selected.size() + (int64_t)eq_list.size() < remaining) { + // Fallback safety + break; + } + } + // Final fill from eq_list if still remaining + for (int64_t i = 0; i < (int64_t)eq_list.size() && (int64_t)selected.size() < KK; ++i) { + selected.push_back(eq_list[i]); + } + + // As a final fallback, if still not enough, use partial_sort + if ((int64_t)selected.size() < KK) { + std::vector idx(N); + for (int64_t i = 0; i < N; ++i) idx[i] = (int32_t)i; + auto cmp = [&](int32_t a, int32_t b){ + float va = *(const float *)(row0 + (size_t)a*src_nb0); + float vb = *(const float *)(row0 + 
(size_t)b*src_nb0); + if (va != vb) return va > vb; + return a < b; + }; + std::partial_sort(idx.begin(), idx.begin() + KK, idx.end(), cmp); + for (int64_t i = 0; i < KK; ++i) out_idx[i] = idx[i]; + continue; + } + + // Output first KK indices (order arbitrary) + for (int64_t i = 0; i < KK; ++i) out_idx[i] = selected[i]; + + // Debug: compare with partial_sort for a few rows + if (r < 8) { + std::vector ref(N); + for (int64_t i = 0; i < N; ++i) ref[i] = (int32_t)i; + auto cmp = [&](int32_t a, int32_t b){ + float va = *(const float *)(row0 + (size_t)a*src_nb0); + float vb = *(const float *)(row0 + (size_t)b*src_nb0); + if (va != vb) return va > vb; + return a < b; + }; + std::partial_sort(ref.begin(), ref.begin() + KK, ref.end(), cmp); + if (dbg) { + printf("[radix debug] row=%lld top: ", (long long)r); + for (int ii = 0; ii < (int)std::min(8, KK); ++ii) printf("%d ", out_idx[ii]); + printf("| ref: "); + for (int ii = 0; ii < (int)std::min(8, KK); ++ii) printf("%d ", ref[ii]); + printf("\n"); + fflush(stdout); + } + } + } +} + +} // anonymous namespace + +namespace llama { + +static inline int find_last_unmasked(const float * col, int N, size_t nb0) { + // col is [N] as a column with row stride nb0; return last index+1 where value > -1e29 + for (int i = N-1; i >= 0; --i) { + float v = *(const float *)((const char*)col + (size_t)i*nb0); + if (v > -1.0e29f) return i+1; + } + return 0; +} + +using std::function; + +struct fused_indexer_userdata { }; + +static void fused_indexer_custom(ggml_tensor * /*dst*/, int /*ith*/, int /*nth*/, void * /*userdata*/) { + // no-op custom CPU fallback; all real work happens in CUDA backend path +} + + + +static ggml_tensor * build_indexer_fused_logits_ex( + ggml_context * ctx, + ggml_tensor * q_tile2d, + ggml_tensor * k_slice, + ggml_tensor * w_slice, + ggml_tensor * ks_head, + bool have_windows, + ggml_tensor * win_starts, + ggml_tensor * win_ends, + int64_t t0, + int64_t Tc) { + ggml_tensor * starts_tile = nullptr; + ggml_tensor * ends_tile = nullptr; + if (have_windows && win_ends) { + if (win_starts) { + size_t off_s = (size_t)t0 * win_starts->nb[0]; + starts_tile = ggml_view_1d(ctx, win_starts, Tc, off_s); + starts_tile = ggml_cont(ctx, starts_tile); + } + if (win_ends) { + size_t off_e = (size_t)t0 * win_ends->nb[0]; + ends_tile = ggml_view_1d(ctx, win_ends, Tc, off_e); + ends_tile = ggml_cont(ctx, ends_tile); + } + } + return ggml_indexer_logits_fused_ex(ctx, q_tile2d, k_slice, w_slice, ks_head, starts_tile, ends_tile); +} +static ggml_tensor * build_indexer_fused_logits( + ggml_context * ctx, + ggml_tensor * q2d, // [D, Tc*H] + ggml_tensor * k2d, // [D, kv] + ggml_tensor * w2d, // [H, Tc] + ggml_tensor * k_scale // [kv] +) { + int64_t kv = k2d->ne[1]; + int64_t Tc = w2d->ne[1]; + ggml_tensor * args[4] = { q2d, k2d, w2d, k_scale }; + return ggml_custom_4d(ctx, GGML_TYPE_F32, kv, Tc, 1, 1, args, 4, fused_indexer_custom, 1, nullptr); +} + + +ggml_tensor * sparse_attn_topk::select_topk_tokens_indexer_kvaware( + ggml_context * ctx, + ggml_tensor * q_indexer, // [D, H, T] + ggml_tensor * k_indexer, // [D, N_kv] + ggml_tensor * weights, // [H, T] + ggml_tensor * kq_mask, // [N_kv, T] or [N_kv, PAD(T)] + int64_t top_k, + const std::function & cb, + ggml_cgraph * gf, + ggml_backend_sched_t sched, + ggml_backend_t /*backend_cpu*/) { + const int64_t D = q_indexer->ne[0]; + const int64_t H = q_indexer->ne[1]; + const int64_t T = q_indexer->ne[2]; + const int64_t N_kv = k_indexer->ne[1]; + + const char * ENV_SPARSE_DEBUG = getenv("LLAMA_SPARSE_DEBUG"); + const 
bool dbg = (ENV_SPARSE_DEBUG && atoi(ENV_SPARSE_DEBUG) != 0); + + if (dbg) { + printf("SPARSE TOPK KV-AWARE (INDEXER): q_indexer [D,H,T]=[%" PRId64 ",%" PRId64 ",%" PRId64 "]\n", D, H, T); + printf("SPARSE TOPK KV-AWARE (INDEXER): k_indexer dims=[%" PRId64 ",%" PRId64 ",%" PRId64 ",%" PRId64 "]\n", + k_indexer->ne[0], k_indexer->ne[1], k_indexer->ne[2], k_indexer->ne[3]); + printf("SPARSE TOPK KV-AWARE (INDEXER): weights [H,T]=[%" PRId64 ",%" PRId64 "]\n", + weights ? weights->ne[0] : -1, weights ? weights->ne[1] : -1); + if (kq_mask) { + + printf("SPARSE TOPK KV-AWARE (INDEXER): kq_mask dims=[%" PRId64 ",%" PRId64 ",%" PRId64 ",%" PRId64 "] type=%d\n", + kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3], (int)kq_mask->type); + } + fflush(stdout); + } + + // Shape/contiguity assertions for weights [H, T] + GGML_ASSERT(D > 0 && H > 0 && T > 0 && N_kv > 0); + GGML_ASSERT(weights != nullptr); + GGML_ASSERT(weights->ne[0] == H); + GGML_ASSERT(weights->ne[1] >= T); + GGML_ASSERT(weights->nb[0] == (size_t) ggml_type_size(weights->type)); + GGML_ASSERT(weights->nb[1] == (size_t) ggml_row_size(weights->type, weights->ne[0])); + // Ensure indexer K depth matches indexer Q depth + GGML_ASSERT(k_indexer->ne[0] == D); + // KV indexer currently expected as 2D [D, N_kv] or 3D with singleton stream + GGML_ASSERT(k_indexer->ne[2] <= 1); + + // Q as [D, H*T] + ggml_tensor * q_perm = ggml_permute(ctx, q_indexer, 0, 2, 1, 3); // [D, T, H] + ggml_tensor * q_cont = ggml_cont(ctx, q_perm); + ggml_tensor * Q2d_full = ggml_reshape_2d(ctx, q_cont, D, T*H); + cb(Q2d_full, "idxkv_Q2d_full", -1); + + // Optional FP16 path for indexer GEMMs + ggml_tensor * k_indexer_f16 = k_indexer; + const char *env_fp16 = getenv("LLAMA_SPARSE_TOPK_FP16"); + const bool use_fp16 = (env_fp16 == nullptr || atoi(env_fp16) != 0); + if (use_fp16 && k_indexer->type != GGML_TYPE_F16) { + k_indexer_f16 = ggml_cast(ctx, k_indexer, GGML_TYPE_F16); + k_indexer_f16 = ggml_cont(ctx, k_indexer_f16); + } + + // Diagnostics: sample K indexer head/tail once per call + if (dbg) { + const int64_t d0 = std::min(k_indexer->ne[0], (int64_t)8); + const int64_t c0 = std::min(k_indexer->ne[1], (int64_t)8); + // head columns + ggml_tensor * kidx_head = ggml_view_2d(ctx, k_indexer, d0, c0, k_indexer->nb[1], 0); + cb(kidx_head, "idxkv_k_indexer_head", -1); + if (gf) { ggml_set_output(kidx_head); ggml_build_forward_expand(gf, kidx_head); } + // tail columns + if (k_indexer->ne[1] > c0) { + size_t off_tail = (k_indexer->ne[1] - c0) * k_indexer->nb[1]; + ggml_tensor * kidx_tail = ggml_view_2d(ctx, k_indexer, d0, c0, k_indexer->nb[1], off_tail); + cb(kidx_tail, "idxkv_k_indexer_tail", -1); + if (gf) { ggml_set_output(kidx_tail); ggml_build_forward_expand(gf, kidx_tail); } + } + } + + bool no_alloc = ggml_get_no_alloc(ctx); + ggml_tensor * mask_full = nullptr; + bool prefer_device_windows = true; + if (const char *env = getenv("LLAMA_SPARSE_TOPK_WINDOWS_DEVICE")) { + prefer_device_windows = atoi(env) != 0; + } + + ggml_tensor * win_starts = nullptr; + (void)win_starts; + ggml_tensor * win_ends = nullptr; + bool have_windows = false; +#ifdef GGML_USE_CUDA + if (!no_alloc && kq_mask && kq_mask->buffer && !ggml_backend_buffer_is_host(kq_mask->buffer)) { + if (prefer_device_windows) { + // Create device-resident starts/ends tensors and fill them via device kernels + win_starts = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, T); + win_ends = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, T); + ggml_set_input(win_starts); + ggml_set_input(win_ends); + have_windows = true; + 
} else { + std::vector<int32_t> starts_h((size_t)T, 0); + ggml_cuda_mask_window_starts_device_to_host_simple((const float *)kq_mask->data, (int)N_kv, (int)T, starts_h.data()); + ggml_tensor * starts = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, T); + memcpy(starts->data, starts_h.data(), sizeof(int32_t) * (size_t)T); + ggml_set_input(starts); + win_starts = starts; have_windows = true; + + std::vector<int32_t> ends_h((size_t)T, 0); + ggml_cuda_mask_window_ends_device_to_host_simple((const float *)kq_mask->data, (int)N_kv, (int)T, ends_h.data()); + ggml_tensor * ends = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, T); + memcpy(ends->data, ends_h.data(), sizeof(int32_t) * (size_t)T); + ggml_set_input(ends); + win_ends = ends; have_windows = true; + } + } +#endif + // If no device-resident windows were created, but we prefer device windows, create empty device + // starts/ends now and let the CUDA top-k kernel derive them from scores + #ifdef GGML_USE_CUDA + if (!no_alloc && !have_windows && prefer_device_windows && kq_mask && kq_mask->buffer && !ggml_backend_buffer_is_host(kq_mask->buffer)) { + win_starts = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, T); + win_ends = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, T); + ggml_set_input(win_starts); + ggml_set_input(win_ends); + have_windows = true; + } + #endif + + +#ifdef GGML_USE_CUDA + +#endif + if (kq_mask) { + cb(kq_mask, "idxkv_kq_mask", -1); + ggml_tensor * tmp_mask = kq_mask; + cb(tmp_mask, "idxkv_mask2d", -1); + if (tmp_mask->ne[0] == N_kv && tmp_mask->ne[1] >= T) { + mask_full = tmp_mask; // use original mask; slice per tile and only contiguize the tile + } else { + printf("[TOPK-INDEXER] kq_mask dims [%lld,%lld] mismatch N_kv=%lld,T=%lld; ignoring mask for indexer selection\n", + (long long) tmp_mask->ne[0], (long long) tmp_mask->ne[1], (long long) N_kv, (long long) T); + fflush(stdout); + mask_full = nullptr; + } + } + + const int64_t k = std::min(top_k, N_kv); + int64_t TILE_T = use_fp16 ? 512 : 256; // heuristic default; overridable via env + if (const char *env = getenv("LLAMA_SPARSE_TOPK_TILE_T")) { + long v = strtol(env, nullptr, 10); + if (v > 0 && v <= 8192) TILE_T = v; + } + if (dbg) { + printf("[TOPK-INDEXER] N_kv=%lld T=%lld k=%lld TILE_T=%lld H=%lld D=%lld\n", + (long long) N_kv, (long long) T, (long long) k, (long long) TILE_T, (long long) H, (long long) D); + fflush(stdout); + } + + // K-scale proxy: RMS over D for each KV column + ggml_tensor * k_sqr = ggml_sqr(ctx, k_indexer); // [D, N_kv] + ggml_tensor * k_sum = ggml_sum_rows(ctx, k_sqr); // [1, N_kv] + ggml_tensor * k_mean = ggml_scale(ctx, k_sum, 1.0f / (float) D); // [1, N_kv] + ggml_tensor * k_scale_vec = ggml_sqrt(ctx, k_mean); // [1, N_kv] + ggml_tensor * k_scale_2d = ggml_transpose(ctx, k_scale_vec); // [N_kv, 1] + k_scale_2d = ggml_cont(ctx, k_scale_2d); + cb(k_scale_2d, "idxkv_k_scale_proxy", -1); + + ggml_tensor * result = nullptr; // [k, T] + for (int64_t t0 = 0; t0 < T; t0 += TILE_T) { + const int64_t Tc = std::min(TILE_T, T - t0); + // If we have per-token windows, compute aggregate kv_end for this tile + int64_t kv_end = N_kv; + if (!no_alloc && have_windows && win_ends && win_ends->buffer && ggml_backend_buffer_is_host(win_ends->buffer) && win_ends->data != nullptr) { + int32_t * e = (int32_t*)win_ends->data; + int64_t max_e = 0; + for (int64_t t = 0; t < Tc; ++t) max_e = std::max(max_e, (int64_t)e[t0 + t]); + kv_end = std::min(N_kv, std::max(k, max_e)); + } else { + const char *env_full_kv = getenv("LLAMA_SPARSE_TOPK_FULL_KV"); + const bool use_full_kv = (env_full_kv ?
atoi(env_full_kv) != 0 : true); + kv_end = use_full_kv ? N_kv : std::min(N_kv, std::max(k, t0 + Tc)); + } + + // Use contiguized [D, T, H] directly for head-wise tiles + ggml_tensor * q3d = q_cont; + + ggml_tensor * scores_tc = nullptr; + { + // Host wall-clock timing for the tile compute path (portable) + auto __t0_wall = std::chrono::high_resolution_clock::now(); + + + const char *env_fused = getenv("LLAMA_SPARSE_INDEXER_FUSED"); + bool use_fused = (env_fused ? atoi(env_fused) != 0 : true); + if (use_fused) { + // prepare q2d tile [D, Tc*H] + size_t q_off = (size_t)t0 * q3d->nb[1]; + ggml_tensor * q_tile3d = ggml_view_3d(ctx, q3d, D, Tc, H, q3d->nb[1], q3d->nb[2], q_off); + q_tile3d = ggml_cont(ctx, q_tile3d); + ggml_tensor * q_tile2d = ggml_reshape_2d(ctx, q_tile3d, D, Tc*H); + q_tile2d = ggml_cont(ctx, q_tile2d); + // Determine kv window for this tile + int64_t kv_s = 0; + int64_t kv_e = kv_end; + if (have_windows && win_ends) { + int32_t * e = (int32_t*)win_ends->data; + int64_t max_e = 0; + for (int64_t t = 0; t < Tc; ++t) max_e = std::max(max_e, (int64_t)e[t0 + t]); + kv_e = std::min(N_kv, std::max(k, max_e)); + } + int64_t kv_len = std::max<int64_t>(0, kv_e - kv_s); + // k slice [D, kv_len] + ggml_tensor * k_slice = ggml_view_2d(ctx, k_indexer_f16, D, kv_len, k_indexer_f16->nb[1], kv_s * k_indexer_f16->nb[1]); + k_slice = ggml_cont(ctx, k_slice); + // w slice [H, Tc] + ggml_tensor * w_slice = ggml_view_2d(ctx, weights, H, Tc, weights->nb[1], t0*weights->nb[1]); + w_slice = ggml_cont(ctx, w_slice); + // k_scale head [kv_len] + ggml_tensor * ks_head = ggml_view_2d(ctx, k_scale_2d, kv_len, 1, k_scale_2d->nb[1], kv_s * k_scale_2d->nb[1]); + ks_head = ggml_reshape_1d(ctx, ks_head, kv_len); + ks_head = ggml_cont(ctx, ks_head); + + if (dbg && sched && t0 == 0) { + ggml_backend_t bq = ggml_backend_sched_get_tensor_backend(sched, q_tile2d); + ggml_backend_t bk = ggml_backend_sched_get_tensor_backend(sched, k_slice); + ggml_backend_t bw = ggml_backend_sched_get_tensor_backend(sched, w_slice); + ggml_backend_t bs = ggml_backend_sched_get_tensor_backend(sched, ks_head); + printf("[idxkv fused inputs strides] q nb=[%zu,%zu] k nb=[%zu,%zu] w nb=[%zu,%zu] ks nb0=%zu\n", + (size_t)q_tile2d->nb[0], (size_t)q_tile2d->nb[1], + (size_t)k_slice->nb[0], (size_t)k_slice->nb[1], + (size_t)w_slice->nb[0], (size_t)w_slice->nb[1], + (size_t)ks_head->nb[0]); + printf("[idxkv fused inputs] backends: q=%s k=%s w=%s ks=%s\n", + bq ? ggml_backend_name(bq) : "null", + bk ? ggml_backend_name(bk) : "null", + bw ? ggml_backend_name(bw) : "null", + bs ? ggml_backend_name(bs) : "null"); + fflush(stdout); + } + + const char *env_fused_dev = getenv("LLAMA_SPARSE_INDEXER_FUSED_DEVICE"); + bool use_fused_device = (env_fused_dev ?
atoi(env_fused_dev) != 0 : true); + if (dbg && t0 == 0) { + printf("[idxkv] fused_device=%d\n", (int)use_fused_device); + fflush(stdout); + } + if (use_fused_device) { + scores_tc = build_indexer_fused_logits_ex(ctx, q_tile2d, k_slice, w_slice, ks_head, have_windows, win_starts, win_ends, t0, Tc); + } else { + scores_tc = build_indexer_fused_logits(ctx, q_tile2d, k_slice, w_slice, ks_head); + } + + if (dbg && t0 == 0) { + ggml_tensor * ref_scores = llama::sparse_attn_indexer::idx_compute_scores_tile(ctx, q3d, k_indexer_f16, weights, k_scale_2d, D, H, Tc, kv_end, t0); + ggml_tensor * ref_head = ggml_view_2d(ctx, ref_scores, std::min(kv_end, (int64_t)8), std::min(Tc, (int64_t)4), ref_scores->nb[1], 0); + cb(ref_head, "idxkv_scores_ref_head", -1); + if (gf) { ggml_set_output(ref_head); ggml_build_forward_expand(gf, ref_head); } + ggml_tensor * fused_head = ggml_view_2d(ctx, scores_tc, std::min(kv_end, (int64_t)8), std::min(Tc, (int64_t)4), scores_tc->nb[1], 0); + cb(fused_head, "idxkv_scores_fused_head", -1); + if (gf) { ggml_set_output(fused_head); ggml_build_forward_expand(gf, fused_head); } + printf("[idxkv debug] shapes: fused=[%lld,%lld,%lld,%lld] ref=[%lld,%lld,%lld,%lld]\n", + (long long)scores_tc->ne[0], (long long)scores_tc->ne[1], (long long)scores_tc->ne[2], (long long)scores_tc->ne[3], + (long long)ref_scores->ne[0], (long long)ref_scores->ne[1], (long long)ref_scores->ne[2], (long long)ref_scores->ne[3]); + fflush(stdout); + ggml_tensor * diff = ggml_sub(ctx, scores_tc, ref_scores); + ggml_tensor * L1 = ggml_sum(ctx, ggml_abs(ctx, diff)); + cb(L1, "idxkv_scores_diff_L1", -1); + if (gf) { ggml_set_output(L1); ggml_build_forward_expand(gf, L1); } + ggml_tensor * q_samp = ggml_view_3d(ctx, q3d, std::min(D, (int64_t)8), std::min(Tc, (int64_t)2), std::min(H, (int64_t)2), q3d->nb[1], q3d->nb[2], 0); + cb(q_samp, "idxkv_q3d_sample", -1); + if (gf) { ggml_set_output(q_samp); ggml_build_forward_expand(gf, q_samp); } + ggml_tensor * w_samp = ggml_view_2d(ctx, w_slice, std::min(H, (int64_t)4), std::min(Tc, (int64_t)4), w_slice->nb[1], 0); + cb(w_samp, "idxkv_w_slice_sample", -1); + if (gf) { ggml_set_output(w_samp); ggml_build_forward_expand(gf, w_samp); } + + size_t ks_off = 0; ggml_tensor * ks_samp = ggml_view_1d(ctx, ks_head, std::min(kv_end, (int64_t)8), ks_off); + cb(ks_samp, "idxkv_ks_head_sample", -1); + + if (gf) { ggml_set_output(ks_samp); ggml_build_forward_expand(gf, ks_samp); } + } + + // End wall timer and print (if LLAMA_SPARSE_PROF set) + auto __t1_wall = std::chrono::high_resolution_clock::now(); + const char * __prof = getenv("LLAMA_SPARSE_PROF"); + if (__prof && *__prof) { + double __ms = 1e3 * std::chrono::duration<double>(__t1_wall - __t0_wall).count(); + static int __cnt_idx_comp = 0; static double __sum_idx_comp = 0.0; __sum_idx_comp += __ms; __cnt_idx_comp++; if (__cnt_idx_comp % 50 == 0) { fprintf(stderr, "[PROFILE] IDX_TILE compute avg_ms=%.3f over 50 tiles (last t0=%lld Tc=%lld kv_end=%lld fused=%d)\n", (float)(__sum_idx_comp/50.0), (long long)t0, (long long)Tc, (long long)kv_end, (int)use_fused); __sum_idx_comp = 0.0; } + } + + } else { + scores_tc = llama::sparse_attn_indexer::idx_compute_scores_tile(ctx, q3d, k_indexer_f16, weights, k_scale_2d, D, H, Tc, kv_end, t0); + } + } + + // Safe K-scale proxy application after head reduction: skip if fused kernel already applied it + { + const char *env_fused2 = getenv("LLAMA_SPARSE_INDEXER_FUSED"); + bool use_fused2 = (env_fused2 && atoi(env_fused2) != 0); + if (dbg && t0 == 0) { + ggml_tensor * pre_head = ggml_view_2d(ctx,
scores_tc, std::min(kv_end, (int64_t)8), std::min(Tc, (int64_t)4), scores_tc->nb[1], 0); + cb(pre_head, "idxkv_scores_pre_kScale_head", -1); + printf("[idxkv] t0=%lld kv_end=%lld fused=%d\n", + (long long)t0, (long long)kv_end, (int)use_fused2); + fflush(stdout); + } + if (!use_fused2) { + ggml_tensor * k_scale_head = ggml_view_2d(ctx, k_scale_2d, kv_end, 1, k_scale_2d->nb[1], 0); + ggml_tensor * k_scale_bcast = ggml_repeat(ctx, k_scale_head, scores_tc); // [kv_end, Tc] + scores_tc = ggml_mul(ctx, scores_tc, k_scale_bcast); + if (dbg && t0 == 0) { + ggml_tensor * post_head = ggml_view_2d(ctx, scores_tc, std::min(kv_end, (int64_t)8), std::min(Tc, (int64_t)4), scores_tc->nb[1], 0); + cb(post_head, "idxkv_scores_post_kScale_head", -1); + } + } + } + + // Debug-only summaries + if (dbg) { + ggml_tensor * idxkv_scores_sum = ggml_sum(ctx, scores_tc); + ggml_tensor * idxkv_scores_ssq = ggml_sum(ctx, ggml_sqr(ctx, scores_tc)); + if (t0 == 0) { + ggml_tensor * idxkv_scores_post_abs_sum = ggml_sum(ctx, ggml_abs(ctx, scores_tc)); + cb(idxkv_scores_post_abs_sum, "idxkv_scores_post_abs_sum", -1); + if (gf) { + ggml_set_output(idxkv_scores_post_abs_sum); + ggml_build_forward_expand(gf, idxkv_scores_post_abs_sum); + } + } + cb(idxkv_scores_sum, "idxkv_scores_sum", -1); + cb(idxkv_scores_ssq, "idxkv_scores_ssq", -1); + } + + scores_tc = ggml_cont(ctx, scores_tc); + + // mask tile if available + if (mask_full) { + ggml_tensor * mask_tc = ggml_view_2d(ctx, mask_full, kv_end, Tc, mask_full->nb[1], t0 * mask_full->nb[1]); + mask_tc = ggml_cont(ctx, mask_tc); + if (mask_tc->type != scores_tc->type) { + mask_tc = ggml_cast(ctx, mask_tc, scores_tc->type); + mask_tc = ggml_cont(ctx, mask_tc); + } + if (dbg) cb(mask_tc, "idxkv_mask_tc", -1); + + // Ensure both operands have row-contiguous layout for safe broadcast add + GGML_ASSERT(scores_tc->nb[0] == (size_t) ggml_type_size(scores_tc->type)); + GGML_ASSERT(mask_tc->nb[0] == (size_t) ggml_type_size(mask_tc->type)); + if (dbg && t0 == 0) { + printf("[TOPK-INDEXER-DBG] (INDEXER) add mask: scores_tc ne=[%" PRId64 ",%" PRId64 "] nb=[%zu,%zu] type=%d | mask_tc ne=[%" PRId64 ",%" PRId64 "] nb=[%zu,%zu] type=%d\n", + scores_tc->ne[0], scores_tc->ne[1], scores_tc->nb[0], scores_tc->nb[1], (int) scores_tc->type, + mask_tc->ne[0], mask_tc->ne[1], mask_tc->nb[0], mask_tc->nb[1], (int) mask_tc->type); + fflush(stdout); + } + + if (dbg && t0 == 0) { + ggml_tensor * mask_head = ggml_view_2d(ctx, mask_tc, std::min(kv_end, (int64_t)8), std::min(Tc, (int64_t)4), mask_tc->nb[1], 0); + cb(mask_head, "idxkv_mask_tc_head", -1); + } + scores_tc = ggml_add(ctx, scores_tc, mask_tc); + // Clamp after mask to avoid inf in diagnostics and stabilize top-k + scores_tc = ggml_clamp(ctx, scores_tc, -1e30f, 1e30f); + + if (dbg && t0 == 0) { + ggml_tensor * post_mask_head = ggml_view_2d(ctx, scores_tc, std::min(kv_end, (int64_t)8), std::min(Tc, (int64_t)4), scores_tc->nb[1], 0); + cb(post_mask_head, "idxkv_scores_post_mask_head", -1); + } + + if (t0 == 0) { + ggml_tensor * idxkv_scores_post_mask_abs_sum = ggml_sum(ctx, ggml_abs(ctx, scores_tc)); + cb(idxkv_scores_post_mask_abs_sum, "idxkv_scores_post_mask_abs_sum", -1); + if (gf) { + ggml_set_output(idxkv_scores_post_mask_abs_sum); + ggml_build_forward_expand(gf, idxkv_scores_post_mask_abs_sum); + } + } + } + + // top-k for this tile -> [k, Tc] + // Ensure top-k runs on CPU to avoid CUDA backend returning invalid indices for generic shapes + ggml_tensor * scores_for_topk = scores_tc; + // Keep on device by default; only copy to host in debug mode 
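+        // For reference: the top-k selection below returns, for each column t of the [kv_end, Tc] score
+        // tile, the int32 indices of its k largest entries (ties broken arbitrarily), restricted to the
+        // per-token window [start_t, end_t) when win_starts/win_ends are provided. A plain CPU sketch of
+        // the same contract, not executed here (score(j, t) is a stand-in for reading the tile):
+        //   std::vector<int32_t> idx(kv_end);
+        //   for (int32_t j = 0; j < (int32_t) kv_end; ++j) idx[j] = j;
+        //   std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
+        //                     [&](int32_t a, int32_t b) { return score(a, t) > score(b, t); });
+        //   // out[:, t] = idx[0 .. k-1], matching the radix_topk_custom CPU fallback above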
+ if (dbg && scores_tc->buffer && !ggml_backend_buffer_is_host(scores_tc->buffer)) { + ggml_tensor * host_scores = ggml_dup_tensor(ctx, scores_tc); + ggml_set_name(host_scores, "idxkv_scores_tc_host"); + host_scores = ggml_cpy(ctx, scores_tc, host_scores); + scores_for_topk = host_scores; + } + if (dbg && sched) { + ggml_backend_t chosen_in = ggml_backend_sched_get_tensor_backend(sched, scores_for_topk); + const char * in_name = chosen_in ? ggml_backend_name(chosen_in) : NULL; + printf("[TOPK] chosen backend for scores_for_topk: %s (non-null=%d)\n", in_name ? in_name : "null", chosen_in ? 1 : 0); + fflush(stdout); + } + // Log scores_for_topk (tile 0 only) and reductions to materialize in eval-callback + if (dbg && t0 == 0) { + cb(scores_for_topk, "idxkv_scores_for_topk", -1); + ggml_tensor * sft_sum = ggml_sum(ctx, scores_for_topk); + ggml_tensor * sft_sumsq = ggml_sum(ctx, ggml_sqr(ctx, scores_for_topk)); + cb(sft_sum, "idxkv_scores_for_topk_sum", -1); + cb(sft_sumsq, "idxkv_scores_for_topk_sumsq", -1); + if (gf) { + ggml_set_output(scores_for_topk); + // profile top-k selection time per tile + + ggml_set_output(sft_sum); + ggml_set_output(sft_sumsq); + ggml_build_forward_expand(gf, scores_for_topk); + ggml_build_forward_expand(gf, sft_sum); + ggml_build_forward_expand(gf, sft_sumsq); + } + } + // Ensure contiguous for CUDA op only if needed, then clamp infinities + ggml_tensor * scores_pre = ggml_is_contiguous(scores_for_topk) ? scores_for_topk : ggml_cont(ctx, scores_for_topk); + ggml_tensor * scores_clamped = ggml_clamp(ctx, scores_pre, -1e30f, 1e30f); + // Compute top-k indices via CUDA radix selection + const int64_t k_tile = std::min(k, scores_clamped->ne[0]); + ggml_tensor * topk_tc = nullptr; + if (have_windows && win_ends) { + // slice starts/ends for this tile [t0, t0+Tc) + ggml_tensor * starts_tile = nullptr; + ggml_tensor * ends_tile = nullptr; + if (win_starts) { + size_t off_s = (size_t)t0 * win_starts->nb[0]; + starts_tile = ggml_view_1d(ctx, win_starts, Tc, off_s); + starts_tile = ggml_cont(ctx, starts_tile); + } + if (win_ends) { + size_t off_e = (size_t)t0 * win_ends->nb[0]; + ends_tile = ggml_view_1d(ctx, win_ends, Tc, off_e); + ends_tile = ggml_cont(ctx, ends_tile); + } + if (dbg) { + printf("[TOPK] using start and end\n"); + fflush(stdout); + } + topk_tc = ggml_sparse_topk_radix_ex(ctx, scores_clamped, (int)k_tile, starts_tile, ends_tile); + } else { + if (dbg) { + printf("[TOPK] not using start and end, have_windows=%s win_ends=%s\n", + have_windows ? "true" : "false", win_ends ? "true" : "false"); + fflush(stdout); + } + topk_tc = ggml_sparse_topk_radix(ctx, scores_clamped, (int)k_tile); + } + if (dbg && t0 == 0) { + cb(topk_tc, "idxkv_topk_radix", -1); + int64_t kk = std::min(k_tile, (int64_t)16); + int64_t tt = std::min(Tc, (int64_t)4); + ggml_tensor * topk_head = ggml_view_2d(ctx, topk_tc, kk, tt, topk_tc->nb[1], 0); + cb(topk_head, "idxkv_topk_indices_head", -1); + if (gf) { ggml_set_output(topk_head); ggml_build_forward_expand(gf, topk_head); } + } + // If we applied a KV window, adjust indices by the start offset + if (have_windows && win_ends) { + // current window start kv_s was 0 in this version; if non-zero in future, add kv_s here + // For now, no offset is needed since kv_s==0 above. Placeholder for future starts support. + } + result = result ? 
ggml_concat(ctx, result, topk_tc, 1) : topk_tc; + } + + cb(result, "idxkv_topk_indices_k_T", -1); + if (gf && result) { + ggml_set_output(result); + ggml_build_forward_expand(gf, result); + } + // Also provide a float32 view for eval-callback visibility on platforms that skip integer dumps + if (dbg && result) { + ggml_tensor * result_f32 = ggml_cast(ctx, result, GGML_TYPE_F32); + cb(result_f32, "idxkv_topk_indices_k_T_f32", -1); + if (gf) { + ggml_set_output(result_f32); + ggml_build_forward_expand(gf, result_f32); + } + } + if (dbg && result) { + printf("SPARSE TOPK KV-AWARE (INDEXER): result topk_indices dims=[%" PRId64 ",%" PRId64 ",%" PRId64 ",%" PRId64 "] type=%d\n", + result->ne[0], result->ne[1], result->ne[2], result->ne[3], (int)result->type); + fflush(stdout); + } + // Keep indices on device by default to avoid host syncs during get_rows + return result; +} + +ggml_tensor * llama::sparse_attn_topk::topk_radix_indices( + ggml_context * ctx, + ggml_tensor * scores, // [N, T] + int64_t k) { + GGML_ASSERT(scores->type == GGML_TYPE_F32); + ggml_tensor * args[1] = { scores }; + return ggml_custom_4d( + ctx, + GGML_TYPE_I32, + /*ne0*/ k, + /*ne1*/ scores->ne[1], + /*ne2*/ 1, + /*ne3*/ 1, + args, + /*n_args*/ 1, + /*fun*/ radix_topk_custom, + /*n_tasks*/ GGML_N_TASKS_MAX, + /*userdata*/ nullptr); +} + +ggml_tensor * llama::sparse_attn_topk::derive_kv_windows(ggml_context * ctx, ggml_tensor * kq_mask, int64_t T, int64_t N_kv, ggml_tensor ** out_starts, ggml_tensor ** out_ends) { + *out_starts = nullptr; *out_ends = nullptr; + if (!kq_mask) return nullptr; + // Expect kq_mask: [N_kv, PAD(T)] row-contiguous on rows + if (kq_mask->ne[0] != N_kv || kq_mask->ne[1] < T) return nullptr; + // Compute starts=0 and ends per token as last unmasked+1 + ggml_tensor * starts = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, T); + ggml_tensor * ends = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, T); + // Copy mask to host buffer and compute ends from the host copy (size by ggml_nbytes to cover padded columns) + std::vector<float> mask_host(ggml_nbytes(kq_mask) / sizeof(float)); + ggml_backend_tensor_get(kq_mask, mask_host.data(), 0, ggml_nbytes(kq_mask)); + // Fill starts with zeros + for (int64_t t = 0; t < T; ++t) { ((int32_t*)starts->data)[t] = 0; } + // ends per column, read from the host copy + for (int64_t t = 0; t < T; ++t) { + const float * col = (const float *)((const char*)mask_host.data() + (size_t)t*kq_mask->nb[1]); + int e = find_last_unmasked(col, (int)N_kv, kq_mask->nb[0]); + ((int32_t*)ends->data)[t] = e; + } + *out_starts = starts; *out_ends = ends; + return starts; +} +} // namespace llama diff --git a/src/llama-sparse-topk.h b/src/llama-sparse-topk.h new file mode 100644 index 00000000000..987822ce7e2 --- /dev/null +++ b/src/llama-sparse-topk.h @@ -0,0 +1,39 @@ +#ifndef LLAMA_SPARSE_TOPK_H +#define LLAMA_SPARSE_TOPK_H + +#include <functional> +#include "ggml.h" +#include "ggml-backend.h" + +namespace llama { + +using std::function; + +// Top-k selector implementation for DeepSeek V3.2 +struct sparse_attn_topk { + // Lightning Indexer KV-aware selection + static ggml_tensor * select_topk_tokens_indexer_kvaware( + ggml_context * ctx, + ggml_tensor * q_indexer, // [D_index, H_index, T] + ggml_tensor * k_indexer, // [D_index, N_kv] + ggml_tensor * weights, // [H_index, T] + ggml_tensor * kq_mask, // [N_kv, PAD(T)] or [N_kv, T] + int64_t top_k, + const function<void(ggml_tensor *, const char *, int)> & cb, + ggml_cgraph * gf, + ggml_backend_sched_t sched, + ggml_backend_t backend_cpu); + + // new: compute top-k indices per column for a scores matrix [N, T] + static ggml_tensor * topk_radix_indices( + ggml_context * ctx, + ggml_tensor * scores, // [N, T] + int64_t k); + //
Windows helper: derive per-token [start,end) from mask or used_kv + static ggml_tensor * derive_kv_windows(ggml_context * ctx, ggml_tensor * kq_mask, int64_t T, int64_t N_kv, ggml_tensor ** out_starts, ggml_tensor ** out_ends); +}; + +} // namespace llama + +#endif // LLAMA_SPARSE_TOPK_H + diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index da938af03bf..16c82244871 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -307,6 +307,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_2_LLM: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: regex_exprs = { "\\p{N}{1,3}", @@ -1846,6 +1847,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "deepseek-v3") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-v3.2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_2_LLM; + clean_spaces = false; } else if ( tokenizer_pre == "falcon") { pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 0d2f28c36c8..4fca21f6dde 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -48,6 +48,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_2_LLM = 40, }; struct LLM_KV; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f4..3714c23dc33 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -226,3 +226,44 @@ target_link_libraries(${TEST_TARGET} PRIVATE llama) llama_build_and_test(test-alloc.cpp) target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src) + + +# Unit test for sparse_attn_indexer::idx_compute_scores_tile +llama_build_and_test(test-indexer-scores-tile.cpp) +target_include_directories(test-indexer-scores-tile PRIVATE ${PROJECT_SOURCE_DIR}/src) +# Sparse attention tests +llama_build_and_test(test-sparse-attn.cpp) +target_include_directories(test-sparse-attn PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src) + +# Reproducer for no_alloc sparse indexer crash +llama_build_and_test(test-sparse-attn-noalloc.cpp) + +# Radix top-k unit test for sparse indexer +llama_build_and_test(test-sparse-topk-radix.cpp) + +# CUDA radix top-k host wrapper quick test: reuse same source (it includes CUDA-only path) +if (GGML_CUDA) + llama_build_and_test(test-sparse-topk-radix-cuda.cpp) + + llama_build_and_test(test-sparse-topk-histogram-cuda.cpp) + llama_build_and_test(test-sparse-topk-select-cuda.cpp) + llama_build_and_test(test-sparse-topk-select-stress-cuda.cpp) + llama_build_and_test(test-sparse-topk-radix-stress-cuda.cpp) + llama_build_and_test(test-sparse-topk-op-cuda.cpp) + llama_build_and_test(test-indexer-fused-op-cuda.cpp) + find_package(CUDAToolkit) + llama_build_and_test(test-fp8-e4m3-cutlass-vs-cpu.cpp) + target_include_directories(test-fp8-e4m3-cutlass-vs-cpu PRIVATE + ${CUDAToolkit_INCLUDE_DIRS} + ${PROJECT_SOURCE_DIR}/ggml/src/ggml-cuda/vendors/cutlass/include + ) + llama_build_and_test(test-sparse-attn-mqa-cuda.cpp) + llama_build_and_test(test-sparse-mla-decode-mqa-cuda.cpp) +endif() + +llama_build_and_test(test-sparse-mla-decode-cuda.cpp) + + +if (GGML_CUDA) + llama_build_and_test(test-indexer-triplet-vs-fused.cpp) +endif() diff --git a/tests/CTestTestfile.cmake b/tests/CTestTestfile.cmake new file mode 100644 index 00000000000..e8b948814b5 --- /dev/null +++ b/tests/CTestTestfile.cmake @@ -0,0 +1,92 @@ +# 
CMake generated Testfile for +# Source directory: /home/jesse/sandbox/llama.cpp/tests +# Build directory: /home/jesse/sandbox/llama.cpp/tests +# +# This file includes the relevant testing commands required for +# testing this directory and lists subdirectories to be tested as well. +add_test(test-tokenizer-0-bert-bge "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-bert-bge.gguf") +set_tests_properties(test-tokenizer-0-bert-bge PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;114;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-command-r "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-command-r.gguf") +set_tests_properties(test-tokenizer-0-command-r PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;115;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-deepseek-coder "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-deepseek-coder.gguf") +set_tests_properties(test-tokenizer-0-deepseek-coder PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;116;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-deepseek-llm "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-deepseek-llm.gguf") +set_tests_properties(test-tokenizer-0-deepseek-llm PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;117;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-falcon "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-falcon.gguf") +set_tests_properties(test-tokenizer-0-falcon PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;118;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-gpt-2 "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-gpt-2.gguf") +set_tests_properties(test-tokenizer-0-gpt-2 PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;119;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-llama-bpe "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-llama-bpe.gguf") +set_tests_properties(test-tokenizer-0-llama-bpe PROPERTIES LABELS "main" WORKING_DIRECTORY "." 
_BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;120;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-llama-spm "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-llama-spm.gguf") +set_tests_properties(test-tokenizer-0-llama-spm PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;121;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-mpt "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-mpt.gguf") +set_tests_properties(test-tokenizer-0-mpt PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;122;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-phi-3 "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-phi-3.gguf") +set_tests_properties(test-tokenizer-0-phi-3 PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;123;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-qwen2 "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-qwen2.gguf") +set_tests_properties(test-tokenizer-0-qwen2 PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;124;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-refact "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-refact.gguf") +set_tests_properties(test-tokenizer-0-refact PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;125;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-0-starcoder "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-starcoder.gguf") +set_tests_properties(test-tokenizer-0-starcoder PROPERTIES LABELS "main" WORKING_DIRECTORY "." 
_BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;126;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizers-ggml-vocabs "/home/jesse/sandbox/llama.cpp/tests/test-tokenizers-repo.sh" "https://huggingface.co/ggml-org/vocabs" "/home/jesse/sandbox/llama.cpp/models/ggml-vocabs") +set_tests_properties(test-tokenizers-ggml-vocabs PROPERTIES LABELS "main" WORKING_DIRECTORY "/home/jesse/sandbox/llama.cpp/bin" _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;64;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;129;llama_test_cmd;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-sampling "/home/jesse/sandbox/llama.cpp/bin/test-sampling") +set_tests_properties(test-sampling PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;143;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-grammar-parser "/home/jesse/sandbox/llama.cpp/bin/test-grammar-parser") +set_tests_properties(test-grammar-parser PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;144;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-grammar-integration "/home/jesse/sandbox/llama.cpp/bin/test-grammar-integration") +set_tests_properties(test-grammar-integration PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;145;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-llama-grammar "/home/jesse/sandbox/llama.cpp/bin/test-llama-grammar") +set_tests_properties(test-llama-grammar PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;146;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-chat "/home/jesse/sandbox/llama.cpp/bin/test-chat") +set_tests_properties(test-chat PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;147;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-json-schema-to-grammar "/home/jesse/sandbox/llama.cpp/bin/test-json-schema-to-grammar") +set_tests_properties(test-json-schema-to-grammar PROPERTIES LABELS "main" WORKING_DIRECTORY "/home/jesse/sandbox/llama.cpp" _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;150;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-tokenizer-1-llama-spm "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-1-spm" "/home/jesse/sandbox/llama.cpp/models/ggml-vocab-llama-spm.gguf") +set_tests_properties(test-tokenizer-1-llama-spm PROPERTIES LABELS "main" WORKING_DIRECTORY "." 
_BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;36;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;176;llama_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-chat-parser "/home/jesse/sandbox/llama.cpp/bin/test-chat-parser") +set_tests_properties(test-chat-parser PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;182;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-chat-template "/home/jesse/sandbox/llama.cpp/bin/test-chat-template") +set_tests_properties(test-chat-template PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;183;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-json-partial "/home/jesse/sandbox/llama.cpp/bin/test-json-partial") +set_tests_properties(test-json-partial PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;184;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-log "/home/jesse/sandbox/llama.cpp/bin/test-log") +set_tests_properties(test-log PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;185;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-regex-partial "/home/jesse/sandbox/llama.cpp/bin/test-regex-partial") +set_tests_properties(test-regex-partial PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;186;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-thread-safety "/home/jesse/sandbox/llama.cpp/bin/test-thread-safety" "-hf" "ggml-org/models" "-hff" "tinyllamas/stories15M-q4_0.gguf" "-ngl" "99" "-p" "The meaning of life is" "-n" "128" "-c" "256" "-ub" "32" "-np" "4" "-t" "2") +set_tests_properties(test-thread-safety PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;189;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-arg-parser "/home/jesse/sandbox/llama.cpp/bin/test-arg-parser") +set_tests_properties(test-arg-parser PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;196;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-opt "/home/jesse/sandbox/llama.cpp/bin/test-opt") +set_tests_properties(test-opt PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;201;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-gguf "/home/jesse/sandbox/llama.cpp/bin/test-gguf") +set_tests_properties(test-gguf PROPERTIES LABELS "main" WORKING_DIRECTORY "." 
_BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;203;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-backend-ops "/home/jesse/sandbox/llama.cpp/bin/test-backend-ops") +set_tests_properties(test-backend-ops PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;204;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-model-load-cancel "/home/jesse/sandbox/llama.cpp/bin/test-model-load-cancel") +set_tests_properties(test-model-load-cancel PROPERTIES LABELS "model" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;206;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-autorelease "/home/jesse/sandbox/llama.cpp/bin/test-autorelease") +set_tests_properties(test-autorelease PROPERTIES LABELS "model" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;207;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-barrier "/home/jesse/sandbox/llama.cpp/bin/test-barrier") +set_tests_properties(test-barrier PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;211;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-quantize-fns "/home/jesse/sandbox/llama.cpp/bin/test-quantize-fns") +set_tests_properties(test-quantize-fns PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;212;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-quantize-perf "/home/jesse/sandbox/llama.cpp/bin/test-quantize-perf") +set_tests_properties(test-quantize-perf PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;213;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-rope "/home/jesse/sandbox/llama.cpp/bin/test-rope") +set_tests_properties(test-rope PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;214;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-mtmd-c-api "/home/jesse/sandbox/llama.cpp/bin/test-mtmd-c-api") +set_tests_properties(test-mtmd-c-api PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;219;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-alloc "/home/jesse/sandbox/llama.cpp/bin/test-alloc") +set_tests_properties(test-alloc PROPERTIES LABELS "main" WORKING_DIRECTORY "." 
_BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;227;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-sparse-attn "/home/jesse/sandbox/llama.cpp/bin/test-sparse-attn") +set_tests_properties(test-sparse-attn PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;231;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-sparse-attn-noalloc "/home/jesse/sandbox/llama.cpp/bin/test-sparse-attn-noalloc") +set_tests_properties(test-sparse-attn-noalloc PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;235;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-sparse-topk-radix "/home/jesse/sandbox/llama.cpp/bin/test-sparse-topk-radix") +set_tests_properties(test-sparse-topk-radix PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;238;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") +add_test(test-sparse-mla-decode-cuda "/home/jesse/sandbox/llama.cpp/bin/test-sparse-mla-decode-cuda") +set_tests_properties(test-sparse-mla-decode-cuda PROPERTIES LABELS "main" WORKING_DIRECTORY "." _BACKTRACE_TRIPLES "/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;102;add_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;255;llama_build_and_test;/home/jesse/sandbox/llama.cpp/tests/CMakeLists.txt;0;") diff --git a/tests/cmake_install.cmake b/tests/cmake_install.cmake new file mode 100644 index 00000000000..5d555c6c9a5 --- /dev/null +++ b/tests/cmake_install.cmake @@ -0,0 +1,704 @@ +# Install script for directory: /home/jesse/sandbox/llama.cpp/tests + +# Set the install prefix +if(NOT DEFINED CMAKE_INSTALL_PREFIX) + set(CMAKE_INSTALL_PREFIX "/usr/local") +endif() +string(REGEX REPLACE "/$" "" CMAKE_INSTALL_PREFIX "${CMAKE_INSTALL_PREFIX}") + +# Set the install configuration name. +if(NOT DEFINED CMAKE_INSTALL_CONFIG_NAME) + if(BUILD_TYPE) + string(REGEX REPLACE "^[^A-Za-z0-9_]+" "" + CMAKE_INSTALL_CONFIG_NAME "${BUILD_TYPE}") + else() + set(CMAKE_INSTALL_CONFIG_NAME "Release") + endif() + message(STATUS "Install configuration: \"${CMAKE_INSTALL_CONFIG_NAME}\"") +endif() + +# Set the component getting installed. +if(NOT CMAKE_INSTALL_COMPONENT) + if(COMPONENT) + message(STATUS "Install component: \"${COMPONENT}\"") + set(CMAKE_INSTALL_COMPONENT "${COMPONENT}") + else() + set(CMAKE_INSTALL_COMPONENT) + endif() +endif() + +# Install shared libraries without execute permission? +if(NOT DEFINED CMAKE_INSTALL_SO_NO_EXE) + set(CMAKE_INSTALL_SO_NO_EXE "1") +endif() + +# Is this installation the result of a crosscompile? +if(NOT DEFINED CMAKE_CROSSCOMPILING) + set(CMAKE_CROSSCOMPILING "FALSE") +endif() + +# Set default install directory permissions. 
+if(NOT DEFINED CMAKE_OBJDUMP) + set(CMAKE_OBJDUMP "/usr/bin/objdump") +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-0" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-0") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-0" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-0") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-0" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-0") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-0" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-0") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sampling" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sampling") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sampling" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-sampling") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sampling" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sampling") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sampling" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sampling") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-parser" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-parser") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-parser" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-grammar-parser") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-parser" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-parser") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-parser" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-parser") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-integration" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-integration") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-integration" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-grammar-integration") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-integration" AND + NOT IS_SYMLINK 
"$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-integration") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-integration" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-grammar-integration") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-llama-grammar" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-llama-grammar") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-llama-grammar" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-llama-grammar") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-llama-grammar" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-llama-grammar") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-llama-grammar" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-llama-grammar") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-chat") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-schema-to-grammar" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-schema-to-grammar") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-schema-to-grammar" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-json-schema-to-grammar") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-schema-to-grammar" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-schema-to-grammar") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-schema-to-grammar" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-schema-to-grammar") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-stats" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-stats") + file(RPATH_CHECK 
+ FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-stats" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-quantize-stats") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-stats" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-stats") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-stats" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-stats") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gbnf-validator" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gbnf-validator") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gbnf-validator" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-gbnf-validator") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gbnf-validator" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gbnf-validator") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gbnf-validator" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gbnf-validator") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-bpe" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-bpe") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-bpe" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-1-bpe") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-bpe" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-bpe") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-bpe" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-bpe") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-spm" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-spm") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-spm" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-tokenizer-1-spm") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-spm" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-spm") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-spm" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" 
"$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-tokenizer-1-spm") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-parser" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-parser") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-parser" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-chat-parser") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-parser" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-parser") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-parser" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-parser") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-template" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-template") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-template" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-chat-template") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-template" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-template") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-template" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-chat-template") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-partial" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-partial") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-partial" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-json-partial") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-partial" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-partial") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-partial" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-json-partial") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-log" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-log") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-log" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-log") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-log" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-log") + 
file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-log" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-log") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-regex-partial" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-regex-partial") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-regex-partial" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-regex-partial") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-regex-partial" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-regex-partial") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-regex-partial" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-regex-partial") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-thread-safety" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-thread-safety") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-thread-safety" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-thread-safety") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-thread-safety" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-thread-safety") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-thread-safety" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-thread-safety") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-arg-parser" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-arg-parser") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-arg-parser" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-arg-parser") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-arg-parser" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-arg-parser") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-arg-parser" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-arg-parser") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-opt" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-opt") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-opt" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE 
EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-opt") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-opt" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-opt") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-opt" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-opt") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gguf" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gguf") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gguf" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-gguf") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gguf" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gguf") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gguf" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-gguf") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-backend-ops" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-backend-ops") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-backend-ops" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-backend-ops") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-backend-ops" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-backend-ops") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-backend-ops" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-backend-ops") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-model-load-cancel" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-model-load-cancel") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-model-load-cancel" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-model-load-cancel") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-model-load-cancel" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-model-load-cancel") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-model-load-cancel" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-model-load-cancel") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-autorelease" AND + NOT IS_SYMLINK 
"$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-autorelease") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-autorelease" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-autorelease") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-autorelease" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-autorelease") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-autorelease" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-autorelease") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-barrier" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-barrier") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-barrier" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-barrier") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-barrier" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-barrier") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-barrier" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-barrier") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-fns" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-fns") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-fns" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-quantize-fns") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-fns" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-fns") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-fns" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-fns") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-perf" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-perf") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-perf" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-quantize-perf") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-perf" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-perf") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-perf" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" 
"$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-quantize-perf") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-rope" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-rope") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-rope" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-rope") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-rope" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-rope") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-rope" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-rope") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-mtmd-c-api" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-mtmd-c-api") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-mtmd-c-api" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-mtmd-c-api") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-mtmd-c-api" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-mtmd-c-api") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-mtmd-c-api" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-mtmd-c-api") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-alloc" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-alloc") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-alloc" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-alloc") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-alloc" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-alloc") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-alloc" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-alloc") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-sparse-attn") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn" + 
OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn-noalloc" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn-noalloc") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn-noalloc" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-sparse-attn-noalloc") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn-noalloc" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn-noalloc") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn-noalloc" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-attn-noalloc") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-topk-radix" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-topk-radix") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-topk-radix" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-sparse-topk-radix") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-topk-radix" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-topk-radix") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-topk-radix" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-topk-radix") + endif() + endif() +endif() + +if(CMAKE_INSTALL_COMPONENT STREQUAL "Unspecified" OR NOT CMAKE_INSTALL_COMPONENT) + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-mla-decode-cuda" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-mla-decode-cuda") + file(RPATH_CHECK + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-mla-decode-cuda" + RPATH "") + endif() + file(INSTALL DESTINATION "${CMAKE_INSTALL_PREFIX}/bin" TYPE EXECUTABLE FILES "/home/jesse/sandbox/llama.cpp/bin/test-sparse-mla-decode-cuda") + if(EXISTS "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-mla-decode-cuda" AND + NOT IS_SYMLINK "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-mla-decode-cuda") + file(RPATH_CHANGE + FILE "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-mla-decode-cuda" + OLD_RPATH "/home/jesse/sandbox/llama.cpp/bin:" + NEW_RPATH "") + if(CMAKE_INSTALL_DO_STRIP) + execute_process(COMMAND "/usr/bin/strip" "$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/bin/test-sparse-mla-decode-cuda") + endif() + endif() +endif() + diff --git a/tests/test-fp8-e4m3-cutlass-vs-cpu.cpp b/tests/test-fp8-e4m3-cutlass-vs-cpu.cpp new file mode 100644 index 00000000000..90b2c685856 --- /dev/null +++ b/tests/test-fp8-e4m3-cutlass-vs-cpu.cpp @@ -0,0 +1,89 @@ +#include +#include +#include +#include +#include +#include 
+#include + +// CUTLASS FP8 E4M3 reference implementation +#include "cutlass/float8.h" + +using cutlass::float_e4m3_t; + +extern "C" { +void ggml_e4m3_to_fp32_row(const uint8_t * x, float * y, int64_t k); +void ggml_fp32_to_e4m3_row_ref(const float * x, uint8_t * y, int64_t k); +} + + +static inline uint32_t f32_bits(float x) { + uint32_t u; std::memcpy(&u, &x, sizeof(u)); return u; +} + +int main() { +#ifndef GGML_USE_CUDA + std::printf("CUDA not enabled; skipping fp8 e4m3 test\n"); + return 0; +#else + // --- Decode test: 256 FP8 codes --- + int mism_decode = 0; + float max_abs_decode = 0.0f; + for (int b = 0; b < 256; ++b) { + uint8_t code = (uint8_t) b; + float cpu; ggml_e4m3_to_fp32_row(&code, &cpu, 1); + float ref = float_e4m3_t::to_float(float_e4m3_t::bitcast(code)); + if (std::isnan(cpu) && std::isnan(ref)) continue; + if (cpu == 0.0f && ref == 0.0f) continue; // treat +0 and -0 as equal + uint32_t bc = f32_bits(cpu); + uint32_t br = f32_bits(ref); + if (bc != br) { + ++mism_decode; + float da = std::fabs(cpu - ref); + if (da > max_abs_decode) max_abs_decode = da; + std::printf("DECODE MISM b=%3d code=0x%02x cpu=%08x ref=%08x cpu=%g ref=%g isNaN(cpu)=%d isNaN(ref)=%d\n", + b, code, bc, br, cpu, ref, std::isnan(cpu), std::isnan(ref)); + } + } + std::printf("FP8 E4M3 decode mismatches=%d max_abs_diff=%.6f\n", mism_decode, max_abs_decode); + if (mism_decode != 0) { + std::printf("TEST FAIL (decode)\n"); + return 1; + } + + // --- Encode test: random + edge cases --- + const int N = 131072; + std::vector vals(N); + std::mt19937 rng(123); + std::uniform_real_distribution dist(-1000.0f, 1000.0f); + for (int i = 0; i < N; ++i) vals[i] = dist(rng); + if (N >= 16) { + vals[0] = 0.0f; + vals[1] = -0.0f; + vals[2] = std::numeric_limits::infinity(); + vals[3] = -std::numeric_limits::infinity(); + vals[4] = std::numeric_limits::quiet_NaN(); + vals[5] = std::numeric_limits::denorm_min(); + vals[6] = 1e-10f; + vals[7] = -1e-10f; + } + + int mism_encode = 0; + for (int i = 0; i < N; ++i) { + float x = vals[i]; + uint8_t c; ggml_fp32_to_e4m3_row_ref(&x, &c, 1); + uint8_t r = float_e4m3_t::from_float(x).storage; + if (c != r) { + ++mism_encode; + } + } + std::printf("FP8 E4M3 encode mismatches=%d over %d samples\n", mism_encode, N); + if (mism_encode != 0) { + std::printf("TEST FAIL (encode)\n"); + return 1; + } + + std::printf("FP8 E4M3 CPU helper matches CUTLASS float_e4m3_t: TEST PASS\n"); + return 0; +#endif +} diff --git a/tests/test-indexer-fused-op-cuda.cpp b/tests/test-indexer-fused-op-cuda.cpp new file mode 100644 index 00000000000..2dde480cae8 --- /dev/null +++ b/tests/test-indexer-fused-op-cuda.cpp @@ -0,0 +1,528 @@ +#include +#include +#include + +#include "../src/llama-sparse-indexer.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +using namespace llama; +// Helpers to simulate CUDA half rounding (round-to-nearest-even) on CPU +static inline uint16_t float_to_half_bits_rtne(float f) { + uint32_t x; std::memcpy(&x, &f, sizeof(x)); + uint32_t sign = (x >> 16) & 0x8000u; + int32_t exp = (int32_t)((x >> 23) & 0xFFu) - 127 + 15; + uint32_t mant = x & 0x007FFFFFu; + if (exp <= 0) { + if (exp < -10) return (uint16_t)sign; + mant |= 0x00800000u; + uint32_t sub = mant >> (1 - exp); + if (sub & 0x00001000u) sub += 0x00002000u; // round to nearest even + return (uint16_t)(sign | (sub >> 13)); + } else if (exp >= 31) { + if (mant == 0) return (uint16_t)(sign | 0x7C00u); + mant >>= 13; + return (uint16_t)(sign | 0x7C00u | mant | (mant == 0)); + } else { + if 
(mant & 0x00001000u) { + mant += 0x00002000u; + if (mant & 0x00800000u) { + mant = 0; + exp += 1; + if (exp >= 31) return (uint16_t)(sign | 0x7C00u); + } + } + return (uint16_t)(sign | ((uint32_t)exp << 10) | (mant >> 13)); + } +} +static inline float half_bits_to_float(uint16_t h) { + uint32_t sign = (h & 0x8000u) << 16; + uint32_t exp = (h & 0x7C00u) >> 10; + uint32_t mant = (h & 0x03FFu); + uint32_t out; + if (exp == 0) { + if (mant == 0) out = sign; + else { + // subnormal + exp = 127 - 15 + 1; + while ((mant & 0x0400u) == 0) { + mant <<= 1; + exp--; + } + mant &= 0x03FFu; + out = sign | (exp << 23) | (mant << 13); + } + } else if (exp == 0x1F) { + out = sign | 0x7F800000u | (mant << 13); + } else { + uint32_t exp_f = exp - 15 + 127; + out = sign | (exp_f << 23) | (mant << 13); + } + float f; std::memcpy(&f, &out, sizeof(f)); + return f; +} +static inline float f32_to_f16_to_f32(float x) { + return half_bits_to_float(float_to_half_bits_rtne(x)); +} + + +static inline uint16_t float_to_bf16_bits_rtne(float f) { + uint32_t x; std::memcpy(&x, &f, sizeof(x)); + uint32_t l = x & 0x0000FFFFu; + uint32_t u = x >> 16; + uint32_t round = (l > 0x8000u) || (l == 0x8000u && (u & 1)); + return (uint16_t)(u + round); +} +static inline float bf16_bits_to_float(uint16_t h) { + uint32_t u = ((uint32_t)h) << 16; + float f; std::memcpy(&f, &u, sizeof(f)); + return f; +} +static inline float f32_to_bf16_to_f32(float x) { + return bf16_bits_to_float(float_to_bf16_bits_rtne(x)); +} + + + +static void cpu_indexer_logits_bf16like(const float *Q, const float *K, const float *W, const float *k_scale, + int D, int H, int Tc, int kv, std::vector &out) { + out.assign((size_t)kv*Tc, 0.0f); + for (int tc=0; tc &out) { + out.assign((size_t)kv*Tc, 0.0f); + for (int tc=0; tc &out) { + out.assign((size_t)kv*Tc, 0.0f); + for (int tc=0; tc dist(-1.0f,1.0f); + + std::vector Q((size_t)D*Tc*H), K((size_t)D*kv), W((size_t)H*Tc), KS(kv); + for (auto &v:Q) v=dist(rng); + for (auto &v:K) v=dist(rng); + for (auto &v:W) v=dist(rng); + for (auto &v:KS) v=std::max(0.1f, std::abs(dist(rng))); + + ggml_init_params ip{}; + ip.mem_size = 64ull*1024*1024; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + if (!ctx) { + printf("ctx init failed\n"); + return 1; + } + + ggml_tensor * q2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, Tc*H); + // Build CPU reference that matches the math path (FP16 in TL kernel) + + ggml_tensor * k2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, kv); + const char *tl = std::getenv("LLAMA_INDEXER_TL_PORT"); + bool use_bf16_ref = (wmma_env && std::atoi(wmma_env) != 0 && H < 16); + bool use_fp16_ref = (!use_bf16_ref) && ((tl && std::atoi(tl) != 0) || (wmma_env && std::atoi(wmma_env) != 0)); + + ggml_tensor * w2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, H, Tc); + ggml_tensor * ks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, kv); + + // Build simple per-token KV windows for performance test + std::vector starts_h(Tc, 0); + std::vector ends_h(Tc, end); + // Create GGML tensors for starts/ends + ggml_tensor * starts = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, Tc); + ggml_tensor * ends = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, Tc); + // Use ex variant to pass windows + ggml_tensor * out = ggml_indexer_logits_fused_ex(ctx, q2d, k2d, w2d, ks, starts, ends); + + ggml_cgraph * gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, out); + + ggml_backend_dev_t cuda_dev = ggml_backend_dev_by_name("CUDA0"); + if (!cuda_dev) { + // fallback: pick any CUDA-like device if name lookup fails + for (size_t i = 0; i < 
ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t d = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(d) == GGML_BACKEND_DEVICE_TYPE_GPU) { + cuda_dev = d; + break; + } + } + } + + // Continue test setup + + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + printf("no CPU device found\n"); + ggml_free(ctx); + return 1; + } + + std::vector O_cpu; + + if (use_fp8_ref) { + // When TL FP8 path is enabled, build scores via the proven CPU Lightning Indexer + // implementation (idx_compute_scores_tile), then apply FP8 emulation if desired + ggml_init_params ip_ref{}; + ip_ref.mem_size = 256ull * 1024 * 1024; + ip_ref.mem_buffer = nullptr; + ip_ref.no_alloc = false; + ggml_context * ctx_ref = ggml_init(ip_ref); + if (!ctx_ref) { + printf("ctx_ref init failed\n"); + return 1; + } + + const int64_t D_ref = D; + const int64_t H_ref = H; + const int64_t Tc_ref = Tc; + const int64_t kv_ref = kv; + + ggml_tensor * q3d = ggml_new_tensor_3d(ctx_ref, GGML_TYPE_F32, D_ref, Tc_ref, H_ref); + ggml_tensor * a_k = ggml_new_tensor_2d(ctx_ref, GGML_TYPE_F32, D_ref, kv_ref); + ggml_tensor * w2d = ggml_new_tensor_2d(ctx_ref, GGML_TYPE_F32, H_ref, Tc_ref); + ggml_tensor * ks1d = ggml_new_tensor_1d(ctx_ref, GGML_TYPE_F32, kv_ref); + ggml_tensor * ks2d = ggml_reshape_2d(ctx_ref, ks1d, kv_ref, 1); + + // Remap Q from [D, Tc*H] layout (column-major, with column index = tc*H + h) + // into q3d [D, Tc, H] so that idx_compute_scores_tile sees the same + // logical q_{t,h} vectors as the fused CUDA kernel. + { + std::vector Q3d((size_t)D_ref * (size_t)Tc_ref * (size_t)H_ref); + for (int64_t tc = 0; tc < Tc_ref; ++tc) { + for (int64_t h = 0; h < H_ref; ++h) { + for (int64_t d = 0; d < D_ref; ++d) { + size_t src = (size_t)d + (size_t)D_ref * ((size_t)tc * (size_t)H_ref + (size_t)h); + size_t dst = (size_t)d + (size_t)D_ref * ((size_t)tc + (size_t)Tc_ref * (size_t)h); + Q3d[dst] = Q[src]; + } + } + } + std::memcpy(q3d->data, Q3d.data(), Q3d.size() * sizeof(float)); + } + std::memcpy(a_k->data, K.data(), ggml_nbytes(a_k)); + std::memcpy(w2d->data, W.data(), ggml_nbytes(w2d)); + std::memcpy(ks1d->data, KS.data(), ggml_nbytes(ks1d)); + + ggml_tensor * scores_ref = llama::sparse_attn_indexer::idx_compute_scores_tile( + ctx_ref, + q3d, + a_k, + w2d, + ks2d, + D_ref, + H_ref, + Tc_ref, + kv_ref, + /*t0=*/0); + + ggml_cgraph * gf_ref = ggml_new_graph(ctx_ref); + ggml_build_forward_expand(gf_ref, scores_ref); + ggml_graph_compute_with_ctx(ctx_ref, gf_ref, /*n_threads=*/1); + + O_cpu.resize((size_t)kv_ref * (size_t)Tc_ref); + std::memcpy(O_cpu.data(), scores_ref->data, O_cpu.size() * sizeof(float)); + + // NOTE: idx_compute_scores_tile already mirrors the TL FP8 math when + // LLAMA_TL_FP8 is enabled, so we no longer apply an extra f32->fp8->f32 + // round-trip on the final scores here. This keeps the CPU reference + // numerically aligned with the TileLang CUDA kernel output. 
+ + ggml_free(ctx_ref); + } else if (use_bf16_ref) { + cpu_indexer_logits_bf16like(Q.data(), K.data(), W.data(), KS.data(), D,H,Tc,kv, O_cpu); + } else if (use_fp16_ref) { + cpu_indexer_logits_f16like(Q.data(), K.data(), W.data(), KS.data(), D,H,Tc,kv, O_cpu); + } else { + cpu_indexer_logits(Q.data(), K.data(), W.data(), KS.data(), D,H,Tc,kv, O_cpu); + } + // Zero CPU reference outside window to align with GPU windowing for this test + for (int tc=0; tcne[0], (long long)q2d->ne[1], (size_t)q2d->nb[0], (size_t)q2d->nb[1], (int)q2d->type); + printf("k2d ne=[%lld,%lld] nb=[%zu,%zu] type=%d\n", (long long)k2d->ne[0], (long long)k2d->ne[1], (size_t)k2d->nb[0], (size_t)k2d->nb[1], (int)k2d->type); + printf("w2d ne=[%lld,%lld] nb=[%zu,%zu] type=%d\n", (long long)w2d->ne[0], (long long)w2d->ne[1], (size_t)w2d->nb[0], (size_t)w2d->nb[1], (int)w2d->type); + printf("ks ne=[%lld] nb0=%zu type=%d\n", (long long)ks->ne[0], (size_t)ks->nb[0], (int)ks->type); + ggml_backend_t bout = ggml_backend_sched_get_tensor_backend(sched, out); + printf("out ne=[%lld,%lld] nb=[%zu,%zu] type=%d backend=%s\n", (long long)out->ne[0], (long long)out->ne[1], (size_t)out->nb[0], (size_t)out->nb[1], (int)out->type, bout?ggml_backend_name(bout):"null"); + + ggml_backend_tensor_set(q2d, Q.data(), 0, ggml_nbytes(q2d)); + ggml_backend_tensor_set(k2d, K.data(), 0, ggml_nbytes(k2d)); + ggml_backend_tensor_set(w2d, W.data(), 0, ggml_nbytes(w2d)); + ggml_backend_tensor_set(ks, KS.data(),0, ggml_nbytes(ks)); + // ensure starts/ends are uploaded to device + ggml_backend_tensor_set(starts, starts_h.data(), 0, sizeof(int32_t)*(size_t)Tc); + ggml_backend_tensor_set(ends, ends_h.data(), 0, sizeof(int32_t)*(size_t)Tc); + + // Warmup run: compute once without sparse profiling so first-call overheads + // (e.g., TMA descriptor setup, JIT init) don't pollute timing logs. + { + const char *sav_prof = std::getenv("LLAMA_SPARSE_PROF"); + const char *sav_prof_ea = std::getenv("LLAMA_SPARSE_PROF_EACH"); + if (sav_prof) unsetenv("LLAMA_SPARSE_PROF"); + if (sav_prof_ea) unsetenv("LLAMA_SPARSE_PROF_EACH"); + ggml_backend_sched_graph_compute(sched, gf); + if (sav_prof) setenv("LLAMA_SPARSE_PROF", sav_prof, 1); else unsetenv("LLAMA_SPARSE_PROF"); + if (sav_prof_ea) setenv("LLAMA_SPARSE_PROF_EACH", sav_prof_ea, 1); else unsetenv("LLAMA_SPARSE_PROF_EACH"); + } + + printf("starting compute\n"); + ggml_status st = ggml_backend_sched_graph_compute(sched, gf); + printf("here5.4\n"); + fflush(stdout); + if (st != GGML_STATUS_SUCCESS) { + printf("backend compute failed: %d\n", (int)st); + fflush(stdout); + ggml_backend_sched_free(sched); + if (cuda) ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx); + return 1; + } + + std::vector O_gpu((size_t)kv*Tc, 0.0f); + ggml_backend_tensor_get(out, O_gpu.data(), 0, ggml_nbytes(out)); + // already computed O_cpu above matching math path + printf("D=%d H=%d Tc=%d kv=%d\n", D,H,Tc,kv); + int tt = Tc < 2 ? Tc : 2; + int kk = kv < 8 ? kv : 8; + for (int tc=0; tcmax_abs) max_abs=da; + if (da>1e-3f) mism++; + } + printf("fused op ggml test: mism=%d max_abs=%.6f\n", mism, max_abs); + printf("TEST %s\n", mism==0?"PASS":"FAIL"); + + ggml_backend_sched_free(sched); + ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx); + + + // Optional prefill-scale bench (env-controlled) + const char *BENCH = std::getenv("LLAMA_INDEXER_BENCH"); + if (BENCH && std::atoi(BENCH) != 0) { + int BD = std::getenv("LLAMA_INDEXER_BENCH_D") ? 
std::atoi(std::getenv("LLAMA_INDEXER_BENCH_D")) : 128; + int BH = std::getenv("LLAMA_INDEXER_BENCH_H") ? std::atoi(std::getenv("LLAMA_INDEXER_BENCH_H")) : 64; + int BT = std::getenv("LLAMA_INDEXER_BENCH_T") ? std::atoi(std::getenv("LLAMA_INDEXER_BENCH_T")) : 512; + int BK = std::getenv("LLAMA_INDEXER_BENCH_KV") ? std::atoi(std::getenv("LLAMA_INDEXER_BENCH_KV")) : 32768; + int iters = std::getenv("LLAMA_INDEXER_BENCH_ITERS") ? std::atoi(std::getenv("LLAMA_INDEXER_BENCH_ITERS")) : 10; + printf("[BENCH] running prefill-scale bench D=%d H=%d Tc=%d kv=%d iters=%d\n", BD, BH, BT, BK, iters); + ggml_init_params ip2{}; + ip2.mem_size = 256ull*1024*1024; + ip2.no_alloc = true; + ggml_context * ctx2 = ggml_init(ip2); + if (!ctx2) printf("ctx2 init failed\n"); + else { + std::mt19937 rng2(2025); + std::uniform_real_distribution dist2(-1.0f,1.0f); + std::vector QB((size_t)BD*BT*BH), KB((size_t)BD*BK), WB((size_t)BH*BT), KSB(BK); + for (auto &v:QB) v=dist2(rng2); + for (auto &v:KB) v=dist2(rng2); + for (auto &v:WB) v=dist2(rng2); + for (auto &v:KSB) v=std::max(0.1f, std::abs(dist2(rng2))); + ggml_tensor * q2 = ggml_new_tensor_2d(ctx2, GGML_TYPE_F32, BD, BT*BH); + ggml_tensor * k2 = ggml_new_tensor_2d(ctx2, GGML_TYPE_F32, BD, BK); + ggml_tensor * w2 = ggml_new_tensor_2d(ctx2, GGML_TYPE_F32, BH, BT); + ggml_tensor * ks2= ggml_new_tensor_1d(ctx2, GGML_TYPE_F32, BK); + ggml_backend_dev_t cuda_dev2 = ggml_backend_dev_by_name("CUDA0"); + if (!cuda_dev2) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t d = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(d) == GGML_BACKEND_DEVICE_TYPE_GPU) { + cuda_dev2 = d; + break; + } + } + } + ggml_backend_dev_t cpu_dev2 = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + ggml_backend_t cuda2 = cuda_dev2 ? ggml_backend_dev_init(cuda_dev2, nullptr) : nullptr; + ggml_backend_t cpu2 = ggml_backend_dev_init(cpu_dev2, nullptr); + ggml_backend_t backs2[2] = { cuda2, cpu2 }; + int n_backs2 = cuda2 ? 
2 : 1; + ggml_backend_sched_t sched2 = ggml_backend_sched_new(backs2, nullptr, n_backs2, GGML_DEFAULT_GRAPH_SIZE, false, true); + ggml_tensor * out2 = ggml_indexer_logits_fused(ctx2, q2, k2, w2, ks2); + ggml_cgraph * gf2 = ggml_new_graph(ctx2); + ggml_build_forward_expand(gf2, out2); + ggml_backend_sched_reset(sched2); + ggml_backend_sched_alloc_graph(sched2, gf2); + // set data after allocation + ggml_backend_tensor_set(q2, QB.data(), 0, ggml_nbytes(q2)); + ggml_backend_tensor_set(k2, KB.data(), 0, ggml_nbytes(k2)); + ggml_backend_tensor_set(w2, WB.data(), 0, ggml_nbytes(w2)); + ggml_backend_tensor_set(ks2,KSB.data(), 0, ggml_nbytes(ks2)); + // warmup + ggml_backend_sched_graph_compute(sched2, gf2); + double sum_ms=0.0; + for (int it=0; it(t1 - t0).count(); + sum_ms += ms; + } + double avg = sum_ms / iters; + printf("[BENCH] FUSED_INDEXER avg_ms=%.3f (D=%d H=%d Tc=%d kv=%d iters=%d)\n", avg, BD, BH, BT, BK, iters); + ggml_backend_sched_free(sched2); + if (cuda2) ggml_backend_free(cuda2); + ggml_backend_free(cpu2); + ggml_free(ctx2); + } + } + return mism==0?0:1; +#endif +} diff --git a/tests/test-indexer-scores-tile.cpp b/tests/test-indexer-scores-tile.cpp new file mode 100644 index 00000000000..8eb2356f13b --- /dev/null +++ b/tests/test-indexer-scores-tile.cpp @@ -0,0 +1,140 @@ +#include + +#include "llama-sparse-indexer.h" + +#include +#include +#include +#include +#include +#include + +using namespace llama; + +static void cpu_reference_scores( + const std::vector & Q, // [D, Tc, H] + const std::vector & K, // [D, kv] + const std::vector & W, // [H, Tc] + const std::vector & KS, // [kv] + int64_t D, int64_t H, int64_t Tc, int64_t kv, + std::vector & out) { + out.assign((size_t)kv * (size_t)Tc, 0.0f); + for (int64_t tc = 0; tc < Tc; ++tc) { + for (int64_t kv_i = 0; kv_i < kv; ++kv_i) { + float acc = 0.0f; + for (int64_t h = 0; h < H; ++h) { + const float *qv = &Q[(size_t)D*((size_t)tc + (size_t)Tc*h)]; + const float *kvp = &K[(size_t)kv_i*D]; + float dot = 0.0f; + for (int64_t d = 0; d < D; ++d) { + dot += qv[d]*kvp[d]; + } + if (dot < 0.0f) dot = 0.0f; + // W stored column-major [H, Tc] in ggml (H is fastest) + acc += dot * W[(size_t)h + (size_t)H * (size_t)tc]; + } + // out is [kv, Tc] laid out as row-major + out[(size_t)kv_i + (size_t)kv * (size_t)tc] = acc * KS[(size_t)kv_i]; + } + } +} + +int main() { + const int64_t D = 64; + const int64_t H = 8; + const int64_t Tc = 64; + const int64_t kv = 4096; + + // NOTE: keep seed in sync with idx_compute_scores_tile reference in llama-sparse-topk.cpp if changed. 
+ std::mt19937 rng(123); + std::uniform_real_distribution dist(-1.0f, 1.0f); + + std::vector Q((size_t)D*Tc*H); + std::vector K((size_t)D*kv); + std::vector W((size_t)H*Tc); + std::vector KS((size_t)kv); + + for (auto & v : Q) v = dist(rng); + for (auto & v : K) v = dist(rng); + for (auto & v : W) v = dist(rng); + for (auto & v : KS) v = std::max(0.1f, std::fabs(dist(rng))); + + // Build CPU reference + std::vector ref; + cpu_reference_scores(Q, K, W, KS, D, H, Tc, kv, ref); + + ggml_init_params ip{}; + ip.mem_size = 256ull*1024*1024; + ip.no_alloc = false; + ggml_context * ctx = ggml_init(ip); + if (!ctx) { + printf("ctx init failed\n"); + return 1; + } + + // q3d: [D, Tc, H] to match sparse_attn_indexer::idx_compute_scores_tile + ggml_tensor * q3d = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, D, Tc, H); + ggml_tensor * a_k = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, kv); + ggml_tensor * w2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, H, Tc); + ggml_tensor * ks1d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, kv); + ggml_tensor * ks2d = ggml_reshape_2d(ctx, ks1d, kv, 1); + + std::memcpy(q3d->data, Q.data(), ggml_nbytes(q3d)); + std::memcpy(a_k->data, K.data(), ggml_nbytes(a_k)); + std::memcpy(w2d->data, W.data(), ggml_nbytes(w2d)); + std::memcpy(ks1d->data, KS.data(), ggml_nbytes(ks1d)); + + ggml_tensor * scores = sparse_attn_indexer::idx_compute_scores_tile( + ctx, q3d, a_k, w2d, ks2d, D, H, Tc, kv, 0); + + const int iters = 10; + + auto run_once = [&]() { + ggml_cgraph * gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, scores); + ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1); + }; + + // warmup + run_once(); + + auto t0 = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < iters; ++i) { + run_once(); + } + auto t1 = std::chrono::high_resolution_clock::now(); + double elapsed_s = std::chrono::duration(t1 - t0).count(); + double avg_ms = 1000.0 * elapsed_s / iters; + + printf("[IDX_TILE_REF] idx_compute_scores_tile (CPU): D=%lld H=%lld Tc=%lld kv=%lld iters=%d avg_ms=%.3f\n", + (long long)D, (long long)H, (long long)Tc, (long long)kv, iters, avg_ms); + + // scores is [kv, Tc] with row-major layout + std::vector out((size_t)kv * (size_t)Tc); + std::memcpy(out.data(), scores->data, ggml_nbytes(scores)); + + int mism = 0; + float max_abs = 0.0f; + for (size_t i = 0; i < out.size(); ++i) { + float da = std::fabs(out[i] - ref[i]); + if (da > 1e-3f) mism++; + if (da > max_abs) max_abs = da; + } + + printf("idx_compute_scores_tile test: mism=%d max_abs=%.6f\n", mism, max_abs); + // debug: print a small window when mismatches occur + if (mism != 0) { + int tc0 = 0, kv0 = 0; + int tc1 = (int) (Tc > 1 ? Tc - 1 : 0); + int kv1 = (int) (kv > 1 ? 
kv - 1 : 0); + printf("sample ref/out at [kv,tc]:\n"); + printf(" [0,0]: ref=%f out=%f diff=%f\n", ref[(size_t)kv0 + (size_t)kv * (size_t)tc0], out[(size_t)kv0 + (size_t)kv * (size_t)tc0], out[(size_t)kv0 + (size_t)kv * (size_t)tc0] - ref[(size_t)kv0 + (size_t)kv * (size_t)tc0]); + printf(" [kv-1,0]: ref=%f out=%f diff=%f\n", ref[(size_t)kv1 + (size_t)kv * (size_t)tc0], out[(size_t)kv1 + (size_t)kv * (size_t)tc0], out[(size_t)kv1 + (size_t)kv * (size_t)tc0] - ref[(size_t)kv1 + (size_t)kv * (size_t)tc0]); + printf(" [0,Tc-1]: ref=%f out=%f diff=%f\n", ref[(size_t)kv0 + (size_t)kv * (size_t)tc1], out[(size_t)kv0 + (size_t)kv * (size_t)tc1], out[(size_t)kv0 + (size_t)kv * (size_t)tc1] - ref[(size_t)kv0 + (size_t)kv * (size_t)tc1]); + printf(" [kv-1,Tc-1]: ref=%f out=%f diff=%f\n", ref[(size_t)kv1 + (size_t)kv * (size_t)tc1], out[(size_t)kv1 + (size_t)kv * (size_t)tc1], out[(size_t)kv1 + (size_t)kv * (size_t)tc1] - ref[(size_t)kv1 + (size_t)kv * (size_t)tc1]); + } + printf("TEST %s\n", mism == 0 ? "PASS" : "FAIL"); + + ggml_free(ctx); + return mism == 0 ? 0 : 1; +} diff --git a/tests/test-indexer-triplet-vs-fused.cpp b/tests/test-indexer-triplet-vs-fused.cpp new file mode 100644 index 00000000000..3708f9ff573 --- /dev/null +++ b/tests/test-indexer-triplet-vs-fused.cpp @@ -0,0 +1,355 @@ +#include "../src/llama-sparse-indexer.h" +#include "../src/llama-sparse-topk.h" +#include "../src/llama-model.h" +#include "../src/llama-impl.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace llama; + +struct CpuContext { + ggml_context * ctx; + CpuContext() { + ggml_init_params p{}; + p.mem_size = 256ull * 1024 * 1024; + p.mem_buffer = nullptr; + p.no_alloc = false; + ctx = ggml_init(p); + } + ~CpuContext() { + if (ctx) ggml_free(ctx); + } +}; + +static llama_model * create_test_model(ggml_context * ctx, int num_layers = 1) { + llama_model_params params = llama_model_default_params(); + llama_model * model = new llama_model(params); + model->arch = LLM_ARCH_DEEPSEEK3_2; + model->layers.resize(num_layers); + + for (int i = 0; i < num_layers; ++i) { + llama_layer & layer = model->layers[i]; + + const int64_t hidden_dim = 512; + const int64_t index_head_dim = 128; + const int64_t index_n_heads = 64; + + layer.attn_indexer_wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_dim, index_head_dim); + layer.attn_indexer_wq_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_dim, index_head_dim * index_n_heads); + layer.attn_indexer_weights_proj= ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_dim, index_n_heads); + layer.attn_indexer_k_norm_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, index_head_dim); + layer.wq_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_dim, hidden_dim); + + auto fill_tensor = [](ggml_tensor * t) { + size_t n = ggml_nelements(t); + std::vector tmp(n); + for (size_t k = 0; k < n; ++k) tmp[k] = (float) rand() / RAND_MAX; + memcpy(t->data, tmp.data(), n * sizeof(float)); + }; + + fill_tensor(layer.attn_indexer_wk); + fill_tensor(layer.attn_indexer_wq_b); + fill_tensor(layer.attn_indexer_weights_proj); + fill_tensor(layer.attn_indexer_k_norm_bias); + fill_tensor(layer.wq_a); + } + + return model; +} + +static void cleanup_test_model(llama_model * model) { + delete model; +} + +int main() { +#ifndef GGML_USE_CUDA + printf("CUDA not enabled; skipping indexer triplet vs fused test\n"); + return 0; +#else + srand(42); + + CpuContext cpu_ctx; + if (!cpu_ctx.ctx) { + printf("cpu_ctx init failed\n"); + return 1; + } + + llama_model * 
model = create_test_model(cpu_ctx.ctx, 1); + + const int64_t n_tokens = 2; + const int64_t hidden_dim = 512; + + ggml_tensor * cur = ggml_new_tensor_2d(cpu_ctx.ctx, GGML_TYPE_F32, hidden_dim, n_tokens); + std::vector cur_data((size_t) hidden_dim * (size_t) n_tokens); + for (size_t i = 0; i < cur_data.size(); ++i) cur_data[i] = (float) rand() / RAND_MAX; + memcpy(cur->data, cur_data.data(), cur_data.size() * sizeof(float)); + + auto cb_noop = [](ggml_tensor *, const char *, int) {}; + + IndexerKVTriplet trip = sparse_attn_indexer::compute_indexer_triplet( + cpu_ctx.ctx, *model, 0, cur, n_tokens, + /*mctx*/ nullptr, /*k_idxs*/ nullptr, + /*inp_pos*/ nullptr, /*n_rot*/ 0, /*rope_type*/ 0, /*n_ctx_orig*/ n_tokens, + /*freq_base*/ 10000.0f, /*freq_scale*/ 1.0f, + /*ext_factor*/ 1.0f, /*attn_factor*/ 1.0f, + /*beta_fast*/ 1.0f, /*beta_slow*/ 1.0f, + cb_noop, /*gf*/ nullptr); + + ggml_tensor * q_indexer = trip.q_indexer; // [D_index, H_index, T] + ggml_tensor * k_indexer = trip.k_indexer_cache;// [D_index, N_kv] + ggml_tensor * idx_weights = trip.idx_weights; // [H_index, T] + + if (!q_indexer || !k_indexer || !idx_weights) { + printf("triplet tensors are null; aborting\n"); + cleanup_test_model(model); + return 1; + } + + const int64_t D_index = q_indexer->ne[0]; + const int64_t H_index = q_indexer->ne[1]; + const int64_t T = q_indexer->ne[2]; + const int64_t N_kv = k_indexer->ne[1]; + + printf("[IDX_TRIPLET] D=%" PRId64 " H=%" PRId64 " T=%" PRId64 " N_kv=%" PRId64 "\n", + D_index, H_index, T, N_kv); + + // Build k_scale_2d as ones so it does not alter comparison + ggml_tensor * ks_vec = ggml_new_tensor_1d(cpu_ctx.ctx, GGML_TYPE_F32, N_kv); + std::vector ks_host((size_t) N_kv, 1.0f); + memcpy(ks_vec->data, ks_host.data(), ks_host.size() * sizeof(float)); + ggml_tensor * k_scale_2d = ggml_reshape_2d(cpu_ctx.ctx, ks_vec, N_kv, 1); + + // Build inputs for idx_compute_scores_tile: q3d [D,T,H], a_k [D,N_kv], w2d [H,T] + // q_indexer is [D,H,T]; match sparse_topk pipeline: permute to [D,T,H] then contiguize. 
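+    // After the permute(0, 2, 1, 3) + cont below, q3d has shape [D_index, T, H_index]
+    // and q3d[d, t, h] corresponds to q_indexer[d, h, t], so the tile reference sees
+    // the same logical q_{t,h} vectors as the fused kernel.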
+ ggml_tensor * q_perm = ggml_permute(cpu_ctx.ctx, q_indexer, 0, 2, 1, 3); + ggml_tensor * q3d = ggml_cont(cpu_ctx.ctx, q_perm); + ggml_tensor * a_k = ggml_reshape_2d(cpu_ctx.ctx, k_indexer, D_index, N_kv); + ggml_tensor * w2d = idx_weights; // already [H_index, T] + + ggml_tensor * scores_tc = sparse_attn_indexer::idx_compute_scores_tile( + cpu_ctx.ctx, q3d, a_k, w2d, k_scale_2d, + D_index, H_index, T, N_kv, + /*t0=*/0); + + ggml_cgraph * gf_cpu = ggml_new_graph(cpu_ctx.ctx); + ggml_build_forward_expand(gf_cpu, scores_tc); + ggml_graph_compute_with_ctx(cpu_ctx.ctx, gf_cpu, /*n_threads=*/1); + + std::vector scores_trip((size_t) N_kv * (size_t) T); + memcpy(scores_trip.data(), scores_tc->data, scores_trip.size() * sizeof(float)); + + // Host buffers laid out as fused kernel expects: Q2d [D, T*H], K2d [D,N_kv], W2d [H,T], KS [N_kv] + std::vector Q2d((size_t) D_index * (size_t) T * (size_t) H_index); + std::vector K2d((size_t) D_index * (size_t) N_kv); + std::vector W2d((size_t) H_index * (size_t) T); + + auto load3 = [](ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2) { + char * base = (char *) t->data; + size_t off = (size_t) i0 * t->nb[0] + + (size_t) i1 * t->nb[1] + + (size_t) i2 * t->nb[2]; + return *(float *) (base + off); + }; + + auto load2 = [](ggml_tensor * t, int64_t i0, int64_t i1) { + char * base = (char *) t->data; + size_t off = (size_t) i0 * t->nb[0] + + (size_t) i1 * t->nb[1]; + return *(float *) (base + off); + }; + + // Q2d: [D_index, T*H_index], column index = t*H_index + h + for (int64_t t = 0; t < T; ++t) { + for (int64_t h = 0; h < H_index; ++h) { + size_t col = (size_t) t * (size_t) H_index + (size_t) h; + for (int64_t d = 0; d < D_index; ++d) { + float v = load3(q3d, d, t, h); + Q2d[(size_t) d + (size_t) D_index * col] = v; + } + } + } + + // K2d: [D_index, N_kv] + for (int64_t kv_i = 0; kv_i < N_kv; ++kv_i) { + for (int64_t d = 0; d < D_index; ++d) { + float v = load2(a_k, d, kv_i); + K2d[(size_t) d + (size_t) D_index * (size_t) kv_i] = v; + } + } + + // W2d: [H_index, T] + for (int64_t t = 0; t < T; ++t) { + for (int64_t h = 0; h < H_index; ++h) { + float v = load2(idx_weights, h, t); + W2d[(size_t) h + (size_t) H_index * (size_t) t] = v; + } + } + + // --- Fused CUDA path --- + ggml_init_params ip_fused{}; + ip_fused.mem_size = 64ull * 1024 * 1024; + ip_fused.no_alloc = true; + ggml_context * ctx_fused = ggml_init(ip_fused); + if (!ctx_fused) { + printf("ctx_fused init failed\n"); + cleanup_test_model(model); + return 1; + } + + ggml_tensor * q2d = ggml_new_tensor_2d(ctx_fused, GGML_TYPE_F32, D_index, T * H_index); + ggml_tensor * k2d = ggml_new_tensor_2d(ctx_fused, GGML_TYPE_F32, D_index, N_kv); + ggml_tensor * w2d_fused = ggml_new_tensor_2d(ctx_fused, GGML_TYPE_F32, H_index, T); + ggml_tensor * ks = ggml_new_tensor_1d(ctx_fused, GGML_TYPE_F32, N_kv); + + ggml_tensor * out = ggml_indexer_logits_fused_ex(ctx_fused, q2d, k2d, w2d_fused, ks, nullptr, nullptr); + ggml_cgraph * gf_fused = ggml_new_graph(ctx_fused); + ggml_build_forward_expand(gf_fused, out); + + ggml_backend_dev_t cuda_dev = ggml_backend_dev_by_name("CUDA0"); + if (!cuda_dev) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t d = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(d) == GGML_BACKEND_DEVICE_TYPE_GPU) { + cuda_dev = d; + break; + } + } + } + + if (!cuda_dev) { + printf("no CUDA device found; skipping fused comparison\n"); + ggml_free(ctx_fused); + cleanup_test_model(model); + return 0; + } + + ggml_backend_dev_t cpu_dev = 
ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + ggml_backend_t cuda = ggml_backend_dev_init(cuda_dev, nullptr); + ggml_backend_t cpu = ggml_backend_dev_init(cpu_dev, nullptr); + + ggml_backend_t backs_arr[2] = { cuda, cpu }; + int n_backs = cuda ? 2 : 1; + ggml_backend_sched_t sched = ggml_backend_sched_new(backs_arr, nullptr, n_backs, GGML_DEFAULT_GRAPH_SIZE, false, true); + if (!sched) { + printf("sched init failed\n"); + ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx_fused); + cleanup_test_model(model); + return 1; + } + + ggml_backend_sched_reset(sched); + if (cuda) { + ggml_backend_sched_set_tensor_backend(sched, q2d, cuda); + ggml_backend_sched_set_tensor_backend(sched, k2d, cuda); + ggml_backend_sched_set_tensor_backend(sched, w2d_fused, cuda); + ggml_backend_sched_set_tensor_backend(sched, ks, cuda); + ggml_backend_sched_set_tensor_backend(sched, out, cuda); + } + ggml_backend_sched_reserve(sched, gf_fused); + ggml_backend_sched_alloc_graph(sched, gf_fused); + + // Upload host data + ggml_backend_tensor_set(q2d, Q2d.data(), 0, ggml_nbytes(q2d)); + ggml_backend_tensor_set(k2d, K2d.data(), 0, ggml_nbytes(k2d)); + ggml_backend_tensor_set(w2d_fused, W2d.data(), 0, ggml_nbytes(w2d_fused)); + ggml_backend_tensor_set(ks, ks_host.data(), 0, ggml_nbytes(ks)); + + // Prefer tiled CUDA indexer; exactness already validated by test-indexer-fused-op-cuda + //setenv("LLAMA_INDEXER_TL_PORT", "0", 1); + + ggml_backend_sched_graph_compute(sched, gf_fused); + + std::vector scores_fused((size_t) N_kv * (size_t) T); + ggml_backend_tensor_get(out, scores_fused.data(), 0, ggml_nbytes(out)); + + // Compare CPU triplet+tile vs fused CUDA + int mism = 0; + float max_abs = 0.0f; + for (size_t i = 0; i < scores_trip.size(); ++i) { + float da = fabs(scores_trip[i] - scores_fused[i]); + if (da > 1e-2f) ++mism; + if (da > max_abs) max_abs = da; + } + + + // Debug: print a small grid of cpu vs fused scores for inspection + { + int max_kv_print = (int) (N_kv < 8 ? N_kv : 8); + int max_t_print = (int) (T < 8 ? T : 8); + printf("[IDX_TRIPLET_VS_FUSED] sample grid (kv x T):\n"); + for (int kv_i = 0; kv_i < max_kv_print; ++kv_i) { + printf("kv=%d:", kv_i); + for (int t = 0; t < max_t_print; ++t) { + size_t idx = (size_t) kv_i + (size_t) N_kv * (size_t) t; + float cpu_v = scores_trip[idx]; + float fused_v = scores_fused[idx]; + float diff = fused_v - cpu_v; + printf(" [%d,%d]=(%8.3f,%8.3f, diff=%8.3f)\n", kv_i, t, cpu_v, fused_v, diff); + } + printf("\n"); + } + } + + // Direct CPU Eq. (1) from q_indexer / k_indexer / idx_weights for small kv,t + { + auto load3 = [](ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2) { + char * base = (char *) t->data; + size_t off = (size_t) i0 * t->nb[0] + + (size_t) i1 * t->nb[1] + + (size_t) i2 * t->nb[2]; + return *(float *) (base + off); + }; + auto load2 = [](ggml_tensor * t, int64_t i0, int64_t i1) { + char * base = (char *) t->data; + size_t off = (size_t) i0 * t->nb[0] + + (size_t) i1 * t->nb[1]; + return *(float *) (base + off); + }; + int max_kv_print = (int) (N_kv < 4 ? N_kv : 4); + int max_t_print = (int) (T < 4 ? 
T : 4); + printf("[IDX_TRIPLET_EQ1] sample grid (kv x T) from direct formula:\n"); + for (int kv_i = 0; kv_i < max_kv_print; ++kv_i) { + printf("kv=%d:", kv_i); + for (int t = 0; t < max_t_print; ++t) { + float acc_direct = 0.0f; + for (int h = 0; h < H_index; ++h) { + // dot(q_{t,h}, k_s) + float dot = 0.0f; + for (int d = 0; d < D_index; ++d) { + float qv = load3(q_indexer, d, h, t); + float kv = load2(k_indexer, d, kv_i); + dot += qv * kv; + } + if (dot < 0.0f) dot = 0.0f; + float w = load2(idx_weights, h, t); + acc_direct += dot * w; + } + printf(" [%d,%d]=%8.3f", kv_i, t, acc_direct); + } + printf("\n"); + } + } + printf("[IDX_TRIPLET_VS_FUSED] mism=%d max_abs=%.6f\n", mism, max_abs); + printf("TEST %s\n", mism == 0 ? "PASS" : "FAIL"); + + ggml_backend_sched_free(sched); + ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx_fused); + cleanup_test_model(model); + + return mism == 0 ? 0 : 1; +#endif +} diff --git a/tests/test-sparse-attn-mqa-cuda.cpp b/tests/test-sparse-attn-mqa-cuda.cpp new file mode 100644 index 00000000000..6bf427c857c --- /dev/null +++ b/tests/test-sparse-attn-mqa-cuda.cpp @@ -0,0 +1,99 @@ +#include "../src/llama-sparse-indexer.h" +#include "../src/llama-model.h" +#include "../src/llama-impl.h" +#include "../src/llama-sparse-topk.h" +#include "../src/llama-sparse-mla-fwd.h" + +#include +#include +#include +#include +#include +#include + +using namespace llama; + +int main() { + printf("=== Sparse MLA kv-aware MQA fused/non-fused repro ===\n"); + // Build graph (no_alloc=true), then schedule and compute with CUDA+CPU + ggml_init_params p{}; + p.mem_size = 64ull * 1024 * 1024; + p.mem_buffer = nullptr; + p.no_alloc = true; + + ggml_context * ctx = ggml_init(p); + if (!ctx) { fprintf(stderr, "ctx init failed\n"); return 1; } + + // Shapes to trigger fused decode (T==1) and Hq != Hkv + const int64_t Dq = 576; // per-head embed for Q/K + const int64_t Hq = 128; // query heads + const int64_t Hkv = 1; // KV heads (MQA) + const int64_t Dv = 576; // per-head embed for V + const int64_t N_kv = 1024; // KV cache length + const int64_t T = 1; // decode step to enable fused path + const int64_t top_k = 64; + + // q_cur [Dq, Hq, T], k_cache [Dq, Hkv, N_kv], v_cache [Dv, Hkv, N_kv] + ggml_tensor * q_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, Dq, Hq, T); + ggml_tensor * k_cache = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, Dq, Hkv, N_kv); + ggml_tensor * v_cache = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, Dv, Hkv, N_kv); + + // topk indices [top_k, T] + ggml_tensor * topk_idx = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, top_k, T); + + // Build graph for kv-aware sparse attention + auto cb = [](ggml_tensor * t, const char * name, int il) { + (void)il; if (!t) { printf("CB: %s=null\n", name); return; } + printf("CB: %s: shape=[%" PRId64 ",%" PRId64 ",%" PRId64 ",%" PRId64 "]\n", name, t->ne[0], t->ne[1], t->ne[2], t->ne[3]); + }; + + ggml_tensor * out = llama::sparse_mla_fwd::apply_sparse_attention_kvaware( + ctx, q_cur, k_cache, v_cache, topk_idx, + /*n_tokens=*/T, /*top_k=*/top_k, /*kq_scale=*/1.0f, + /*kq_mask=*/nullptr, /*attn_softcap=*/0.0f, cb); + + ggml_cgraph * gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, out); + + ggml_backend_dev_t cuda_dev = ggml_backend_dev_by_name("CUDA0"); + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + ggml_backend_t cuda = cuda_dev ? 
ggml_backend_dev_init(cuda_dev, nullptr) : nullptr; + ggml_backend_t cpu = ggml_backend_dev_init(cpu_dev, nullptr); + if (!cuda || !cpu) { printf("Missing backend(s)\n"); return 0; } + + ggml_backend_t backs[2] = { cuda, cpu }; + int nb = 2; + ggml_backend_sched_t sched = ggml_backend_sched_new(backs, nullptr, nb, GGML_DEFAULT_GRAPH_SIZE, false, true); + + // Assign inputs to CUDA backend to avoid mixed-backend allocation issues + ggml_backend_sched_set_tensor_backend(sched, q_cur, cuda); + ggml_backend_sched_set_tensor_backend(sched, k_cache, cuda); + ggml_backend_sched_set_tensor_backend(sched, v_cache, cuda); + ggml_backend_sched_set_tensor_backend(sched, topk_idx, cuda); + // Allocate graph (no explicit reserve) + ggml_backend_sched_alloc_graph(sched, gf); + + // Initialize host buffers + std::vector hQ((size_t)Dq*Hq*T, 0.0f); + std::vector hK((size_t)Dq*Hkv*N_kv, 0.0f); + std::vector hV((size_t)Dv*Hkv*N_kv, 0.0f); + std::vector hTopK((size_t)top_k*T); + for (int i = 0; i < top_k; ++i) hTopK[i] = i % (int)N_kv; // simple ascending indices + + // Upload to backend + ggml_backend_tensor_set(q_cur, hQ.data(), 0, ggml_nbytes(q_cur)); + ggml_backend_tensor_set(k_cache,hK.data(), 0, ggml_nbytes(k_cache)); + ggml_backend_tensor_set(v_cache,hV.data(), 0, ggml_nbytes(v_cache)); + ggml_backend_tensor_set(topk_idx, hTopK.data(), 0, ggml_nbytes(topk_idx)); + + ggml_status st = ggml_backend_sched_graph_compute(sched, gf); + if (st != GGML_STATUS_SUCCESS) { printf("compute failed (%d)\n", (int)st); return 1; } + + ggml_backend_sched_free(sched); + ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx); + + printf("=== Repro finished ===\n"); + return 0; +} diff --git a/tests/test-sparse-attn-noalloc.cpp b/tests/test-sparse-attn-noalloc.cpp new file mode 100644 index 00000000000..66f88315fa2 --- /dev/null +++ b/tests/test-sparse-attn-noalloc.cpp @@ -0,0 +1,115 @@ +#include "../src/llama-sparse-indexer.h" +#include "../src/llama-model.h" +#include "../src/llama-impl.h" +#include "../src/llama-sparse-topk.h" +#include "../src/llama-sparse-mla-fwd.h" + +#include +#include +#include +#include + +using namespace llama; + +static llama_model* create_test_model_noalloc(ggml_context * ctx, int num_layers, + int64_t n_embd, + int64_t index_head_dim, + int64_t index_n_heads) { + llama_model_params params = llama_model_default_params(); + llama_model* model = new llama_model(params); + model->arch = LLM_ARCH_DEEPSEEK3_2; + model->layers.resize(num_layers); + + for (int i = 0; i < num_layers; ++i) { + llama_layer & layer = model->layers[i]; + // Shapes chosen to satisfy ggml_mul_mat invariants used in compute_token_importance + // cur: [n_embd, n_tokens] + // wk: [n_embd, index_head_dim] -> wk * cur => [index_head_dim, n_tokens] + layer.attn_indexer_wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, index_head_dim); + // wq_a: [n_embd, n_embd] -> wq_a * x_t => [n_embd, 1] + layer.wq_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + // wq_b: [n_embd, index_head_dim * index_n_heads] -> wq_b * qr => [index_head_dim*index_n_heads, 1] + layer.attn_indexer_wq_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, index_head_dim * index_n_heads); + // weights proj: [n_embd, index_n_heads] -> weights * x_t => [index_n_heads, 1] + layer.attn_indexer_weights_proj = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, index_n_heads); + // k_norm bias + layer.attn_indexer_k_norm_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, index_head_dim); + } + + return model; +} + +int main() { + printf("=== DeepSeek 
V3.2-Exp Sparse Attention no_alloc Unit Test ===\n"); + // Create ggml context with no_alloc=true to simulate graph build context + ggml_init_params p{}; + p.mem_size = 64ull * 1024 * 1024; // 64 MB for meta objects + p.mem_buffer = nullptr; // let ggml allocate + p.no_alloc = true; // IMPORTANT: no_alloc + + ggml_context * ctx = ggml_init(p); + if (!ctx) { + fprintf(stderr, "Failed to init ggml context\n"); + return 1; + } + + const int64_t n_embd = 7168; + const int64_t n_tokens = 4096; + const int64_t index_head_dim = 128; + const int64_t index_n_heads = 64; + + llama_model * model = create_test_model_noalloc(ctx, 1, n_embd, index_head_dim, index_n_heads); + + // Create a placeholder current hidden state tensor: [n_embd, n_tokens] + ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); + + auto cb = [](ggml_tensor * t, const char * name, int il) { + (void)il; + if (!t) { + printf("CB: %s is null\n", name); + return; + } + printf("CB: %s: shape=[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "]\n", + name, t->ne[0], t->ne[1], t->ne[2], t->ne[3]); + }; + + printf("About to call build_kvaware_topk_indices...\n"); + ggml_tensor * topk_indices = llama::sparse_attn_indexer::build_kvaware_topk_indices( + ctx, *model, 0, cur, n_tokens, /*mctx*/ nullptr, /*k_idxs*/ nullptr, /*kq_mask*/ nullptr, /*top_k*/ 64, + /*inp_pos*/ nullptr, /*n_rot*/ 0, /*rope_type*/ 0, /*n_ctx_orig*/ 0, + /*freq_base*/ 0.0f, /*freq_scale*/ 1.0f, /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f, + /*beta_fast*/ 1.0f, /*beta_slow*/ 1.0f, + cb, /*gf*/ nullptr, /*sched*/ nullptr, /*backend_cpu*/ nullptr); + + if (topk_indices) { + printf("OK: topk_indices shape = [%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "]\n", + topk_indices->ne[0], topk_indices->ne[1], topk_indices->ne[2], topk_indices->ne[3]); + } else { + printf("topk_indices null\n"); + } + + // Reproduce the exact shape configuration observed during runtime startup + // q_cur: [576, 128, 4096], k_cur: [576, 1, 4096], v_cur: [512, 1, 4096] + // This mismatch between K/V head dims triggers the reshape assert in apply_sparse_attention + const int64_t q_embd_head = 576; + const int64_t q_n_head = 128; + const int64_t kv_n_head = 1; + const int64_t k_embd_head = 576; + const int64_t v_embd_head = 512; // intentionally different than K to reproduce the issue + + ggml_tensor * q_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, q_embd_head, q_n_head, n_tokens); + ggml_tensor * k_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k_embd_head, kv_n_head, n_tokens); + ggml_tensor * v_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v_embd_head, kv_n_head, n_tokens); + + printf("About to call apply_sparse_attention_kvaware (expected to reproduce runtime assertion)...\n"); + fflush(stdout); + ggml_tensor * sparse_out = llama::sparse_mla_fwd::apply_sparse_attention_kvaware( + ctx, q_cur, k_cur, v_cur, topk_indices, n_tokens, /*top_k=*/64, /*kq_scale=*/1.0f, /*kq_mask=*/nullptr, /*attn_softcap=*/0.0f, cb); + (void)sparse_out; + + delete model; + ggml_free(ctx); + + printf("=== Test finished ===\n"); + return 0; +} diff --git a/tests/test-sparse-attn.cpp b/tests/test-sparse-attn.cpp new file mode 100644 index 00000000000..b4eed19b03f --- /dev/null +++ b/tests/test-sparse-attn.cpp @@ -0,0 +1,314 @@ +#include "../src/llama-sparse-indexer.h" +#include "../src/llama-sparse-mla-fwd.h" +#include "../src/llama-sparse-topk.h" +#include "../src/llama-model.h" +#include "../src/llama-impl.h" + +#include +#include +#include +#include +#include + +#include + +// Include llama.h for 
model parameter functions +#include "../include/llama.h" + +#include +#include +#include +#include +#include +#include + +// Simple test to verify sparse attention tensor operations +// This test focuses on the core operations without loading a full model + +// Function declarations +void test_compute_indexer_triplet(); +void test_select_topk_tokens_indexer_kvaware(); + +struct TestContext { + ggml_context * ctx; + ggml_backend_t backend; + + TestContext() { + // Create a simple CPU backend for testing + backend = ggml_backend_cpu_init(); + + // Create a context with reasonable size and let GGML handle allocation + ggml_init_params p{}; + p.mem_size = 100 * 1024 * 1024; // 100MB + p.mem_buffer = nullptr; + p.no_alloc = false; // Let GGML handle allocation + ctx = ggml_init(p); + } + + ~TestContext() { + if (ctx) ggml_free(ctx); + if (backend) ggml_backend_free(backend); + } +}; + +// Helper function to create a minimal llama_model instance for testing +// This creates a real llama_model instance and populates the sparse attention tensors +static llama_model* create_test_model(TestContext & test_ctx, int num_layers = 1) { + // Create model parameters with default values + llama_model_params params = llama_model_default_params(); + + // Create a real llama_model instance + llama_model* model = new llama_model(params); + + // Set the architecture to DeepSeek3_2 + model->arch = LLM_ARCH_DEEPSEEK3_2; + + // Initialize the layers vector + model->layers.resize(num_layers); + + // Create and populate sparse attention tensors for each layer + for (int i = 0; i < num_layers; i++) { + llama_layer& layer = model->layers[i]; + + // Based on DeepSeek V3.2-Exp architecture + const int64_t hidden_dim = 512; + const int64_t index_n_heads = 64; + const int64_t index_head_dim = 128; + + // Indexer key projection + layer.attn_indexer_wk = ggml_new_tensor_2d(test_ctx.ctx, GGML_TYPE_F32, hidden_dim, index_head_dim); + + // Indexer query projection (wq_b) + layer.attn_indexer_wq_b = ggml_new_tensor_2d(test_ctx.ctx, GGML_TYPE_F32, hidden_dim, index_head_dim * index_n_heads); + + // Indexer weights projection + layer.attn_indexer_weights_proj = ggml_new_tensor_2d(test_ctx.ctx, GGML_TYPE_F32, hidden_dim, index_n_heads); + + // Indexer normalization bias + layer.attn_indexer_k_norm_bias = ggml_new_tensor_1d(test_ctx.ctx, GGML_TYPE_F32, index_head_dim); + + // Query projection (wq_a) for non-lite version + layer.wq_a = ggml_new_tensor_2d(test_ctx.ctx, GGML_TYPE_F32, hidden_dim, hidden_dim); + + // Initialize tensors with random data + std::vector wk_data(ggml_nelements(layer.attn_indexer_wk)); + for (size_t j = 0; j < wk_data.size(); j++) { + wk_data[j] = (float)rand() / RAND_MAX; + } + memcpy(layer.attn_indexer_wk->data, wk_data.data(), wk_data.size() * sizeof(float)); + + std::vector wq_b_data(ggml_nelements(layer.attn_indexer_wq_b)); + for (size_t j = 0; j < wq_b_data.size(); j++) { + wq_b_data[j] = (float)rand() / RAND_MAX; + } + memcpy(layer.attn_indexer_wq_b->data, wq_b_data.data(), wq_b_data.size() * sizeof(float)); + + std::vector weights_proj_data(ggml_nelements(layer.attn_indexer_weights_proj)); + for (size_t j = 0; j < weights_proj_data.size(); j++) { + weights_proj_data[j] = (float)rand() / RAND_MAX; + } + memcpy(layer.attn_indexer_weights_proj->data, weights_proj_data.data(), weights_proj_data.size() * sizeof(float)); + + if (layer.attn_indexer_k_norm_bias != nullptr) { + std::vector bias_data(ggml_nelements(layer.attn_indexer_k_norm_bias)); + for (size_t j = 0; j < bias_data.size(); j++) { + 
bias_data[j] = (float)rand() / RAND_MAX; + } + memcpy(layer.attn_indexer_k_norm_bias->data, bias_data.data(), bias_data.size() * sizeof(float)); + } + + if (layer.wq_a != nullptr) { + std::vector wq_a_data(ggml_nelements(layer.wq_a)); + for (size_t j = 0; j < wq_a_data.size(); j++) { + wq_a_data[j] = (float)rand() / RAND_MAX; + } + memcpy(layer.wq_a->data, wq_a_data.data(), wq_a_data.size() * sizeof(float)); + } + } + + return model; +} + +// Helper function to cleanup the test model +static void cleanup_test_model(llama_model* model) { + if (model) { + delete model; + } +} + +// Test the Lightning Indexer kv-aware topk builder +void test_compute_indexer_triplet() { + printf("Testing compute_indexer_triplet...\n"); + fflush(stdout); + + + TestContext test_ctx; + + // Create a real llama_model instance instead of mocking + llama_model* model = create_test_model(test_ctx, 1); + + // Create a mock current hidden state tensor + const int64_t n_tokens = 16; + const int64_t hidden_dim = 512; + + ggml_tensor * cur = ggml_new_tensor_2d(test_ctx.ctx, GGML_TYPE_F32, hidden_dim, n_tokens); + + // Initialize with random values + std::vector cur_data(hidden_dim * n_tokens); + for (size_t i = 0; i < cur_data.size(); i++) { + cur_data[i] = (float)rand() / RAND_MAX; + } + // Copy data directly to tensor memory + memcpy(cur->data, cur_data.data(), cur_data.size() * sizeof(float)); + + // Create a simple callback function + auto cb = [](ggml_tensor * tensor, const char * name, int layer_idx) { + (void)layer_idx; // Unused parameter + printf("Tensor '%s' (layer %d): shape [", name, layer_idx); + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->ne[i] > 1) { + printf("%" PRId64, tensor->ne[i]); + if (i < GGML_MAX_DIMS - 1 && tensor->ne[i+1] > 1) { + printf(", "); + } + } + } + printf("]\n"); + fflush(stdout); + }; + + // Test both lite and non-lite versions + for (bool is_lite : {false, true}) { + printf("Testing %s version...\n", is_lite ? 
"lite" : "non-lite"); + fflush(stdout); + + try { + auto trip = llama::sparse_attn_indexer::compute_indexer_triplet( + test_ctx.ctx, *model, 0, cur, n_tokens, /*mctx*/ nullptr, /*k_idxs*/ nullptr, + /*inp_pos*/ nullptr, /*n_rot*/ 0, /*rope_type*/ 0, /*n_ctx_orig*/ 0, + /*freq_base*/ 0.0f, /*freq_scale*/ 1.0f, /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f, + /*beta_fast*/ 1.0f, /*beta_slow*/ 1.0f, + cb, /*gf*/ nullptr); + + if (trip.q_indexer && trip.k_indexer_cache && trip.idx_weights) { + printf("Success: triplet shapes q_indexer=[%" PRId64 ", %" PRId64 ", %" PRId64 "] k_indexer_cache=[%" PRId64 ", %" PRId64 "] idx_weights=[%" PRId64 ", %" PRId64 "]\n", + trip.q_indexer->ne[0], trip.q_indexer->ne[1], trip.q_indexer->ne[2], + trip.k_indexer_cache->ne[0], trip.k_indexer_cache->ne[1], + trip.idx_weights->ne[0], trip.idx_weights->ne[1]); + fflush(stdout); + } else { + printf("Error: triplet contains null tensor(s)\n"); + fflush(stdout); + } + } catch (const std::exception& e) { + printf("Exception: %s\n", e.what()); + fflush(stdout); + } + } + + // Cleanup the model + cleanup_test_model(model); + + printf("compute_indexer_triplet test completed\n\n"); + fflush(stdout); +} + + +// Test the select_topk_tokens_indexer_kvaware function +void test_select_topk_tokens_indexer_kvaware() { + printf("Testing select_topk_tokens_indexer_kvaware...\n"); + fflush(stdout); + + TestContext test_ctx; + + // Simple synthetic shapes + const int64_t D_index = 128; + const int64_t H_index = 8; + const int64_t T = 32; + const int64_t N_kv = 64; + const int64_t top_k = 8; + + // Allocate tensors + ggml_tensor * q_indexer = ggml_new_tensor_3d(test_ctx.ctx, GGML_TYPE_F32, D_index, H_index, T); + ggml_tensor * k_indexer = ggml_new_tensor_2d(test_ctx.ctx, GGML_TYPE_F32, D_index, N_kv); + ggml_tensor * weights = ggml_new_tensor_2d(test_ctx.ctx, GGML_TYPE_F32, H_index, T); + ggml_tensor * kq_mask = ggml_new_tensor_2d(test_ctx.ctx, GGML_TYPE_F32, N_kv, T); + + // Initialize data + std::vector qidx_data(ggml_nelements(q_indexer)); + std::vector kidx_data(ggml_nelements(k_indexer)); + std::vector w_data(ggml_nelements(weights)); + std::vector mask_data(ggml_nelements(kq_mask), 0.0f); + + for (size_t i = 0; i < qidx_data.size(); ++i) qidx_data[i] = (float) rand() / RAND_MAX; + for (size_t i = 0; i < kidx_data.size(); ++i) kidx_data[i] = (float) rand() / RAND_MAX; + for (size_t i = 0; i < w_data.size(); ++i) w_data[i] = (float) rand() / RAND_MAX; + + // Mask out some KV rows for the first few tokens + for (int64_t t = 0; t < T; ++t) { + for (int64_t j = 0; j < N_kv/16; ++j) { + mask_data[t * N_kv + j] = -INFINITY; + } + } + + memcpy(q_indexer->data, qidx_data.data(), qidx_data.size() * sizeof(float)); + memcpy(k_indexer->data, kidx_data.data(), kidx_data.size() * sizeof(float)); + memcpy(weights->data, w_data.data(), w_data.size() * sizeof(float)); + memcpy(kq_mask->data, mask_data.data(),mask_data.size()* sizeof(float)); + + auto cb = [](ggml_tensor * tensor, const char * name, int layer_idx) { + (void)layer_idx; + printf("Tensor '%s': shape [", name); + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->ne[i] > 1) { + printf("%" PRId64, tensor->ne[i]); + if (i < GGML_MAX_DIMS - 1 && tensor->ne[i+1] > 1) { + printf(", "); + } + } + } + printf("]\n"); + fflush(stdout); + }; + + try { + ggml_tensor * topk_indices = llama::sparse_attn_topk::select_topk_tokens_indexer_kvaware( + test_ctx.ctx, q_indexer, k_indexer, weights, kq_mask, top_k, cb, /*gf*/ nullptr, /*sched*/ nullptr, /*backend_cpu*/ nullptr); + + if (topk_indices) { + 
assert(topk_indices->ne[0] == top_k); + assert(topk_indices->ne[1] == T); + printf("Success: idxkv topk_indices tensor created with shape [%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "]\n", + topk_indices->ne[0], topk_indices->ne[1], topk_indices->ne[2], topk_indices->ne[3]); + fflush(stdout); + } else { + printf("Error: idxkv topk_indices is null\n"); + fflush(stdout); + } + } catch (const std::exception& e) { + printf("Exception: %s\n", e.what()); + fflush(stdout); + } + + printf("select_topk_tokens_indexer_kvaware test completed\n\n"); + fflush(stdout); +} + + +// Main test function +int main() { + printf("=== DeepSeek V3.2-Exp Sparse Attention Unit Tests ===\n\n"); + fflush(stdout); + + // Initialize random seed for reproducible tests + srand(42); + + // Run individual tests + test_compute_indexer_triplet(); + test_select_topk_tokens_indexer_kvaware(); + + printf("=== All tests completed ===\n"); + fflush(stdout); + + return 0; +} diff --git a/tests/test-sparse-mla-decode-cuda.cpp b/tests/test-sparse-mla-decode-cuda.cpp new file mode 100644 index 00000000000..4645624d815 --- /dev/null +++ b/tests/test-sparse-mla-decode-cuda.cpp @@ -0,0 +1,93 @@ +#include +#include +#include +#include +#include +#include +#include +static void ref_sparse_decode(const float *Q, const float *K, const float *V, const int32_t *topk, + int D, int H, int Dv, int N, int Ksel, float kq_scale, float softcap, + std::vector &Out) { + Out.assign((size_t)Dv*H, 0.0f); + for (int h=0; h scores(Ksel); + float m = -1e30f; + for (int i=0; i= 0 && idx < N) { + dot = 0.0f; + const float * qh = Q + (size_t)D*h; + const float * kh = K + (size_t)D*(h + (size_t)H*idx); + for (int d=0; d 0.f) dot = std::tanh(dot/softcap)*softcap; + } + scores[i] = dot; m = std::max(m, dot); + } + float ssum = 0.0f; + for (int i=0; i= N) continue; + float p = scores[i] * inv; + const float * vh = V + (size_t)Dv*(h + (size_t)H*idx); + acc += p * vh[dv]; + } + Out[dv + (size_t)Dv*h] = acc; + } + } +} +int main(){ +#ifndef GGML_USE_CUDA + printf("CUDA not enabled; skipping sparse mla decode test\n"); + return 0; +#else + const int D=64, H=4, Dv=64, N=256, Ksel=32; + float kq_scale = 1.0f; float softcap = 0.0f; + std::mt19937 rng(42); std::uniform_real_distribution dist(-1.0f,1.0f); + std::vector Q((size_t)D*H), K((size_t)D*H*N), V((size_t)Dv*H*N); + std::vector TOPK(Ksel); + for (auto &v:Q) v=dist(rng); + for (auto &v:K) v=dist(rng); + for (auto &v:V) v=dist(rng); + for (int i=0;i O_gpu((size_t)Dv*H); + ggml_backend_tensor_get(out, O_gpu.data(), 0, ggml_nbytes(out)); + std::vector O_ref; + ref_sparse_decode(Q.data(), K.data(), V.data(), TOPK.data(), D,H,Dv,N,Ksel,kq_scale,softcap,O_ref); + float max_abs=0.0f; int mism=0; for (size_t i=0;imax_abs) max_abs=d; if (d>1e-3f) mism++; } + printf("sparse mla decode fused: mism=%d max_abs=%.6f\n", mism, max_abs); + ggml_backend_sched_free(sched); + if (cuda) ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx); + printf("TEST %s\n", mism==0?"PASS":"FAIL"); + return mism==0?0:1; +#endif +} diff --git a/tests/test-sparse-mla-decode-mqa-cuda.cpp b/tests/test-sparse-mla-decode-mqa-cuda.cpp new file mode 100644 index 00000000000..f7dfdc695e2 --- /dev/null +++ b/tests/test-sparse-mla-decode-mqa-cuda.cpp @@ -0,0 +1,114 @@ +#include +#include +#include +#include +#include +#include +#include + +static void ref_sparse_decode_mqa( + const float *Q, // [D * Hq] + const float *K, // [D * Hkv * N] + const float *V, // [Dv * Hkv * N] + const int32_t *topk, // [Ksel] + int D, int Hq, int Hkv, int Dv, 
int N, int Ksel, + float kq_scale, float softcap, + std::vector &Out) { // [Dv * Hq] + + Out.assign((size_t)Dv*Hq, 0.0f); + for (int h = 0; h < Hq; ++h) { + const int hk = (Hkv == 1 ? 0 : (h % Hkv)); + std::vector scores(Ksel); + float m = -1e30f; + for (int i = 0; i < Ksel; ++i) { + int idx = topk[i]; + float dot = -1e30f; + if (idx >= 0 && idx < N) { + dot = 0.0f; + const float * qh = Q + (size_t)D*h; + const float * kh = K + (size_t)D*(hk + (size_t)Hkv*idx); + for (int d = 0; d < D; ++d) dot += qh[d]*kh[d]; + dot *= kq_scale; + if (softcap > 0.f) dot = std::tanh(dot/softcap)*softcap; + } + scores[i] = dot; m = std::max(m, dot); + } + float ssum = 0.0f; + for (int i = 0; i < Ksel; ++i) { scores[i] = std::exp(scores[i]-m); ssum += scores[i]; } + float inv = 1.0f/ssum; + for (int dv = 0; dv < Dv; ++dv) { + float acc = 0.0f; + for (int i = 0; i < Ksel; ++i) { + int idx = topk[i]; if (idx < 0 || idx >= N) continue; + float p = scores[i] * inv; + const float * vh = V + (size_t)Dv*(hk + (size_t)Hkv*idx); + acc += p * vh[dv]; + } + Out[dv + (size_t)Dv*h] = acc; + } + } +} + +int main(){ +#ifndef GGML_USE_CUDA + printf("CUDA not enabled; skipping sparse mla decode MQA test\n"); + return 0; +#else + // MQA/GQA repro: Hq != Hkv + const int D=128, Hq=8, Hkv=1, Dv=128, N=1024, Ksel=64; + float kq_scale = 1.0f; float softcap = 0.0f; + std::mt19937 rng(123); std::uniform_real_distribution dist(-1.0f,1.0f); + + std::vector Q((size_t)D*Hq), K((size_t)D*Hkv*N), V((size_t)Dv*Hkv*N); + std::vector TOPK(Ksel); + for (auto &v:Q) v=dist(rng); + for (auto &v:K) v=dist(rng); + for (auto &v:V) v=dist(rng); + for (int i=0;i O_gpu((size_t)Dv*Hq); + ggml_backend_tensor_get(out, O_gpu.data(), 0, ggml_nbytes(out)); + + std::vector O_ref; + ref_sparse_decode_mqa(Q.data(), K.data(), V.data(), TOPK.data(), D,Hq,Hkv,Dv,N,Ksel,kq_scale,softcap,O_ref); + float max_abs=0.0f; int mism=0; for (size_t i=0;imax_abs) max_abs=d; if (d>5e-3f) mism++; } + printf("sparse mla decode fused (MQA): mism=%d max_abs=%.6f\n", mism, max_abs); + + ggml_backend_sched_free(sched); + if (cuda) ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx); + + printf("TEST %s\n", mism==0?"PASS":"FAIL"); + return mism==0?0:1; +#endif +} diff --git a/tests/test-sparse-topk-histogram-cuda.cpp b/tests/test-sparse-topk-histogram-cuda.cpp new file mode 100644 index 00000000000..4092faf6ac1 --- /dev/null +++ b/tests/test-sparse-topk-histogram-cuda.cpp @@ -0,0 +1,100 @@ +#include +#include +#include +#include +#include +#include + +#ifdef GGML_USE_CUDA +#include +#endif + +static inline uint16_t float_to_half_bits_rtne(float f) { + uint32_t x; std::memcpy(&x, &f, sizeof(x)); + uint32_t sign = (x >> 16) & 0x8000u; + int32_t exp = (int32_t)((x >> 23) & 0xFFu) - 127 + 15; + uint32_t mant = x & 0x007FFFFFu; + if (exp <= 0) { + if (exp < -10) return (uint16_t)sign; + mant |= 0x00800000u; + uint32_t sub = mant >> (1 - exp); + // round to nearest even + if (sub & 0x00001000u) sub += 0x00002000u; + return (uint16_t)(sign | (sub >> 13)); + } else if (exp >= 31) { + // Inf/NaN + if (mant == 0) return (uint16_t)(sign | 0x7C00u); + mant >>= 13; + return (uint16_t)(sign | 0x7C00u | mant | (mant == 0)); + } else { + // round to nearest even + if (mant & 0x00001000u) { + mant += 0x00002000u; + if (mant & 0x00800000u) { mant = 0; exp += 1; if (exp >= 31) return (uint16_t)(sign | 0x7C00u); } + } + return (uint16_t)(sign | ((uint32_t)exp << 10) | (mant >> 13)); + } +} + +static inline uint32_t float_to_key_desc(float x) { + uint32_t u; std::memcpy(&u, &x, 
sizeof(u)); + if ((int32_t)u < 0) { return ~u; } else { return u ^ 0x80000000u; } +} +static inline uint8_t key32_msb_bin_desc_host(float x) { + return (uint8_t)(float_to_key_desc(x) >> 24); +} + +int main() { +#ifndef GGML_USE_CUDA + printf("CUDA not enabled; skipping histogram test\n"); + return 0; +#else + printf("Testing CUDA histogram kernel (top byte) ...\n"); + const int N = 512; // ensure 16-byte alignment for float4 loads (multiple of 4) + const int T = 7; + + std::mt19937 rng(1234); + std::uniform_real_distribution dist(-2.0f, 3.0f); + std::vector scores((size_t)N*T); + for (auto & v : scores) v = dist(rng); + + std::vector gt_counts_gpu(256*(size_t)T, 0); + std::vector thr_bins_gpu((size_t)T, 0); + + // call CUDA host wrapper + ggml_cuda_topk_histogram_host(scores.data(), N, T, + gt_counts_gpu.data(), thr_bins_gpu.data()); + + // CPU reference using fp16 coarse bin mapping (to match kernel) + std::vector gt_counts_ref(256*(size_t)T, 0); + for (int t = 0; t < T; ++t) { + unsigned int hist[256] = {0}; + for (int i = 0; i < N; ++i) { + uint8_t b0 = key32_msb_bin_desc_host(scores[i + (size_t)N*t]); + hist[b0]++; + } + unsigned int sum = 0; + for (int b = 255; b >= 0; --b) { + gt_counts_ref[b + 256*(size_t)t] = sum; + sum += hist[b]; + } + } + + // Compare + bool ok = true; + for (size_t i = 0; i < gt_counts_ref.size(); ++i) { + if (gt_counts_ref[i] != gt_counts_gpu[i]) { + printf("Mismatch at %zu: ref=%u gpu=%u\n", i, gt_counts_ref[i], gt_counts_gpu[i]); + ok = false; break; + } + } + + if (ok) { + printf("Histogram CUDA test: PASS\n"); + return 0; + } else { + printf("Histogram CUDA test: FAIL\n"); + return 1; + } +#endif +} diff --git a/tests/test-sparse-topk-op-cuda.cpp b/tests/test-sparse-topk-op-cuda.cpp new file mode 100644 index 00000000000..f7078d3bd8e --- /dev/null +++ b/tests/test-sparse-topk-op-cuda.cpp @@ -0,0 +1,135 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +static void build_and_run(int N, int T, int K, int end=0, bool warmup = false) { + // host scores + std::mt19937 rng(42); + std::uniform_real_distribution dist(-5.0f, 5.0f); + std::vector scores_h((size_t)N*T); + for (auto & v : scores_h) v = dist(rng); + + if (end) { + // Force l_end_idx = 1 for every column: only row 0 is > -1e29 + for (int t = 0; t < T; ++t) { + scores_h[0 + (size_t)N * t] = 1.0f; // any normal value > -1e29 + for (int i = end; i < N; ++i) { + scores_h[i + (size_t)N * t] = -1.0e30f; // masked (<= -1e29) + } + } + } + + ggml_init_params ip{}; ip.mem_size = 64ull*1024*1024; ip.no_alloc = true; + ggml_context* ctx = ggml_init(ip); + + ggml_tensor* scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N, T); + ggml_tensor* idx = ggml_sparse_topk_radix(ctx, scores, K); + + ggml_cgraph* gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, idx); + + ggml_backend_dev_t cuda_dev = ggml_backend_dev_by_name("CUDA0"); + if (!cuda_dev) { + // fallback: pick any CUDA-like device if name lookup fails + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t d = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(d) == GGML_BACKEND_DEVICE_TYPE_GPU) { cuda_dev = d; break; } + } + } + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { printf("no CPU device found\n"); ggml_free(ctx); return; } + + ggml_backend_t cuda = cuda_dev ? 
ggml_backend_dev_init(cuda_dev, nullptr) : nullptr; + ggml_backend_t cpu = ggml_backend_dev_init(cpu_dev, nullptr); + if (!cpu || (!cuda && cuda_dev)) { printf("backend init failed\n"); if (cuda) ggml_backend_free(cuda); if (cpu) ggml_backend_free(cpu); ggml_free(ctx); return; } + + ggml_backend_t backs_arr[2] = { cuda, cpu }; + int n_backs = cuda ? 2 : 1; + ggml_backend_sched_t sched = ggml_backend_sched_new(backs_arr, nullptr, n_backs, GGML_DEFAULT_GRAPH_SIZE, false, true); + if (!sched) { printf("sched init failed\n"); if (cuda) ggml_backend_free(cuda); ggml_backend_free(cpu); ggml_free(ctx); return; } + + ggml_backend_sched_reset(sched); + // Reserve exact buffer sizes to avoid reallocation warnings during alloc_graph + ggml_backend_sched_reserve(sched, gf); + ggml_backend_sched_alloc_graph(sched, gf); + + // copy host scores into device tensor + ggml_backend_tensor_set(scores, scores_h.data(), 0, ggml_nbytes(scores)); + + if (warmup) { + // during warmup, don’t print failures and don’t pollute profiling averages + const char *sav_prof = getenv("LLAMA_SPARSE_PROF"); + const char *sav_prof_ea = getenv("LLAMA_SPARSE_PROF_EACH"); + if (sav_prof) unsetenv("LLAMA_SPARSE_PROF"); + if (sav_prof_ea) unsetenv("LLAMA_SPARSE_PROF_EACH"); + ggml_backend_sched_graph_compute(sched, gf); + if (sav_prof) setenv("LLAMA_SPARSE_PROF", sav_prof, 1); else unsetenv("LLAMA_SPARSE_PROF"); + if (sav_prof_ea) setenv("LLAMA_SPARSE_PROF_EACH", sav_prof_ea, 1); else unsetenv("LLAMA_SPARSE_PROF_EACH"); + // skip the rest of validation and cleanup; return early + ggml_graph_clear(gf); + ggml_backend_sched_free(sched); + ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx); + return; + } + printf("starting compute\n"); + fflush(stdout); + ggml_status st = ggml_backend_sched_graph_compute(sched, gf); + if (st != GGML_STATUS_SUCCESS) { + printf("backend compute failed: %d\n", (int)st); + fflush(stdout); + return; + } + + // read back indices from device + std::vector idx_h((size_t)K*T, -1); + ggml_backend_tensor_get(idx, idx_h.data(), 0, ggml_nbytes(idx)); + + // validate threshold criterion per column + int ok = 1; + for (int t = 0; t < T; ++t) { + std::vector col(N); + for (int i = 0; i < N; ++i) col[i] = scores_h[i + (size_t)N*t]; + std::vector sorted = col; + std::nth_element(sorted.begin(), sorted.begin() + (K-1), sorted.end(), std::greater()); + float thresh = sorted[K-1]; + std::vector seen(N, 0); + for (int i = 0; i < K; ++i) { + int ix = idx_h[i + (size_t)K*t]; + if (ix < 0 || ix >= N) { printf("op-cuda: invalid idx %d\n", ix); ok = 0; break; } + if (seen[ix]) { printf("op-cuda: duplicate idx %d\n", ix); ok = 0; break; } + seen[ix] = 1; + if (!(col[ix] + 0.0f >= thresh)) { printf("op-cuda: below threshold\n"); ok = 0; break; } + } + if (!ok) break; + } + if (!ok) { + // force non-zero exit to make ctest fail + std::fprintf(stderr, "sparse_topk_radix op test failed (N=%d T=%d K=%d)\n", N, T, K); + std::fflush(stderr); + std::exit(1); + } + + + printf("sparse_topk_radix op test (%d,%d,%d): %s\n", N, T, K, ok?"PASS":"FAIL"); + fflush(stdout); + ggml_graph_clear(gf); + ggml_backend_sched_free(sched); + ggml_backend_free(cuda); + ggml_backend_free(cpu); + ggml_free(ctx); +} + +int main() { + build_and_run(4096, 1, 256, /*end=*/0, /*warmup=*/true); + build_and_run(4096, 1, 256, /*end=*/0); + build_and_run(4096, 1, 256, /*end=*/1); + build_and_run(163840, 1, 256, /*end=*/1); + return 0; +} diff --git a/tests/test-sparse-topk-radix-cuda.cpp b/tests/test-sparse-topk-radix-cuda.cpp new file mode 100644 
index 00000000000..a1b5bb83c34 --- /dev/null +++ b/tests/test-sparse-topk-radix-cuda.cpp @@ -0,0 +1,136 @@ +#include "../src/llama-sparse-topk.h" +#include "../src/llama-impl.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef GGML_USE_CUDA +static int test_cuda_topk_radix_host(); +#endif + +// Reference top-k (descending) using std::partial_sort on indices +static void ref_topk(const float * row, int64_t N, int64_t k, std::vector & out) { + out.resize(k); + std::vector idx(N); + for (int64_t i = 0; i < N; ++i) idx[i] = (int32_t)i; + auto cmp = [&](int32_t a, int32_t b){ return row[a] > row[b]; }; + if (k < N) { + std::partial_sort(idx.begin(), idx.begin()+k, idx.end(), cmp); + std::copy(idx.begin(), idx.begin()+k, out.begin()); + } else { + std::sort(idx.begin(), idx.end(), cmp); + std::copy(idx.begin(), idx.begin()+k, out.begin()); + } +} + +int main() { + // Build a small random problem: N_kv=4096, T=8 + const int64_t N_kv = 4096; + const int64_t T = 8; + const int64_t H = 4; + const int64_t D = 64; + // removed unused variable k + + // Initialize GGML context + ggml_init_params p{}; + p.mem_size = 256ull * 1024 * 1024; + p.mem_buffer = nullptr; + p.no_alloc = false; + ggml_context * ctx = ggml_init(p); + assert(ctx); + // Create tensors: q_indexer [D,H,T], k_indexer [D,N_kv], weights [H,T] + ggml_tensor * q_indexer = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, D, H, T); + ggml_tensor * k_indexer = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, N_kv); + ggml_tensor * weights = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, H, T); + + + // Fill with random values and write into tensors directly (no backend set/get) + std::mt19937 rng(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + auto fill_tensor = [&](ggml_tensor * t, std::vector & host) { + host.resize(ggml_nelements(t)); + for (auto & v : host) v = dist(rng); + // write respecting strides + for (int64_t i3 = 0; i3 < t->ne[3]; ++i3) + for (int64_t i2 = 0; i2 < t->ne[2]; ++i2) + for (int64_t i1 = 0; i1 < t->ne[1]; ++i1) + for (int64_t i0 = 0; i0 < t->ne[0]; ++i0) { + size_t off = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0]; + size_t lin = i0 + t->ne[0]*(i1 + t->ne[1]*(i2 + t->ne[2]*i3)); + *(float*)((char*)t->data + off) = host[lin]; + } + }; + std::vector Q, Khost, Whost; + fill_tensor(q_indexer, Q); + fill_tensor(k_indexer, Khost); + fill_tensor(weights, Whost); + + int status = 0; +#ifdef GGML_USE_CUDA + // also run CUDA host-wrapper validation + status |= test_cuda_topk_radix_host(); +#endif + return status; +} + +#ifdef GGML_USE_CUDA +#include +#include +#endif + +static int test_cuda_topk_radix_host() { +#ifdef GGML_USE_CUDA + printf("Running CUDA radix top-k host wrapper test...\n"); + const int N = 1024; + const int T = 8; + const int k = 32; + std::mt19937 rng(123); + std::uniform_real_distribution dist(-1.0f, 1.0f); + std::vector scores((size_t)N*T); + for (auto & v : scores) v = dist(rng); + std::vector idx_cuda((size_t)k*T); + // call host wrapper + ggml_cuda_topk_radix_indices_host(scores.data(), N, T, k, idx_cuda.data()); + // compute reference per column and check threshold criterion + for (int t = 0; t < T; ++t) { + std::vector col(N); + for (int i = 0; i < N; ++i) col[i] = scores[i + N*t]; + std::vector sorted = col; + std::nth_element(sorted.begin(), sorted.begin() + (k-1), sorted.end(), std::greater()); + float thresh = sorted[k-1]; + std::vector seen(N, 0); + for (int i = 0; i < k; ++i) { + int idx = idx_cuda[i + k*t]; + if (idx < 0 || idx 
>= N) { + printf("CUDA: column %d out-of-bounds index %d N=%d\n", t, idx, N); + return 1; + } + if (seen[idx]) { + printf("CUDA: column %d duplicate index %d\n", t, idx); + return 1; + } + seen[idx] = 1; + if (!(col[idx] + 0.0f >= thresh)) { + printf("CUDA: column %d value below threshold v=%.6f thresh=%.6f\n", t, col[idx], thresh); + return 1; + } + } + } + printf("CUDA radix top-k host wrapper: PASS\n"); + return 0; +#else + (void)0; // unused + return 0; +#endif +} + diff --git a/tests/test-sparse-topk-radix-stress-cuda.cpp b/tests/test-sparse-topk-radix-stress-cuda.cpp new file mode 100644 index 00000000000..a060227fccb --- /dev/null +++ b/tests/test-sparse-topk-radix-stress-cuda.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include +#include +#include +#ifdef GGML_USE_CUDA +#include +#endif +static inline uint32_t float_to_key_desc(float x) { + uint32_t u; std::memcpy(&u, &x, sizeof(u)); + if ((int32_t)u < 0) { return ~u; } else { return u | 0x80000000u; } +} +int main() { +#ifndef GGML_USE_CUDA + printf("CUDA not enabled; skipping radix stress test\n"); + return 0; +#else + printf("Radix end-to-end stress CUDA ...\n"); + const int N = 32768; + const int T = 2; + const int K = 64; + std::mt19937 rng(2027); + std::uniform_real_distribution dist(-20.0f, 20.0f); + std::vector scores((size_t)N*T); + for (auto & v : scores) v = dist(rng); + std::vector idx_gpu((size_t)K*T, -1); + // Exercise full path via histogram+select inside device wrapper + ggml_cuda_topk_radix_indices_host(scores.data(), N, T, K, idx_gpu.data()); + // Validate per column + bool ok = true; + for (int t = 0; t < T; ++t) { + std::vector col(N); + for (int i = 0; i < N; ++i) col[i] = scores[i + (size_t)N*t]; + std::vector sorted = col; + std::nth_element(sorted.begin(), sorted.begin() + (K-1), sorted.end(), std::greater()); + float thresh = sorted[K-1]; + std::vector seen(N, 0); + for (int i = 0; i < K; ++i) { + int idx = idx_gpu[i + K*t]; + if (idx < 0 || idx >= N) { printf("Stress-radix: invalid idx %d\n", idx); ok = false; break; } + if (seen[idx]) { printf("Stress-radix: duplicate idx %d\n", idx); ok = false; break; } + seen[idx] = 1; + if (!(col[idx] + 0.0f >= thresh)) { printf("Stress-radix: below threshold\n"); ok = false; break; } + } + if (!ok) break; + } + printf("Radix stress CUDA: %s\n", ok ? "PASS" : "FAIL"); + return ok ? 
0 : 1; +#endif +} diff --git a/tests/test-sparse-topk-radix.cpp b/tests/test-sparse-topk-radix.cpp new file mode 100644 index 00000000000..c3de78eea71 --- /dev/null +++ b/tests/test-sparse-topk-radix.cpp @@ -0,0 +1,177 @@ +#include "../src/llama-sparse-topk.h" +#include "../src/llama-impl.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// Reference top-k (descending) using std::partial_sort on indices +static void ref_topk(const float * row, int64_t N, int64_t k, std::vector & out) { + out.resize(k); + std::vector idx(N); + for (int64_t i = 0; i < N; ++i) idx[i] = (int32_t)i; + auto cmp = [&](int32_t a, int32_t b){ return row[a] > row[b]; }; + if (k < N) { + std::partial_sort(idx.begin(), idx.begin()+k, idx.end(), cmp); + std::copy(idx.begin(), idx.begin()+k, out.begin()); + } else { + std::sort(idx.begin(), idx.end(), cmp); + std::copy(idx.begin(), idx.begin()+k, out.begin()); + } +} + +int main() { + // Build a small random problem: N_kv=4096, T=8 + const int64_t N_kv = 4096; + const int64_t T = 8; + const int64_t H = 4; + const int64_t D = 64; + const int64_t k = 64; + + // Initialize GGML context + ggml_init_params p{}; + p.mem_size = 256ull * 1024 * 1024; + p.mem_buffer = nullptr; + p.no_alloc = false; + ggml_context * ctx = ggml_init(p); + assert(ctx); + // Create tensors: q_indexer [D,H,T], k_indexer [D,N_kv], weights [H,T] + ggml_tensor * q_indexer = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, D, H, T); + ggml_tensor * k_indexer = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, N_kv); + ggml_tensor * weights = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, H, T); + + + // Fill with random values and write into tensors directly (no backend set/get) + std::mt19937 rng(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + auto fill_tensor = [&](ggml_tensor * t, std::vector & host) { + host.resize(ggml_nelements(t)); + for (auto & v : host) v = dist(rng); + // write respecting strides + for (int64_t i3 = 0; i3 < t->ne[3]; ++i3) + for (int64_t i2 = 0; i2 < t->ne[2]; ++i2) + for (int64_t i1 = 0; i1 < t->ne[1]; ++i1) + for (int64_t i0 = 0; i0 < t->ne[0]; ++i0) { + size_t off = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0]; + size_t lin = i0 + t->ne[0]*(i1 + t->ne[1]*(i2 + t->ne[2]*i3)); + *(float*)((char*)t->data + off) = host[lin]; + } + }; + std::vector Q, Khost, Whost; + fill_tensor(q_indexer, Q); + fill_tensor(k_indexer, Khost); + fill_tensor(weights, Whost); + + // Build top-k indices via our new radix top-k on precomputed scores + // First, form the scores matrix for all tokens T: [N_kv, T] + ggml_tensor * q_perm = ggml_permute(ctx, q_indexer, 0, 2, 1, 3); // [D, T, H] + ggml_tensor * q_cont = ggml_cont(ctx, q_perm); + ggml_tensor * Q2d = ggml_reshape_2d(ctx, q_cont, D, T*H); // [D, T*H] + ggml_tensor * logits_all = ggml_mul_mat(ctx, k_indexer, Q2d); // [N_kv, T*H] + ggml_tensor * logits_resh= ggml_reshape_3d(ctx, logits_all, N_kv, H, T); + ggml_tensor * logits_act = ggml_relu(ctx, logits_resh); + ggml_tensor * w = ggml_reshape_3d(ctx, weights, 1, H, T); + ggml_tensor * w_bc = ggml_repeat(ctx, w, logits_act); + ggml_tensor * contrib = ggml_mul(ctx, logits_act, w_bc); + ggml_tensor * contrib_perm = ggml_permute(ctx, contrib, 1, 0, 2, 3); + contrib_perm = ggml_cont(ctx, contrib_perm); + ggml_tensor * sum_h = ggml_sum_rows(ctx, contrib_perm); // [1, N_kv, T] + ggml_tensor * scores_2d = ggml_reshape_2d(ctx, sum_h, N_kv, T); // [N_kv, T] + // Apply k_scale proxy + ggml_tensor * k_sqr = ggml_sqr(ctx, k_indexer); + 
ggml_tensor * k_sum = ggml_sum_rows(ctx, k_sqr); + ggml_tensor * k_mean= ggml_scale(ctx, k_sum, 1.0f/float(D)); + ggml_tensor * k_scale_vec = ggml_sqrt(ctx, k_mean); // [1, N_kv] + ggml_tensor * k_scale_2d = ggml_transpose(ctx, k_scale_vec); // [N_kv, 1] + k_scale_2d = ggml_cont(ctx, k_scale_2d); + ggml_tensor * k_scale_b = ggml_repeat(ctx, k_scale_2d, scores_2d); + ggml_tensor * scores_scaled = ggml_mul(ctx, scores_2d, k_scale_b); + + ggml_tensor * scores_cont = ggml_cont(ctx, scores_scaled); + ggml_tensor * idx_radix = llama::sparse_attn_topk::topk_radix_indices(ctx, scores_cont, k); + ggml_tensor * idx_ref = ggml_top_k(ctx, scores_cont, k); + // Build and execute graph to materialize both + struct ggml_cgraph * gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, idx_radix); + ggml_build_forward_expand(gf, idx_ref); + ggml_graph_compute_with_ctx(ctx, gf, 4); + + // Copy both tensors to host + std::vector RAD(ggml_nelements(idx_radix)); + std::vector REF(ggml_nelements(idx_ref)); + for (int64_t i3 = 0; i3 < idx_radix->ne[3]; ++i3) + for (int64_t i2 = 0; i2 < idx_radix->ne[2]; ++i2) + for (int64_t i1 = 0; i1 < idx_radix->ne[1]; ++i1) + for (int64_t i0 = 0; i0 < idx_radix->ne[0]; ++i0) { + size_t off = i3*idx_radix->nb[3] + i2*idx_radix->nb[2] + i1*idx_radix->nb[1] + i0*idx_radix->nb[0]; + size_t lin = i0 + idx_radix->ne[0]*(i1 + idx_radix->ne[1]*(i2 + idx_radix->ne[2]*i3)); + RAD[lin] = *(int32_t*)((char*)idx_radix->data + off); + } + for (int64_t i3 = 0; i3 < idx_ref->ne[3]; ++i3) + for (int64_t i2 = 0; i2 < idx_ref->ne[2]; ++i2) + for (int64_t i1 = 0; i1 < idx_ref->ne[1]; ++i1) + for (int64_t i0 = 0; i0 < idx_ref->ne[0]; ++i0) { + size_t off = i3*idx_ref->nb[3] + i2*idx_ref->nb[2] + i1*idx_ref->nb[1] + i0*idx_ref->nb[0]; + size_t lin = i0 + idx_ref->ne[0]*(i1 + idx_ref->ne[1]*(i2 + idx_ref->ne[2]*i3)); + REF[lin] = *(int32_t*)((char*)idx_ref->data + off); + } + if (RAD.size() != REF.size()) { + printf("Size mismatch: radix=%zu ref=%zu\n", RAD.size(), REF.size()); + return 1; + } + + // Validate correctness by threshold criterion per column (order and specific indices may differ under ties) + const int64_t KK = idx_radix->ne[0]; + const int64_t Tcol = idx_radix->ne[1]; + const int64_t N = scores_cont->ne[0]; + const size_t nb0 = scores_cont->nb[0]; + const size_t nb1 = scores_cont->nb[1]; + for (int64_t t = 0; t < Tcol; ++t) { + // collect column values + std::vector col(N); + for (int64_t i = 0; i < N; ++i) { + col[i] = *(float*)((char*)scores_cont->data + i*nb0 + t*nb1); + } + // compute threshold = K-th largest value + std::vector sorted = col; + std::nth_element(sorted.begin(), sorted.begin() + (KK-1), sorted.end(), std::greater()); + float thresh = sorted[KK-1]; + // gather radix indices for this column and verify + std::vector rad_k(KK); + for (int64_t i = 0; i < KK; ++i) { + size_t lin = i + KK * t; + rad_k[i] = RAD[lin]; + } + // check bounds and uniqueness + std::vector seen(N, 0); + for (int64_t i = 0; i < KK; ++i) { + int idx = rad_k[i]; + if (idx < 0 || idx >= N) { + printf("Column %lld: index out of bounds: %d (N=%lld)\n", (long long)t, idx, (long long)N); + return 1; + } + if (seen[idx]) { + printf("Column %lld: duplicate index in top-k: %d\n", (long long)t, idx); + return 1; + } + seen[idx] = 1; + float v = col[idx]; + if (!(v + 0.0f >= thresh)) { // allow equality under ties + printf("Column %lld: value below threshold: idx=%d v=%.6f thresh=%.6f\n", (long long)t, idx, v, thresh); + return 1; + } + } + } + + printf("radix top-k unit test: PASS\n"); + return 0; +} diff 
--git a/tests/test-sparse-topk-select-cuda.cpp b/tests/test-sparse-topk-select-cuda.cpp new file mode 100644 index 00000000000..f1e1b565cec --- /dev/null +++ b/tests/test-sparse-topk-select-cuda.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include +#include + +#ifdef GGML_USE_CUDA +#include +#endif + +static inline uint16_t float_to_half_bits_rtne(float f) { + uint32_t x; std::memcpy(&x, &f, sizeof(x)); + uint32_t sign = (x >> 16) & 0x8000u; + int32_t exp = (int32_t)((x >> 23) & 0xFFu) - 127 + 15; + uint32_t mant = x & 0x007FFFFFu; + if (exp <= 0) { + if (exp < -10) return (uint16_t)sign; + mant |= 0x00800000u; + uint32_t sub = mant >> (1 - exp); + if (sub & 0x00001000u) sub += 0x00002000u; // round to nearest even + return (uint16_t)(sign | (sub >> 13)); + } else if (exp >= 31) { + if (mant == 0) return (uint16_t)(sign | 0x7C00u); + mant >>= 13; return (uint16_t)(sign | 0x7C00u | mant | (mant == 0)); + } else { + if (mant & 0x00001000u) { mant += 0x00002000u; if (mant & 0x00800000u) { mant = 0; exp += 1; if (exp >= 31) return (uint16_t)(sign | 0x7C00u); } } + return (uint16_t)(sign | ((uint32_t)exp << 10) | (mant >> 13)); + } +} +static inline uint32_t float_to_key_desc(float x) { + uint32_t u; std::memcpy(&u, &x, sizeof(u)); + if ((int32_t)u < 0) { return ~u; } else { return u ^ 0x80000000u; } +} +static inline uint8_t key32_msb_bin_desc_host(float x) { + return (uint8_t)(float_to_key_desc(x) >> 24); +} + +int main() { +#ifndef GGML_USE_CUDA + printf("CUDA not enabled; skipping select test\n"); + return 0; +#else + printf("Testing CUDA select kernel (given histogram) ...\n"); + const int N = 512; + const int T = 5; + const int K = 32; + + std::mt19937 rng(2025); + std::uniform_real_distribution dist(-5.0f, 5.0f); + std::vector scores((size_t)N*T); + for (auto & v : scores) v = dist(rng); + + // Build histogram greater-counts on CPU to feed kernel + std::vector gt_counts(256*(size_t)T, 0); + for (int t = 0; t < T; ++t) { + unsigned int hist[256] = {0}; + for (int i = 0; i < N; ++i) { + uint8_t b0 = key32_msb_bin_desc_host(scores[i + (size_t)N*t]); + hist[b0]++; + } + unsigned int sum = 0; + for (int b = 255; b >= 0; --b) { + gt_counts[b + 256*(size_t)t] = sum; + sum += hist[b]; + } + } + + // Print gt_counts for debug (first column only) + printf("gt_counts (t=0): "); + for (int b = 0; b < 256; ++b) { + printf("%u ", gt_counts[b + 256*0]); + } + printf("\n"); + + // Run selection only + std::vector idx_gpu((size_t)K*T, -1); + ggml_cuda_topk_select_host(scores.data(), N, T, K, gt_counts.data(), idx_gpu.data()); + + // Validate: threshold criterion per column + bool ok = true; + for (int t = 0; t < T; ++t) { + std::vector col(N); + for (int i = 0; i < N; ++i) col[i] = scores[i + (size_t)N*t]; + std::vector sorted = col; + std::nth_element(sorted.begin(), sorted.begin() + (K-1), sorted.end(), std::greater()); + float thresh = sorted[K-1]; + std::vector seen(N, 0); + for (int i = 0; i < K; ++i) { + int idx = idx_gpu[i + K*t]; + if (idx < 0 || idx >= N) { + printf("Select: column %d invalid index %d\n", t, idx); + ok = false; break; + } + if (seen[idx]) { + printf("Select: column %d duplicate index %d\n", t, idx); + ok = false; break; + } + seen[idx] = 1; + if (!(col[idx] + 0.0f >= thresh)) { + printf("Select: column %d value %.6f below threshold %.6f\n", t, col[idx], thresh); + ok = false; break; + } + } + if (!ok) break; + } + + if (ok) { + printf("Select CUDA test: PASS (note: current kernel may still have edge cases)\n"); + return 0; + } else { + printf("Select CUDA 
test: FAIL\n"); + return 1; + } +#endif +} diff --git a/tests/test-sparse-topk-select-stress-cuda.cpp b/tests/test-sparse-topk-select-stress-cuda.cpp new file mode 100644 index 00000000000..351c156930a --- /dev/null +++ b/tests/test-sparse-topk-select-stress-cuda.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include +#include +#ifdef GGML_USE_CUDA +#include +#endif +static inline uint16_t float_to_half_bits_rtne(float f) { + uint32_t x; std::memcpy(&x, &f, sizeof(x)); + uint32_t sign = (x >> 16) & 0x8000u; + int32_t exp = (int32_t)((x >> 23) & 0xFFu) - 127 + 15; + uint32_t mant = x & 0x007FFFFFu; + if (exp <= 0) { + if (exp < -10) return (uint16_t)sign; + mant |= 0x00800000u; + uint32_t sub = mant >> (1 - exp); + if (sub & 0x00001000u) sub += 0x00002000u; // round to nearest even + return (uint16_t)(sign | (sub >> 13)); + } else if (exp >= 31) { + if (mant == 0) return (uint16_t)(sign | 0x7C00u); + mant >>= 13; return (uint16_t)(sign | 0x7C00u | mant | (mant == 0)); + } else { + if (mant & 0x00001000u) { mant += 0x00002000u; if (mant & 0x00800000u) { mant = 0; exp += 1; if (exp >= 31) return (uint16_t)(sign | 0x7C00u); } } + return (uint16_t)(sign | ((uint32_t)exp << 10) | (mant >> 13)); + } +} +static inline uint32_t float_to_key_desc(float x) { + uint32_t u; std::memcpy(&u, &x, sizeof(u)); + if ((int32_t)u < 0) { return ~u; } else { return u ^ 0x80000000u; } +} +static inline uint8_t key32_msb_bin_desc_host(float x) { + return (uint8_t)(float_to_key_desc(x) >> 24); +} +int main() { +#ifndef GGML_USE_CUDA + printf("CUDA not enabled; skipping select stress test\n"); + return 0; +#else + printf("Select stress CUDA ...\n"); + const int N = 16384; + const int T = 3; + const int K = 32; + std::mt19937 rng(2026); + std::uniform_real_distribution dist(-10.0f, 10.0f); + std::vector scores((size_t)N*T); + for (auto & v : scores) v = dist(rng); + // Build gt_counts on CPU + std::vector gt_counts(256*(size_t)T, 0); + for (int t = 0; t < T; ++t) { + unsigned int hist[256] = {0}; + for (int i = 0; i < N; ++i) { + uint8_t b0 = key32_msb_bin_desc_host(scores[i + (size_t)N*t]); + hist[b0]++; + } + unsigned int sum = 0; + for (int b = 255; b >= 0; --b) { + gt_counts[b + 256*(size_t)t] = sum; + sum += hist[b]; + } + } + std::vector idx_gpu((size_t)K*T, -1); + ggml_cuda_topk_select_host(scores.data(), N, T, K, gt_counts.data(), idx_gpu.data()); + // Validate + bool ok = true; + for (int t = 0; t < T; ++t) { + std::vector col(N); + for (int i = 0; i < N; ++i) col[i] = scores[i + (size_t)N*t]; + std::vector sorted = col; + std::nth_element(sorted.begin(), sorted.begin() + (K-1), sorted.end(), std::greater()); + float thresh = sorted[K-1]; + std::vector seen(N, 0); + for (int i = 0; i < K; ++i) { + int idx = idx_gpu[i + K*t]; + if (idx < 0 || idx >= N) { printf("Stress-select: invalid idx %d\n", idx); ok = false; break; } + if (seen[idx]) { printf("Stress-select: duplicate idx %d\n", idx); ok = false; break; } + seen[idx] = 1; + if (!(col[idx] + 0.0f >= thresh)) { printf("Stress-select: below threshold\n"); ok = false; break; } + } + if (!ok) break; + } + printf("Select stress CUDA: %s\n", ok ? "PASS" : "FAIL"); + return ok ? 0 : 1; +#endif +}