@@ -52,13 +52,18 @@ def __init__(
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm."
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
@@ -113,21 +118,17 @@ def can_scale(self, kv_scales: KVScales) -> bool:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
             )
             return False
 
@@ -161,7 +162,7 @@ def store(
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]
 
-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),
@@ -197,7 +198,15 @@ def store(
                     key, value, key_cache, value_cache, slots
                 )
             else:
-                paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+                paged_reshape_and_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slots,
+                    kv_scales.key_scale_cpu,
+                    kv_scales.value_scale_cpu,
+                )
 
 
 def paged_reshape_and_cache(
@@ -206,7 +215,10 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
+
     if SYSTEM == "cuda":
         try:
             import attention_kernels
@@ -224,8 +236,15 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
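
Note (not part of the diff): a minimal standalone sketch of the dtype/scale selection that the new ROCm branch performs before calling vllm's reshape_and_cache. The helper name select_rocm_cache_args is hypothetical and used only for illustration; it mirrors the kv_cache_dtype logic added above.

import torch

def select_rocm_cache_args(key_cache, value_cache, k_scale=1.0, v_scale=1.0):
    # Mirrors the added ROCm branch: FP8 caches are reinterpreted as uint8
    # and tagged "fp8" so the kernel applies the provided per-tensor scales;
    # any other cache dtype keeps "auto".
    kv_cache_dtype = "auto"
    if key_cache.dtype == torch.float8_e4m3fn:
        key_cache = key_cache.view(torch.uint8)
        value_cache = value_cache.view(torch.uint8)
        kv_cache_dtype = "fp8"
    return key_cache, value_cache, kv_cache_dtype, k_scale, v_scale

# Example: an FP8 cache selects the "fp8" path and forwards the scales.
kc = torch.zeros(2, 4, dtype=torch.float8_e4m3fn)
vc = torch.zeros(2, 4, dtype=torch.float8_e4m3fn)
kc_u8, vc_u8, dtype_str, k, v = select_rocm_cache_args(kc, vc, 0.5, 0.5)
assert dtype_str == "fp8" and kc_u8.dtype == torch.uint8 and k == 0.5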