@@ -47,7 +47,7 @@
 from exllamav2.pos_embedding import ExLlamaV2PosEmbedding
 from exllamav2.compat import safe_move_tensor
 from exllamav2.fasttensors import cleanup_stfiles
-from exllamav2.device import ExLlamaV2DeviceContext
+from exllamav2.device import ExLlamaV2DeviceContext, set_device_streams
 from exllamav2.tensor_p import TPContext, BROADCAST_VC
 import gc
 import threading
@@ -923,6 +923,10 @@ def forward_chunk(self,
             seq_len <= self.config.max_output_len, \
             "seq_len exceeds max_output_len"

+        # Ensure streams are always set in the current thread
+
+        set_device_streams()
+
         # Output

         r = {}
@@ -944,10 +948,6 @@ def forward_chunk(self,
             cache.current_seq_len = past_len

         device = self.modules[0].device_idx
-        if device is not None and device >= 0:
-            context = self.get_device_context(device)
-            if context:
-                torch.cuda.set_stream(context.stream)

         for idx, module in enumerate(self.modules):

@@ -969,9 +969,6 @@ def forward_chunk(self,
             n_device = module.device_idx
             if n_device is not None and n_device != device and n_device >= 0:
                 x = safe_move_tensor(x, n_device, non_blocking = True)
-                device = n_device
-                context = self.get_device_context(device)
-                torch.cuda.set_stream(context.stream)

             x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras, **kwargs)

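The body of `set_device_streams` is not shown in this diff. Below is a minimal sketch of what it could look like, assuming `exllamav2.device` keeps a module-level registry mapping a device index to the `torch.cuda.Stream` owned by each `ExLlamaV2DeviceContext` (the `global_streams` name and how it gets populated are assumptions here, not confirmed by the commit):

```python
import torch

# Hypothetical module-level registry in exllamav2.device, assumed to be
# filled in as each ExLlamaV2DeviceContext allocates its stream.
global_streams: dict[int, torch.cuda.Stream] = {}

def set_device_streams():
    # torch.cuda tracks the "current stream" per thread, so a forward pass
    # driven from a freshly spawned thread would otherwise run on each
    # device's default stream. Re-apply every registered stream in the
    # calling thread before any modules execute.
    for device_idx, stream in global_streams.items():
        with torch.cuda.device(device_idx):
            torch.cuda.set_stream(stream)
```

Calling this once at the top of `forward_chunk` puts the thread-local stream state right before any module runs, which is why the per-device `get_device_context(...)` / `torch.cuda.set_stream(...)` calls removed in the hunks above are no longer needed when execution hops between devices.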