#
# Copyright (c) Alibaba, Inc. and its affiliates.
# @file basic_example_chatglm4.py
#
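#
# This example downloads the GLM-4-9B-Chat weights, converts them to the
# DashInfer model format, and runs a small batch of prompts concurrently
# through the engine. Run it from this directory so that the relative
# config path resolves, e.g.:
#   python basic_example_chatglm4.py             # convert and run with default weights
#   python basic_example_chatglm4.py --quantize  # enable dynamic weight quantization during conversion
#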
import os
import copy
import time
import random
import argparse
import subprocess
from jinja2 import Template
from concurrent.futures import ThreadPoolExecutor

from dashinfer.helper import EngineHelper, ConfigManager


def download_model(model_id, revision, source="modelscope"):
    print(f"Downloading model {model_id} (revision: {revision}) from {source}")
    if source == "modelscope":
        from modelscope import snapshot_download
        model_dir = snapshot_download(model_id, revision=revision)
    elif source == "huggingface":
        from huggingface_hub import snapshot_download
        model_dir = snapshot_download(repo_id=model_id)
    else:
        raise ValueError(f"Unknown model source: {source}")

    print(f"Saved model to path {model_dir}")

    return model_dir


def create_test_prompt(default_gen_cfg=None):
    input_list = [
        "浙江的省会在哪",
        "Where is the capital of Zhejiang?",
        "将“温故而知新”翻译成英文,并解释其含义",
    ]

    user_msg = {"role": "user", "content": ""}
    assistant_msg = {"role": "assistant", "content": ""}

    # GLM-4 chat markup: "[gMASK] <sop> " prefix, then the user turn,
    # then an empty assistant header for the model to complete
    prompt_template = Template(
        "[gMASK] <sop> " + "<|{{user_role}}|>\n" + "{{user_content}}" +
        "<|{{assistant_role}}|>\n")

    gen_cfg_list = []
    prompt_list = []
    for text in input_list:
        user_msg["content"] = text
        prompt = prompt_template.render(user_role=user_msg["role"],
                                        user_content=user_msg["content"],
                                        assistant_role=assistant_msg["role"])
        prompt_list.append(prompt)
        if default_gen_cfg is not None:
            # each request gets its own generation config with a fresh seed
            gen_cfg = copy.deepcopy(default_gen_cfg)
            gen_cfg["seed"] = random.randint(0, 10000)
            gen_cfg_list.append(gen_cfg)

    return prompt_list, gen_cfg_list


def process_request(request_list, engine_helper: EngineHelper):

    def print_inference_result(request):
        msg = "***********************************\n"
        msg += f"* Answer (dashinfer) for Request {request.id}\n"
        msg += "***********************************\n"
        msg += f"** context_time: {request.context_time} s, generate_time: {request.generate_time} s\n\n"
        msg += f"** encoded input, len: {request.in_tokens_len} **\n{request.in_tokens}\n\n"
        msg += f"** encoded output, len: {request.out_tokens_len} **\n{request.out_tokens}\n\n"
        msg += f"** text input **\n{request.in_text}\n\n"
        msg += f"** text output **\n{request.out_text}\n\n"
        print(msg)

    def done_callback(future):
        request = future.argument
        future.result()  # re-raises any exception from the worker thread
        print_inference_result(request)

    # create a thread pool sized to the engine's maximum batch size
    executor = ThreadPoolExecutor(
        max_workers=engine_helper.engine_config["engine_max_batch"])

    try:
        # submit all requests to the thread pool
        futures = []
        for request in request_list:
            future = executor.submit(engine_helper.process_one_request, request)
            future.argument = request
            future.add_done_callback(done_callback)
            futures.append(future)
    finally:
        # block until every submitted request has finished
        executor.shutdown(wait=True)

    return


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--quantize', action='store_true')
    args = parser.parse_args()

    config_file = "../model_config/config_chatglm4_9b.json"
    config = ConfigManager.get_config_from_json(config_file)
    config["convert_config"]["do_dynamic_quantize_convert"] = args.quantize

    # locate the installed dashinfer package so the engine can find its daemon binary,
    # then map the NUMA nodes listed in config["device_ids"] to the engine's env vars
    cmd = "pip show dashinfer | grep 'Location' | cut -d ' ' -f 2"
    package_location = subprocess.run(cmd,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE,
                                      shell=True,
                                      text=True)
    package_location = package_location.stdout.strip()
    os.environ["AS_DAEMON_PATH"] = package_location + "/dashinfer/allspark/bin"
    os.environ["AS_NUMA_NUM"] = str(len(config["device_ids"]))
    os.environ["AS_NUMA_OFFSET"] = str(config["device_ids"][0])

    ## download original model
    ## download model from huggingface
    # original_model = {
    #     "source": "huggingface",
    #     "model_id": "THUDM/glm-4-9b-chat",
    #     "revision": "",
    #     "model_path": ""
    # }

    ## download model from modelscope
    original_model = {
        "source": "modelscope",
        "model_id": "ZhipuAI/glm-4-9b-chat",
        "revision": "master",
        "model_path": ""
    }
    original_model["model_path"] = download_model(original_model["model_id"],
                                                  original_model["revision"],
                                                  original_model["source"])

    ## init EngineHelper class
    engine_helper = EngineHelper(config)
    engine_helper.verbose = True
    engine_helper.init_tokenizer(original_model["model_path"])

    ## convert huggingface model to dashinfer model
    ## only one conversion is required
    engine_helper.convert_model(original_model["model_path"])

    ## inference
    engine_helper.init_engine()

    prompt_list, gen_cfg_list = create_test_prompt(
        engine_helper.default_gen_cfg)
    request_list = engine_helper.create_request(prompt_list, gen_cfg_list)

    global_start = time.time()
    process_request(request_list, engine_helper)
    global_end = time.time()

    total_timecost = global_end - global_start
    # engine_helper.print_inference_result_all(request_list)
    engine_helper.print_profiling_data(request_list, total_timecost)
    print(f"total timecost: {total_timecost} s")

    engine_helper.uninit_engine()