
Commit 8851441

Add flash decoding kernel and enable prefill chunking and prefix caching on Intel CPU/XPU (#2815)
* flash decoding
* enable xpu flashdecoding
* set flashdecoding blocksize as 64
* enable flashdecoding, prefill chunking and prefix caching
* add flashdecoding-ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 82f6ea1 commit 8851441

File tree

6 files changed, +97 -40 lines changed

Dockerfile_intel

Lines changed: 3 additions & 3 deletions
@@ -224,9 +224,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 FROM ${PLATFORM} AS final
-ENV ATTENTION=paged
-ENV PREFIX_CACHING=0
-ENV PREFILL_CHUNKING=0
+ENV ATTENTION=flashdecoding-ipex
+ENV PREFIX_CACHING=1
+ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]

launcher/src/main.rs

Lines changed: 3 additions & 1 deletion
@@ -143,7 +143,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         }
     }
 
-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
         "paged"
     } else {
         "flashdecoding"

server/text_generation_server/layers/attention/ipex.py

Lines changed: 67 additions & 31 deletions
@@ -1,9 +1,12 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 
 SUPPORTS_WINDOWING = False
 
@@ -28,22 +31,38 @@ def attention(
     out = torch.empty_like(query)
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    ipex.llm.functional.varlen_attention(
-        query.contiguous() if query.device.type == "xpu" else query,
-        key.contiguous() if key.device.type == "xpu" else key,
-        value.contiguous() if value.device.type == "xpu" else value,
-        out,
-        seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        seqlen.max_q,
-        seqlen.max_q,
-        0.0,
-        softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            causal,
+            block_tables,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query.contiguous() if query.device.type == "xpu" else query,
+            key.contiguous() if key.device.type == "xpu" else key,
+            value.contiguous() if value.device.type == "xpu" else value,
+            out,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_q,
+            0.0,
+            softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 
     return out
 
@@ -64,20 +83,37 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
-    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        kv_cache.key,
-        kv_cache.value,
-        kv_head_mapping,
-        softmax_scale,
-        block_tables,
-        input_lengths,
-        BLOCK_SIZE,
-        max_s,
-        None,
-    )
+
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            True,
+            block_tables,
+            None,
+        )
+    else:
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            BLOCK_SIZE,
+            max_s,
+            None,
+        )
     return out
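Taken together, both call sites now branch on `ATTENTION`: with `flashdecoding-ipex`, prefill and decode go through `ipex.llm.modules.PagedAttention.flash_attn_varlen_func` and read K/V from the paged cache via `block_tables` (decode passes `causal=True`), which is what lets prefix caching and prefill chunking work on this backend; otherwise the previous varlen prefill and single-query decode kernels remain in use. A condensed, non-authoritative sketch of the decode-side dispatch, with call signatures copied from the diff and the surrounding objects (`seqlen`, `kv_cache`) assumed to be the usual TGI structures:

```python
import intel_extension_for_pytorch as ipex


def decode_attention_sketch(out, query, kv_cache, seqlen, block_tables, softmax_scale,
                            kv_head_mapping, block_size, max_s, attention_backend):
    # Condensed restatement of paged_attention() above; not drop-in code.
    if attention_backend == "flashdecoding-ipex":
        # One flash-decoding kernel over the paged KV cache; decode is always causal.
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            True,
            block_tables,
            None,
        )
    else:
        # Legacy single-query paged-attention kernel.
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out, query, kv_cache.key, kv_cache.value, kv_head_mapping,
            softmax_scale, block_tables, input_lengths, block_size, max_s, None,
        )
    return out
```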

server/text_generation_server/layers/attention/kv_cache.py

Lines changed: 10 additions & 1 deletion
@@ -66,7 +66,9 @@ def __init__(
         else:
             x = BLOCK_SIZE // element_size
 
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
+        ):
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -80,6 +82,7 @@ def __init__(
                 ),
             )
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share same layout
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, num_heads, BLOCK_SIZE, head_size),
@@ -187,6 +190,12 @@ def store(
             shape = key_cache.shape
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
         else:
             paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
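The layout logic above can be summarized as follows: flashdecoding, flashinfer, and flashdecoding-ipex on XPU keep the cache as (num_blocks, BLOCK_SIZE, num_heads, head_size), ipex on CPU keeps (num_blocks, num_heads, BLOCK_SIZE, head_size), and tokens are written by flat slot index (block * BLOCK_SIZE + offset). A small runnable sketch of that slot arithmetic with toy sizes (not the real cache allocation path):

```python
import torch

# Toy sizes, for illustration only.
num_blocks, BLOCK_SIZE, num_heads, head_size = 4, 64, 2, 8

# flashdecoding / flashinfer / flashdecoding-ipex on XPU: token position is the 2nd axis.
flash_layout = torch.zeros(num_blocks, BLOCK_SIZE, num_heads, head_size)
# ipex on CPU (paged and flashdecoding kernels share this layout): heads come first.
cpu_layout = torch.zeros(num_blocks, num_heads, BLOCK_SIZE, head_size)

# Slot-based store used by the flash layout: slot = block_id * BLOCK_SIZE + offset.
slots = torch.tensor([0, 1, BLOCK_SIZE + 1])          # block 0 pos 0/1, block 1 pos 1
key = torch.randn(len(slots), num_heads, head_size)
shape = flash_layout.shape
flash_layout.view(-1, shape[-2], shape[-1])[slots] = key
assert torch.equal(flash_layout[1, 1], key[2])        # slot 65 lands in block 1, offset 1
```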

server/text_generation_server/models/globals.py

Lines changed: 9 additions & 2 deletions
@@ -14,26 +14,33 @@
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")
 
-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1
 
+
 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16
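The resulting block-size selection (matching the "set flashdecoding blocksize as 64" item in the commit message) reads as a simple mapping; a sketch, with the helper name being illustrative rather than part of the module:

```python
def block_size_for(attention: str) -> int:
    # Illustrative restatement of the BLOCK_SIZE selection in globals.py.
    if attention == "flashdecoding":
        return 256
    if attention == "flashinfer":
        return 1
    if attention == "flashdecoding-ipex":
        return 64
    return 16  # "paged" and anything else that passes the _expected check


assert block_size_for("flashdecoding-ipex") == 64
assert block_size_for("paged") == 16
```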

server/text_generation_server/models/model.py

Lines changed: 5 additions & 2 deletions
@@ -79,10 +79,13 @@ def __init__(
                 "Prefill chunking will be turned off",
             )
             support_chunking = False
-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False
