
Commit c20025d

Add fp8 kv cache for ROCm (#2856)
* add fp8 kv cache for rocm
* improvements
* update log statement
* remove bookkeeping field
1 parent de19e7e commit c20025d
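
As a quick orientation before the diffs: the sketch below mirrors the dtype/backend validation this commit adds to KVCache.__init__ in kv_cache.py. It is an illustration written for this page, not code from the repository; ATTENTION and SYSTEM are passed in as plain strings rather than imported from the server.

import torch

def fp8_kv_cache_supported(dtype: torch.dtype, attention: str, system: str) -> bool:
    # Mirrors the new constructor check: FP8 caches are accepted for
    # flashinfer on CUDA and paged attention on ROCm; e5m2 is rejected on ROCm.
    if dtype not in (torch.float8_e5m2, torch.float8_e4m3fn):
        return True  # non-FP8 caches are unaffected by this commit
    if attention == "flashinfer" and system == "cuda":
        return True  # both float8_e5m2 and float8_e4m3fn
    if attention == "paged" and system == "rocm":
        return dtype == torch.float8_e4m3fn  # float8_e5m2 is rejected on ROCm
    return False

print(fp8_kv_cache_supported(torch.float8_e4m3fn, "paged", "rocm"))      # True
print(fp8_kv_cache_supported(torch.float8_e5m2, "paged", "rocm"))        # False
print(fp8_kv_cache_supported(torch.float8_e4m3fn, "flashinfer", "cuda")) # True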

2 files changed: +62, -34 lines

server/text_generation_server/layers/attention/kv_cache.py

Lines changed: 38 additions & 19 deletions
@@ -52,13 +52,18 @@ def __init__(
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":

@@ -113,21 +118,17 @@ def can_scale(self, kv_scales: KVScales) -> bool:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
             )
             return False

@@ -161,7 +162,7 @@ def store(
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]
 
-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),

@@ -197,7 +198,15 @@ def store(
                 key, value, key_cache, value_cache, slots
             )
         else:
-            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
+            )
 
 
 def paged_reshape_and_cache(

@@ -206,7 +215,10 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
+
     if SYSTEM == "cuda":
         try:
             import attention_kernels

@@ -224,8 +236,15 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
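
The central trick in the hunk above is that the cache stays allocated as torch.float8_e4m3fn, but the vLLM ROCm kernel receives a torch.uint8 view of the same storage together with kv_cache_dtype="fp8". A minimal standalone demonstration of that reinterpretation (written for this page, not taken from the commit; it assumes a PyTorch build that ships the float8 dtypes):

import torch

# Allocate a toy cache in the FP8 storage dtype.
fp8_cache = torch.empty(2, 4, dtype=torch.float8_e4m3fn)

# view() with a same-sized dtype reinterprets the bytes: no copy, no conversion.
as_uint8 = fp8_cache.view(torch.uint8)

print(as_uint8.dtype)                               # torch.uint8
print(as_uint8.data_ptr() == fp8_cache.data_ptr())  # True: same storage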

server/text_generation_server/layers/attention/rocm.py

Lines changed: 24 additions & 15 deletions
@@ -133,6 +133,15 @@ def paged_attention(
 
     out = torch.empty_like(query)
 
+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of

@@ -147,18 +156,18 @@ def paged_attention(
         ops.paged_attention_v1(
             out,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
             input_lengths,
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
         )
     else:
         # Run PagedAttention V2.

@@ -182,18 +191,18 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
                 input_lengths,
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
             )
         else:
             ops.paged_attention_rocm(

@@ -202,18 +211,18 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
                 input_lengths,
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                 None,
                 _PARTITION_SIZE,
             )
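
For context on the key_scale_cpu / value_scale_cpu floats that now reach the ROCm kernels: a per-tensor FP8 scale is typically chosen so that the tensor divided by the scale fits the float8_e4m3fn range, whose largest finite value is 448. The sketch below shows one common way such a scale could be derived from calibration data; it is an illustration of the general idea, not code from text-generation-inference.

import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def per_tensor_fp8_scale(x: torch.Tensor) -> float:
    # Pick the scale so that (x / scale) stays within the e4m3fn range;
    # the attention kernel later dequantizes with (quantized * scale).
    return (x.abs().max() / FP8_E4M3_MAX).item()

calibration_keys = torch.randn(1024, 64) * 3.0  # stand-in for observed key activations
print(per_tensor_fp8_scale(calibration_keys))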
