Skip to content

Commit f940c2c

Browse files
authored
support glm-4-9b-chat (#10)
1 parent 9d50215 commit f940c2c

File tree

11 files changed

+303
-15
lines changed

11 files changed

+303
-15
lines changed

README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,14 @@ During inference, the quantized weight is recovered as bfloat16 for matrix multi
8080

8181
# Supported Models
8282

83-
| Architecture | Models | DashInfer model_type | HuggingFace Models | ModelScope Models |
84-
|:------------:|:------:|:--------------------:|:------------------:|:-----------------:|
85-
| QWenLMHeadModel | Qwen | Qwen_v10 | [Qwen/Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat),<br>[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat),<br>[Qwen/Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat), etc. | [qwen/Qwen-1_8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary),<br>[qwen/Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary),<br>[qwen/Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary), etc. |
86-
| Qwen2ForCausalLM | Qwen1.5 | Qwen_v15 | [Qwen/Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat),<br>[Qwen/Qwen1.5-1.8B-Chat](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat),<br>[Qwen/Qwen1.5-4B-Chat](https://huggingface.co/Qwen/Qwen1.5-4B-Chat),<br>[Qwen/Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat),<br>[Qwen/Qwen1.5-14B-Chat](https://huggingface.co/Qwen/Qwen1.5-14B-Chat), etc. | [qwen/Qwen1.5-0.5B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary),<br>[qwen/Qwen1.5-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat/summary),<br>[qwen/Qwen1.5-4B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat/summary),<br>[qwen/Qwen1.5-7B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat/summary),<br>[qwen/Qwen1.5-14B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat/summary), etc. |
87-
| ChatGLMModel | ChatGLM | ChatGLM_v2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b),<br>[THUDM/chatglm2-6b-32k](https://huggingface.co/THUDM/chatglm2-6b-32k) | [ZhipuAI/chatglm2-6b](https://modelscope.cn/models/ZhipuAI/chatglm2-6b/summary),<br>[ZhipuAI/chatglm2-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm2-6b-32k/summary) |
88-
| ChatGLMModel | ChatGLM | ChatGLM_v3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b),<br>[THUDM/chatglm3-6b-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary),<br>[ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k/summary) |
89-
| LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) |
83+
| Architecture | Models | DashInfer model_type | HuggingFace Models | ModelScope Models | DashInfer Models |
84+
|:------------:|:------:|:--------------------:|:------------------:|:-----------------:|:----------------:|
85+
| QWenLMHeadModel | Qwen | Qwen_v10 | [Qwen/Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat),<br>[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat),<br>[Qwen/Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat), etc. | [qwen/Qwen-1_8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary),<br>[qwen/Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary),<br>[qwen/Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary), etc. | / |
86+
| Qwen2ForCausalLM | Qwen1.5 | Qwen_v15 | [Qwen/Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat),<br>[Qwen/Qwen1.5-1.8B-Chat](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat),<br>[Qwen/Qwen1.5-4B-Chat](https://huggingface.co/Qwen/Qwen1.5-4B-Chat),<br>[Qwen/Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat),<br>[Qwen/Qwen1.5-14B-Chat](https://huggingface.co/Qwen/Qwen1.5-14B-Chat), etc. | [qwen/Qwen1.5-0.5B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary),<br>[qwen/Qwen1.5-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat/summary),<br>[qwen/Qwen1.5-4B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat/summary),<br>[qwen/Qwen1.5-7B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat/summary),<br>[qwen/Qwen1.5-14B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat/summary), etc. | / |
87+
| ChatGLMModel | ChatGLM | ChatGLM_v2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b),<br>[THUDM/chatglm2-6b-32k](https://huggingface.co/THUDM/chatglm2-6b-32k) | [ZhipuAI/chatglm2-6b](https://modelscope.cn/models/ZhipuAI/chatglm2-6b/summary),<br>[ZhipuAI/chatglm2-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm2-6b-32k/summary) | / |
88+
| ChatGLMModel | ChatGLM | ChatGLM_v3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b),<br>[THUDM/chatglm3-6b-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary),<br>[ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k/summary) | / |
89+
| ChatGLMModel | ChatGLM | ChatGLM_v4 | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) | [dash-infer/glm-4-9b-chat-DI](https://modelscope.cn/models/dash-infer/glm-4-9b-chat-DI/summary) |
90+
| LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) | / |
9091

9192
# Software Architecture
9293

README_CN.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,14 @@ $$ x_{u8} = x_{fp32} / scale + zeropoint $$
8181

8282
# 模型支持
8383

84-
| Architecture | Models | DashInfer model_type | HuggingFace Models | ModelScope Models |
85-
|:------------:|:------:|:--------------------:|:------------------:|:-----------------:|
86-
| QWenLMHeadModel | Qwen | Qwen_v10 | [Qwen/Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat),<br>[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat),<br>[Qwen/Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat), etc. | [qwen/Qwen-1_8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary),<br>[qwen/Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary),<br>[qwen/Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary), etc. |
87-
| Qwen2ForCausalLM | Qwen1.5 | Qwen_v15 | [Qwen/Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat),<br>[Qwen/Qwen1.5-1.8B-Chat](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat),<br>[Qwen/Qwen1.5-4B-Chat](https://huggingface.co/Qwen/Qwen1.5-4B-Chat),<br>[Qwen/Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat),<br>[Qwen/Qwen1.5-14B-Chat](https://huggingface.co/Qwen/Qwen1.5-14B-Chat), etc. | [qwen/Qwen1.5-0.5B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary),<br>[qwen/Qwen1.5-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat/summary),<br>[qwen/Qwen1.5-4B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat/summary),<br>[qwen/Qwen1.5-7B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat/summary),<br>[qwen/Qwen1.5-14B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat/summary), etc. |
88-
| ChatGLMModel | ChatGLM | ChatGLM_v2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b),<br>[THUDM/chatglm2-6b-32k](https://huggingface.co/THUDM/chatglm2-6b-32k) | [ZhipuAI/chatglm2-6b](https://modelscope.cn/models/ZhipuAI/chatglm2-6b/summary),<br>[ZhipuAI/chatglm2-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm2-6b-32k/summary) |
89-
| ChatGLMModel | ChatGLM | ChatGLM_v3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b),<br>[THUDM/chatglm3-6b-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary),<br>[ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k/summary) |
90-
| LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) |
84+
| Architecture | Models | DashInfer model_type | HuggingFace Models | ModelScope Models | DashInfer Models |
85+
|:------------:|:------:|:--------------------:|:------------------:|:-----------------:|:----------------:|
86+
| QWenLMHeadModel | Qwen | Qwen_v10 | [Qwen/Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat),<br>[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat),<br>[Qwen/Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat), etc. | [qwen/Qwen-1_8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary),<br>[qwen/Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary),<br>[qwen/Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary), etc. | / |
87+
| Qwen2ForCausalLM | Qwen1.5 | Qwen_v15 | [Qwen/Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat),<br>[Qwen/Qwen1.5-1.8B-Chat](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat),<br>[Qwen/Qwen1.5-4B-Chat](https://huggingface.co/Qwen/Qwen1.5-4B-Chat),<br>[Qwen/Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat),<br>[Qwen/Qwen1.5-14B-Chat](https://huggingface.co/Qwen/Qwen1.5-14B-Chat), etc. | [qwen/Qwen1.5-0.5B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary),<br>[qwen/Qwen1.5-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat/summary),<br>[qwen/Qwen1.5-4B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat/summary),<br>[qwen/Qwen1.5-7B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat/summary),<br>[qwen/Qwen1.5-14B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat/summary), etc. | / |
88+
| ChatGLMModel | ChatGLM | ChatGLM_v2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b),<br>[THUDM/chatglm2-6b-32k](https://huggingface.co/THUDM/chatglm2-6b-32k) | [ZhipuAI/chatglm2-6b](https://modelscope.cn/models/ZhipuAI/chatglm2-6b/summary),<br>[ZhipuAI/chatglm2-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm2-6b-32k/summary) | / |
89+
| ChatGLMModel | ChatGLM | ChatGLM_v3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b),<br>[THUDM/chatglm3-6b-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary),<br>[ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k/summary) | / |
90+
| ChatGLMModel | ChatGLM | ChatGLM_v4 | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) | [dash-infer/glm-4-9b-chat-DI](https://modelscope.cn/models/dash-infer/glm-4-9b-chat-DI/summary) |
91+
| LlamaForCausalLM | LLaMA-2 | LLaMA_v2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) | / |
9192

9293
# 软件框架
9394

csrc/core/model/chatglm/chatglm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,5 @@ AsStatus ChatGLMModel::Forward(const TensorMap& inputs, TensorMap* outputs) {
102102

103103
REGISTER_MODEL("ChatGLM_v2", ChatGLM_v2Model)
104104
REGISTER_MODEL("ChatGLM_v3", ChatGLM_v3Model)
105+
REGISTER_MODEL("ChatGLM_v4", ChatGLM_v4Model)
105106
} // namespace allspark

csrc/core/model/chatglm/chatglm.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,10 @@ class ChatGLM_v3Model : public ChatGLMModel {
3131
explicit ChatGLM_v3Model(const std::string& model_type = "")
3232
: ChatGLMModel(model_type){};
3333
};
34+
35+
class ChatGLM_v4Model : public ChatGLMModel {
36+
public:
37+
explicit ChatGLM_v4Model(const std::string& model_type = "")
38+
: ChatGLMModel(model_type){};
39+
};
3440
} // namespace allspark
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#
2+
# Copyright (c) Alibaba, Inc. and its affiliates.
3+
# @file basic_example_chatglm4.py
4+
#
5+
import os
6+
import copy
7+
import time
8+
import random
9+
import argparse
10+
import subprocess
11+
from jinja2 import Template
12+
from concurrent.futures import ThreadPoolExecutor
13+
14+
from dashinfer.helper import EngineHelper, ConfigManager
15+
16+
17+
def download_model(model_id, revision, source="modelscope"):
18+
print(f"Downloading model {model_id} (revision: {revision}) from {source}")
19+
if source == "modelscope":
20+
from modelscope import snapshot_download
21+
model_dir = snapshot_download(model_id, revision=revision)
22+
elif source == "huggingface":
23+
from huggingface_hub import snapshot_download
24+
model_dir = snapshot_download(repo_id=model_id)
25+
else:
26+
raise ValueError("Unknown source")
27+
28+
print(f"Save model to path {model_dir}")
29+
30+
return model_dir
31+
32+
33+
def create_test_prompt(default_gen_cfg=None):
34+
input_list = [
35+
"浙江的省会在哪",
36+
"Where is the capital of Zhejiang?",
37+
"将“温故而知新”翻译成英文,并解释其含义",
38+
]
39+
40+
user_msg = {"role": "user", "content": ""}
41+
assistant_msg = {"role": "assistant", "content": ""}
42+
43+
prompt_template = Template(
44+
"[gMASK] <sop> " + "<|{{user_role}}|>\n" + "{{user_content}}" +
45+
"<|{{assistant_role}}|>\n")
46+
47+
gen_cfg_list = []
48+
prompt_list = []
49+
for i in range(len(input_list)):
50+
user_msg["content"] = input_list[i]
51+
prompt = prompt_template.render(user_role=user_msg["role"], user_content=user_msg["content"],
52+
assistant_role=assistant_msg["role"])
53+
prompt_list.append(prompt)
54+
if default_gen_cfg != None:
55+
gen_cfg = copy.deepcopy(default_gen_cfg)
56+
gen_cfg["seed"] = random.randint(0, 10000)
57+
gen_cfg_list.append(gen_cfg)
58+
59+
return prompt_list, gen_cfg_list
60+
61+
62+
def process_request(request_list, engine_helper: EngineHelper):
63+
64+
def print_inference_result(request):
65+
msg = "***********************************\n"
66+
msg += f"* Answer (dashinfer) for Request {request.id}\n"
67+
msg += "***********************************\n"
68+
msg += f"** context_time: {request.context_time} s, generate_time: {request.generate_time} s\n\n"
69+
msg += f"** encoded input, len: {request.in_tokens_len} **\n{request.in_tokens}\n\n"
70+
msg += f"** encoded output, len: {request.out_tokens_len} **\n{request.out_tokens}\n\n"
71+
msg += f"** text input **\n{request.in_text}\n\n"
72+
msg += f"** text output **\n{request.out_text}\n\n"
73+
print(msg)
74+
75+
def done_callback(future):
76+
request = future.argument
77+
future.result()
78+
print_inference_result(request)
79+
80+
# create a threadpool
81+
executor = ThreadPoolExecutor(
82+
max_workers=engine_helper.engine_config["engine_max_batch"])
83+
84+
try:
85+
# submit all tasks to the threadpool
86+
futures = []
87+
for request in request_list:
88+
future = executor.submit(engine_helper.process_one_request, request)
89+
future.argument = request
90+
future.add_done_callback(done_callback)
91+
futures.append(future)
92+
finally:
93+
executor.shutdown(wait=True)
94+
95+
return
96+
97+
98+
if __name__ == '__main__':
99+
parser = argparse.ArgumentParser()
100+
parser.add_argument('--quantize', action='store_true')
101+
args = parser.parse_args()
102+
103+
config_file = "../model_config/config_chatglm4_9b.json"
104+
config = ConfigManager.get_config_from_json(config_file)
105+
config["convert_config"]["do_dynamic_quantize_convert"] = args.quantize
106+
107+
cmd = f"pip show dashinfer | grep 'Location' | cut -d ' ' -f 2"
108+
package_location = subprocess.run(cmd,
109+
stdout=subprocess.PIPE,
110+
stderr=subprocess.PIPE,
111+
shell=True,
112+
text=True)
113+
package_location = package_location.stdout.strip()
114+
os.environ["AS_DAEMON_PATH"] = package_location + "/dashinfer/allspark/bin"
115+
os.environ["AS_NUMA_NUM"] = str(len(config["device_ids"]))
116+
os.environ["AS_NUMA_OFFSET"] = str(config["device_ids"][0])
117+
118+
## download original model
119+
## download model from huggingface
120+
# original_model = {
121+
# "source": "huggingface",
122+
# "model_id": "THUDM/glm-4-9b-chat",
123+
# "revision": "",
124+
# "model_path": ""
125+
# }
126+
127+
## download model from modelscope
128+
original_model = {
129+
"source": "modelscope",
130+
"model_id": "ZhipuAI/glm-4-9b-chat",
131+
"revision": "master",
132+
"model_path": ""
133+
}
134+
original_model["model_path"] = download_model(original_model["model_id"],
135+
original_model["revision"],
136+
original_model["source"])
137+
138+
## init EngineHelper class
139+
engine_helper = EngineHelper(config)
140+
engine_helper.verbose = True
141+
engine_helper.init_tokenizer(original_model["model_path"])
142+
143+
## convert huggingface model to dashinfer model
144+
## only one conversion is required
145+
engine_helper.convert_model(original_model["model_path"])
146+
147+
## inference
148+
engine_helper.init_engine()
149+
150+
prompt_list, gen_cfg_list = create_test_prompt(
151+
engine_helper.default_gen_cfg)
152+
request_list = engine_helper.create_request(prompt_list, gen_cfg_list)
153+
154+
global_start = time.time()
155+
process_request(request_list, engine_helper)
156+
global_end = time.time()
157+
158+
total_timecost = global_end - global_start
159+
# engine_helper.print_inference_result_all(request_list)
160+
engine_helper.print_profiling_data(request_list, total_timecost)
161+
print(f"total timecost: {total_timecost} s")
162+
163+
engine_helper.uninit_engine()
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#
2+
# Copyright (c) Alibaba, Inc. and its affiliates.
3+
# @file basic_example_chatglm4_dimodel_simple.py
4+
#
5+
import copy
6+
import random
7+
8+
from modelscope import snapshot_download
9+
from dashinfer.helper import EngineHelper, ConfigManager
10+
11+
model_path = snapshot_download("dash-infer/glm-4-9b-chat-DI")
12+
13+
config_file = model_path + "/" + "di_config.json"
14+
config = ConfigManager.get_config_from_json(config_file)
15+
config["model_path"] = model_path
16+
17+
## init EngineHelper class
18+
engine_helper = EngineHelper(config)
19+
engine_helper.verbose = True
20+
engine_helper.init_tokenizer(model_path)
21+
22+
## init engine
23+
engine_helper.init_engine()
24+
25+
## prepare inputs and generation configs
26+
user_input = "浙江的省会在哪"
27+
prompt = "[gMASK] <sop> " + "<|user|>\n" + user_input + "<|assistant|>\n"
28+
gen_cfg = copy.deepcopy(engine_helper.default_gen_cfg)
29+
gen_cfg["seed"] = random.randint(0, 10000)
30+
request_list = engine_helper.create_request([prompt], [gen_cfg])
31+
32+
## inference
33+
engine_helper.process_one_request(request_list[0])
34+
engine_helper.print_inference_result_all(request_list)
35+
36+
engine_helper.uninit_engine()

0 commit comments

Comments
 (0)