
Commit d658b5d

Deepseek R1 for Gaudi backend (#3211)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 58934c8 commit d658b5d


41 files changed: +1133 / -238 lines

Dockerfile_gaudi
Lines changed: 4 additions & 1 deletion

@@ -60,6 +60,8 @@ FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytor
 ENV ATTENTION=default
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
+ENV PT_HPU_LAZY_MODE=1
+ENV PT_HPU_WEIGHT_SHARING=0

 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -95,7 +97,8 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
-RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
+RUN pip install git+https://github.com/HabanaAI/vllm-hpu-extension.git@a060794
+
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
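The two new PT_HPU_* variables are plain environment defaults baked into the image, so they can be overridden when the container is started. Below is a minimal sketch of how runtime code could read such toggles, assuming the usual "1"/"0" string convention; the variable names come from the Dockerfile above, while the env_flag helper is illustrative and not part of the backend.

import os

def env_flag(name: str, default: str) -> bool:
    # Illustrative helper: treat "1"/"true" (case-insensitive) as enabled.
    return os.environ.get(name, default).lower() in ("1", "true")

# Defaults mirror the Dockerfile: lazy mode on, weight sharing off.
lazy_mode = env_flag("PT_HPU_LAZY_MODE", "1")
weight_sharing = env_flag("PT_HPU_WEIGHT_SHARING", "0")
print(f"lazy_mode={lazy_mode} weight_sharing={weight_sharing}")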

backends/gaudi/server/text_generation_server/cli.py
Lines changed: 9 additions & 1 deletion

@@ -26,6 +26,11 @@ class Dtype(str, Enum):
     bloat16 = "bfloat16"


+class KVCacheDtype(str, Enum):
+    fp8_e4m3fn = "fp8_e4m3fn"
+    fp8_e5m2 = "fp8_e5m2"
+
+
 @app.command()
 def serve(
     model_id: str,
@@ -34,6 +39,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: Optional[KVCacheDtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -93,7 +99,8 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = "bfloat16" if dtype is None else dtype.value
-    logger.info(f"quantize={quantize}")
+    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
+    logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
     if dtype is not None and quantize not in {
         None,
         "bitsandbytes",
@@ -175,6 +182,7 @@ def terminate_handler(sig, frame):
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         uds_path,
         max_input_tokens,
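The new kv_cache_dtype option follows the same pattern as dtype: the typer enum is downgraded to its plain string value (or left as None) before being handed to the server. A standalone sketch of that downgrade step, with the enum copied from the diff and an illustrative resolve_kv_cache_dtype helper that is not part of the repository:

from enum import Enum
from typing import Optional


class KVCacheDtype(str, Enum):
    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"


def resolve_kv_cache_dtype(kv_cache_dtype: Optional[KVCacheDtype]) -> Optional[str]:
    # Mirrors the CLI's "downgrade enum into str" step: None passes through,
    # an enum member becomes its plain string value.
    return None if kv_cache_dtype is None else kv_cache_dtype.value


assert resolve_kv_cache_dtype(None) is None
assert resolve_kv_cache_dtype(KVCacheDtype.fp8_e4m3fn) == "fp8_e4m3fn"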

backends/gaudi/server/text_generation_server/layers/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
 from text_generation_server.layers.conv import load_conv2d
+from text_generation_server.layers.fp8 import Fp8Linear

 from text_generation_server.layers.lora import (
     LoraLinear,
@@ -27,6 +28,7 @@
     "TensorParallelEmbedding",
     "SpeculativeHead",
     "LoraLinear",
+    "Fp8Linear",
     "TensorParallelMultiAdapterLinear",
     "TensorParallelAdapterRowLinear",
     "load_layer_norm",

backends/gaudi/server/text_generation_server/layers/attention/__init__.py
Lines changed: 4 additions & 1 deletion

@@ -10,18 +10,21 @@
     SUPPORTS_WINDOWING,
     attention,
     paged_attention,
+    paged_attention_mla,
 )


 # KVCache needs `reshape_and_cache`, so ensure that it is defined already.
-from .kv_cache import KVCache, get_kv_scales
+from .kv_cache import KVCache, get_kv_scales, KVCompressCache

 __all__ = [
     "attention",
     "get_kv_scales",
     "paged_attention",
+    "paged_attention_mla",
     "SUPPORTS_WINDOWING",
     "KVCache",
+    "KVCompressCache",
     "Seqlen",
     "HPUPagedAttentionMetadata",
     "trim_seqlen_metadata",

backends/gaudi/server/text_generation_server/layers/attention/hpu.py
Lines changed: 96 additions & 14 deletions

@@ -11,11 +11,61 @@
 SUPPORTS_WINDOWING = False


-def fetch_from_cache(cache, blocks):
-    if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
-        return cache[: blocks.size(0)]
-    else:
-        return cache.index_select(0, blocks)
+class FP8Matmul(torch.nn.Module):
+
+    def __init__(self, scale_other):
+        super().__init__()
+        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
+        self.scale_other = scale_other
+
+    def quant_input(self, x, scale):
+        return torch.ops.hpu.cast_to_fp8_v2(
+            x, scale, False, False, torch.float8_e4m3fn
+        )[0]
+
+    def matmul_fp8(
+        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
+    ):
+        return torch.ops.hpu.fp8_gemm_v2(
+            A=x,
+            trans_A=False,
+            B=other,
+            trans_B=False,
+            D=None,
+            out_dtype=out_dtype,
+            A_scale_inv=scale_input_inv,
+            B_scale_inv=scale_other_inv,
+            bias=None,
+            accumulate=False,
+        )
+
+    def forward(self, input, other):
+        qinput = self.quant_input(input, self.scale_input)
+        qother = self.quant_input(other, self.scale_other)
+        output = self.matmul_fp8(
+            qinput,
+            qother,
+            out_dtype=torch.bfloat16,
+            scale_input_inv=1.0 / self.scale_input,
+            scale_other_inv=1.0 / self.scale_other,
+        )
+        return output
+
+
+class FetchFromCache(torch.nn.Module):
+
+    def __init__(self, scale_inv):
+        super().__init__()
+        self.scale_inv = scale_inv
+
+    def forward(self, cache, blocks):
+        if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
+            out = cache[: blocks.size(0)]
+        else:
+            out = cache.index_select(0, blocks)
+        if out.dtype == torch.float8_e4m3fn:
+            out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
+        return out


 def attention(
@@ -67,6 +117,7 @@ def paged_attention(
     hpu_attention_meta: HPUPagedAttentionMetadata,
 ):
     batch_size, head_num, head_size = query.shape
+    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
     output = ops.flat_pa(
         query=query.view(batch_size, 1, head_num * head_size),
         key_cache=kv_cache.key,
@@ -76,19 +127,50 @@ def paged_attention(
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
-        matmul_qk_op=Matmul(),
-        matmul_av_op=Matmul(),
+        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
         batch2block_matmul_op=Matmul(),
         block2batch_matmul_op=Matmul(),
-        keys_fetch_func=fetch_from_cache,
-        values_fetch_func=fetch_from_cache,
+        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+        values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
     )
     # Reshape the output tensor.
     return output.view(batch_size, head_num, head_size)


-__all__ = [
-    "SUPPORTS_WINDOWING",
-    "attention",
-    "paged_attention",
-]
+def paged_attention_mla(
+    query: torch.Tensor,
+    kv_cache: KVCache,
+    kv_head_mapping: torch.Tensor,
+    softmax_scale: float,
+    seqlen: Seqlen,
+    *,
+    kv_scales: KVScales,
+    softcap: Optional[float] = None,
+    hpu_attention_meta: HPUPagedAttentionMetadata,
+    kv_lora_rank: int = 0,
+):
+    batch_size, head_num, head_size = query.shape
+    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
+    output = ops.flat_pa_mla(
+        query=query,
+        key_cache=kv_cache.key,
+        value_cache=None,
+        block_list=hpu_attention_meta.block_list,
+        block_mapping=hpu_attention_meta.block_mapping,
+        block_bias=hpu_attention_meta.attn_bias,
+        block_groups=hpu_attention_meta.block_groups,
+        scale=softmax_scale,
+        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
+        batch2block_matmul_op=Matmul(),
+        block2batch_matmul_op=Matmul(),
+        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+        values_fetch_func=None,
+        kv_lora_rank=kv_lora_rank,
+    )
+    # Reshape the output tensor.
+    return output.view(batch_size, head_num, -1)
+
+
+__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]

backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
Lines changed: 76 additions & 4 deletions

@@ -50,6 +50,8 @@ def __init__(
     ):
         """Construct the key-value cache for a layer."""
         ## TODO FP8 kv cache support
+        if dtype is torch.float8_e5m2:
+            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

         self.kv_cache = (
             torch.zeros(
@@ -101,22 +103,92 @@ def store(
             key_cache,
             value_cache,
             slots,
-            kv_scales.key_scale_cpu,
-            kv_scales.value_scale_cpu,
+            kv_scales.key_scale,
+            kv_scales.value_scale,
         )


+class KVCompressCache(KVCache):
+    """
+    Key-value cache for attention layers.
+    """
+
+    kv_cache: torch.Tensor
+
+    def __init__(
+        self,
+        *,
+        num_blocks: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        """Construct the key-value cache for a layer."""
+        ## TODO FP8 kv cache support
+        if dtype is torch.float8_e5m2:
+            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
+
+        self.kv_cache = torch.zeros(
+            (num_blocks, BLOCK_SIZE, 1, head_size),
+            dtype=dtype,
+            device=device,
+        )
+
+    @property
+    def dtype(self):
+        """Get the data type of the cache."""
+        return self.kv_cache.dtype
+
+    @property
+    def key(self):
+        """Get the key cache."""
+
+        return self.kv_cache
+
+    @property
+    def value(self):
+        """Get the value cache."""
+
+        return self.kv_cache
+
+    def store(
+        self,
+        *,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        slots: torch.Tensor,
+        kv_scales: KVScales,
+    ):
+        """Store the key and value at the given slots."""
+        ## TODO FP8 kv cache support
+
+        block_idx = slots // BLOCK_SIZE
+        block_offset = slots % BLOCK_SIZE
+        if self.kv_cache.dtype == torch.float8_e4m3fn:
+            key = torch.ops.hpu.cast_to_fp8_v2(
+                key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
+            )[0]
+        cache_ops.insert_or_update_cache(key, self.kv_cache, block_idx, block_offset)
+
+
 def paged_reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
-    k_scale: float = 1.0,
-    v_scale: float = 1.0,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
 ):
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
+    if key_cache.dtype == torch.float8_e4m3fn:
+        key = torch.ops.hpu.cast_to_fp8_v2(
+            key, k_scale, False, False, torch.float8_e4m3fn
+        )[0]
+        value = torch.ops.hpu.cast_to_fp8_v2(
+            value, v_scale, False, False, torch.float8_e4m3fn
+        )[0]
     cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
     cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
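Both KVCompressCache.store and paged_reshape_and_cache address the paged cache by splitting each flat slot index into a block index and an offset inside that block before calling insert_or_update_cache. A small worked example of that arithmetic; the BLOCK_SIZE value here is chosen only for illustration and is not necessarily the backend's constant:

import torch

BLOCK_SIZE = 128  # illustrative value, not necessarily the backend's

slots = torch.tensor([0, 1, 127, 128, 300])
block_idx = slots // BLOCK_SIZE     # which cache block each slot lands in
block_offset = slots % BLOCK_SIZE   # position of the slot inside that block
print(block_idx.tolist())     # [0, 0, 0, 1, 2]
print(block_offset.tolist())  # [0, 1, 127, 0, 44]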
