Skip to content

Commit 44bdb75

Browse files
committed
support chat parsing for gpt-oss
1 parent 6b30372 commit 44bdb75

File tree

5 files changed

+36
-4
lines changed

5 files changed

+36
-4
lines changed

common/arg.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -2922,11 +2922,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
         "- none: leaves thoughts unparsed in `message.content`\n"
         "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
-        "(default: deepseek)",
+        "(default: auto)",
         [](common_params & params, const std::string & value) {
             /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
             else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
             else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
+            else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
             else { throw std::invalid_argument("invalid value"); }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
```

common/chat.cpp

Lines changed: 29 additions & 0 deletions
```diff
@@ -592,6 +592,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
+        case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -1289,6 +1290,26 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
         tool_calls_end);
 }

+static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+    auto prompt = apply(tmpl, inputs);
+
+    data.prompt = prompt;
+    data.format = COMMON_CHAT_FORMAT_GPT_OSS;
+
+    // TODO: support tool calls in GPT-OSS?
+
+    return data;
+}
+static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
+    // TODO @ngxson : this won't work with --special enabled, we should fix that
+    builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+}
+
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
@@ -1774,6 +1795,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_hermes_2_pro(tmpl, params);
     }

+    // GPT-OSS
+    if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_gpt_oss(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -1925,6 +1951,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_COMMAND_R7B:
             common_chat_parse_command_r7b(builder);
             break;
+        case COMMON_CHAT_FORMAT_GPT_OSS:
+            common_chat_parse_gpt_oss(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
```

common/chat.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -109,6 +109,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_GPT_OSS,

     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
```

common/common.h

Lines changed: 2 additions & 1 deletion
```diff
@@ -236,6 +236,7 @@ struct common_params_diffusion {

 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
+    COMMON_REASONING_FORMAT_AUTO,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
     COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
@@ -394,7 +395,7 @@ struct common_params {
     std::string chat_template = ""; // NOLINT
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
-    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
     int reasoning_budget = -1;
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
```

src/llama-vocab.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -2330,9 +2330,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         }
     }

-    // @ngxson : quick hack for gpt-oss
+    // @ngxson : quick hack for gpt-oss, always render these tokens
     for (const auto & t : token_to_id) {
-        if (t.first == "<|channel|>" || t.first == "<|message|>") {
+        if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") {
             id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
         }
     }
```

0 commit comments

Comments
 (0)