11 changes: 8 additions & 3 deletions tools/server/server-common.cpp
@@ -1263,7 +1263,11 @@ json convert_anthropic_to_oai(const json & body) {
     return oai_body;
 }
 
-json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
+json format_embeddings_response_oaicompat(
+        const json & request,
+        const std::string & model_name,
+        const json & embeddings,
+        bool use_base64) {
     json data = json::array();
     int32_t n_tokens = 0;
     int i = 0;
@@ -1293,7 +1297,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb
     }
 
     json res = json {
-        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"model", json_value(request, "model", model_name)},
         {"object", "list"},
         {"usage", json {
             {"prompt_tokens", n_tokens},
@@ -1307,6 +1311,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb
 
 json format_response_rerank(
         const json & request,
+        const std::string & model_name,
         const json & ranks,
         bool is_tei_format,
         std::vector<std::string> & texts,
@@ -1338,7 +1343,7 @@ json format_response_rerank(
     if (is_tei_format) return results;
 
     json res = json{
-        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"model", json_value(request, "model", model_name)},
         {"object", "list"},
         {"usage", json{
             {"prompt_tokens", n_tokens},
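
The behavioral change in these two formatters is only the fallback for the "model" field: a client-supplied "model" value still wins, and the server-resolved name is used only when the request omits it (previously the hard-coded DEFAULT_OAICOMPAT_MODEL, i.e. "gpt-3.5-turbo"). The sketch below is an illustration only, not the server's actual json_value() helper; the helper and the resolved model name are simplified, hypothetical stand-ins.

// Illustration only: simplified stand-in for llama.cpp's json_value() helper,
// with a hypothetical server-resolved model name.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::ordered_json;

// returns body["model"] if present and a string, otherwise the supplied default
static std::string model_or_default(const json & body, const std::string & model_name) {
    return body.contains("model") && body["model"].is_string()
        ? body["model"].get<std::string>()
        : model_name;
}

int main() {
    const std::string model_name = "gemma-3-4b-it-Q4_K_M.gguf"; // hypothetical resolved name

    json req_plain = {{"input", "hello"}};                        // no "model" field
    json req_named = {{"input", "hello"}, {"model", "my-alias"}}; // client overrides

    std::cout << model_or_default(req_plain, model_name) << "\n"; // gemma-3-4b-it-Q4_K_M.gguf
    std::cout << model_or_default(req_named, model_name) << "\n"; // my-alias
}
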
9 changes: 6 additions & 3 deletions tools/server/server-common.h
@@ -13,8 +13,6 @@
 #include <vector>
 #include <cinttypes>
 
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
-
 const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
 
 using json = nlohmann::ordered_json;
@@ -298,11 +296,16 @@ json oaicompat_chat_params_parse(
 json convert_anthropic_to_oai(const json & body);
 
 // TODO: move it to server-task.cpp
-json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
+json format_embeddings_response_oaicompat(
+        const json & request,
+        const std::string & model_name,
+        const json & embeddings,
+        bool use_base64 = false);
 
 // TODO: move it to server-task.cpp
 json format_response_rerank(
         const json & request,
+        const std::string & model_name,
         const json & ranks,
         bool is_tei_format,
         std::vector<std::string> & texts,
27 changes: 22 additions & 5 deletions tools/server/server-context.cpp
@@ -17,6 +17,7 @@
 #include <cinttypes>
 #include <memory>
 #include <unordered_set>
+#include <filesystem>
 
 // fix problem with std::min and std::max
 #if defined(_WIN32)
@@ -518,6 +519,8 @@ struct server_context_impl {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;
 
+    std::string model_name; // name of the loaded model, to be used by API
+
    common_chat_templates_ptr chat_templates;
    oaicompat_parser_options oai_parser_opt;
 
@@ -758,6 +761,18 @@
        }
        SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
 
+        if (!params_base.model_alias.empty()) {
+            // user explicitly specified model name
+            model_name = params_base.model_alias;
+        } else if (!params_base.model.name.empty()) {
+            // use model name in registry format (for models in cache)
+            model_name = params_base.model.name;
+        } else {
+            // fallback: derive model name from file name
+            auto model_path = std::filesystem::path(params_base.model.path);
+            model_name = model_path.filename().string();
+        }
+
        // thinking is enabled if:
        // 1. It's not explicitly disabled (reasoning_budget == 0)
        // 2. The chat template supports it
@@ -2605,6 +2620,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
    task.params = server_task::params_from_json_cmpl(
            ctx_server.ctx,
            ctx_server.params_base,
+            ctx_server.model_name,
            data);
    task.id_slot = json_value(data, "id_slot", -1);
 
@@ -2939,7 +2955,7 @@ void server_routes::init_routes() {
        json data = {
            { "default_generation_settings", default_generation_settings_for_props },
            { "total_slots", ctx_server.params_base.n_parallel },
-            { "model_alias", ctx_server.params_base.model_alias },
+            { "model_alias", ctx_server.model_name },
            { "model_path", ctx_server.params_base.model.path },
            { "modalities", json {
                {"vision", ctx_server.oai_parser_opt.allow_image},
@@ -3181,8 +3197,8 @@ void server_routes::init_routes() {
        json models = {
            {"models", {
                {
-                    {"name", params.model_alias.empty() ? params.model.path : params.model_alias},
-                    {"model", params.model_alias.empty() ? params.model.path : params.model_alias},
+                    {"name", ctx_server.model_name},
+                    {"model", ctx_server.model_name},
                    {"modified_at", ""},
                    {"size", ""},
                    {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
@@ -3204,7 +3220,7 @@ void server_routes::init_routes() {
            {"object", "list"},
            {"data", {
                {
-                    {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
+                    {"id", ctx_server.model_name},
                    {"object", "model"},
                    {"created", std::time(0)},
                    {"owned_by", "llamacpp"},
@@ -3351,6 +3367,7 @@ void server_routes::init_routes() {
        // write JSON response
        json root = format_response_rerank(
            body,
+            ctx_server.model_name,
            responses,
            is_tei_format,
            documents,
@@ -3613,7 +3630,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
 
    // write JSON response
    json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
-        ? format_embeddings_response_oaicompat(body, responses, use_base64)
+        ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64)
        : json(responses);
    res->ok(root);
    return res;
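
For readers skimming the diff, the new name-resolution logic above boils down to a three-step fallback: explicit --alias, then the registry-style name for cached models, then the model file's basename. Below is a minimal standalone sketch of that order; the struct and sample values are illustrative assumptions, not llama.cpp's real common_params.

// Standalone sketch of the fallback order; types and values are illustrative.
#include <filesystem>
#include <iostream>
#include <string>

struct params_sketch {           // hypothetical stand-in for params_base
    std::string model_alias;     // set explicitly by the user (e.g. --alias)
    std::string registry_name;   // registry-format name for cached models
    std::string model_path;      // path of the loaded GGUF file
};

static std::string resolve_model_name(const params_sketch & p) {
    if (!p.model_alias.empty()) {
        return p.model_alias;    // user explicitly specified model name
    }
    if (!p.registry_name.empty()) {
        return p.registry_name;  // use model name in registry format
    }
    // fallback: derive model name from file name
    return std::filesystem::path(p.model_path).filename().string();
}

int main() {
    std::cout << resolve_model_name({"", "", "/models/gemma-3-4b-it-Q4_K_M.gguf"}) << "\n"; // gemma-3-4b-it-Q4_K_M.gguf
    std::cout << resolve_model_name({"my-model", "", "/models/gemma-3-4b-it-Q4_K_M.gguf"}) << "\n"; // my-model
}

Deriving the name from the file basename keeps the /v1/models, /api/tags and /props responses meaningful even when no alias is configured.
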
2 changes: 1 addition & 1 deletion tools/server/server-task.cpp
@@ -147,6 +147,7 @@ json task_params::to_json(bool only_metrics) const {
 task_params server_task::params_from_json_cmpl(
         const llama_context * ctx,
         const common_params & params_base,
+        const std::string & model_name,
         const json & data) {
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -450,7 +451,6 @@ task_params server_task::params_from_json_cmpl(
         }
     }
 
-    std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
     params.oaicompat_model = json_value(data, "model", model_name);
 
     return params;
1 change: 1 addition & 0 deletions tools/server/server-task.h
@@ -120,6 +120,7 @@ struct server_task {
     static task_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
+            const std::string & model_name,
             const json & data);
 
     // utility function
4 changes: 2 additions & 2 deletions tools/server/tests/unit/test_chat_completion.py
@@ -59,7 +59,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
 )
 def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
-    server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
+    server.model_alias = "llama-test-model"
     server.start()
     res = server.make_stream_request("POST", "/chat/completions", data={
         "max_tokens": max_tokens,
@@ -81,7 +81,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         else:
             assert "role" not in choice["delta"]
         assert data["system_fingerprint"].startswith("b")
-        assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+        assert data["model"] == "llama-test-model"
         if last_cmpl_id is None:
             last_cmpl_id = data["id"]
         assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream