Commit 630f198

flashinfer: switch to plan API (#2904)
This change doesn't switch `forward` to `run` yet, since it requires that we have access to the softmax scale and the logit softcap outside the model.
Parent: 8f6146f

2 files changed: +2, -5 lines
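To make the migration concrete, here is a minimal sketch of the two call patterns, assuming flashinfer's prefill/decode wrappers. This is illustrative only, not TGI code: `wrapper`, `old_style`, `new_style`, and `layout_kwargs` are hypothetical names, and the layout arguments are passed through opaquely.

# Illustrative sketch of the API migration (not TGI code). `wrapper` stands
# for a flashinfer BatchPrefillWithPagedKVCacheWrapper or
# BatchDecodeWithPagedKVCacheWrapper; `layout_kwargs` bundles the indptr /
# indices / page-length arguments visible in the diff below.

def old_style(wrapper, q, kv_data, layout_kwargs, softmax_scale, softcap):
    # Deprecated pattern: begin_forward registers the batch layout, per-call
    # options travel with forward, and end_forward tears the plan down.
    wrapper.begin_forward(**layout_kwargs)
    try:
        return wrapper.forward(
            q, kv_data, sm_scale=softmax_scale, logits_soft_cap=softcap
        )
    finally:
        wrapper.end_forward()

def new_style(wrapper, q, kv_data, layout_kwargs, softmax_scale, softcap):
    # plan API: a single plan() call captures the layout and the options;
    # run() then only needs the tensors, and there is no teardown call. This
    # is why fully switching forward -> run requires the softmax scale and
    # the logit softcap to be available outside the model, at plan time.
    wrapper.plan(**layout_kwargs, sm_scale=softmax_scale, logits_soft_cap=softcap)
    return wrapper.run(q, kv_data)

This commit takes only the first half of that step: the context managers below now call `plan`, while `cuda.py` keeps calling `forward`.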

server/text_generation_server/layers/attention/cuda.py (0 additions, 1 deletion)

@@ -235,7 +235,6 @@ def attention(
         paged_kv_cache=(kv_cache.key, kv_cache.value),
         logits_soft_cap=softcap,
         sm_scale=softmax_scale,
-        window_left=window_size_left,
         k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
         v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
     )

server/text_generation_server/layers/attention/flashinfer.py (2 additions, 4 deletions)

@@ -84,7 +84,7 @@ def use_prefill_with_paged_kv_state(

     token = prefill_with_paged_kv_state.set(state)
     try:
-        state.begin_forward(
+        state.plan(
             qo_indptr=cu_seqlens,
             paged_kv_indptr=indptr,
             paged_kv_indices=block_tables,
@@ -99,7 +99,6 @@ def use_prefill_with_paged_kv_state(
         )
         yield
     finally:
-        state.end_forward()
         if token is not None:
             prefill_with_paged_kv_state.reset(token)
@@ -200,7 +199,7 @@ def use_decode_state(
     token = decode_state.set(state)

     try:
-        state.begin_forward(
+        state.plan(
             indptr=indptr,
             indices=block_tables,
             last_page_len=last_page_len,
@@ -214,6 +213,5 @@ def use_decode_state(
         )
         yield
     finally:
-        state.end_forward()
         if token is not None:
             decode_state.reset(token)
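For orientation, both context managers touched above follow the same contextvars pattern. Below is a minimal sketch of the decode side as it reads after this commit; the `**plan_kwargs` signature is an assumption to keep the sketch short, not the real function signature.

from contextlib import contextmanager
from contextvars import ContextVar

# Sketch of the pattern in flashinfer.py after this commit (the real state
# object is a flashinfer decode wrapper; plan kwargs are elided here).
decode_state: ContextVar = ContextVar("decode_state")

@contextmanager
def use_decode_state(state, **plan_kwargs):
    token = decode_state.set(state)
    try:
        # Planning happens once on entry; with the plan API there is no
        # end_forward() counterpart left to call on exit.
        state.plan(**plan_kwargs)
        yield
    finally:
        if token is not None:
            decode_state.reset(token)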
