
Commit 773c178

x574chen and Xiaotong Chen authored
Support transformers run Vision model (#49)
* add transformers to run qwen2-vl vision model for accuracy baseline
* update multimodal requirements: add dashinfer
* Skip TensorRT package import if TensorRT is not installed
* use flash_attention_2 and update doc

Co-authored-by: Xiaotong Chen <cxt459847@alibaba-inc.com>
1 parent 96de58b commit 773c178

File tree: 8 files changed, +66 / −34 lines


docs/sphinx/vlm/vlm_offline_inference_en.rst

Lines changed: 11 additions & 6 deletions
@@ -111,21 +111,26 @@ You can also use OpenAI's Python client library:
 
 Launching with CLI
 -------------------------
-You can also opt to install dashinfer-vlm locally and use command line to launch server.
+You can install dashinfer-vlm locally and use the command line to launch the server by following these steps. We highly recommend using NVIDIA PyTorch Containers `nvcr.io/nvidia/pytorch:xx.xx-py3` for setup.
 
-1. Pull dashinfer docker image (see :ref:`docker-label`)
-2. Install TensorRT Python package, and download TensorRT GA build from NVIDIA Developer Zone.
+1. (Optional when TensorRT is installed) Install TensorRT Python package, and download TensorRT GA build from NVIDIA Developer Zone.
 
 .. code-block:: bash
 
     wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz
     tar -xvzf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz
+    pip install `pwd`/TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp310-none-linux_x86_64.whl
     export LD_LIBRARY_PATH=`pwd`/TensorRT-10.5.0.18/lib
 
-3. Install dashinfer Python Package from `release <https://github.com/modelscope/dash-infer/releases>`_
-4. Install dashinfer-vlm: ``pip install dashinfer-vlm``.
+2. Install dashinfer-vlm: ``pip install dashinfer-vlm``, or install from source code
 
-Now you can launch server with command line:
+.. code-block:: bash
+
+    git clone https://github.com/modelscope/dash-infer.git
+    cd dash-infer/multimodal/
+    pip install -e ./
+
+3. Launch server with command line:
 
 .. code-block:: bash
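Since the TensorRT step is now optional, a quick environment check (a minimal Python sketch, not part of the documented steps) shows which --vision_engine values are usable before launching the server:

import importlib.util

# TensorRT is only required for --vision_engine tensorrt; the transformers backend works without it.
has_trt = importlib.util.find_spec("tensorrt") is not None
print("usable vision engines:", ["tensorrt", "transformers"] if has_trt else ["transformers"])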

multimodal/dashinfer_vlm/api_server/config.py

Lines changed: 6 additions & 1 deletion
@@ -48,7 +48,7 @@ def add_context_args(parser):
         "--vision_engine",
         type=str,
         default="tensorrt",
-        choices=["tensorrt"],
+        choices=["tensorrt", "transformers"],
         help="engine to run vision model",
     )
     group.add_argument(
@@ -76,6 +76,11 @@ def add_context_args(parser):
         action="store_true",
         help="enable FP8",
     )
+    group.add_argument(
+        "--dtype",
+        default="bfloat16",
+        choices=["bfloat16", "float16"],
+    )
     group.add_argument(
         "--min-pixels",
         default=4*28*28,
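A minimal sketch of how the new flags combine, using a bare argparse parser rather than the actual add_context_args() group (names and defaults mirror the diff above):

import argparse

# Hypothetical stand-in for the dashinfer_vlm argument group; only the changed flags are shown.
parser = argparse.ArgumentParser()
parser.add_argument("--vision_engine", type=str, default="tensorrt",
                    choices=["tensorrt", "transformers"],
                    help="engine to run vision model")
parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16"])

args = parser.parse_args(["--vision_engine", "transformers", "--dtype", "float16"])
print(args.vision_engine, args.dtype)  # -> transformers float16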

multimodal/dashinfer_vlm/api_server/conversation.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def get_content(self, content) -> Tuple:
                 image_list = []
             else:
                 if image_list[0] == "image":
-                    text = "<|vision_start|><|vision_end|>\n" * (len(image_list) - 1)
+                    text = "<|vision_start|><|vision_end|>" * (len(image_list) - 1)
                 elif image_list[0] == "video":
                     text = "<|vision_start|><|vision_end|>\n"
                 else:
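The only change is dropping the trailing newline between repeated image placeholders. A small illustration of the resulting prompt fragment (the contents of image_list are hypothetical; as in the diff, the first element is the modality tag):

image_list = ["image", "url_1", "url_2", "url_3"]  # hypothetical: modality tag + three images

old_text = "<|vision_start|><|vision_end|>\n" * (len(image_list) - 1)
new_text = "<|vision_start|><|vision_end|>" * (len(image_list) - 1)

print(repr(old_text))  # placeholders separated by newlines
print(repr(new_text))  # placeholders concatenated directly (no separating newline)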

multimodal/dashinfer_vlm/api_server/protocol/openai_api_protocol.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ class ChatCompletionRequest(BaseModel):
     ]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
-    top_k: Optional[int] = 1
+    top_k: Optional[int] = 0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     max_completion_tokens: Optional[int] = None

multimodal/dashinfer_vlm/api_server/server.py

Lines changed: 3 additions & 1 deletion
@@ -79,7 +79,7 @@ def init():
     home_dir = os.environ.get("HOME") or "/root"
     output_dir = os.path.join(home_dir, ".cache/as_model/", model.split("/")[-1])
     model_name = "model"
-    data_type = "bfloat16"
+    data_type = context.get("dtype")
 
     model_loader = HuggingFaceVLModel(
         model,
@@ -474,7 +474,9 @@ def get_vl_request(
         "min_length": 5,
         "frequency_penalty": frequency_penalty,
         "presence_penalty": presence_penalty,
+        # "repetition_penalty": 1.05,
         "length_penalty": 1,
+        "stop_words_ids": [[151643], [151644], [151645]],
         "eos_token_id": context.get("eos_token_id"),
         "seed": 1234567,
     }
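The new stop_words_ids are Qwen2 special-token IDs. A quick way to confirm what they decode to, assuming a Qwen2-VL tokenizer is available (the checkpoint name below is illustrative, not taken from this commit):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# Expected: ['<|endoftext|>', '<|im_start|>', '<|im_end|>'] for the Qwen2 tokenizer family.
print(tok.convert_ids_to_tokens([151643, 151644, 151645]))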

multimodal/dashinfer_vlm/vl_inference/runtime/hie_worker.py

Lines changed: 17 additions & 9 deletions
@@ -11,7 +11,11 @@
 import threading
 import queue
 
-from ..utils.trt.vit_process import VisualTRT_V2
+try:
+    from ..utils.trt.vit_process import VisualTRT_V2
+except Exception:
+    pass
+
 import torch
 import numpy as np
 import time
@@ -75,11 +79,11 @@ def run(self):
         if self.model_type == "QWEN2-VL":
             # warm up
             image = torch.randn(
-                10080,
+                2436,
                 1176,
                 dtype=torch.float16 if self.precision == "fp16" else torch.float32,
             )
-            grid_thw = torch.tensor([[1, 120, 84]], dtype=torch.int64)
+            grid_thw = torch.tensor([[1, 58, 42]], dtype=torch.int64)
             first_grid = grid_thw[0, 0].item()
             batch_tensor = torch.zeros(first_grid)
             dict(
@@ -93,6 +97,10 @@ def run(self):
             self.model = VisualTRT_V2(
                 vit_engine_path=self.model_path, trt_vit_config=self.trt_vit_config
             )
+        elif self.backend == "transformers":
+            self.model = self.model_path.to(self.device)
+            with torch.no_grad():
+                self.model(image.to(self.device), grid_thw=grid_thw.to(self.device))
         elif self.backend == "hie":
             raise NotImplementedError
         else:
@@ -126,9 +134,6 @@ def get_vit_result(self, image, input_info):
         if self.model_type == "QWEN1-VL":
             output = self.model(image, use_flashattn=True)
         elif self.model_type == "QWEN2-VL":
-            # grid_thw = torch.tensor(
-            #     [input_info["vit_grid_t"], input_info["vit_grid_h"], input_info["vit_grid_w"]], dtype=torch.int32
-            # ).unsqueeze(0)
             grid_thw = np.array(
                 [
                     [
@@ -144,11 +149,14 @@ def get_vit_result(self, image, input_info):
             batch_tensor = torch.zeros(first_grid).to(
                 dtype=torch.int32, device=self.device
             )
-            # output = self.model(image, grid_thw, batch_tensor)
-            output = self.model(image, grid_thw, batch_tensor)
-            # print("vit output shape: ", output.shape)
+            if self.backend == "tensorrt":
+                output = self.model(image, grid_thw, batch_tensor)
+            elif self.backend == "transformers":
+                with torch.no_grad():
+                    output = self.model(image, grid_thw=grid_thw)
         else:
             output = self.model(image.contiguous().to(self.device), input_info)
+
         return output
 
     def process_request(self, task: VitRequest) -> None:
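The new warm-up shapes are internally consistent: grid_thw = [[1, 58, 42]] describes 1 × 58 × 42 = 2436 patches, matching the first dimension of the warm-up tensor, and 1176 is the flattened per-patch feature size with Qwen2-VL's default patching (3 channels × temporal patch 2 × 14 × 14). A standalone sanity check of that relationship (a sketch, not taken from the repository):

import torch

grid_thw = torch.tensor([[1, 58, 42]], dtype=torch.int64)
num_patches = int(grid_thw.prod(dim=-1).sum())   # 1 * 58 * 42 = 2436
patch_dim = 3 * 2 * 14 * 14                      # channels * temporal patch * patch_h * patch_w = 1176

image = torch.randn(num_patches, patch_dim, dtype=torch.float16)
assert image.shape == (2436, 1176)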

multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py

Lines changed: 26 additions & 14 deletions
@@ -5,6 +5,7 @@
 import os
 import torch
 import glob
+import warnings
 from modelscope import snapshot_download
 from transformers import Qwen2VLForConditionalGeneration, AutoConfig, AutoTokenizer
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
@@ -13,8 +14,20 @@
 from dashinfer import allspark
 from dashinfer.allspark.model_loader import HuggingFaceModel, ModelSerializerException
 from dashinfer.allspark.model_config import QWen2ConfigAdapter
-from .trt.onnx_to_plan import ONNX_TRT
-
+try:
+    from .trt.onnx_to_plan import ONNX_TRT
+except Exception:
+    warnings.warn("TensorRT package is not available", ImportWarning)
+
+def dtype_to_torch_dtype(dtype):
+    if dtype == "float32":
+        return torch.float32
+    elif dtype == "float16":
+        return torch.float16
+    elif dtype == "bfloat16":
+        return torch.bfloat16
+    else:
+        raise ValueError("unsupported data type: {}".format(dtype))
 
 class HuggingFaceVLModel(HuggingFaceModel):
     def __init__(
@@ -49,6 +62,8 @@ def load_model(
             self.torch_model = Qwen2VLForConditionalGeneration.from_pretrained(
                 self.hf_model_path,
                 trust_remote_code=self.trust_remote_code,
+                torch_dtype=dtype_to_torch_dtype(self.data_type),
+                device_map="cpu",
                 **kwargs,
             ).eval()
             self.vit_config = Qwen2VLVisionConfig.from_pretrained(
@@ -62,17 +77,6 @@ def load_model(
                 trust_remote_code=self.trust_remote_code,
                 **kwargs,
             )
-            self.torch_model = self.torch_model.cpu()
-
-            if self.data_type == "float32":
-                self.torch_model.float()
-            elif self.data_type == "float16":
-                self.torch_model.half()
-            elif self.data_type == "bfloat16":
-                self.torch_model.bfloat16()
-            else:
-                self.torch_model = None
-                raise ValueError("unsupported data type: {}".format(self.data_type))
         except Exception as e:
             print(
                 f"exception when load model: {self.hf_model_path} , exception: {e}"
@@ -122,9 +126,17 @@ def serialize(
             onnx_trt_obj = ONNX_TRT(self.hf_model_path)
             onnx_trt_obj.export_onnx(onnxFile)
             onnx_trt_obj.generate_trt_engine(onnxFile, self.vision_model_path)
+        elif self.vision_engine == "transformers":
+            visual_model = Qwen2VLForConditionalGeneration.from_pretrained(
+                self.hf_model_path,
+                trust_remote_code=self.trust_remote_code,
+                torch_dtype=dtype_to_torch_dtype(self.data_type),
+                device_map="cpu",
+                attn_implementation="flash_attention_2",
+            ).visual.eval()
+            self.vision_model_path = visual_model
         else:
             raise ValueError(f"unsupported engine {self.vision_engine}")
-
         # Convert Allspark LLM
         enable_quant = self.fp8
         weight_only_quant=False
multimodal/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-tensorrt==10.5.0
+dashinfer@https://github.com/modelscope/dash-infer/releases/download/v2.0.0-rc3/dashinfer-2.0.0rc3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
 av
 numpy==1.24.3
 requests==2.32.3
