
Commit eecca27

feat: improve qwen2-vl startup (#2802)
* feat: tokenize each request individually and increase warmup image size
* feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller
* fix: address image resize and rebase changes
* feat: update to run qwen2-vl tests
* fix: tweak param types
1 parent 6e982f4 commit eecca27

11 files changed: +173 -95 lines changed


backends/client/src/lib.rs

Lines changed: 1 addition & 1 deletion
@@ -86,6 +86,6 @@ impl ChunksToString for Vec<InputChunk> {
     }
 }

-static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

 pub type Result<T> = std::result::Result<T, ClientError>;
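
The `WARMUP_IMAGE_BASE64` constant (duplicated in the v2 and v3 client modules below) is the PNG attached to warmup requests for vision models. To inspect what is actually embedded, a minimal sketch, assuming Pillow is available (not part of this commit):

```python
import base64
import io

from PIL import Image  # assumption: Pillow is installed


def png_size_from_base64(b64: str) -> tuple[int, int]:
    """Decode a base64-encoded PNG and return its (width, height)."""
    return Image.open(io.BytesIO(base64.b64decode(b64))).size


# Usage: png_size_from_base64(WARMUP_IMAGE_BASE64)
```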

backends/v2/src/client/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -63,6 +63,6 @@ impl From<transport::Error> for ClientError {
     }
 }

-static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

 pub type Result<T> = std::result::Result<T, ClientError>;

backends/v3/src/client/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -62,6 +62,6 @@ impl From<Chunk> for InputChunk {
     }
 }

-static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

 pub type Result<T> = std::result::Result<T, ClientError>;
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "The correct answer is: blue",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1733445131,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.4.2-dev0-native",
+  "usage": {
+    "completion_tokens": 7,
+    "prompt_tokens": 27,
+    "total_tokens": 34
+  }
+}
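
This snapshot has the shape of an OpenAI-style chat completion served by TGI's `/v1/chat/completions` route. A minimal sketch of a request that yields this kind of response, assuming a local TGI instance serving Qwen/Qwen2-VL-2B-Instruct on localhost:3000 (host and port are assumptions, not part of this commit):

```python
import requests

# Hypothetical local endpoint; adjust host/port to your TGI deployment.
resp = requests.post(
    "http://localhost:3000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen2-VL-2B-Instruct",
        "max_tokens": 20,
        "seed": 42,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is the color of the sky?"},
                ],
            },
        ],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```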
Lines changed: 80 additions & 81 deletions
@@ -1,81 +1,80 @@
-# Disabled because it's broken.
-# import pytest
-#
-#
-# @pytest.fixture(scope="module")
-# def flash_qwen2_vl_handle(launcher):
-#     with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
-#         yield handle
-#
-#
-# @pytest.fixture(scope="module")
-# async def flash_qwen2(flash_qwen2_vl_handle):
-#     await flash_qwen2_vl_handle.health(300)
-#     return flash_qwen2_vl_handle.client
-#
-#
-# @pytest.mark.private
-# async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
-#     response = await flash_qwen2.chat(
-#         max_tokens=100,
-#         seed=42,
-#         messages=[
-#             {
-#                 "role": "user",
-#                 "content": [
-#                     {
-#                         "type": "image_url",
-#                         "image_url": {
-#                             "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
-#                         },
-#                     },
-#                     {"type": "text", "text": "Describe this image."},
-#                 ],
-#             },
-#         ],
-#     )
-#
-#     assert (
-#         response.choices[0].message.content
-#         == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
-#     )
-#
-#     assert response == response_snapshot
-#
-#
-# @pytest.mark.private
-# async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
-#     responses = await flash_qwen2.chat(
-#         max_tokens=100,
-#         seed=42,
-#         messages=[
-#             {
-#                 "role": "user",
-#                 "content": [
-#                     {
-#                         "type": "image_url",
-#                         "image_url": {
-#                             "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
-#                         },
-#                     },
-#                     {"type": "text", "text": "Describe this image."},
-#                 ],
-#             },
-#         ],
-#         stream=True,
-#     )
-#
-#     count = 0
-#     generated = ""
-#     last_response = None
-#     async for response in responses:
-#         count += 1
-#         generated += response.choices[0].delta.content
-#         last_response = response
-#
-#     assert (
-#         generated
-#         == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
-#     )
-#     assert count == 58
-#     assert last_response == response_snapshot
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_qwen2_vl_handle(launcher):
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_qwen2(flash_qwen2_vl_handle):
+    await flash_qwen2_vl_handle.health(300)
+    return flash_qwen2_vl_handle.client
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
+    response = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+    )
+
+    assert (
+        response.choices[0].message.content
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
+    responses = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert (
+        generated
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+    assert count == 58
+    assert last_response == response_snapshot
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_qwen2_vl_handle(launcher):
+    with launcher(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        max_input_length=40,
+        max_batch_prefill_tokens=50,
+        max_total_tokens=51,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_qwen2(flash_qwen2_vl_handle):
+    await flash_qwen2_vl_handle.health(300)
+    return flash_qwen2_vl_handle.client
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
+    response = await flash_qwen2.chat(
+        max_tokens=20,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What is the color of the sky?"},
+                ],
+            },
+        ],
+    )
+
+    assert response.choices[0].message.content == "The correct answer is: blue"
+
+    assert response == response_snapshot

server/text_generation_server/models/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,7 @@
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
+import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1217,6 +1218,11 @@ def get_model(
     else:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        # TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
+        logger.warning(
+            "Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal or less than 2."
+        )
+        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
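
For Qwen2-VL the loader now drops CUDA graph batch sizes of 2 and below by rebinding `globals.CUDA_GRAPHS`. A minimal sketch of the filter's effect; the example list is illustrative, the real value comes from `text_generation_server.models.globals` (typically set via the `CUDA_GRAPHS` environment variable or launcher flag):

```python
# Illustrative default-style list of CUDA graph batch sizes.
cuda_graphs = [1, 2, 4, 8, 16, 32]

# Keep only graph batch sizes strictly greater than 2, as get_model() now does for Qwen2-VL.
filtered = list(filter(lambda x: x > 2, cuda_graphs))
assert filtered == [4, 8, 16, 32]
```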

server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py

Lines changed: 6 additions & 1 deletion
@@ -138,7 +138,12 @@ def forward(
             dim=-1,
         )

-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(
+            query,
+            torch.select(kv, dim=1, index=0),
+            cos[: query.shape[0], ...],
+            sin[: query.shape[0], ...],
+        )

         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
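
The rotary call now slices `cos` and `sin` down to the number of query rows, so position tables computed for a longer sequence cannot cause a shape mismatch against the current batch. A toy sketch of why the slice keeps broadcasting aligned; the rotary math below is a simplified stand-in for illustration, not TGI's kernel:

```python
import torch


def apply_rotary(query: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Toy rotary application: rotate paired halves of the head dim (illustrative only)."""
    q1, q2 = query.chunk(2, dim=-1)
    rotated = torch.cat((-q2, q1), dim=-1)
    return query * cos + rotated * sin


tokens, heads, head_dim = 5, 2, 8
query = torch.randn(tokens, heads, head_dim)
# cos/sin may be precomputed for a longer sequence than this batch actually uses.
cos = torch.randn(16, 1, head_dim)
sin = torch.randn(16, 1, head_dim)

# Slicing to query.shape[0] keeps the broadcast shapes aligned with the batch.
out = apply_rotary(query, cos[: query.shape[0], ...], sin[: query.shape[0], ...])
assert out.shape == query.shape
```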

server/text_generation_server/models/custom_modeling/qwen2_vl.py

Lines changed: 3 additions & 2 deletions
@@ -517,11 +517,11 @@ def forward(
         pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
-        pixel_attention_mask=None,
+        pixel_attention_mask: Optional[torch.Tensor] = None,
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        image_indices: Optional[torch.Tensor] = None,
     ):
         inputs_embeds = self.embed_tokens(input_ids)

@@ -533,6 +533,7 @@ def forward(
             ).squeeze(0)
             inputs_embeds[input_ids == self.image_token_id] = image_embeds

+        max_s = max(max_s, inputs_embeds.size(0))
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
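
The added `max_s = max(max_s, inputs_embeds.size(0))` appears to guard against a `max_s` that is smaller than the flattened embedding length once image placeholder tokens are accounted for (for example during warmup). A trivial sketch with made-up numbers:

```python
# Illustrative numbers only: the flattened embeddings (text plus expanded image
# tokens) can be longer than the max_s derived from text tokens alone.
max_s = 40                # upper bound computed before image tokens were considered
inputs_embeds_len = 55    # rows actually present in inputs_embeds

# Mirror of the added line: never pass a max_s smaller than the real sequence length.
max_s = max(max_s, inputs_embeds_len)
assert max_s == 55
```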

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 11 additions & 7 deletions
@@ -56,11 +56,13 @@
     MEM_POOL,
     ATTENTION,
     BLOCK_SIZE,
-    CUDA_GRAPHS,
     REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
+
+# avoid coping CUDA_GRAPHS value by importing globals as a module
+import text_generation_server.models.globals as globals
 from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -1635,8 +1637,8 @@ def warmup(
                     int(val)
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
-            elif CUDA_GRAPHS is not None:
-                tuning_sequences = CUDA_GRAPHS
+            elif globals.CUDA_GRAPHS is not None:
+                tuning_sequences = globals.CUDA_GRAPHS
             else:
                 tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

@@ -1675,13 +1677,14 @@ def warmup(
                 "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
             )

-        if CUDA_GRAPHS:
+        if globals.CUDA_GRAPHS:
             try:
                 log_master(
-                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+                    logger.info,
+                    f"Cuda Graphs are enabled for sizes {globals.CUDA_GRAPHS}",
                 )
                 # Warmup cuda graphs
-                for bs in CUDA_GRAPHS:
+                for bs in globals.CUDA_GRAPHS:
                     synchronize(self.device)
                     free_memory = get_free_memory(
                         self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
@@ -1705,7 +1708,8 @@ def warmup(
                     logger.exception("Decode cuda graph warmup failed")
         else:
             log_master(
-                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+                logger.info,
+                f"Cuda Graphs are disabled (CUDA_GRAPHS={globals.CUDA_GRAPHS}).",
             )

         assert max_input_tokens is not None
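
The switch from `from ...globals import CUDA_GRAPHS` to `import text_generation_server.models.globals as globals` matters because a `from` import copies the binding at import time, while attribute access on the module sees later reassignments such as the Qwen2-VL override in `get_model`. A self-contained sketch of the difference, using a throwaway module object in place of the real globals module:

```python
import types

# Hypothetical stand-in for text_generation_server.models.globals.
settings = types.ModuleType("settings")
settings.CUDA_GRAPHS = [1, 2, 4, 8]

# Equivalent of `from settings import CUDA_GRAPHS`: the binding is copied once.
cuda_graphs_copy = settings.CUDA_GRAPHS

# What get_model() effectively does for Qwen2-VL: rebind the module attribute.
settings.CUDA_GRAPHS = [4, 8]

print(cuda_graphs_copy)      # [1, 2, 4, 8] -- the copied binding misses the override
print(settings.CUDA_GRAPHS)  # [4, 8]       -- module-attribute access sees it
```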
