Commit 0d65c04

docs: update example code and apiserver doc. (#51)
* docs: update example code and apiserver doc.
* readme: add colab notebook for gpu example.
1 parent 773c178 commit 0d65c04

5 files changed (+3079, −130 lines)
README.md

Lines changed: 3 additions & 1 deletion
@@ -121,7 +121,9 @@ For the detailed user manual, please refer to the documentation: [Documentation
 
 In `<path_to_dashinfer>/examples` there are examples for C++ and Python interfaces, and please refer to the documentation in `<path_to_dashinfer>/documents/EN` to run the examples.
 
-- [Basic Python Example](examples/python/0_basic/basic_example_qwen_v10_io.ipynb) [![Open In PAI-DSW](https://modelscope.oss-cn-beijing.aliyuncs.com/resource/Open-in-DSW20px.svg)](https://gallery.pai-ml.com/#/import/https://github.com/modelscope/dash-infer/blob/main/examples/python/0_basic/basic_example_qwen_v10_io.ipynb)
+- [Base GPU Python Example](examples/python/0_basic/cuda/demo_dashinfer_2_0_gpu_example.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/modelscope/dash-infer/blob/main/examples/python/0_basic/cuda/demo_dashinfer_2_0_gpu_example.ipynb)
 - [Documentation for All Python Examples](docs/EN/examples_python.md)
 - [Documentation for C++ Examples](docs/EN/examples_cpp.md)
 
docs/sphinx/get_started/quick_start_api_py_en.md

Lines changed: 103 additions & 125 deletions
@@ -17,93 +17,85 @@ Below is an example of how to quickly serialize a Hugging Face model and perform
 ### Inference Python Example
 This is an example of using asynchronous interface to obtain output, with bfloat16, in memory model serialize, and async output processing. The model is downloaded from Modelscope. Initiating requests and receiving outputs are both asynchronous, and can be handled according to your application needs.

-```py
-import os
-import modelscope
-from modelscope.utils.constant import DEFAULT_MODEL_REVISION
-
-from dashinfer import allspark
-from dashinfer.allspark import *
-from dashinfer.allspark.engine import *
-from dashinfer.allspark.prompt_utils import PromptTemplate
-
-# if use in memory serialize, change this flag to True
-in_memory = True
-device_list=[0]
-
-modelscope_name ="qwen/Qwen2.5-1.5B-Instruct"
-ms_version = DEFAULT_MODEL_REVISION
-output_base_folder="output_qwen"
-model_local_path=""
-tmp_dir = "model_output"
-
-model_local_path = modelscope.snapshot_download(modelscope_name, ms_version)
-safe_model_name = str(modelscope_name).replace("/", "_")
-model_convert_folder = os.path.join(output_base_folder, safe_model_name)
-
-model_loader = allspark.HuggingFaceModel(model_local_path, safe_model_name, in_memory_serialize=in_memory, trust_remote_code=True)
-engine = allspark.Engine()
-
-model_loader.load_model().serialize(engine, model_output_dir=tmp_dir).free_model()
-
-runtime_cfg_builder = model_loader.create_reference_runtime_config_builder(safe_model_name, TargetDevice.CUDA, device_list, max_batch=8)
-# this builder can change runtime parameter
-# like change to engine max length to a smaller value
-runtime_cfg_builder.max_length(2048)
-
-runtime_cfg = runtime_cfg_builder.build()
-
-# install model to engine
-engine.install_model(runtime_cfg)
-
-model_loader.free_memory_serialize_file()
-
-# start the model inference
-engine.start_model(safe_model_name)
-
-input_str = "How to protect our planet and build a green future?"
-input_str = PromptTemplate.apply_chatml_template(input_str)
-messages = [
-    {"role": "system", "content": "You are a helpful assistant."},
-    {"role": "user", "content": input_str}]
-
-templated_input_str = model_loader.init_tokenizer().get_tokenizer().apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-# generate a reference generate config.
-gen_cfg = model_loader.create_reference_generation_config_builder(runtime_cfg)
-# change generate config base on this generation config, like change top_k = 1
-gen_cfg.update({"top_k": 1})
+```python
+import os
+
+from dashinfer import allspark
+from dashinfer.allspark import *
+from dashinfer.allspark.engine import *
+from dashinfer.allspark.prompt_utils import PromptTemplate
+
+# Configuration
+in_memory = False
+device_list = [0]  # single card by default, 4 cards replace with [0,1,2,3]
+model_name = "qwen/Qwen2.5-1.5B-Instruct"
+output_base_folder = "model_output"
+user_data_type = "float16"  # most device supports float16
+use_modelscope = True
+
+# Download and prepare the model
+if use_modelscope:
+    import modelscope
+    from modelscope.utils.constant import DEFAULT_MODEL_REVISION
+    model_local_path = modelscope.snapshot_download(model_name, DEFAULT_MODEL_REVISION)
+else:
+    model_local_path = model_name

-status, handle, queue = engine.start_request_text(safe_model_name, model_loader, templated_input_str, gen_cfg)
+safe_model_name = model_name.replace("/", "_")
+model_convert_folder = os.path.join(output_base_folder, safe_model_name)

-generated_ids = []
-status = queue.GenerateStatus()
+# Initialize model and engine
+model_loader = allspark.HuggingFaceModel(model_local_path, safe_model_name,
+                                         in_memory_serialize=in_memory,
+                                         user_set_data_type=user_data_type)
+engine = allspark.Engine()

-# in following 3 status, it means tokens are generating
-while (status == GenerateRequestStatus.Init or status == GenerateRequestStatus.Generating or status == GenerateRequestStatus.ContextFinished):
-    elements = queue.Get()
-    if elements is not None:
-        generated_ids += elements.ids_from_generate
-    status = queue.GenerateStatus()
+# Load and serialize the model
+model_loader.load_model().serialize(engine, model_output_dir=output_base_folder).free_model()

-    if status == GenerateRequestStatus.GenerateFinished:
-        break
-        # This means generated is finished.
+# Configure runtime settings
+runtime_cfg = model_loader.create_reference_runtime_config_builder(
+    safe_model_name, TargetDevice.CUDA, device_list, max_batch=8).max_length(2048).build()
+engine.install_model(runtime_cfg)
+engine.start_model(safe_model_name)

-    if status == GenerateRequestStatus.GenerateInterrupted:
-        break
-        # This means the GPU has no available resources; the request has been halted by the engine.
-        # The client should collect the tokens generated so far and initiate a new request later.
+if in_memory: model_loader.free_memory_serialize_file()

-# de-tokenize id to text
-output_text = model_loader.init_tokenizer().get_tokenizer().decode(generated_ids)
+# Prepare input
+input_str = "How to protect our planet and build a green future?"
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": PromptTemplate.apply_chatml_template(input_str)}
+]
+templated_input_str = model_loader.init_tokenizer().get_tokenizer().apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True)

-print(f"model: {modelscope_name} input:\n{input_str} \n output:\n{output_text}\n")
-print(f"input token:\n {model_loader.init_tokenizer().get_tokenizer().encode(input_str)}")
-print(f"output token:\n {generated_ids}")
+# Configure generation settings
+gen_cfg = model_loader.create_reference_generation_config_builder(runtime_cfg)
+gen_cfg.update({"top_k": 1})

-engine.release_request(safe_model_name, handle)
+# Generate response
+status, handle, queue = engine.start_request_text(
+    safe_model_name, model_loader, templated_input_str, gen_cfg)
+generated_ids = []
+
+while True:
+    elements = queue.Get()
+    if elements:
+        generated_ids += elements.ids_from_generate
+    status = queue.GenerateStatus()
+    if status in [GenerateRequestStatus.GenerateFinished,
+                  GenerateRequestStatus.GenerateInterrupted]:
+        break
+
+# Decode and print output
+output_text = model_loader.init_tokenizer().get_tokenizer().decode(generated_ids)
+print(f"Model: {model_name}\nInput: {input_str}\nOutput: {output_text}")

-engine.stop_model(safe_model_name)
+# Clean up
+engine.release_request(safe_model_name, handle)
+engine.stop_model(safe_model_name)
+print(f"Model: {model_name} have been released.")
 ```

 ### Explanations of the Code:
@@ -112,7 +104,9 @@ In this example, the `HuggingFaceModel` (`dashinfer.allspark.model_loader.Huggin
 
 If you want to convert only once, pass `skip_if_exists=True`. If existing files are found, the model conversion step will be skipped. The model files will reside in the `{output_base_folder}` directory, generating two files: `{safe_model_name}.asparam`, `{safe_model_name}.asmodel`. The `free_model()` function will release the Hugging Face model files to save memory.
 ```python
-model_loader = allspark.HuggingFaceModel(model_local_path, safe_model_name, trust_remote_code=True)
+model_loader = allspark.HuggingFaceModel(model_local_path, safe_model_name,
+                                         in_memory_serialize=in_memory,
+                                         user_set_data_type=user_data_type)
 engine = allspark.Engine()
 ```
 
@@ -123,6 +117,8 @@ use `model_loader.serialize()` for uniform API like in sample code, or use `seri
 `skip_if_exists` means that if a local serialized file already exists, local file serialization will be bypassed.
 
 ```python
+model_loader.load_model().serialize(engine, model_output_dir=output_base_folder).free_model()
+# or
 if in_memory:
     (model_loader.load_model()
      .serialize_to_memory(engine, enable_quant=init_quant, weight_only_quant=weight_only_quant)
@@ -140,21 +136,19 @@ In this code section, inference is conducted using a single CUDA card, with the
 If using in-memory serialization, you can release the memory file after `install_model`, since it is no longer needed.
 
 ```python
-if in_memory:
-    model_loader.free_memory_serialize_file()
+if in_memory: model_loader.free_memory_serialize_file()
+
 ```
 
 Upon calling `start_model`, the engine will perform a warm-up step that simulates a run with the maximum length set in the runtime parameters and the maximum batch size to ensure that no new resources will be requested during subsequent runs, ensuring stability. If the warm-up fails, reduce the length settings in the runtime configurations to lower resource demand. After completion of the warm-up, the engine enters a state ready to accept requests.
 
 ```python
-runtime_cfg_builder = model_loader.create_reference_runtime_config_builder(safe_model_name, TargetDevice.CUDA, [0, 1], max_batch=8)
-# like change to engine max length to a smaller value
-runtime_cfg_builder.max_length(2048)
+runtime_cfg = model_loader.create_reference_runtime_config_builder(
+    safe_model_name, TargetDevice.CUDA, device_list, max_batch=8).max_length(2048).build()
 # like enable int8 kv-cache or int4 kv cache rather than fp16 kv-cache
 # runtime_cfg_builder.kv_cache_mode(AsCacheMode.AsCacheQuantI8)
 # or u4
 # runtime_cfg_builder.kv_cache_mode(AsCacheMode.AsCacheQuantU4)
-runtime_cfg = runtime_cfg_builder.build()
 # install model to engine
 engine.install_model(runtime_cfg)
 # start the model inference
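The kv-cache options commented in the hunk above are easiest to apply by keeping an explicit builder rather than the chained one-liner. A minimal sketch follows, using only the builder methods shown in this document (`max_length`, `kv_cache_mode`, `build`); whether `kv_cache_mode()` is itself chainable is not stated, so the builder is kept in a variable.

```python
# Sketch only: build the runtime config with int8 kv-cache enabled.
runtime_cfg_builder = model_loader.create_reference_runtime_config_builder(
    safe_model_name, TargetDevice.CUDA, device_list, max_batch=8)
runtime_cfg_builder.max_length(2048)
runtime_cfg_builder.kv_cache_mode(AsCacheMode.AsCacheQuantI8)  # or AsCacheMode.AsCacheQuantU4
runtime_cfg = runtime_cfg_builder.build()
engine.install_model(runtime_cfg)
```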
@@ -166,23 +160,42 @@ The following code is focused on generating configurations and applying text te
 
 ```python
 gen_cfg = model_loader.create_reference_generation_config_builder(runtime_cfg)
-# change generate config based on this generation config, like change top_k = 1
 gen_cfg.update({"top_k": 1})
-gen_cfg.update({"repetition_penalty": 1.1})
 ```
 This code takes recommended generation parameters from Hugging Face's `generation_config.json` and makes optional modifications. It then asynchronously initiates model inference, where `status` indicates the success of the API. If successful, `handle` and `queue` are used for subsequent requests. The `handle` represents the request handle, while `queue` indicates the output queue; each request has its own output queue, which continuously accumulates generated tokens. This queue will only be released after `release_request` is invoked.
 
 ```python
-status, handle, queue = engine.start_request_text(safe_model_name, model_loader, input_str, gen_cfg)
+status, handle, queue = engine.start_request_text(
+    safe_model_name, model_loader, templated_input_str, gen_cfg)
 ```
 
 #### 5. Handling Output
-##### 5.1 Synchronous Processing
 
 DashInfer prioritizes asynchronous APIs for optimal performance and to align with the inherent nature of LLMs. Sending and receiving requests is primarily designed for asynchronous operation. However, for compatibility with user preferences accustomed to synchronous calls, we provide `engine.sync_request()`. This API allows users to block until the generation request completes.
 
+##### 5.1 Asynchronous Processing
+Asynchronous processing differs in that it requires repeated calls to the queue until the status changes to `GenerateRequestStatus.ContextFinished`. A normal state machine transition goes:
+`Init` (initial state) -> `ContextFinished` (prefill completed and first token generated) ->
+`Generating` (in progress) -> `GenerateFinished` (completed).
+During this normal state transition, an exceptional state can occur: `GenerateInterrupted`, which indicates resource shortages, causing the request to pause while its resources are temporarily released for others. This often happens under heavy loads.
+
+```python
+generated_ids = []
+while True:
+    elements = queue.Get()
+    if elements:
+        generated_ids += elements.ids_from_generate
+    status = queue.GenerateStatus()
+    if status in [GenerateRequestStatus.GenerateFinished, GenerateRequestStatus.GenerateInterrupted]:
+        break
+```
+
+##### 5.2 Synchronous Processing
+
 The subsequent call to `sync_request` will block until generation is finished, simulating a synchronous call. Without this invocation, operations on the queue can proceed but will require polling. The following code synchronously fetches all currently generated IDs from the queue, blocking at this point if there are IDs yet to be generated until completion or an error occurs.
 
+Synchronous processing is not shown in this example code; you can adapt the example by following the code below.
+
 Here's an example:
 
 ```python
@@ -201,45 +214,10 @@ generated_elem = queue.Get()
 generated_ids = generated_elem.ids_from_generate
 ```
 
+#### 6. Decode Token
 For usage of the queue class, you can use `help(dashinfer.allspark.ResultQueue)` for detailed information. The next step converts IDs back into text:
 
 ```python
 output_text = model_loader.init_tokenizer().get_tokenizer().decode(generated_ids)
 ```
 
-##### 5.2 Asynchronous Processing
-Asynchronous processing differs in that it requires repeated calls to the queue until the status changes to `GenerateRequestStatus.ContextFinished`. A normal state machine transition goes:
-`Init` (initial state) -> `ContextFinished` (prefill completed and first token generated) ->
-`Generating` (in progress) -> `GenerateFinished` (completed).
-During this normal state transition, an exceptional state can occur: `GenerateInterrupted`, which indicates resource shortages, causing the request to pause while its resources are temporarily released for others. This often happens under heavy loads.
-
-```python
-generated_ids2 = []
-# async fetch output result.
-# looping until status is not okay
-print(f"2 request: status: {queue2.GenerateStatus()}")
-status = queue2.GenerateStatus()
-# in the following 3 statuses, it means tokens are generating
-while (status == GenerateRequestStatus.Init
-       or status == GenerateRequestStatus.Generating
-       or status == GenerateRequestStatus.ContextFinished):
-    print(f"2 request: status: {queue2.GenerateStatus()}")
-    elements = queue2.Get()
-    if elements is not None:
-        print(f"new token: {elements.ids_from_generate}")
-        generated_ids2 += elements.ids_from_generate
-    status = queue2.GenerateStatus()
-    if status == GenerateRequestStatus.GenerateFinished:
-        break
-        # This means generation is finished.
-    if status == GenerateRequestStatus.GenerateInterrupted:
-        break
-        # This means the GPU has no available resources; the request has been halted by the engine.
-        # The client should collect the tokens generated so far and initiate a new request later.
-
-if test:
-    test.assertEqual(queue2.GenerateStatus(), GenerateRequestStatus.GenerateFinished)
-
-print(f"generated id: {queue2.GenerateStatus()} {generated_ids2}")
-output_text2 = model_loader.init_tokenizer().get_tokenizer().decode(generated_ids2)
-```
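The synchronous path described in section 5.2 of the doc above is only partially visible in this diff, so here is a minimal sketch under stated assumptions. It reuses only calls that appear in this document (`start_request_text`, `sync_request`, `Get`, `ids_from_generate`, `release_request`); the exact argument list of `engine.sync_request()` is an assumption, not taken from this commit.

```python
# Sketch only: blocking variant of the request flow above.
# Assumption: sync_request(model_name, handle) blocks until the request finishes.
status, handle, queue = engine.start_request_text(
    safe_model_name, model_loader, templated_input_str, gen_cfg)
engine.sync_request(safe_model_name, handle)      # block until generation completes

generated_elem = queue.Get()                      # queue now holds all generated ids
generated_ids = generated_elem.ids_from_generate
output_text = model_loader.init_tokenizer().get_tokenizer().decode(generated_ids)
print(output_text)

engine.release_request(safe_model_name, handle)
```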

docs/sphinx/get_started/quick_start_api_server_en.md

Lines changed: 28 additions & 3 deletions
@@ -1,6 +1,31 @@
-## Quick Start Guide for OpenAI API Server
+## Quick Start Guide for OpenAI API Chat Server
 
-### Start OpenAI Server with Docker
+### Test the OpenAI API Server (fastapi)
+
+Prepare the environment:
+`pip install fastapi uvicorn openai`
+
+Start the server:
+`python ./dash-infer/examples/api_server/fastapi/fastapi-server.py`
+
+You can change the parameters; check `fastapi-server.py -h` for the available options.
+
+After the server starts, it will print logs like:
+```
+INFO: Started server process [4898]
+INFO: Waiting for application startup.
+INFO: Application startup complete.
+INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+```
+
+Test with the OpenAI client:
+`python ./dash-infer/examples/api_server/fastapi/openai-client.py`
+
+This client exercises the server through the OpenAI client in both streaming and blocking modes.
+
+
+
+### Start OpenAI Server with Docker (fastchat)
 
 We provide a Docker image to start the OpenAI server.
 This example demonstrates how to use Docker to run DashInfer as an inference engine, providing OpenAI API endpoints.
@@ -40,7 +65,7 @@ docker run \
 
 You can also build your own fastchat Docker image by modifying the Docker file `scripts/docker/fschat_ubuntu_cuda.Dockerfile`.
 
-### Testing the OpenAI API Server
+### Testing the OpenAI API Server (fastchat)
 
 #### Testing with OpenAI SDK
 In `examples/api_server/fschat/openai-client.py`, the official OpenAI SDK is used to test the API server.
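For the fastapi section above, a minimal client along the lines of `openai-client.py` might look like the sketch below, using the official OpenAI SDK in blocking and streaming modes. The base URL path (`/v1`), the placeholder API key, and the model name are assumptions; match them to what `fastapi-server.py` actually serves.

```python
# Sketch only: query the locally running API server with the OpenAI SDK.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")  # assumed route and key
messages = [{"role": "user", "content": "How to protect our planet and build a green future?"}]

# Blocking call: returns the full completion at once.
resp = client.chat.completions.create(model="qwen/Qwen2.5-1.5B-Instruct", messages=messages)
print(resp.choices[0].message.content)

# Streaming call: prints tokens as they arrive.
stream = client.chat.completions.create(model="qwen/Qwen2.5-1.5B-Instruct",
                                        messages=messages, stream=True)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
```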

examples/api_server/fastapi/fastapi-server.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,6 @@
 from dashinfer.allspark.runtime_config import AsModelRuntimeConfigBuilder
 import uvicorn
 import json
-import modelscope
 
 # Parse command-line arguments
 parser = argparse.ArgumentParser(description="FastAPI server with custom options")
@@ -45,6 +44,7 @@
 tmp_dir = "model_output"
 
 if args.use_modelscope:
+    import modelscope
     modelscope_model_name = args.model_name
     model_local_path = modelscope.snapshot_download(modelscope_model_name)
 else:
