1 change: 1 addition & 0 deletions README.md
@@ -17,6 +17,7 @@ LLM inference in C/C++

## Hot topics

- Support for the `gpt-oss` model with native MXFP4 format has been added | [PR](https://github.com/ggml-org/llama.cpp/pull/15091) | [Collaboration with NVIDIA](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) | [Comment](https://github.com/ggml-org/llama.cpp/discussions/15095)
- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
- Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
3 changes: 2 additions & 1 deletion common/arg.cpp
@@ -2947,11 +2947,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
"- none: leaves thoughts unparsed in `message.content`\n"
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
"(default: deepseek)",
"(default: auto)",
[](common_params & params, const std::string & value) {
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
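For a concrete sense of what the documented `none` and `deepseek` values mean for an OpenAI-compatible response, here is a minimal, illustrative Python sketch (DeepSeek-style `<think>` tags assumed; this is not llama.cpp code):

```python
import re

def split_reasoning(raw: str, reasoning_format: str) -> dict:
    """Illustrative only: how --reasoning-format shapes the returned message."""
    if reasoning_format == "none":
        # thoughts stay inline in message.content
        return {"role": "assistant", "content": raw}
    # "deepseek": move <think>...</think> into message.reasoning_content
    m = re.match(r"\s*<think>(.*?)</think>\s*(.*)", raw, flags=re.DOTALL)
    if m:
        return {"role": "assistant",
                "reasoning_content": m.group(1).strip(),
                "content": m.group(2)}
    return {"role": "assistant", "content": raw}

print(split_reasoning("<think>check the units</think>42", "deepseek"))
# {'role': 'assistant', 'reasoning_content': 'check the units', 'content': '42'}
```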
50 changes: 48 additions & 2 deletions common/chat.cpp
@@ -126,6 +126,8 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
typedef minja::chat_template common_chat_template;

struct common_chat_templates {
bool add_bos;
bool add_eos;
bool has_explicit_template; // Model had builtin template or template override was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
@@ -143,6 +145,8 @@ struct templates_params {
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
json extra_context;
bool add_bos;
bool add_eos;
};

common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -445,6 +449,8 @@ std::string common_chat_format_single(

common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
inputs.add_bos = tmpls->add_bos;
inputs.add_eos = tmpls->add_eos;

std::string fmt_past_msg;
if (!past_msg.empty()) {
@@ -469,6 +475,8 @@ std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
inputs.add_bos = tmpls->add_bos;
inputs.add_eos = tmpls->add_eos;
auto add_simple_msg = [&](auto role, auto content) {
common_chat_msg msg;
msg.role = role;
@@ -546,6 +554,8 @@ common_chat_templates_ptr common_chat_templates_init(
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
bool add_bos = false;
bool add_eos = false;
if (model) {
const auto * vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
@@ -560,9 +570,13 @@
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
add_bos = llama_vocab_get_add_bos(vocab);
add_eos = llama_vocab_get_add_eos(vocab);
}
common_chat_templates_ptr tmpls(new common_chat_templates());
tmpls->has_explicit_template = has_explicit_template;
tmpls->add_bos = add_bos;
tmpls->add_eos = add_eos;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
@@ -592,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -600,6 +615,7 @@ const char * common_chat_format_name(common_chat_format format) {
const char * common_reasoning_format_name(common_reasoning_format format) {
switch (format) {
case COMMON_REASONING_FORMAT_NONE: return "none";
case COMMON_REASONING_FORMAT_AUTO: return "auto";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
default:
@@ -748,10 +764,10 @@ static std::string apply(
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (string_starts_with(result, tmpl.bos_token())) {
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
if (string_ends_with(result, tmpl.eos_token())) {
if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());
}
return result;
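The reason the strip is now gated on the vocab flags, as I read it: if the tokenizer will prepend/append the special tokens itself, a template that also emits them would otherwise produce duplicates, while a vocab that does not auto-add them needs the template's copy kept. A minimal Python sketch of that rule (hypothetical helper, not llama.cpp code):

```python
def strip_special(rendered: str, bos: str, eos: str, add_bos: bool, add_eos: bool) -> str:
    """Sketch of conditional BOS/EOS stripping after chat-template expansion."""
    # only drop a leading BOS if tokenization would add another one anyway
    if add_bos and rendered.startswith(bos):
        rendered = rendered[len(bos):]
    # likewise for a trailing EOS
    if add_eos and rendered.endswith(eos):
        rendered = rendered[: len(rendered) - len(eos)]
    return rendered

# a template that hard-codes "<s>" while the vocab also auto-adds BOS:
print(strip_special("<s>Hello", "<s>", "</s>", add_bos=True, add_eos=True))    # "Hello"
print(strip_special("<s>Hello", "<s>", "</s>", add_bos=False, add_eos=False))  # "<s>Hello"
```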
@@ -1289,6 +1305,26 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
tool_calls_end);
}

static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);

data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_GPT_OSS;

// TODO: support tool calls in GPT-OSS?

return data;
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
// TODO @ngxson : this won't work with --special enabled, we should fix that
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
}
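For reference, a rough Python sketch of the analysis/final channel split this parser performs; the marker strings are taken from the code above, everything else is illustrative:

```python
ANALYSIS = "<|channel|>analysis<|message|>"
FINAL = "<|start|>assistant<|channel|>final<|message|>"

def parse_gpt_oss(text: str) -> dict:
    """Split a GPT-OSS style completion into reasoning and final content."""
    reasoning, content = "", text
    if text.startswith(ANALYSIS) and FINAL in text:
        body = text[len(ANALYSIS):]
        reasoning, content = body.split(FINAL, 1)
    return {"reasoning_content": reasoning.strip(), "content": content.strip()}

sample = ANALYSIS + "User asks for 2+2." + FINAL + "4"
print(parse_gpt_oss(sample))
# {'reasoning_content': 'User asks for 2+2.', 'content': '4'}
```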

static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
@@ -1731,6 +1767,8 @@ static common_chat_params common_chat_templates_apply_jinja(
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = inputs.add_bos;
params.add_eos = inputs.add_eos;

params.extra_context = json::object();
for (auto el : inputs.chat_template_kwargs) {
@@ -1772,6 +1810,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_hermes_2_pro(tmpl, params);
}

// GPT-OSS
if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_gpt_oss(tmpl, params);
}

// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -1923,6 +1966,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
3 changes: 3 additions & 0 deletions common/chat.h
@@ -109,6 +109,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GPT_OSS,

COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -127,6 +128,8 @@ struct common_chat_templates_inputs {
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::map<std::string, std::string> chat_template_kwargs;
bool add_bos = false;
bool add_eos = false;
};

struct common_chat_params {
3 changes: 2 additions & 1 deletion common/common.h
@@ -236,6 +236,7 @@ struct common_params_diffusion {

enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};
@@ -394,7 +395,7 @@ struct common_params {
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response

114 changes: 114 additions & 0 deletions convert_hf_to_gguf.py
@@ -7950,6 +7950,119 @@ def set_vocab(self):
self.gguf_writer.add_chat_template(chat_template)


@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPT_OSS

def transform_nibble_layout(self, tensor):
assert tensor.dtype == torch.uint8
assert tensor.shape[-1] == 16
# swap nibbles
t_lo = tensor & 0x0F
t_hi = tensor & 0xF0
t_swapped = (t_lo << 4) | (t_hi >> 4)
tensor = t_swapped
# transform aaaa...bbbb... to abababab...
blk_a, blk_b = tensor.chunk(2, dim=-1)
# get a_
blk_a0 = (blk_a & 0xF0).view(-1, 1)
blk_a1 = (blk_a << 4).view(-1, 1)
blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
# get _b
blk_b0 = (blk_b >> 4).view(-1, 1)
blk_b1 = (blk_b & 0x0F).view(-1, 1)
blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
# swap once more
out = blk_a | blk_b
out_h = out & 0xF0
out_l = out & 0x0F
out = (out_h >> 4) | (out_l << 4)
return out
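As a quick illustration of the first step above, a standalone nibble-swap on a single byte; the later steps then interleave the two 16-byte halves (the aaaa...bbbb to abab... reordering described in the comments) so the packing matches what the GGUF/ggml MXFP4 loader expects:

```python
# Illustrative only: swap the two 4-bit halves of a byte, as t_swapped does above.
b = 0xA3                                        # high nibble 0xA, low nibble 0x3
swapped = ((b & 0x0F) << 4) | ((b & 0xF0) >> 4)
assert swapped == 0x3A
```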

def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
assert blocks.dtype == torch.uint8
assert scales.dtype == torch.uint8
scales = scales.unsqueeze(-1)
assert len(blocks.shape) == 4
assert len(scales.shape) == 4
blocks = self.transform_nibble_layout(blocks)
new_data = torch.concat((scales, blocks), dim=-1)
new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
# flatten last dim
new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
new_data = new_data.numpy()
self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
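To make the resulting layout concrete: `torch.concat((scales, blocks), dim=-1)` puts the block's shared scale byte first, followed by the 16 packed nibble bytes, so every 32 logical weights occupy 17 bytes. A back-of-the-envelope sketch (illustrative, not part of the patch):

```python
elements_per_block = 32
packed_bytes = elements_per_block // 2           # two FP4 codes per byte
bytes_per_block = 1 + packed_bytes               # 1 scale byte + 16 data bytes = 17
print(bytes_per_block * 8 / elements_per_block)  # 4.25 bits per weight
```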

def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
blocks0: Tensor = torch.zeros(1)
blocks1: Tensor = torch.zeros(1)
found_mxfp4_tensors = False
# we assume that tensors are loaded in the correct order
for name, data_torch in self.get_tensors():
if "mlp.experts.down_proj_blocks" in name:
blocks0 = data_torch
elif "mlp.experts.down_proj_scales" in name:
new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
self.repack_mxfp4(new_name, blocks0, data_torch)
found_mxfp4_tensors = True
elif "mlp.experts.gate_up_proj_blocks" in name:
blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
elif "mlp.experts.gate_up_proj_scales" in name:
scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
self.repack_mxfp4(new_name_gate, blocks0, scales0)
self.repack_mxfp4(new_name_up, blocks1, scales1)
found_mxfp4_tensors = True
if not found_mxfp4_tensors:
raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using an MXFP4 model.")
return []

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if "sinks" in name:
name += ".weight"

# correct naming for down_proj
if "down_proj" in name:
if name.endswith("_bias"):
name = name.replace("down_proj_bias", "down_proj.bias")
else:
return []

# split the gate_up into gate and up
if "gate_up_proj" in name:
if name.endswith("_bias"):
name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
return [
(self.map_tensor_name(name_gate), gate_proj_bias),
(self.map_tensor_name(name_up), up_proj_bias)
]
else:
return []

return [(self.map_tensor_name(name), data_torch)]
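The `::2` / `1::2` slicing here (and in `generate_extra_tensors` above) assumes the fused HF tensors interleave gate and up rows; a standalone sketch of the same de-interleave on a made-up bias vector:

```python
import torch

# hypothetical fused bias, rows interleaved as g0, u0, g1, u1, ...
fused = torch.tensor([0.0, 100.0, 1.0, 101.0, 2.0, 102.0])
gate, up = fused[..., ::2], fused[..., 1::2]   # same slicing as modify_tensors
print(gate.tolist())  # [0.0, 1.0, 2.0]
print(up.tolist())    # [100.0, 101.0, 102.0]
```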

def set_vocab(self):
self._set_vocab_gpt2()

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])

rope_scaling = self.hparams.get("rope_scaling") or {}
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))


@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
class LFM2Model(TextModel):
@@ -8089,6 +8202,7 @@ class LazyTorchTensor(gguf.LazyBase):
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.uint8: np.uint8,
}

# used for safetensors slices