diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index c9245745756..f3f2edc0cc4 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -101,8 +101,6 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
 
-    common_chat_msg chat_msg;
-
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
 
     llama_token sampled;
 
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
-
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -183,13 +178,10 @@ struct server_slot {
         stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text = 0;
-        chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }
 
-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            generated_text,
-            /* is_partial= */ stop != STOP_TYPE_EOS,
-            task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
-
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);
@@ -1284,8 +1259,6 @@ struct server_context_impl {
             } else {
                 res->content = tkn.text_to_send;
                 res->tokens = { tkn.tok };
-
-                slot.update_chat_msg(res->oaicompat_msg_diffs);
             }
 
             res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
         res->id_slot = slot.id;
         res->index = slot.task->index;
 
-        res->content = slot.generated_text;
-        res->tokens  = std::move(slot.generated_tokens);
+        // in stream mode, content and tokens are already in last partial chunk
+        if (slot.task->params.stream) {
+            res->content = "";
+            res->tokens  = llama_tokens{};
+        } else {
+            res->content = std::move(slot.generated_text);
+            res->tokens  = std::move(slot.generated_tokens);
+        }
         res->timings = slot.get_timings();
         res->prompt = slot.task->tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
         res->res_type          = slot.task->params.res_type;
         res->oaicompat_model   = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg     = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
+        // tracking generation state and partial tool calls
+        std::vector<task_result_state> states;
+
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
        //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
+        states.reserve(inputs.size());
 
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2628,10 +2610,12 @@ static std::unique_ptr handle_completions_impl(
             task.params.res_type         = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model   = ctx_server.model_name;
+            states.push_back(task.params.oaicompat_chat_syntax);
 
             tasks.push_back(std::move(task));
         }
 
+        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2657,7 +2641,6 @@ static std::unique_ptr handle_completions_impl(
             // if single request, return single object instead of array
             res->ok(arr.size() == 1 ? arr[0] : arr);
         }
-
     } else {
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr handle_completions_impl(
         }
 
         // next responses are streamed
+        // to be sent immediately
+        json first_result_json = first_result->to_json();
         if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-            res->data = format_anthropic_sse(first_result->to_json());
+            res->data = format_anthropic_sse(first_result_json);
         } else {
-            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+            res->data = format_oai_sse(first_result_json);
         }
         res->status = 200;
         res->content_type = "text/event-stream";
 
         res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
-            if (should_stop()) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            if (!res_this->data.empty()) {
-                // flush the first chunk
-                output = std::move(res_this->data);
-                res_this->data.clear();
-                return true;
-            }
-
-            server_response_reader & rd = res_this->rd;
-
-            // check if there is more data
-            if (!rd.has_next()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    // Anthropic doesn't send [DONE], message_stop was already sent
-                    output = "";
-                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
-                    output = "data: [DONE]\n\n";
-                } else {
-                    output = "";
-                }
-                SRV_DBG("%s", "all results received, terminating stream\n");
-                return false; // no more data, terminate
-            }
-
-            // receive subsequent results
-            auto result = rd.next(should_stop);
-            if (result == nullptr) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            // send the results
-            json res_json = result->to_json();
-            if (result->is_error()) {
+            static auto format_error = [](task_response_type res_type, const json & res_json) {
                 if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse({
+                    return format_anthropic_sse({
                         {"event", "error"},
                         {"data", res_json},
                     });
                 } else {
-                    output = format_oai_sse(json {{ "error", res_json }});
+                    return format_oai_sse(json {{ "error", res_json }});
                 }
-                SRV_DBG("%s", "error received during streaming, terminating stream\n");
-                return false; // terminate on error
-            } else {
-                GGML_ASSERT(
-                    dynamic_cast<server_task_result_cmpl_partial *>(result.get()) != nullptr
-                    || dynamic_cast<server_task_result_cmpl_final *>(result.get()) != nullptr
-                );
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse(res_json);
+            };
+
+            try {
+                if (should_stop()) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; + } + + server_response_reader & rd = res_this->rd; + + // check if there is more data + if (!rd.has_next()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + // Anthropic doesn't send [DONE], message_stop was already sent + output = ""; + } else if (res_type != TASK_RESPONSE_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; + } + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate + } + + // receive subsequent results + auto result = rd.next(should_stop); + if (result == nullptr) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + // send the results + if (result->is_error()) { + json res_json = result->to_json(); + output = format_error(res_type, res_json); + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error } else { - output = format_oai_sse(res_json); + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + json res_json = result->to_json(); + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse(res_json); + } else { + output = format_oai_sse(res_json); + } } - } - // has next data, continue - return true; + // has next data, continue + return true; + + } catch (const std::exception & e) { + json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER); + output = format_error(res_type, error_json); + + // terminate on exception + return false; + } }; } diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 38a4858522e..10196128db1 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -271,6 +271,10 @@ void server_response::terminate() { // server_response_reader // +void server_response_reader::set_states(std::vector && states) { + this->states = std::move(states); +} + void server_response_reader::post_tasks(std::vector && tasks) { id_tasks = server_task::get_list_id(tasks); queue_results.add_waiting_tasks(tasks); @@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function SRV_DBG("%s", "received error result, stopping further processing\n"); return result; } + if (!states.empty()) { + // update the generation state if needed + size_t idx = result->get_index(); + GGML_ASSERT(idx < states.size()); + result->update(states[idx]); + } if (result->is_stop()) { received_count++; } diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 209d2017c7e..a5c3179d8ca 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -7,6 +7,8 @@ #include #include +// struct for managing server tasks +// in most cases, use server_response_reader to post new tasks and retrieve results struct server_queue { private: int id = 0; @@ -67,6 +69,8 @@ struct server_queue { void cleanup_pending_task(int id_target); }; +// struct for managing server responses +// in most cases, use server_response_reader to retrieve results struct server_response { private: bool running = true; @@ -120,6 +124,10 @@ struct server_response_reader { bool cancelled = false; int polling_interval_seconds; + // tracking generation state and partial tool calls + // only used 
+    // only used by streaming completions
+    std::vector<task_result_state> states;
+
     // should_stop function will be called each polling_interval_seconds
     server_response_reader(std::pair server_queues, int polling_interval_seconds)
         : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
@@ -127,6 +135,7 @@ struct server_response_reader {
         stop();
     }
 
+    void set_states(std::vector<task_result_state> && states);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 8a9477d7321..df066264778 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -565,6 +565,7 @@ std::vector completion_token_output::str_to_bytes(const std::stri
 // server_task_result_cmpl_final
 //
 json server_task_result_cmpl_final::to_json() {
+    GGML_ASSERT(is_updated && "update() must be called before to_json()");
     switch (res_type) {
         case TASK_RESPONSE_TYPE_NONE:
             return to_json_non_oaicompat();
@@ -582,8 +583,8 @@ json server_task_result_cmpl_final::to_json() {
 json server_task_result_cmpl_final::to_json_non_oaicompat() {
     json res = json {
         {"index",   index},
-        {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
-        {"tokens",  stream ? llama_tokens {} : tokens},
+        {"content", content},
+        {"tokens",  tokens},
         {"id_slot", id_slot},
         {"stop",    true},
         {"model",   oaicompat_model},
@@ -619,7 +620,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() {
     json res = json {
         {"choices", json::array({
             json{
-                {"text",          stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                {"text",          content},
                 {"index",         index},
                 {"logprobs",      logprobs},
                 {"finish_reason", finish_reason},
@@ -700,6 +701,25 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
     return res;
 }
 
+common_chat_msg task_result_state::update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs) {
+    generated_text += text_added;
+    auto msg_prv_copy = chat_msg;
+    SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+    auto new_msg = common_chat_parse(
+        generated_text,
+        is_partial,
+        oaicompat_chat_syntax);
+    if (!new_msg.empty()) {
+        new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
+        chat_msg = new_msg;
+        diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
+    }
+    return chat_msg;
+}
+
 json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
     std::time_t t = std::time(0);
     std::string finish_reason = "length";
@@ -956,6 +976,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
 // server_task_result_cmpl_partial
 //
 json server_task_result_cmpl_partial::to_json() {
+    GGML_ASSERT(is_updated && "update() must be called before to_json()");
     switch (res_type) {
         case TASK_RESPONSE_TYPE_NONE:
             return to_json_non_oaicompat();
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index a22d7cab116..8e7b9e3e310 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -161,6 +161,25 @@ struct result_prompt_progress {
     json to_json() const;
 };
 
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
 struct server_task_result {
     int id = -1;
     int id_slot = -1;
@@ -175,6 +194,9 @@ struct server_task_result {
     virtual int get_index() {
         return -1;
     }
+    virtual void update(task_result_state &) {
+        // only used by server_task_result_cmpl_*
+    }
     virtual json to_json() = 0;
     virtual ~server_task_result() = default;
 };
@@ -233,9 +255,10 @@ struct server_task_result_cmpl_final : server_task_result {
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    common_chat_msg oaicompat_msg;
+    common_chat_msg oaicompat_msg; // to be populated by update()
 
-    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+    bool is_updated = false;
 
     virtual int get_index() override {
         return index;
@@ -247,6 +270,11 @@ struct server_task_result_cmpl_final : server_task_result {
 
     virtual json to_json() override;
 
+    virtual void update(task_result_state & state) override {
+        is_updated = true;
+        oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
+    }
+
     json to_json_non_oaicompat();
 
     json to_json_oaicompat();
@@ -280,7 +308,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+    bool is_updated = false;
 
     virtual int get_index() override {
         return index;
@@ -292,6 +321,11 @@ struct server_task_result_cmpl_partial : server_task_result {
 
     virtual json to_json() override;
 
+    virtual void update(task_result_state & state) override {
+        is_updated = true;
+        state.update_chat_msg(content, true, oaicompat_msg_diffs);
+    }
+
     json to_json_non_oaicompat();
 
     json to_json_oaicompat();
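
Not part of the patch, for illustration only: a minimal, self-contained C++ sketch of the flow this diff introduces, where per-task generation state is owned by the response reader and each result is passed through update() before it is serialized. All types below are simplified, hypothetical stand-ins, not the real llama.cpp server types.

    // sketch.cpp -- hypothetical, simplified stand-ins for the server types
    #include <cassert>
    #include <iostream>
    #include <string>
    #include <vector>

    struct task_result_state {          // stand-in for task_result_state in server-task.h
        std::string generated_text;     // accumulates chunks across partial results
    };

    struct task_result {                // stand-in for server_task_result_cmpl_*
        size_t      index   = 0;        // which task this result belongs to
        std::string content;            // only the newly generated chunk
        bool        updated = false;

        // merge this chunk into the per-task state before serialization
        void update(task_result_state & state) {
            state.generated_text += content;
            updated = true;
        }

        std::string to_text(const task_result_state & state) const {
            assert(updated && "update() must be called before serialization");
            return "task " + std::to_string(index) + ": " + state.generated_text;
        }
    };

    struct response_reader {            // stand-in for server_response_reader
        std::vector<task_result_state> states;  // one state per posted task, as set_states() would do

        task_result next(task_result result) {
            assert(result.index < states.size());
            result.update(states[result.index]);  // mirrors the update() call inside next() in the patch
            return result;
        }
    };

    int main() {
        response_reader rd;
        rd.states.resize(1);            // one task posted

        const std::vector<std::string> chunks = {"Hello", ", ", "world"};
        for (const std::string & chunk : chunks) {
            task_result res = rd.next(task_result{0, chunk});
            std::cout << res.to_text(rd.states[0]) << "\n";
        }
        return 0;
    }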