@@ -221,7 +221,10 @@ def set_device_map(self,
221221
222222 self .device_context = []
223223 for idx , scratch_bytes in enumerate (fixed_bytes ):
224- self .device_context .append (ExLlamaV2DeviceContext (self , idx , scratch_bytes ))
224+ if scratch_bytes > 0 :
225+ self .device_context .append (ExLlamaV2DeviceContext (self , idx , scratch_bytes ))
226+ else :
227+ self .device_context .append (None )
225228
226229 # Create map for cache
227230
@@ -300,7 +303,8 @@ def load_tp(
300303 callback : Callable [[int , int ], None ] | None = None ,
301304 callback_gen : Callable [[int , int ], None ] | None = None ,
302305 progress : bool = False ,
303- expect_cache_tokens : int = 0
306+ expect_cache_tokens : int = 0 ,
307+ expect_cache_base : type = None
304308 ):
305309
306310 if progress :
@@ -313,7 +317,7 @@ def callback_pb(a, b):
313317 assert callback is None , \
314318 "Cannot use callback function and console progress bar at the same time."
315319 callback = callback_pb
316- f = self .load_tp_gen (gpu_split , callback , callback_gen , expect_cache_tokens )
320+ f = self .load_tp_gen (gpu_split , callback , callback_gen , expect_cache_tokens , expect_cache_base )
317321 for item in f :
318322 pass
319323 if progress :
@@ -325,10 +329,11 @@ def load_tp_gen(
325329 gpu_split : list [float ] | None = None ,
326330 callback : Callable [[int , int ], None ] | None = None ,
327331 callback_gen : Callable [[int , int ], None ] | None = None ,
328- expect_cache_tokens : int = 0
332+ expect_cache_tokens : int = 0 ,
333+ expect_cache_base : type = None
329334 ):
330335 self .config .no_graphs = True
331- self .tp_context = TPContext (self , gpu_split , expect_cache_tokens )
336+ self .tp_context = TPContext (self , gpu_split , expect_cache_tokens , expect_cache_base )
332337
333338 # Create device tensors
334339
0 commit comments