@@ -244,6 +244,32 @@ def load(
         callback_gen: Callable[[int, int], None] | None = None,
         progress: bool = False
     ):
+        """
+        Load model, regular manual split mode.
+
+        :param gpu_split:
+            List of VRAM allocations for weights and fixed buffers per GPU. Does not account for the size of
+            the cache, which must be allocated afterwards with reference to the loaded model, and whose split
+            across GPUs depends on which devices end up receiving which attention layers.
+
+            If None, only the first GPU is used.
+
+        :param lazy:
+            Only set the device map according to the split, without actually loading any of the modules.
+            Modules can subsequently be loaded and unloaded one by one for layer-streaming mode.
+
+        :param stats:
+            Legacy, unused.
+
+        :param callback:
+            Callable invoked after each layer has loaded, for progress updates etc.
+
+        :param callback_gen:
+            Same as callback, but for use by async functions.
+
+        :param progress:
+            If True, create a rich progress bar in the console while loading. Cannot be used with callbacks.
+        """

         if progress:
             progressbar = get_basic_progress()
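
As a usage sketch only (the `ExLlamaV2Config`/`ExLlamaV2` construction, the model path, and the GB units for the split are assumptions based on the repo's examples, not part of this diff), manual split mode might be driven like this:

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

config = ExLlamaV2Config("/path/to/model")  # placeholder path
model = ExLlamaV2(config)

# Budget ~18 GB on GPU 0 and ~24 GB on GPU 1 for weights and fixed buffers
# (assuming the split is given in GB). Per the docstring, this split does
# not include the cache.
model.load(gpu_split = [18.0, 24.0], progress = True)

# Allocate the cache afterwards, with reference to the loaded model; its
# split across GPUs follows from where the attention layers landed.
cache = ExLlamaV2Cache(model)
```
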
@@ -270,7 +296,6 @@ def load_gen(
         callback: Callable[[int, int], None] | None = None,
         callback_gen: Callable[[int, int], None] | None = None
     ):
-
         with torch.inference_mode():

             stats_ = self.set_device_map(gpu_split or [99999])
@@ -306,7 +331,34 @@ def load_tp(
         expect_cache_tokens: int = 0,
         expect_cache_base: type = None
     ):
+        """
+        Load model, tensor-parallel mode.
+
+        :param gpu_split:
+            List of VRAM allocations per GPU. The loader attempts to balance tensor splits so as to stay
+            within these allocations, accounting for an uneven distribution of attention heads and the
+            expected size of the cache.
+
+            If None, the loader attempts to use all available GPUs and creates a split based on the currently
+            available VRAM, as reported by nvidia-smi etc.
+
+        :param callback:
+            Callable invoked after each layer has loaded, for progress updates etc.
+
+        :param callback_gen:
+            Same as callback, but for use by async functions.

+        :param progress:
+            If True, create a rich progress bar in the console while loading. Cannot be used with callbacks.
+
+        :param expect_cache_tokens:
+            Expected size of the cache, in tokens (i.e. max_seq_len * max_batch_size, or just the cache size
+            for use with the dynamic generator), to inform the automatic tensor split. If not provided, the
+            configured max_seq_len for the model is assumed.
+
+        :param expect_cache_base:
+            Cache type to expect, e.g. ExLlamaV2Cache_Q6. Also informs the tensor split. If not provided, an
+            FP16 cache is assumed.
+        """
         if progress:
             progressbar = get_basic_progress()
             progressbar.start()
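
A corresponding sketch for tensor-parallel loading, using only the parameters documented above (the config/model setup and the top-level import of the quantized cache class are assumptions):

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache_Q6

config = ExLlamaV2Config("/path/to/model")  # placeholder path
model = ExLlamaV2(config)

# Split across all visible GPUs based on currently free VRAM, while
# budgeting for a 32768-token Q6 cache when balancing the tensor split.
model.load_tp(
    gpu_split = None,
    expect_cache_tokens = 32768,
    expect_cache_base = ExLlamaV2Cache_Q6,
    progress = True,
)
```
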
@@ -400,7 +452,31 @@ def load_autosplit(
         callback_gen: Callable[[int, int], None] | None = None,
         progress: bool = False
     ):
+        """
+        Load model, auto-split mode. This mode loads the model and builds the cache in parallel, using
+        available devices in turn and moving on to the next device whenever the previous one is full.
+
+        :param cache:
+            Cache constructed with lazy = True. Actual tensor allocation for the cache happens while loading
+            the model.
+
+        :param reserve_vram:
+            Number of bytes to reserve on each device, either for all devices (as an int) or per device (as a
+            list).

+        :param last_id_only:
+            If True, the model is loaded in a mode that can only output one set of logits (i.e. one token
+            position) per forward pass. This conserves memory if the model is only to be used for generating
+            text and not e.g. perplexity measurement.
+
+        :param callback:
+            Callable invoked after each layer has loaded, for progress updates etc.
+
+        :param callback_gen:
+            Same as callback, but for use by async functions.
+
+        :param progress:
+            If True, create a rich progress bar in the console while loading. Cannot be used with callbacks.
+        """
         if progress:
             progressbar = get_basic_progress()
             progressbar.start()
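
And a sketch of auto-split loading with a lazily constructed cache, per the docstring (again assuming the usual config/model setup from the repo's examples):

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

config = ExLlamaV2Config("/path/to/model")  # placeholder path
model = ExLlamaV2(config)

# Construct the cache with lazy = True: no tensors are allocated yet.
cache = ExLlamaV2Cache(model, lazy = True)

# Load weights and cache together, reserving 256 MB per device and keeping
# only the last token position's logits to save memory during generation.
model.load_autosplit(
    cache,
    reserve_vram = 256 * 1024**2,
    last_id_only = True,
    progress = True,
)
```
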
@@ -569,6 +645,9 @@ def load_autosplit_gen(


     def unload(self):
+        """
+        Unloads the model and frees all unmanaged resources.
+        """

         for module in self.modules:
             module.unload()