
Commit 951520d

server: delegate result_state creation to server_task (ggml-org#17835)
* server: delegate result_state creation to server_task
* remove unused states
* add more docs
1 parent 68522c6 commit 951520d

6 files changed: +76 additions, -40 deletions

tools/server/README-dev.md

Lines changed: 26 additions & 1 deletion
@@ -42,7 +42,15 @@ graph TD
     server_response --> server_routes
 ```
 
-TODO: mention about how batching is handled by `server_slot`
+### Batching
+
+The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, it iterates through all active slots to populate this batch. For each slot, either the token generated in the previous decoding step or a run of the slot's remaining prompt tokens is added to the batch.
+
+Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
+
+Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to run inference. This call is the primary computational bottleneck of `update_slots()`.
+
+After decoding, the server either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot still has prompt tokens left to process, it yields until the next `update_slots()` iteration.
 
 ### Thread Management
 
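To make the new "Batching" section concrete, here is a minimal C++ sketch of one `update_slots()` iteration. `llama_decode`, `common_batch_clear`, `common_batch_add`, and `common_sampler_sample` are real llama.cpp APIs, but the `slot_sketch` type and the control flow are simplified assumptions, not the actual `server_slot` code (which also handles LoRA compatibility, prompt caching, context shifting, and the embeddings path):

```cpp
#include <vector>

#include "llama.h"
#include "common.h"
#include "sampling.h"

// Hypothetical, stripped-down slot for illustration; the real server_slot
// in server-context.cpp carries far more state.
struct slot_sketch {
    llama_seq_id             id      = 0;
    llama_pos                n_past  = 0;
    int32_t                  i_batch = -1;   // where this slot's logits live in the batch
    llama_token              sampled = LLAMA_TOKEN_NULL;
    std::vector<llama_token> prompt;         // prompt tokens not yet decoded
    common_sampler         * smpl    = nullptr;
    bool                     active  = false;
};

// One update_slots() iteration, simplified.
static void update_slots_sketch(llama_context * ctx, llama_batch & batch,
                                std::vector<slot_sketch> & slots, int32_t n_batch) {
    common_batch_clear(batch); // a single batch is shared across all slots

    for (auto & slot : slots) {
        if (!slot.active) {
            continue;
        }
        if (slot.prompt.empty()) {
            // generation phase: add the token sampled in the previous step
            slot.i_batch = batch.n_tokens;
            common_batch_add(batch, slot.sampled, slot.n_past++, { slot.id }, true);
        } else {
            // prompt phase: add as many prompt tokens as the batch can hold
            while (!slot.prompt.empty() && batch.n_tokens < n_batch) {
                const bool is_last = slot.prompt.size() == 1;
                if (is_last) {
                    slot.i_batch = batch.n_tokens; // we sample from the last prompt token
                }
                common_batch_add(batch, slot.prompt.front(), slot.n_past++, { slot.id }, is_last);
                slot.prompt.erase(slot.prompt.begin());
            }
        }
    }

    // the primary computational bottleneck of update_slots()
    llama_decode(ctx, batch);

    for (auto & slot : slots) {
        if (slot.active && slot.prompt.empty() && slot.i_batch >= 0) {
            // the prompt is fully processed: sample the next token; slots with
            // prompt tokens left simply yield until the next iteration
            slot.sampled = common_sampler_sample(slot.smpl, ctx, slot.i_batch);
        }
    }
}
```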

@@ -62,6 +70,23 @@ Each incoming HTTP request is handled by its own thread managed by the HTTP library
 - All JSON formatting and chat template logic must stay in the HTTP layer.
 - Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
 
+### Example trace of a request
+
+Here is an example trace of an API request for text completion:
+
+- A request arrives at the HTTP layer.
+- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
+- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
+- `server_res_generator` creates a new `task_result_state` for each task:
+    - `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
+- `server_task` is moved into the `server_queue` inside `server_context`.
+- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
+- `update_slots()` processes the task as described in the "Batching" section above.
+- Results may be sent using `send_partial_response` or `send_final_response`, either of which creates a new `server_task_result` and pushes it to the response queue.
+- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
+- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
+- `server_res_generator` then calls `response->to_json()` and passes the result to the HTTP layer.
+
 ### Testing
 
 `llama-server` includes an automated test suite based on `pytest`.
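Sketched in code, the trace above collapses to roughly the following handler. The shapes of `server_task`, `server_res_generator`, and `post_tasks()` come from this commit; the parsing step is elided and the result-polling loop (`has_next()`/`next()`) is shown in simplified form, so treat it as an illustration rather than the real `handle_completions_impl`:

```cpp
// Condensed illustration of the request trace above; the real
// handle_completions_impl in server-context.cpp additionally handles
// multi-prompt inputs, n_cmpl > 1 child tasks, and error responses.
static std::unique_ptr<server_res_generator> handle_completion_sketch(
        server_context_impl & ctx_server, const json & data) {
    auto res = std::make_unique<server_res_generator>(ctx_server);

    // 1. parse the request into native C++ types as early as possible
    server_task task(SERVER_TASK_TYPE_COMPLETION);
    task.id = ctx_server.queue_tasks.get_new_id();
    // task.params = ... parsed from `data` ...

    // 2. post_tasks() creates one task_result_state per task (kept on this
    //    HTTP thread) and moves the tasks into the server queue
    std::vector<server_task> tasks;
    tasks.push_back(std::move(task));
    res->rd.post_tasks(std::move(tasks));

    // 3. update_slots() processes the task on the server side; results land
    //    on the response queue, where the reader picks them up
    while (res->rd.has_next()) {
        server_task_result_ptr result = res->rd.next([] { return false; }); // should_stop
        // the result is stateless: update() merges it with the HTTP-side
        // task_result_state before to_json() produces the final payload
        res->ok(result->to_json());
    }
    return res;
}
```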

tools/server/server-context.cpp

Lines changed: 9 additions & 12 deletions
@@ -2589,6 +2589,10 @@ struct server_context_impl {
     int get_slot_n_ctx() {
         return slots.back().n_ctx;
     }
+
+    server_response_reader get_response_reader() {
+        return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
+    }
 };
 
 //
@@ -2618,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
     return impl->ctx;
 }
 
-std::pair<server_queue &, server_response &> server_context::get_queues() {
-    return { impl->queue_tasks, impl->queue_results };
+server_response_reader server_context::get_response_reader() {
+    return impl->get_response_reader();
 }
 
@@ -2628,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
 struct server_res_generator : server_http_res {
     server_response_reader rd;
     server_res_generator(server_context_impl & ctx_server)
-        : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
+        : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
     void ok(const json & response_data) {
         status = 200;
        data = safe_json_to_str(response_data);
@@ -2661,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
-        // tracking generation state and partial tool calls
-        std::vector<task_result_state> states;
-
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2679,7 +2680,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
-        states.reserve(inputs.size());
         int idx = 0;
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2698,7 +2698,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type          = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model   = ctx_server.model_name;
-            states.push_back(task.params.oaicompat_chat_syntax);
 
             if (task.params.n_cmpl > 1) {
                 task.n_children = task.params.n_cmpl - 1;
@@ -2707,15 +2706,13 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                     task.id,
                     ctx_server.queue_tasks.get_new_id(),
                     idx++);
-                states.push_back(child.params.oaicompat_chat_syntax);
                 tasks.push_back(std::move(child));
             }
         }
 
         tasks.push_back(std::move(task));
     }
 
-    rd.set_states(std::move(states));
     rd.post_tasks(std::move(tasks));
 } catch (const std::exception & e) {
     res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -3445,7 +3442,7 @@ void server_routes::init_routes() {
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+    server_response_reader rd = ctx_server.get_response_reader();
     {
         std::vector<server_task> tasks;
         tasks.reserve(documents.size());
@@ -3705,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+    server_response_reader rd = ctx_server.get_response_reader();
     {
         std::vector<server_task> tasks;
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {

tools/server/server-context.h

Lines changed: 2 additions & 3 deletions
@@ -31,9 +31,8 @@ struct server_context {
     // get the underlying llama_context
     llama_context * get_llama_context() const;
 
-    // get the underlying queue_tasks and queue_results
-    // used by the CLI application
-    std::pair<server_queue &, server_response &> get_queues();
+    // get a new response reader, used by the CLI application
+    server_response_reader get_response_reader();
 };
 
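For the CLI use case mentioned in the comment, the new accessor replaces manual queue wiring. A minimal usage sketch (`make_completion_task()` is a hypothetical helper, not part of this commit):

```cpp
// Old pattern: unpack the queue pair and build the reader by hand
//   auto queues = ctx.get_queues();
//   server_response_reader rd(queues, HTTP_POLLING_SECONDS);
// New pattern: the context hands out a ready-to-use reader
server_response_reader rd = ctx.get_response_reader();

server_task task = make_completion_task(); // hypothetical helper
rd.post_task(std::move(task)); // the reader also creates the task_result_state

while (rd.has_next()) {
    // ... poll for results; the reader owns the per-task states ...
}
```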

tools/server/server-queue.cpp

Lines changed: 11 additions & 2 deletions
@@ -271,12 +271,21 @@ void server_response::terminate() {
 // server_response_reader
 //
 
-void server_response_reader::set_states(std::vector<task_result_state> && states) {
-    this->states = std::move(states);
+void server_response_reader::post_task(server_task && task) {
+    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    id_tasks.insert(task.id);
+    states.push_back(task.create_state());
+    queue_results.add_waiting_task_id(task.id);
+    queue_tasks.post(std::move(task));
 }
 
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
+    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     id_tasks = server_task::get_list_id(tasks);
+    states.reserve(tasks.size());
+    for (size_t i = 0; i < tasks.size(); i++) {
+        states.push_back(tasks[i].create_state());
+    }
     queue_results.add_waiting_tasks(tasks);
     queue_tasks.post(std::move(tasks));
 }
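Note the new `GGML_ASSERT`s: a reader is now single-use, since posting twice would leave `states` out of sync with `id_tasks`. A small sketch of the contract (assuming a `server_context_impl` named `ctx_server` and two prepared tasks):

```cpp
server_response_reader rd = ctx_server.get_response_reader();

rd.post_task(std::move(task_a));    // OK: id_tasks was empty; a state is created
// rd.post_task(std::move(task_b)); // would trip the assert above; use a
                                    // fresh reader for each batch of tasks
```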

tools/server/server-queue.h

Lines changed: 3 additions & 3 deletions
@@ -129,13 +129,13 @@ struct server_response_reader {
     std::vector<task_result_state> states;
 
     // should_stop function will be called every polling_interval_seconds
-    server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
-        : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
+    server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
+        : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
     ~server_response_reader() {
         stop();
     }
 
-    void set_states(std::vector<task_result_state> && states);
+    void post_task(server_task && task);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;
tools/server/server-task.h

Lines changed: 25 additions & 19 deletions
@@ -85,6 +85,25 @@ struct task_params {
     json to_json(bool only_metrics = false) const;
 };
 
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
 struct server_task {
     int id    = -1; // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
@@ -149,6 +168,12 @@ struct server_task {
         copy.tokens = tokens.clone();
         return copy;
     }
+
+    // the task will be moved into queue, then onto slots
+    // however, the state must be kept by caller (e.g., HTTP thread)
+    task_result_state create_state() const {
+        return task_result_state(params.oaicompat_chat_syntax);
+    }
 };
 
 struct result_timings {
@@ -180,25 +205,6 @@ struct result_prompt_progress {
     json to_json() const;
 };
 
-// struct for tracking the state of a task (e.g., for streaming)
-struct task_result_state {
-    // tracking diffs for partial tool calls
-    std::vector<common_chat_msg_diff> diffs;
-    common_chat_syntax oaicompat_chat_syntax;
-    common_chat_msg chat_msg;
-    std::string generated_text; // append new chunks of generated text here
-    std::vector<std::string> generated_tool_call_ids;
-
-    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
-        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
-
-    // parse partial tool calls and update the internal state
-    common_chat_msg update_chat_msg(
-        const std::string & text_added,
-        bool is_partial,
-        std::vector<common_chat_msg_diff> & diffs);
-};
-
 struct server_task_result {
     int id      = -1;
     int id_slot = -1;
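The comment on `create_state()` is the heart of this commit: the task moves into the queue while the state derived from it stays with the HTTP thread. A sketch of that split (the streamed `chunk_text` and the surrounding `queue_tasks` are assumed context, not new API):

```cpp
server_task task(SERVER_TASK_TYPE_COMPLETION);
// task.params.oaicompat_chat_syntax = ... parsed from the request ...

// create_state() copies params.oaicompat_chat_syntax into a fresh state
task_result_state state = task.create_state();
queue_tasks.post(std::move(task)); // the task is gone; `state` remains valid

// later, for each streamed chunk of generated text:
std::vector<common_chat_msg_diff> diffs;
common_chat_msg msg = state.update_chat_msg(chunk_text, /*is_partial=*/true, diffs);
// `diffs` now carries the incremental content/tool-call deltas for this chunk
```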
