@@ -142,7 +142,7 @@ def custom_offload_with_hook(
142142 user_hook .attach ()
143143 return user_hook
144144
145-
145+ # this is the class that user can customize to implement their own offload strategy
146146class AutoOffloadStrategy :
147147 """
148148 Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
@@ -213,7 +213,101 @@ def search_best_candidate(module_sizes, min_memory_offload):
213213 return hooks_to_offload
214214
215215
216+ # utils for display component info in a readable format
217+ # TODO: move to a different file
218+ def summarize_dict_by_value_and_parts (d : Dict [str , Any ]) -> Dict [str , Any ]:
219+ """Summarizes a dictionary by finding common prefixes that share the same value.
220+
221+ For a dictionary with dot-separated keys like: {
222+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
223+ 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
224+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
225+ }
226+
227+ Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
228+ 'down_blocks': [0.6], 'up_blocks': [0.3]
229+ }
230+ """
231+ # First group by values - convert lists to tuples to make them hashable
232+ value_to_keys = {}
233+ for key , value in d .items ():
234+ value_tuple = tuple (value ) if isinstance (value , list ) else value
235+ if value_tuple not in value_to_keys :
236+ value_to_keys [value_tuple ] = []
237+ value_to_keys [value_tuple ].append (key )
238+
239+ def find_common_prefix (keys : List [str ]) -> str :
240+ """Find the shortest common prefix among a list of dot-separated keys."""
241+ if not keys :
242+ return ""
243+ if len (keys ) == 1 :
244+ return keys [0 ]
245+
246+ # Split all keys into parts
247+ key_parts = [k .split ("." ) for k in keys ]
248+
249+ # Find how many initial parts are common
250+ common_length = 0
251+ for parts in zip (* key_parts ):
252+ if len (set (parts )) == 1 : # All parts at this position are the same
253+ common_length += 1
254+ else :
255+ break
256+
257+ if common_length == 0 :
258+ return ""
259+
260+ # Return the common prefix
261+ return "." .join (key_parts [0 ][:common_length ])
262+
263+ # Create summary by finding common prefixes for each value group
264+ summary = {}
265+ for value_tuple , keys in value_to_keys .items ():
266+ prefix = find_common_prefix (keys )
267+ if prefix : # Only add if we found a common prefix
268+ # Convert tuple back to list if it was originally a list
269+ value = list (value_tuple ) if isinstance (d [keys [0 ]], list ) else value_tuple
270+ summary [prefix ] = value
271+ else :
272+ summary ["" ] = value # Use empty string if no common prefix
273+
274+ return summary
275+
276+
216277class ComponentsManager :
278+ """
279+ A central registry and management system for model components across multiple pipelines.
280+
281+ [`ComponentsManager`] provides a unified way to register, track, and reuse model components
282+ (like UNet, VAE, text encoders, etc.) across different modular pipelines. It includes
283+ features for duplicate detection, memory management, and component organization.
284+
285+ <Tip warning={true}>
286+
287+ This is an experimental feature and is likely to change in the future.
288+
289+ </Tip>
290+
291+ Example:
292+ ```python
293+ from diffusers import ComponentsManager
294+
295+ # Create a components manager
296+ cm = ComponentsManager()
297+
298+ # Add components
299+ cm.add("unet", unet_model, collection="sdxl")
300+ cm.add("vae", vae_model, collection="sdxl")
301+
302+ # Enable auto offloading
303+ cm.enable_auto_cpu_offload(device="cuda")
304+
305+ # Retrieve components
306+ unet = cm.get_one(name="unet", collection="sdxl")
307+ ```
308+ """
309+
310+
217311 _available_info_fields = [
218312 "model_id" ,
219313 "added_time" ,
@@ -278,7 +372,19 @@ def _lookup_ids(
278372 def _id_to_name (component_id : str ):
279373 return "_" .join (component_id .split ("_" )[:- 1 ])
280374
281- def add (self , name , component , collection : Optional [str ] = None ):
375+ def add (self , name : str , component : Any , collection : Optional [str ] = None ):
376+ """
377+ Add a component to the ComponentsManager.
378+
379+ Args:
380+ name (str): The name of the component
381+ component (Any): The component to add
382+ collection (Optional[str]): The collection to add the component to
383+
384+ Returns:
385+ str: The unique component ID, which is generated as "{name}_{id(component)}" where
386+ id(component) is Python's built-in unique identifier for the object
387+ """
282388 component_id = f"{ name } _{ id (component )} "
283389
284390 # check for duplicated components
@@ -334,6 +440,12 @@ def add(self, name, component, collection: Optional[str] = None):
334440 return component_id
335441
336442 def remove (self , component_id : str = None ):
443+ """
444+ Remove a component from the ComponentsManager.
445+
446+ Args:
447+ component_id (str): The ID of the component to remove
448+ """
337449 if component_id not in self .components :
338450 logger .warning (f"Component '{ component_id } ' not found in ComponentsManager" )
339451 return
@@ -545,6 +657,22 @@ def matches_pattern(component_id, pattern, exact_match=False):
545657 return get_return_dict (matches , return_dict_with_names )
546658
547659 def enable_auto_cpu_offload (self , device : Union [str , int , torch .device ] = "cuda" , memory_reserve_margin = "3GB" ):
660+ """
661+ Enable automatic CPU offloading for all components.
662+
663+ The algorithm works as follows:
664+ 1. All models start on CPU by default
665+ 2. When a model's forward pass is called, it's moved to the execution device
666+ 3. If there's insufficient memory, other models on the device are moved back to CPU
667+ 4. The system tries to offload the smallest combination of models that frees enough memory
668+ 5. Models stay on the execution device until another model needs memory and forces them off
669+
670+ Args:
671+ device (Union[str, int, torch.device]): The execution device where models are moved for forward passes
672+ memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
673+ memory to keep free on the device to avoid running out of memory during
674+ model execution (e.g., for intermediate activations, gradients, etc.)
675+ """
548676 if not is_accelerate_available ():
549677 raise ImportError ("Make sure to install accelerate to use auto_cpu_offload" )
550678
@@ -574,6 +702,9 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda"
574702 self ._auto_offload_device = device
575703
576704 def disable_auto_cpu_offload (self ):
705+ """
706+ Disable automatic CPU offloading for all components.
707+ """
577708 if self .model_hooks is None :
578709 self ._auto_offload_enabled = False
579710 return
@@ -595,13 +726,12 @@ def get_model_info(
595726 """Get comprehensive information about a component.
596727
597728 Args:
598- component_id: Name of the component to get info for
599- fields: Optional field (s) to return. Can be a string for single field or list of fields.
729+ component_id (str) : Name of the component to get info for
730+ fields ( Optional[Union[str, List[str]]]): Field (s) to return. Can be a string for single field or list of fields.
600731 If None, uses the available_info_fields setting.
601732
602733 Returns:
603- Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a
604- single field is requested as string, returns just that field's value.
734+ Dictionary containing requested component metadata. If fields is specified, returns only those fields. Otherwise, returns all fields.
605735 """
606736 if component_id not in self .components :
607737 raise ValueError (f"Component '{ component_id } ' not found in ComponentsManager" )
@@ -808,15 +938,16 @@ def get_one(
808938 load_id : Optional [str ] = None ,
809939 ) -> Any :
810940 """
811- Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in
812- a component_id Raises an error if multiple components match or none are found. support pattern matching for
813- name
941+ Get a single component by either:
942+ - searching name (pattern matching), collection, or load_id.
943+ - passing in a component_id
944+ Raises an error if multiple components match or none are found.
814945
815946 Args:
816- component_id: Optional component ID to get
817- name: Component name or pattern
818- collection: Optional collection to filter by
819- load_id: Optional load_id to filter by
947+ component_id (Optional[str]) : Optional component ID to get
948+ name (Optional[str]) : Component name or pattern
949+ collection (Optional[str]) : Optional collection to filter by
950+ load_id (Optional[str]) : Optional load_id to filter by
820951
821952 Returns:
822953 A single component
@@ -847,6 +978,13 @@ def get_one(
847978 def get_ids (self , names : Union [str , List [str ]] = None , collection : Optional [str ] = None ):
848979 """
849980 Get component IDs by a list of names, optionally filtered by collection.
981+
982+ Args:
983+ names (Union[str, List[str]]): List of component names
984+ collection (Optional[str]): Optional collection to filter by
985+
986+ Returns:
987+ List[str]: List of component IDs
850988 """
851989 ids = set ()
852990 if not isinstance (names , list ):
@@ -858,6 +996,20 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str]
858996 def get_components_by_ids (self , ids : List [str ], return_dict_with_names : Optional [bool ] = True ):
859997 """
860998 Get components by a list of IDs.
999+
1000+ Args:
1001+ ids (List[str]):
1002+ List of component IDs
1003+ return_dict_with_names (Optional[bool]):
1004+ Whether to return a dictionary with component names as keys:
1005+
1006+ Returns:
1007+ Dict[str, Any]: Dictionary of components.
1008+ - If return_dict_with_names=True, keys are component names.
1009+ - If return_dict_with_names=False, keys are component IDs.
1010+
1011+ Raises:
1012+ ValueError: If duplicate component names are found in the search results when return_dict_with_names=True
8611013 """
8621014 components = {id : self .components [id ] for id in ids }
8631015
@@ -877,65 +1029,17 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional
8771029 def get_components_by_names (self , names : List [str ], collection : Optional [str ] = None ):
8781030 """
8791031 Get components by a list of names, optionally filtered by collection.
880- """
881- ids = self .get_ids (names , collection )
882- return self .get_components_by_ids (ids )
883-
884-
885- def summarize_dict_by_value_and_parts (d : Dict [str , Any ]) -> Dict [str , Any ]:
886- """Summarizes a dictionary by finding common prefixes that share the same value.
887-
888- For a dictionary with dot-separated keys like: {
889- 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
890- 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
891- 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
892- }
893-
894- Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
895- 'down_blocks': [0.6], 'up_blocks': [0.3]
896- }
897- """
898- # First group by values - convert lists to tuples to make them hashable
899- value_to_keys = {}
900- for key , value in d .items ():
901- value_tuple = tuple (value ) if isinstance (value , list ) else value
902- if value_tuple not in value_to_keys :
903- value_to_keys [value_tuple ] = []
904- value_to_keys [value_tuple ].append (key )
9051032
906- def find_common_prefix (keys : List [str ]) -> str :
907- """Find the shortest common prefix among a list of dot-separated keys."""
908- if not keys :
909- return ""
910- if len (keys ) == 1 :
911- return keys [0 ]
912-
913- # Split all keys into parts
914- key_parts = [k .split ("." ) for k in keys ]
915-
916- # Find how many initial parts are common
917- common_length = 0
918- for parts in zip (* key_parts ):
919- if len (set (parts )) == 1 : # All parts at this position are the same
920- common_length += 1
921- else :
922- break
923-
924- if common_length == 0 :
925- return ""
1033+ Args:
1034+ names (List[str]): List of component names
1035+ collection (Optional[str]): Optional collection to filter by
9261036
927- # Return the common prefix
928- return "." . join ( key_parts [ 0 ][: common_length ])
1037+ Returns:
1038+ Dict[str, Any]: Dictionary of components with component names as keys
9291039
930- # Create summary by finding common prefixes for each value group
931- summary = {}
932- for value_tuple , keys in value_to_keys .items ():
933- prefix = find_common_prefix (keys )
934- if prefix : # Only add if we found a common prefix
935- # Convert tuple back to list if it was originally a list
936- value = list (value_tuple ) if isinstance (d [keys [0 ]], list ) else value_tuple
937- summary [prefix ] = value
938- else :
939- summary ["" ] = value # Use empty string if no common prefix
1040+ Raises:
1041+ ValueError: If duplicate component names are found in the search results
1042+ """
1043+ ids = self .get_ids (names , collection )
1044+ return self .get_components_by_ids (ids )
9401045
941- return summary
0 commit comments