File tree Expand file tree Collapse file tree 3 files changed +15
-13
lines changed
Expand file tree Collapse file tree 3 files changed +15
-13
lines changed Original file line number Diff line number Diff line change @@ -757,16 +757,16 @@ def forward_paged_tp_old(
757757 else :
758758 cache_seqlens_a = attn_params .cache_seqlens_tp
759759
760- if cache .q_block == 1 :
761- cache .get_kv_state (
762- self .layer_idx ,
763- batch_size ,
764- 0 ,
765- attn_params .max_cache_seqlen ,
766- page_size ,
767- attn_params .cache_seqlens_tp ,
768- attn_params .block_index_tp
769- )
760+ # if cache.q_block == 1:
761+ # cache.get_kv_state(
762+ # self.layer_idx,
763+ # batch_size,
764+ # 0,
765+ # attn_params.max_cache_seqlen,
766+ # page_size,
767+ # attn_params.cache_seqlens_tp,
768+ # attn_params.block_index_tp
769+ # )
770770
771771 flash_kwargs = {}
772772 if self .sliding_window :
Original file line number Diff line number Diff line change @@ -845,7 +845,6 @@ def copy_states(
845845 to_row : int ,
846846 to_rows : int
847847 ):
848- # TODO: Parallel implementation
849848 for cache , tcache in zip (self .caches , target .caches ):
850849 cache .copy_states (
851850 tcache ,
@@ -865,8 +864,10 @@ def touch_device(self, device):
865864
866865
867866 def all_tensors (self ):
868- # TODO: Support defrag with TP cache
869- return []
867+ tensors = []
868+ for cache in self .caches :
869+ tensors += cache .all_tensors ()
870+ return tensors
870871
871872
872873 def reset (self ):
Original file line number Diff line number Diff line change @@ -2222,6 +2222,7 @@ def prefill(self, results: list):
22222222 best_match_page = page
22232223 if best_match_page and best_match > 1 :
22242224 page = seq .allocated_pages [p0 ]
2225+ # print([sap.page_index for sap in seq.allocated_pages])
22252226 for c in [self .generator .cache ] if not self .generator .draft_model else \
22262227 [self .generator .cache , self .generator .draft_cache ]:
22272228 c .copy_states (
You can’t perform that action at this time.
0 commit comments