Skip to content

Commit e705319

Browse files
committed
Enable defrag for paged TP cache
1 parent e89dc5b commit e705319

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

exllamav2/attn.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -757,16 +757,16 @@ def forward_paged_tp_old(
757757
else:
758758
cache_seqlens_a = attn_params.cache_seqlens_tp
759759

760-
if cache.q_block == 1:
761-
cache.get_kv_state(
762-
self.layer_idx,
763-
batch_size,
764-
0,
765-
attn_params.max_cache_seqlen,
766-
page_size,
767-
attn_params.cache_seqlens_tp,
768-
attn_params.block_index_tp
769-
)
760+
# if cache.q_block == 1:
761+
# cache.get_kv_state(
762+
# self.layer_idx,
763+
# batch_size,
764+
# 0,
765+
# attn_params.max_cache_seqlen,
766+
# page_size,
767+
# attn_params.cache_seqlens_tp,
768+
# attn_params.block_index_tp
769+
# )
770770

771771
flash_kwargs = {}
772772
if self.sliding_window:

exllamav2/cache.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -845,7 +845,6 @@ def copy_states(
845845
to_row: int,
846846
to_rows: int
847847
):
848-
# TODO: Parallel implementation
849848
for cache, tcache in zip(self.caches, target.caches):
850849
cache.copy_states(
851850
tcache,
@@ -865,8 +864,10 @@ def touch_device(self, device):
865864

866865

867866
def all_tensors(self):
868-
# TODO: Support defrag with TP cache
869-
return []
867+
tensors = []
868+
for cache in self.caches:
869+
tensors += cache.all_tensors()
870+
return tensors
870871

871872

872873
def reset(self):

exllamav2/generator/dynamic.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2222,6 +2222,7 @@ def prefill(self, results: list):
22222222
best_match_page = page
22232223
if best_match_page and best_match > 1:
22242224
page = seq.allocated_pages[p0]
2225+
# print([sap.page_index for sap in seq.allocated_pages])
22252226
for c in [self.generator.cache] if not self.generator.draft_model else \
22262227
[self.generator.cache, self.generator.draft_cache]:
22272228
c.copy_states(

0 commit comments

Comments
 (0)