diff --git a/common/chat.cpp b/common/chat.cpp index 111b4a21b36..955c42852a9 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -622,6 +622,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; + case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; default: throw std::runtime_error("Unknown chat format"); } @@ -2059,6 +2060,94 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) { } } +static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { + // Parse thinking tags first - this handles the main reasoning content + builder.try_parse_reasoning("", ""); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Parse tool calls - Seed-OSS uses format + static const common_regex tool_call_begin_regex(""); + static const common_regex tool_call_end_regex(""); + static const common_regex function_regex("]+)>"); + static const common_regex param_regex("]+)>"); + + while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) { + builder.consume_spaces(); // Consume whitespace after + + // Look for function call inside tool call, ignore any content before it + if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) { + auto function_name = builder.str(func_res->groups[1]); + + // Parse Seed-OSS parameters value + json args = json::object(); + // Parse all parameters + while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) { + // again, ignore noise around parameters + auto param_name = builder.str(param_res->groups[1]); + builder.move_to(param_res->groups[0].end); + builder.consume_spaces(); // Consume whitespace after parameter + auto savedPos = builder.pos(); + if (auto param_parse = builder.try_find_literal("")) { + auto param = param_parse->prelude; + builder.move_to(savedPos); + try { + if (auto param_res = builder.try_consume_json()) { + args[param_name] = param_res->json; + } else { + args[param_name] = param; + } + } catch (json::exception &) { + args[param_name] = param; + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool parameter"); + } + } + // Look for closing function tag + auto end_func = builder.try_find_literal(""); + if (end_func) { + builder.move_to(end_func->groups[0].end); + builder.consume_spaces(); // Consume whitespace after + + // Add the tool call with parsed arguments, but only if we REALLY got the literal + auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end); + auto funlen = std::string("").length(); + if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("")) { + if (!builder.add_tool_call(function_name, "", args.dump())) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + // Look for closing tool call tag + if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) { + builder.move_to(end_tool->groups[0].end); + builder.consume_spaces(); // Consume trailing whitespace after tool call + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } else { + // No function found - don't consume content here, let it be handled at the end + break; + } + } + + // Consume any remaining whitespace after all tool call processing + builder.consume_spaces(); + auto remaining = builder.consume_rest(); + // If there's any non-whitespace content remaining, add it as content + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2075,8 +2164,62 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha return data; } +static common_chat_params common_chat_params_init_seed_oss( + const common_chat_template & tmpl, + templates_params & params, + const common_chat_templates_inputs & inputs) +{ + common_chat_params data; + data.prompt = apply(tmpl, params); + data.format = COMMON_CHAT_FORMAT_SEED_OSS; + if (string_ends_with(data.prompt, "")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (params.tools.is_array() && !params.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(params.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + + // Create rule for Seed-OSS function call format + std::string param_rules; + if (parameters.contains("properties")) { + for (const auto & [key, value] : parameters.at("properties").items()) { + param_rules += "\"\"" + builder.add_schema(name + "-arg-" + key, value) + + "\"\""; + } + } + + tool_rules.push_back(builder.add_rule(name + "-call", + "\"\" space \"\" space " + + param_rules + + " \"\" space \"\"")); + }); + + data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "" }); + + data.preserved_tokens = { + "", "", "", "", + "", "", + }; + + builder.add_rule("root", string_join(tool_rules, " | ")); + }); + } + return data; +} + static common_chat_params common_chat_templates_apply_jinja( - const struct common_chat_templates * tmpls, + const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) { templates_params params; @@ -2145,6 +2288,11 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_gpt_oss(tmpl, params); } + // Seed-OSS + if (src.find("") != std::string::npos) { + return common_chat_params_init_seed_oss(tmpl, params, inputs); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -2303,6 +2451,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_GPT_OSS: common_chat_parse_gpt_oss(builder); break; + case COMMON_CHAT_FORMAT_SEED_OSS: + common_chat_parse_seed_oss(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index d1e480c9187..b09ff3b126a 100644 --- a/common/chat.h +++ b/common/chat.h @@ -111,6 +111,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_GRANITE, COMMON_CHAT_FORMAT_GPT_OSS, + COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6c8a0340605..df37c4a6e43 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7546,9 +7546,13 @@ def __init__(self, *args, **kwargs): ] # n_group and d_inner are used during reshape_tensors for mamba2 - self.d_model = self.find_hparam(["hidden_size", "d_model"]) - self.n_group = self.find_hparam(["n_groups"]) - self.d_inner = self.find_hparam(["expand"]) * self.d_model + # NOTE: Explicitly include hparam prefix prefix for d_model to + # disambiguate with top-level head_dim + # NOTE 2: If needed for future models, this can be isolated in a method + # to separate the prefix setting and teh keys used + self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"]) + self.n_group = self.find_hparam(["n_groups", "num_groups"]) + self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model def get_attn_layers(self): # Explicit list of layer type names @@ -7609,12 +7613,12 @@ def set_gguf_parameters(self): ## Mamba mixer params ## self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) - self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"])) + self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"])) self.gguf_writer.add_ssm_group_count(self.n_group) self.gguf_writer.add_ssm_inner_size(self.d_inner) # NOTE: The mamba_dt_rank is _not_ the right field for how this is used # in llama.cpp - self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"])) ## Attention params ## head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) @@ -7641,6 +7645,55 @@ def set_vocab(self): Mamba2Model.set_vocab(self) +@ModelBase.register("NemotronHForCausalLM") +class NemotronHModel(GraniteHybridModel): + """Hybrid mamba2/attention model from NVIDIA""" + model_arch = gguf.MODEL_ARCH.NEMOTRON_H + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Save the top-level head_dim for later + self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim")) + assert self.head_dim is not None, "Could not find the attention head dim in config" + + # Don't use expand to calculate d_inner + self.d_inner = self.find_hparam(["num_heads"]) * self.d_model + + # Update the ssm / attn / mlp layers + # M: Mamba2, *: Attention, -: MLP + hybrid_override_pattern = self.hparams["hybrid_override_pattern"] + self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"] + self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"] + + def get_attn_layers(self): + hybrid_override_pattern = self.hparams["hybrid_override_pattern"] + assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!" + return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_key_length(self.head_dim) + self.gguf_writer.add_value_length(self.head_dim) + + # Set feed_forward_length + # NOTE: This will trigger an override warning. This is preferrable to + # duplicating all the parent logic + n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) + self.gguf_writer.add_feed_forward_length([ + n_ff if i in self._mlp_layers else 0 for i in range(self.block_count) + ]) + + def set_vocab(self): + super().set_vocab() + + # The tokenizer _does_ add a BOS token (via post_processor type + # TemplateProcessing) but does not set add_bos_token to true in the + # config, so we need to explicitly override it here. + self.gguf_writer.add_add_bos_token(True) + + @ModelBase.register("BailingMoeForCausalLM") class BailingMoeModel(TextModel): model_arch = gguf.MODEL_ARCH.BAILINGMOE diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index e1fbf0e1366..1c76566344a 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -1,5 +1,6 @@ #include "binbcast.cuh" #include +#include static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; @@ -22,13 +23,16 @@ static __device__ __forceinline__ float op_div(const float a, const float b) { return a / b; } -template + + +template static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13) { + const int ne0, const int ne1, const int ne2, const int ne3, + const int ne10, const int ne11, const int ne12, const int ne13, + /*int s0, */ const int s1, const int s2, const int s3, + /*int s00,*/ const int s01, const int s02, const int s03, + /*int s10,*/ const int s11, const int s12, const int s13, + src1_ptrs... src1s) { const int i0s = blockDim.x*blockIdx.x + threadIdx.x; const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; @@ -46,24 +50,31 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; + const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + + float result = src0_row ? (float) src0_row[i0] : 0.0f; + if constexpr (sizeof...(src1_ptrs) > 0) { + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + } else { + result = bin_op(result, (float)src1[i_src1 + i10]); + } + + dst_row[i0] = (dst_t) result; } } -template -static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13) { - +template +static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + const int ne0, const int ne1, const int ne2,const int ne3, + const int ne10, const int ne11, const int ne12, const int ne13, + /*int s0, */ const int s1, const int s2, const int s3, + /*int s00,*/ const int s01, const int s02, const int s03, + /*int s10,*/ const int s11, const int s12, const int s13, + src1_ptrs ... src1s) { const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i3 = i/(ne2*ne1*ne0); @@ -83,12 +94,190 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; + const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + + float result = src0_row ? (float) src0_row[i0] : 0.0f; + if constexpr (sizeof...(src1_ptrs) > 0) { + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + } else { + result = bin_op(result, (float)src1[i_src1 + i10]); + } + + dst_row[i0] = (dst_t) result; +} + +template +static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, + cudaStream_t stream, std::index_sequence) { + GGML_TENSOR_BINARY_OP_LOCALS + + int nr0 = ne10 / ne0; + int nr1 = ne11 / ne1; + int nr2 = ne12 / ne2; + int nr3 = ne13 / ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + int64_t cne[] = { ne0, ne1, ne2, ne3 }; + int64_t cne0[] = { ne00, ne01, ne02, ne03 }; + int64_t cne1[] = { ne10, ne11, ne12, ne13 }; + + size_t cnb[] = { nb0, nb1, nb2, nb3 }; + size_t cnb0[] = { nb00, nb01, nb02, nb03 }; + size_t cnb1[] = { nb10, nb11, nb12, nb13 }; + + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); + //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); + //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); + //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s00 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0 / 2LL, 1LL); + + dim3 block_dims; + block_dims.x = std::min(hne0, block_size); + block_dims.y = std::min(ne1, block_size / block_dims.x); + block_dims.z = std::min(std::min(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); + + dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, + (ne1 + block_dims.y - 1) / block_dims.y, + (ne2 * ne3 + block_dims.z - 1) / block_dims.z); + + if (block_nums.z > 65535) { + int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + if constexpr (sizeof...(I) > 0) { + k_bin_bcast_unravel + <<>>(src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12,s13, + (const src1_t *) dst->src[I + 1]->data...); + } else { + k_bin_bcast_unravel + <<>>(src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12,s13); + } + } else { + if constexpr (sizeof...(I) > 0) { + k_bin_bcast + <<>>(src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12,s13, + (const src1_t *) dst->src[I + 1]->data...); + } else { + k_bin_bcast + <<>>(src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12,s13); + } + } + } } template @@ -120,160 +309,14 @@ static __global__ void k_repeat_back( dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; } -template +template struct bin_bcast_cuda { template void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, cudaStream_t stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - int nr0 = ne10/ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne[] = {ne0, ne1, ne2, ne3}; - int64_t cne0[] = {ne00, ne01, ne02, ne03}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - - size_t cnb[] = {nb0, nb1, nb2, nb3}; - size_t cnb0[] = {nb00, nb01, nb02, nb03}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb, cne); - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne); - collapse(cne0); - collapse(cne1); - } - } - } - - { - int64_t ne0 = cne[0]; - int64_t ne1 = cne[1]; - int64_t ne2 = cne[2]; - int64_t ne3 = cne[3]; - - //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); - //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); - //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); - //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb[0]; - size_t nb1 = cnb[1]; - size_t nb2 = cnb[2]; - size_t nb3 = cnb[3]; - - size_t nb00 = cnb0[0]; - size_t nb01 = cnb0[1]; - size_t nb02 = cnb0[2]; - size_t nb03 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); - - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); - - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); - - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - dim3 block_dims; - block_dims.x = std::min(hne0, block_size); - block_dims.y = std::min(ne1, block_size / block_dims.x); - block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U); - - dim3 block_nums( - (hne0 + block_dims.x - 1) / block_dims.x, - (ne1 + block_dims.y - 1) / block_dims.y, - (ne2*ne3 + block_dims.z - 1) / block_dims.z - ); - - if (block_nums.z > 65535) { - // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - k_bin_bcast_unravel<<>>( - src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00, */ s01, s02, s03, - /* s10, */ s11, s12, s13); - } else { - k_bin_bcast<<>>( - src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00, */ s01, s02, s03, - /* s10, */ s11, s12, s13); - } - } + launch_bin_bcast_pack( + src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence{}); } }; @@ -312,7 +355,7 @@ static void ggml_cuda_op_bin_bcast( } void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); + ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); } void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -331,6 +374,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } +template +static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + cudaStream_t stream = ctx.stream(); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + launch_bin_bcast_pack(src0, src1, dst, + (const float *) src0->data, (const float *) src1->data, (float *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const half *) src1->data, (half *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const float *) src1->data, (half *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const float *) src1->data, (float *) dst->data, + stream, std::make_index_sequence{}); + } else { + fprintf(stderr, + "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n", + __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + + +void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 3ac1c9b03fc..62bc950111b 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4c02b57227a..e06f95f0819 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2821,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } - if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = nullptr; + + if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) { + add = cgraph->nodes[node_idx+2]; + } GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); @@ -2835,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } + if (add && (add->src[0]->type != GGML_TYPE_F32 || + add->src[1]->type != GGML_TYPE_F32 || + add->type != GGML_TYPE_F32) ) { + return false; + } + //if rms norm is the B operand, then we don't handle broadcast if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { return false; @@ -2845,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } + if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) { + return false; + } + return true; } @@ -2891,7 +2906,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + + if (node->op == GGML_OP_ADD) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, GGML_OP_ADD); + + for (; n_fuse <= 6; ++n_fuse){ + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; + } + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; + } + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; + } + } + + n_fuse++; + + if (n_fuse > 1) { + for (int j = 0; j < n_fuse - 1; ++j) { + node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + cgraph->nodes[i + n_fuse - 1]->data = node->data; + ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + i += n_fuse - 1; + + continue; + } + } + + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); i++; continue; diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index bddcca51b7b..d5157d958b7 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -104,12 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } } -template -static __global__ void rms_norm_f32( - const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0, - const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0, - const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) { +template +static __global__ void rms_norm_f32(const float * x, float * dst, + const int ncols, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const float eps, + const float * mul = nullptr, + const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, + const int64_t mul_stride_sample = 0, + const int mul_ncols = 0, + const int mul_nrows = 0, + const int mul_nchannels = 0, + const int mul_nsamples = 0, + const float * add = nullptr, + const int64_t add_stride_row = 0, + const int64_t add_stride_channel = 0, + const int64_t add_stride_sample = 0, + const int add_ncols = 0, + const int add_nrows = 0, + const int add_nchannels = 0, + const int add_nsamples = 0) { + const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -118,6 +136,8 @@ static __global__ void rms_norm_f32( const int sample = blockIdx.z; const int tid = threadIdx.x; + static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying"); + x += sample*stride_sample + channel*stride_channel + row*stride_row; dst += ((sample*nchannels + channel)*nrows + row)*ncols; @@ -128,6 +148,13 @@ static __global__ void rms_norm_f32( mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; } + if constexpr (do_add) { + const int add_row = row % add_nrows; + const int add_channel = channel % add_nchannels; + const int add_sample = sample % add_nsamples; + add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; + } + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { @@ -154,7 +181,11 @@ static __global__ void rms_norm_f32( const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - if constexpr (do_multiply) { + if constexpr (do_multiply && do_add) { + const int mul_col = col % mul_ncols; + const int add_col = col % add_ncols; + dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; + } else if constexpr (do_multiply) { const int mul_col = col % mul_ncols; dst[col] = scale * x[col] * mul[mul_col]; } else { @@ -331,23 +362,70 @@ static void rms_norm_f32_cuda( } } -static void rms_norm_mul_f32_cuda( - const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, - const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, - const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, - const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples, - const float eps, cudaStream_t stream) { +static void rms_norm_mul_f32_cuda(const float * x, + const float * mul, + const float * add, + float * dst, + const int ncols, + const int nrows, + const int nchannels, + const int nsamples, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const int64_t mul_stride_row, + const int64_t mul_stride_channel, + const int64_t mul_stride_sample, + const int mul_ncols, + const int mul_nrows, + const int mul_nchannels, + const int mul_nsamples, + const int64_t add_stride_row, + const int64_t add_stride_channel, + const int64_t add_stride_sample, + const int add_ncols, + const int add_nrows, + const int add_nchannels, + const int add_nsamples, + const float eps, + cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (mul == nullptr) { rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); return; } - if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + if (add == nullptr) { + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32<<>>(x, dst, + ncols, stride_row, stride_channel, stride_sample, eps, + mul, mul_stride_row, mul_stride_channel, mul_stride_sample, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true><<>>(x, dst, + ncols, stride_row, stride_channel, stride_sample, eps, + mul, mul_stride_row, mul_stride_channel, mul_stride_sample, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + } } else { - const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32<<>>(x, dst, + ncols, stride_row, stride_channel, stride_sample, eps, + mul, mul_stride_row, mul_stride_channel, mul_stride_sample, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + add, add_stride_row, add_stride_channel, add_stride_sample, + add_ncols, add_nrows, add_nchannels, add_nsamples); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true, true><<>>(x, dst, + ncols, stride_row, stride_channel, stride_sample, eps, + mul, mul_stride_row, mul_stride_channel, mul_stride_sample, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + add, add_stride_row, add_stride_channel, add_stride_sample, + add_ncols, add_nrows, add_nchannels, add_nsamples); + } } } @@ -491,7 +569,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * const int mul_nchannels = mul_src->ne[2]; const int mul_nsamples = mul_src->ne[3]; - rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream); + rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d, + ne00, ne01, ne02, ne03, + /*s00*/ s01, s02, s03, + /*mul_s00*/ mul_s01, mul_s02, mul_s03, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + /*add_s00*/ 0, 0, 0, + 0, 0, 0, 0, + eps, stream); +} + +void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + ggml_tensor * mul_tensor, + ggml_tensor * add_tensor) { + const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0]; + float eps = 0.0f; + + memcpy(&eps, dst->op_params, sizeof(float)); + + const float * src0_d = (const float *) rms_norm_src->data; + const float * mul_d = nullptr; + const ggml_tensor * mul_src = nullptr; + + if (mul_tensor->src[0] == dst) { + mul_d = (float *) mul_tensor->src[1]->data; + mul_src = mul_tensor->src[1]; + } else if (mul_tensor->src[1] == dst) { + mul_d = (float *) mul_tensor->src[0]->data; + mul_src = mul_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + const float * add_d = nullptr; + const ggml_tensor * add_src = nullptr; + + if (add_tensor->src[0] == mul_tensor) { + add_d = (float *) add_tensor->src[1]->data; + add_src = add_tensor->src[1]; + } else if (add_tensor->src[1] == mul_tensor) { + add_d = (float *) add_tensor->src[0]->data; + add_src = add_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + float * dst_d = (float *) add_tensor->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(add_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(eps >= 0.0f); + + const int64_t ne00 = rms_norm_src->ne[0]; + const int64_t ne01 = rms_norm_src->ne[1]; + const int64_t ne02 = rms_norm_src->ne[2]; + const int64_t ne03 = rms_norm_src->ne[3]; + + const size_t ts0 = ggml_type_size(rms_norm_src->type); + GGML_ASSERT(rms_norm_src->nb[0] == ts0); + const int64_t s01 = rms_norm_src->nb[1] / ts0; + const int64_t s02 = rms_norm_src->nb[2] / ts0; + const int64_t s03 = rms_norm_src->nb[3] / ts0; + + const size_t ts_mul = ggml_type_size(mul_src->type); + GGML_ASSERT(mul_src->nb[0] == ts_mul); + const int64_t mul_s01 = mul_src->nb[1] / ts_mul; + const int64_t mul_s02 = mul_src->nb[2] / ts_mul; + const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + + const int mul_ncols = mul_src->ne[0]; + const int mul_nrows = mul_src->ne[1]; + const int mul_nchannels = mul_src->ne[2]; + const int mul_nsamples = mul_src->ne[3]; + + const size_t ts_add = ggml_type_size(add_src->type); + GGML_ASSERT(add_src->nb[0] == ts_add); + const int64_t add_s01 = add_src->nb[1] / ts_add; + const int64_t add_s02 = add_src->nb[2] / ts_add; + const int64_t add_s03 = add_src->nb[3] / ts_add; + + const int add_ncols = add_src->ne[0]; + const int add_nrows = add_src->ne[1]; + const int add_nchannels = add_src->ne[2]; + const int add_nsamples = add_src->ne[3]; + + rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d, + ne00,ne01, ne02, ne03, + /*s00*/ s01, s02, s03, + /*mul_s00*/ mul_s01, mul_s02, mul_s03, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + /*add_s00*/ add_s01, add_s02, add_s03, + add_ncols, add_nrows, add_nchannels, add_nsamples, + eps, stream); } void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 7ea7bd4df3c..a74f6376720 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor); +void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + ggml_tensor * mul_tensor, + ggml_tensor * add_tensor); + void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a581f9601f0..6156d35c2ad 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum): T5ENCODER = auto() JAIS = auto() NEMOTRON = auto() + NEMOTRON_H = auto() EXAONE = auto() EXAONE4 = auto() GRANITE = auto() @@ -700,6 +701,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.NEMOTRON_H: "nemotron_h", MODEL_ARCH.EXAONE: "exaone", MODEL_ARCH.EXAONE4: "exaone4", MODEL_ARCH.GRANITE: "granite", @@ -2297,6 +2299,25 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.NEMOTRON_H: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.EXAONE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index abb21fa8219..497f48809f8 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -191,6 +191,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.q_proj", # llama4 "model.transformer.blocks.{bid}.q_proj", # llada "layers.{bid}.self_attn.q_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.q_proj", # nemotron-h ), # Attention key @@ -209,6 +210,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.k_proj", # llama4 "model.transformer.blocks.{bid}.k_proj", # llada "layers.{bid}.self_attn.k_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.k_proj", # nemotron-h ), # Attention value @@ -226,6 +228,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.v_proj", # llama4 "model.transformer.blocks.{bid}.v_proj", # llada "layers.{bid}.self_attn.v_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.v_proj", # nemotron-h ), # Attention output @@ -260,6 +263,7 @@ class TensorNameMap: "transformer_encoder.{bid}.wo", # neobert "model.transformer.blocks.{bid}.attn_out", # llada "layers.{bid}.self_attn.o_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.o_proj", # nemotron-h ), # Attention output norm @@ -387,6 +391,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.up", # smallthinker "model.transformer.blocks.{bid}.up_proj", # llada "layers.{bid}.mlp.up_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.up_proj", # nemotron-h ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -480,6 +485,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.down", # smallthinker "model.transformer.blocks.{bid}.ff_out", # llada "layers.{bid}.mlp.down_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.down_proj", # nemotron-h ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/models/templates/ByteDance-Seed-OSS.jinja b/models/templates/ByteDance-Seed-OSS.jinja new file mode 100644 index 00000000000..903ebaaba77 --- /dev/null +++ b/models/templates/ByteDance-Seed-OSS.jinja @@ -0,0 +1,171 @@ +{# ----------‑‑‑ special token variables ‑‑‑---------- #} +{%- set bos_token = '' -%} +{%- set eos_token = '' -%} +{%- set pad_token = '' -%} +{%- set toolcall_begin_token = '' -%} +{%- set toolcall_end_token = '' -%} +{%- set think_begin_token = '' -%} +{%- set think_end_token = '' -%} +{%- set budget_begin_token = ''-%} +{%- set budget_end_token = ''-%} +{# -------------- reflection-interval lookup -------------- #} +{%- if not thinking_budget is defined %} +{%- set thinking_budget = -1 -%} +{%- endif -%} +{%- set budget_reflections_v05 = { + 0: 0, + 512: 128, + 1024: 256, + 2048: 512, + 4096: 512, + 8192: 1024, + 16384: 1024 +} -%} +{# Find the first gear that is greater than or equal to the thinking_budget. #} +{%- set ns = namespace(interval = None) -%} +{%- for k, v in budget_reflections_v05 | dictsort -%} + {%- if ns.interval is none and thinking_budget <= k -%} + {%- set ns.interval = v -%} + {%- endif -%} +{%- endfor -%} +{# If it exceeds the maximum gear, use the value of the last gear #} +{%- if ns.interval is none -%} + {%- set ns.interval = budget_reflections_v05[16384] -%} +{%- endif -%} +{# ---------- Preprocess the system message ---------- #} +{%- if messages[0]["role"] == "system" %} +{%- set system_message = messages[0]["content"] %} +{%- set loop_messages = messages[1:] %} +{%- else %} +{%- set loop_messages = messages %} +{%- endif %} +{# ---------- Ensure tools exist ---------- #} +{%- if not tools is defined or tools is none %} +{%- set tools = [] %} +{%- endif %} +{# tools2doc.jinja #} +{%- macro py_type(t) -%} + {%- if t == "string" -%}str + {%- elif t in ("number", "integer") -%}int + {%- elif t == "boolean" -%}bool + {%- elif t == "array" -%}list + {%- else -%}Any{%- endif -%} +{%- endmacro -%} +{# ---------- Output the system block ---------- #} +{%- if system_message is defined %} +{{ bos_token + "system\n" + system_message }} +{%- else %} +{%- if tools is iterable and tools | length > 0 %} +{{ bos_token + "system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query." }} +{%- endif %} +{%- endif %} +{%- if use_json_tooldef is defined and use_json_tooldef %} + +{{"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing any task, you must decide how to call them based on the descriptions and parameters of these tools."}} +{{ tools | tojson(ensure_ascii=False) }} +{%- else %} +{%- for item in tools if item.type == "function" %} + + +Function: +def {{ item.function.name }}( +{%- for name, spec in item.function.parameters.properties.items() %} + {{- name }}: {{ py_type(spec.type) }}{% if not loop.last %},{% endif %} +{%- endfor %}): + """ + {{ item.function.description | trim }} + + {# ---------- Args ---------- #} + {%- if item.function.parameters.properties %} + Args: + {%- for name, spec in item.function.parameters.properties.items() %} + + - {{ name }} ({{ py_type(spec.type) }}) + {%- if name in item.function.parameters.required %} [必填]{% else %} [选填]{% endif %}: + {{- " " ~ (spec.description or "") }} + {%- endfor %} + {%- endif %} + + {# ---------- Returns ---------- #} + {%- if item.function.returns is defined + and item.function.returns.properties is defined + and item.function.returns.properties %} + Returns: + {%- for name, spec in item.function.returns.properties.items() %} + + - {{ name }} ({{ py_type(spec.type) }}): + {{- " " ~ (spec.description or "") }} + {%- endfor %} + {%- endif %} + + """ +{%- endfor %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + +{{"工具调用请遵循如下格式:\n\n\nvalue_1\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n"}} +{%- endif %} +{# End the system block line #} +{%- if system_message is defined or tools is iterable and tools | length > 0 %} +{{ eos_token }} +{%- endif %} +{# ---------- Thinking Budget ---------- #} +{%- if thinking_budget is defined %} +{%- if thinking_budget == 0 %} +{{ bos_token+"system" }} +{{ "You are an intelligent assistant that can answer questions in one step without the need for reasoning and thinking, that is, your thinking budget is 0. Next, please skip the thinking process and directly start answering the user's questions." }} +{{ eos_token }} +{%- elif not thinking_budget == -1 %} +{{ bos_token+"system" }} +{{ "You are an intelligent assistant with reflective ability. In the process of thinking and reasoning, you need to strictly follow the thinking budget, which is "}}{{thinking_budget}}{{". That is, you need to complete your thinking within "}}{{thinking_budget}}{{" tokens and start answering the user's questions. You will reflect on your thinking process every "}}{{ns.interval}}{{" tokens, stating how many tokens have been used and how many are left."}} +{{ eos_token }} +{%- endif %} +{%- endif %} +{# ---------- List the historical messages one by one ---------- #} +{%- for message in loop_messages %} +{%- if message.role == "assistant" + and message.tool_calls is defined + and message.tool_calls is iterable + and message.tool_calls | length > 0 %} +{{ bos_token + message.role }} +{%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %} +{{ "\n" + think_begin_token + message.reasoning_content | trim + think_end_token }} +{%- endif %} +{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} +{{ "\n" + message.content | trim + "\n" }} +{%- endif %} +{%- for tool_call in message.tool_calls %} +{%- if tool_call.function is defined %}{% set tool_call = tool_call.function %}{% endif %} +{{ "\n" + toolcall_begin_token + "\n\n" }} +{%- if tool_call.arguments is defined %} +{%- for arg_name, arg_value in tool_call.arguments | items %} +{{ "" }} +{%- set arg_value = arg_value if arg_value is string else arg_value | string %} +{{ arg_value+"\n" }} +{%- endfor %} +{%- endif %} +{{ "\n" + toolcall_end_token }} +{%- endfor %} +{{ eos_token }} +{%- elif message.role in ["user", "system"] %} +{{ bos_token + message.role + "\n" + message.content + eos_token }} +{%- elif message.role == "assistant" %} +{{ bos_token + message.role }} +{%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %} +{{ "\n" + think_begin_token + message.reasoning_content | trim + think_end_token }} +{%- endif %} +{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} +{{ "\n" + message.content | trim + eos_token }} +{%- endif %} +{# Include the tool role #} +{%- else %} +{{ bos_token + message.role + "\n" + message.content + eos_token }} +{%- endif %} +{%- endfor %} +{# ---------- Control the model to start continuation ---------- #} +{%- if add_generation_prompt %} +{{ bos_token+"assistant\n" }} +{%- if thinking_budget == 0 %} +{{ think_begin_token + "\n" + budget_begin_token + "The current thinking budget is 0, so I will directly start answering the question." + budget_end_token + "\n" + think_end_token }} +{%- endif %} +{%- endif %} \ No newline at end of file diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a61dc177ac9..d5c8477f4aa 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -69,6 +69,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, @@ -1550,6 +1551,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_NEMOTRON_H, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // dense FFN + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_EXAONE, { @@ -2355,6 +2381,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_NEMOTRON_H: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 94b0bef719b..86c119692d8 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -73,6 +73,7 @@ enum llm_arch { LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, + LLM_ARCH_NEMOTRON_H, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index f71c40f8e3f..8182a9adf53 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -788,6 +788,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri } struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { + LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str()); const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 30974a723fd..f3e0e9ac64b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1570,6 +1570,27 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_NEMOTRON_H: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // A layer is recurrent IFF the n_head_kv value is set to 0 and + // the n_ff value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 56: type = LLM_TYPE_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_EXAONE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4688,6 +4709,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_NEMOTRON_H: + { + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // all blocks use the attn norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recurrent(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else if (hparams.n_ff(i) == 0) { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_EXAONE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5862,7 +5952,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID) { + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -14129,6 +14220,138 @@ struct llm_build_nemotron : public llm_graph_context { } }; +struct llm_build_nemotron_h : public llm_graph_context_mamba { + llm_build_nemotron_h( + const llama_model & model, + const llm_graph_params & params) : + llm_graph_context_mamba(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (hparams.is_recurrent(il)) { + // ssm layer // + cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); + } else if (hparams.n_ff(il) == 0) { + // attention layer // + cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il); + } else { + cur = build_ffn_layer(cur, model, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // add residual + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "block_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * build_attention_layer( + ggml_tensor * cur, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { + + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + return cur; + } + + ggml_tensor * build_ffn_layer( + ggml_tensor * cur, + const llama_model & model, + const int il) { + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + return cur; + } +}; + struct llm_build_exaone : public llm_graph_context { llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -18277,6 +18500,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, nullptr); } else if (llm_arch_is_hybrid(arch)) { + + // The main difference between hybrid architectures is the + // layer filters, so pick the right one here + llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; + llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; + if (arch == LLM_ARCH_FALCON_H1) { + filter_attn = [&](int32_t) { return true; }; + filter_recr = [&](int32_t) { return true; }; + } else if (arch == LLM_ARCH_NEMOTRON_H) { + filter_attn = [&](int32_t il) { + return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + filter_recr = [&](int32_t il) { + return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + } + const auto padding = llama_kv_cache::get_padding(cparams); cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); @@ -18296,8 +18536,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, - /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr, - /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr); + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); } else { const auto padding = llama_kv_cache::get_padding(cparams); @@ -18625,6 +18865,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_NEMOTRON_H: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_EXAONE: { llm = std::make_unique(*this, params); @@ -18860,6 +19104,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_NEMOTRON_H: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index a6daa93a82b..8120b45c4bf 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1621,6 +1621,140 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); } + { + // Seed-OSS format tests + auto tmpls = read_templates("models/templates/ByteDance-Seed-OSS.jinja"); + std::vector end_tokens{ "" }; + + assert_equals(COMMON_CHAT_FORMAT_SEED_OSS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_SEED_OSS, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + + // Test simple reasoning content + assert_msg_equals( + simple_assist_msg("Hello, world!", "I'm thinking about the answer"), + common_chat_parse( + "I'm thinking about the answerHello, world!", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_SEED_OSS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test budget reflection tags + common_chat_msg msg_budget_reflect; + msg_budget_reflect.role = "assistant"; + msg_budget_reflect.content = "Token usage: 45/1000\nI should continue thinking to find the best solution.I need to calculate this step by step."; + msg_budget_reflect.reasoning_content = "Token usage: 45/1000\nI should continue thinking to find the best solution."; + assert_msg_equals( + msg_budget_reflect, + common_chat_parse( + "Token usage: 45/1000\nI should continue thinking to find the best solution." + "Token usage: 45/1000\nI should continue thinking to find the best solution." + "I need to calculate this step by step.", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_SEED_OSS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test tool calls with Seed-OSS format + common_chat_msg msg_tool_call; + msg_tool_call.role = "assistant"; + msg_tool_call.tool_calls.push_back({"calculate_sum", "{\"numbers\": [1, 2, 3]}", ""}); + assert_msg_equals( + msg_tool_call, + common_chat_parse( + "\n" + "\n" + "[1, 2, 3]\n" + "\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_SEED_OSS})); + + // Test reasoning + tool call combination + common_chat_msg msg_reasoning_tool; + msg_reasoning_tool.role = "assistant"; + msg_reasoning_tool.content = ""; + msg_reasoning_tool.reasoning_content = "I need to calculate the sum of these numbers"; + msg_reasoning_tool.tool_calls.push_back({"calculate_sum", "{\"numbers\": [1, 2, 3]}", ""}); + assert_msg_equals( + msg_reasoning_tool, + common_chat_parse( + "I need to calculate the sum of these numbers" + "\n" + "\n" + "[1, 2, 3]\n" + "\n" + "", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_SEED_OSS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test deltas: the number of tool calls in partial parses should never decrease + std::string tool_msg = "\n" + "\n" + "[1, 2, 3]\n" + ""; + std::size_t previousToolCalls = 0; + for (std::size_t i = std::string("").length(); i < tool_msg.length() - 1; i++) { + auto partial = tool_msg.substr(0, i); + auto partial_res = common_chat_parse(partial, true, { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_REASONING_FORMAT_DEEPSEEK }); + if (partial_res.tool_calls.size() < previousToolCalls) { + throw std::runtime_error("Tool call size decreased on partial: " + partial + " from " + std::to_string(previousToolCalls) + " to " + std::to_string(partial_res.tool_calls.size())); + } + previousToolCalls = partial_res.tool_calls.size(); + } + + // Test multiple parameters in tool call + common_chat_msg msg_multi_param; + msg_multi_param.role = "assistant"; + msg_multi_param.tool_calls.push_back({"process_data", "{\"input\": \"test\", \"format\": \"json\"}", ""}); + assert_msg_equals( + msg_multi_param, + common_chat_parse( + "\n" + "\n" + "test\n" + "json\n" + "\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_SEED_OSS})); + + // Test partial parsing for incomplete tool call - don't actually add the call until parsing parameters is done + assert_msg_equals( + simple_assist_msg("", ""), + common_chat_parse( + "\n" + "\n" + "[1,\n", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_SEED_OSS})); + + // Test incomplete reasoning tag + assert_msg_equals( + simple_assist_msg("", "I was thinking"), + common_chat_parse( + "I was thinking", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_SEED_OSS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test content without reasoning + assert_msg_equals( + simple_assist_msg("This is a simple response without reasoning."), + common_chat_parse( + "This is a simple response without reasoning.", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_SEED_OSS})); + } } static void test_msg_diffs_compute() {