171 changes: 85 additions & 86 deletions tools/server/server-context.cpp
@@ -101,8 +101,6 @@ struct server_slot {
std::string generated_text;
llama_tokens generated_tokens;

common_chat_msg chat_msg;

std::vector<completion_token_output> generated_token_probs;

bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {

llama_token sampled;

common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::vector<std::string> generated_tool_call_ids;

// stats
size_t n_sent_text = 0; // number of sent text character

@@ -183,13 +178,10 @@ struct server_slot {
stop = STOP_TYPE_NONE;
stopping_word = "";
n_sent_text = 0;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

generated_tokens.clear();
generated_token_probs.clear();
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();

// clear speculative decoding stats
n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
return timings;
}

const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
GGML_ASSERT(task);

auto previous_msg = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
/* is_partial= */ stop != STOP_TYPE_EOS,
task->params.oaicompat_chat_syntax);
if (!new_msg.empty()) {
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
}
return chat_msg;
}

size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
GGML_ASSERT(task);

@@ -1284,8 +1259,6 @@ struct server_context_impl {
} else {
res->content = tkn.text_to_send;
res->tokens = { tkn.tok };

slot.update_chat_msg(res->oaicompat_msg_diffs);
}

res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
res->id_slot = slot.id;

res->index = slot.task->index;
res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens);
// in stream mode, content and tokens are already in last partial chunk
if (slot.task->params.stream) {
res->content = "";
res->tokens = llama_tokens{};
} else {
res->content = std::move(slot.generated_text);
res->tokens = std::move(slot.generated_tokens);
}
res->timings = slot.get_timings();
res->prompt = slot.task->tokens.detokenize(ctx, true);
res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
res->res_type = slot.task->params.res_type;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);

// populate res.probs_output
if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
try {
std::vector<server_task> tasks;

// tracking generation state and partial tool calls
std::vector<task_result_state> states;

const auto & prompt = data.at("prompt");
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
tasks.reserve(inputs.size());
states.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);

@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = ctx_server.model_name;
states.push_back(task.params.oaicompat_chat_syntax);

tasks.push_back(std::move(task));
}

rd.set_states(std::move(states));
rd.post_tasks(std::move(tasks));
} catch (const std::exception & e) {
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
// if single request, return single object instead of array
res->ok(arr.size() == 1 ? arr[0] : arr);
}

} else {
// in streaming mode, the first error must be treated as non-stream response
// this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
}

// next responses are streamed
// to be sent immediately
json first_result_json = first_result->to_json();
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
res->data = format_anthropic_sse(first_result->to_json());
res->data = format_anthropic_sse(first_result_json);
} else {
res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
res->data = format_oai_sse(first_result_json);
}
res->status = 200;
res->content_type = "text/event-stream";
res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
if (should_stop()) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
}

if (!res_this->data.empty()) {
// flush the first chunk
output = std::move(res_this->data);
res_this->data.clear();
return true;
}

server_response_reader & rd = res_this->rd;

// check if there is more data
if (!rd.has_next()) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
// Anthropic doesn't send [DONE], message_stop was already sent
output = "";
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
output = "data: [DONE]\n\n";
} else {
output = "";
}
SRV_DBG("%s", "all results received, terminating stream\n");
return false; // no more data, terminate
}

// receive subsequent results
auto result = rd.next(should_stop);
if (result == nullptr) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
}

// send the results
json res_json = result->to_json();
if (result->is_error()) {
static auto format_error = [](task_response_type res_type, const json & res_json) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
output = format_anthropic_sse({
return format_anthropic_sse({
{"event", "error"},
{"data", res_json},
});
} else {
output = format_oai_sse(json {{ "error", res_json }});
return format_oai_sse(json {{ "error", res_json }});
}
SRV_DBG("%s", "error received during streaming, terminating stream\n");
return false; // terminate on error
} else {
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
output = format_anthropic_sse(res_json);
};

try {
if (should_stop()) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
}

if (!res_this->data.empty()) {
// flush the first chunk
output = std::move(res_this->data);
res_this->data.clear();
return true;
}

server_response_reader & rd = res_this->rd;

// check if there is more data
if (!rd.has_next()) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
// Anthropic doesn't send [DONE], message_stop was already sent
output = "";
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
output = "data: [DONE]\n\n";
} else {
output = "";
}
SRV_DBG("%s", "all results received, terminating stream\n");
return false; // no more data, terminate
}

// receive subsequent results
auto result = rd.next(should_stop);
if (result == nullptr) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
}

// send the results
if (result->is_error()) {
json res_json = result->to_json();
output = format_error(res_type, res_json);
SRV_DBG("%s", "error received during streaming, terminating stream\n");
return false; // terminate on error
} else {
output = format_oai_sse(res_json);
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
json res_json = result->to_json();
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
output = format_anthropic_sse(res_json);
} else {
output = format_oai_sse(res_json);
}
}
}

// has next data, continue
return true;
// has next data, continue
return true;

} catch (const std::exception & e) {
json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
output = format_error(res_type, error_json);

// terminate on exception
return false;
}
};
}

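For context, a rough sketch of how the HTTP layer might drain the streaming callback installed above. The res->next(std::string &) signature and the first-chunk flush behaviour come from this diff; the sink object and the surrounding loop are assumptions, not part of the PR:

// hypothetical driver loop, not part of this change
std::string chunk;
bool more = true;
while (more) {
    more = res->next(chunk);      // fills chunk with the next SSE payload
    if (!chunk.empty()) {
        sink.write(chunk);        // forward to the client verbatim
        chunk.clear();
    }
    // when next() returns false, the last payload (if any) was either the
    // closing "data: [DONE]\n\n" chunk or a formatted error event
}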
10 changes: 10 additions & 0 deletions tools/server/server-queue.cpp
@@ -271,6 +271,10 @@ void server_response::terminate() {
// server_response_reader
//

void server_response_reader::set_states(std::vector<task_result_state> && states) {
this->states = std::move(states);
}

void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
id_tasks = server_task::get_list_id(tasks);
queue_results.add_waiting_tasks(tasks);
@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
SRV_DBG("%s", "received error result, stopping further processing\n");
return result;
}
if (!states.empty()) {
// update the generation state if needed
size_t idx = result->get_index();
GGML_ASSERT(idx < states.size());
result->update(states[idx]);
}
if (result->is_stop()) {
received_count++;
}
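Taken together with the handle_completions_impl changes in server-context.cpp, the intended call order looks roughly like the sketch below; set_states, post_tasks and next are from this diff, while the surrounding variables are assumptions:

// one task_result_state per input prompt, in task-index order
std::vector<task_result_state> states;
states.reserve(tasks.size());
for (const auto & task : tasks) {
    states.push_back(task.params.oaicompat_chat_syntax);
}
rd.set_states(std::move(states));   // must happen before results are read
rd.post_tasks(std::move(tasks));
// every result returned by rd.next() is now passed through
// result->update(states[result->get_index()]) before the caller sees it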
9 changes: 9 additions & 0 deletions tools/server/server-queue.h
@@ -7,6 +7,8 @@
#include <mutex>
#include <unordered_set>

// struct for managing server tasks
// in most cases, use server_response_reader to post new tasks and retrieve results
struct server_queue {
private:
int id = 0;
@@ -67,6 +69,8 @@ struct server_queue {
void cleanup_pending_task(int id_target);
};

// struct for managing server responses
// in most cases, use server_response_reader to retrieve results
struct server_response {
private:
bool running = true;
@@ -120,13 +124,18 @@ struct server_response_reader {
bool cancelled = false;
int polling_interval_seconds;

// tracking generation state and partial tool calls
// only used by streaming completions
std::vector<task_result_state> states;

// should_stop function will be called each polling_interval_seconds
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
~server_response_reader() {
stop();
}

void set_states(std::vector<task_result_state> && states);
void post_tasks(std::vector<server_task> && tasks);
bool has_next() const;

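Finally, a minimal sketch of a blocking consumer built on the reader declared above, assuming the queues are already wired up; per the comment in this header, should_stop is polled every polling_interval_seconds while waiting. The handler names are placeholders:

// hypothetical consumer loop
auto should_stop = [&]() { return connection_closed; };
while (rd.has_next()) {
    server_task_result_ptr result = rd.next(should_stop);
    if (result == nullptr) {
        break;                             // should_stop fired while waiting
    }
    if (result->is_error()) {
        handle_error(result->to_json());   // placeholder error handler
        break;
    }
    handle_result(result->to_json());      // placeholder result handler
}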