Skip to content

Commit abe598b

Browse files
committed
move states to server_response_reader
1 parent e89b80d commit abe598b

File tree

3 files changed

+24
-7
lines changed

3 files changed

+24
-7
lines changed

tools/server/server-context.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2571,12 +2571,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
25712571
auto completion_id = gen_chatcmplid();
25722572
auto & rd = res->rd;
25732573

2574-
// tracking generation state and partial tool calls
2575-
std::vector<task_result_state> states;
2576-
25772574
try {
25782575
std::vector<server_task> tasks;
25792576

2577+
// tracking generation state and partial tool calls
2578+
std::vector<task_result_state> states;
2579+
25802580
const auto & prompt = data.at("prompt");
25812581
// TODO: this log can become very long, put it behind a flag or think about a more compact format
25822582
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2615,6 +2615,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
26152615
tasks.push_back(std::move(task));
26162616
}
26172617

2618+
rd.set_states(std::move(states));
26182619
rd.post_tasks(std::move(tasks));
26192620
} catch (const std::exception & e) {
26202621
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2635,7 +2636,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
26352636
json arr = json::array();
26362637
for (auto & res : all_results.results) {
26372638
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
2638-
res->update(states[res->get_index()]); // update generation state
26392639
arr.push_back(res->to_json());
26402640
}
26412641
// if single request, return single object instead of array
@@ -2656,7 +2656,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
26562656
dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
26572657
|| dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
26582658
);
2659-
first_result->update(states[first_result->get_index()]); // update generation state
26602659
}
26612660

26622661
// next responses are streamed
@@ -2669,7 +2668,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
26692668
}
26702669
res->status = 200;
26712670
res->content_type = "text/event-stream";
2672-
res->next = [res_this = res.get(), res_type, &should_stop, states = std::move(states)](std::string & output) mutable -> bool {
2671+
res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) mutable -> bool {
26732672
static auto format_error = [](task_response_type res_type, const json & res_json) {
26742673
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
26752674
return format_anthropic_sse({
@@ -2728,7 +2727,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
27282727
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
27292728
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
27302729
);
2731-
result->update(states[result->get_index()]); // update generation state
27322730
json res_json = result->to_json();
27332731
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
27342732
output = format_anthropic_sse(res_json);

tools/server/server-queue.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,10 @@ void server_response::terminate() {
271271
// server_response_reader
272272
//
273273

274+
// Installs the per-task generation states on this reader.
// The vector is taken by rvalue reference and moved into the `states` member,
// replacing whatever the member previously held. When `states` is non-empty,
// next() uses it to update each result it processes (indexed by result->get_index()).
void server_response_reader::set_states(std::vector<task_result_state> && states) {
275+
// `this->` is required: the parameter shadows the member of the same name
this->states = std::move(states);
276+
}
277+
274278
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
275279
id_tasks = server_task::get_list_id(tasks);
276280
queue_results.add_waiting_tasks(tasks);
@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
298302
SRV_DBG("%s", "received error result, stopping further processing\n");
299303
return result;
300304
}
305+
if (!states.empty()) {
306+
// update the generation state if needed
307+
auto idx = result->get_index();
308+
GGML_ASSERT(idx >= 0 && idx < states.size());
309+
result->update(states[idx]);
310+
}
301311
if (result->is_stop()) {
302312
received_count++;
303313
}

tools/server/server-queue.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <mutex>
88
#include <unordered_set>
99

10+
// struct for managing server tasks
11+
// in most cases, use server_response_reader to post new tasks and retrieve results
1012
struct server_queue {
1113
private:
1214
int id = 0;
@@ -67,6 +69,8 @@ struct server_queue {
6769
void cleanup_pending_task(int id_target);
6870
};
6971

72+
// struct for managing server responses
73+
// in most cases, use server_response_reader to retrieve results
7074
struct server_response {
7175
private:
7276
bool running = true;
@@ -120,13 +124,18 @@ struct server_response_reader {
120124
bool cancelled = false;
121125
int polling_interval_seconds;
122126

127+
// tracking generation state and partial tool calls
128+
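// applied by next() to each result when non-empty (populated via set_states())
// NOTE(review): described as streaming-only, but handle_completions_impl appears to call set_states() for every request — confirm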
129+
std::vector<task_result_state> states;
130+
123131
// should_stop function will be called each polling_interval_seconds
124132
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
125133
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
126134
~server_response_reader() {
127135
stop();
128136
}
129137

138+
void set_states(std::vector<task_result_state> && states);
130139
void post_tasks(std::vector<server_task> && tasks);
131140
bool has_next() const;
132141

0 commit comments

Comments
 (0)