1414
1515import inspect
1616from dataclasses import dataclass , field
17- from typing import Any , Dict , List , Optional , Tuple , Union
18- import importlib
19- from collections import OrderedDict
20- import PIL
17+ from typing import Any , Dict , List , Tuple , Union
18+
2119import torch
2220from tqdm .auto import tqdm
2321
2422from ..configuration_utils import ConfigMixin
25- from ..loaders import StableDiffusionXLLoraLoaderMixin , TextualInversionLoaderMixin
26- from ..models import ImageProjection
27- from ..models .attention_processor import AttnProcessor2_0 , XFormersAttnProcessor
28- from ..models .lora import adjust_lora_scale_text_encoder
2923from ..utils import (
30- USE_PEFT_BACKEND ,
3124 is_accelerate_available ,
3225 is_accelerate_version ,
3326 logging ,
34- scale_lora_layers ,
35- unscale_lora_layers ,
3627)
3728from ..utils .hub_utils import validate_hf_hub_args
38- from ..utils .torch_utils import randn_tensor
39- from .pipeline_loading_utils import _fetch_class_library_tuple , _get_pipeline_class
40- from .pipeline_utils import DiffusionPipeline , StableDiffusionMixin
4129from .auto_pipeline import _get_model
30+ from .pipeline_loading_utils import _fetch_class_library_tuple , _get_pipeline_class
31+ from .pipeline_utils import DiffusionPipeline
32+
4233
4334if is_accelerate_available ():
4435 import accelerate
5142}
5243
5344
54-
5545@dataclass
5646class PipelineState :
5747 """
@@ -225,6 +215,7 @@ class ModularPipelineBuilder(ConfigMixin):
225215 Base class for all Modular pipelines.
226216
227217 """
218+
228219 config_name = "model_index.json"
229220 model_cpu_offload_seq = None
230221 hf_device_map = None
@@ -316,7 +307,7 @@ def components(self) -> Dict[str, Any]:
316307 expected_components = set ()
317308 for block in self .pipeline_blocks :
318309 expected_components .update (block .components .keys ())
319-
310+
320311 components = {}
321312 for name in expected_components :
322313 if hasattr (self , name ):
@@ -349,8 +340,8 @@ def auxiliaries(self) -> Dict[str, Any]:
349340 @property
350341 def configs (self ) -> Dict [str , Any ]:
351342 r"""
352- The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the
353- pipeline blocks.
343+ The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the pipeline
344+ blocks.
354345
355346 Returns (`dict`):
356347 A dictionary containing all the configs defined in the pipeline blocks.
@@ -393,31 +384,32 @@ def __call__(self, *args, **kwargs):
393384
394385 def remove_blocks (self , indices : Union [int , List [int ]]):
395386 """
396- Remove one or more blocks from the pipeline by their indices and clean up associated components,
397- configs, and auxiliaries that are no longer needed by remaining blocks.
387+ Remove one or more blocks from the pipeline by their indices and clean up associated components, configs, and
388+ auxiliaries that are no longer needed by remaining blocks.
398389
399390 Args:
400391 indices (Union[int, List[int]]): The index or list of indices of blocks to remove
401392 """
402393 # Convert single index to list
403394 indices = [indices ] if isinstance (indices , int ) else indices
404-
395+
405396 # Validate indices
406397 for idx in indices :
407398 if not 0 <= idx < len (self .pipeline_blocks ):
408- raise ValueError (f"Invalid block index { idx } . Index must be between 0 and { len (self .pipeline_blocks ) - 1 } " )
409-
399+ raise ValueError (
400+ f"Invalid block index { idx } . Index must be between 0 and { len (self .pipeline_blocks ) - 1 } "
401+ )
402+
410403 # Sort indices in descending order to avoid shifting issues when removing
411404 indices = sorted (indices , reverse = True )
412-
405+
413406 # Store blocks to be removed
414407 blocks_to_remove = [self .pipeline_blocks [idx ] for idx in indices ]
415-
408+
416409 # Remove blocks from pipeline
417410 for idx in indices :
418411 self .pipeline_blocks .pop (idx )
419412
420-
421413 # Consolidate items to remove from all blocks
422414 components_to_remove = {k : v for block in blocks_to_remove for k , v in block .components .items ()}
423415 auxiliaries_to_remove = {k : v for block in blocks_to_remove for k , v in block .auxiliaries .items ()}
@@ -448,15 +440,15 @@ def remove_blocks(self, indices: Union[int, List[int]]):
448440
449441 def add_blocks (self , pipeline_blocks , at : int = - 1 ):
450442 """Add blocks to the pipeline.
451-
443+
452444 Args:
453445 pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances.
454446 at (int, optional): Index at which to insert the blocks. Defaults to -1 (append at end).
455447 """
456448 # Convert single block to list for uniform processing
457449 if not isinstance (pipeline_blocks , (list , tuple )):
458450 pipeline_blocks = [pipeline_blocks ]
459-
451+
460452 # Validate insert_at index
461453 if at != - 1 and not 0 <= at <= len (self .pipeline_blocks ):
462454 raise ValueError (f"Invalid at index { at } . Index must be between 0 and { len (self .pipeline_blocks )} " )
@@ -465,24 +457,24 @@ def add_blocks(self, pipeline_blocks, at: int = -1):
465457 components_to_add = {}
466458 configs_to_add = {}
467459 auxiliaries_to_add = {}
468-
460+
469461 # Add blocks in order
470462 for i , block in enumerate (pipeline_blocks ):
471463 # Add block to pipeline at specified position
472464 if at == - 1 :
473465 self .pipeline_blocks .append (block )
474466 else :
475467 self .pipeline_blocks .insert (at + i , block )
476-
468+
477469 # Collect components that don't already exist
478470 for k , v in block .components .items ():
479471 if not hasattr (self , k ) or (getattr (self , k , None ) is None and v is not None ):
480472 components_to_add [k ] = v
481-
473+
482474 # Collect configs and auxiliaries
483475 configs_to_add .update (block .configs )
484476 auxiliaries_to_add .update (block .auxiliaries )
485-
477+
486478 # Validate all required components and auxiliaries after consolidation
487479 for block in pipeline_blocks :
488480 for required_component in block .required_components :
@@ -513,44 +505,37 @@ def add_blocks(self, pipeline_blocks, at: int = -1):
513505 if configs_to_add :
514506 self .register_to_config (** configs_to_add )
515507 for key , value in auxiliaries_to_add .items ():
516-
517508 setattr (self , key , value )
518509
519510 def replace_blocks (self , pipeline_blocks , at : int ):
520511 """Replace one or more blocks in the pipeline at the specified index.
521-
512+
522513 Args:
523- pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances
514+ pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances
524515 that will replace existing blocks.
525516 at (int): Index at which to replace the blocks.
526517 """
527518 # Convert single block to list for uniform processing
528519 if not isinstance (pipeline_blocks , (list , tuple )):
529520 pipeline_blocks = [pipeline_blocks ]
530-
521+
531522 # Validate replace_at index
532523 if not 0 <= at < len (self .pipeline_blocks ):
533- raise ValueError (
534- f"Invalid at index { at } . Index must be between 0 and { len (self .pipeline_blocks ) - 1 } "
535- )
536-
524+ raise ValueError (f"Invalid at index { at } . Index must be between 0 and { len (self .pipeline_blocks ) - 1 } " )
525+
537526 # Add new blocks first
538527 self .add_blocks (pipeline_blocks , at = at )
539-
528+
540529 # Calculate indices to remove
541530 # We need to remove the original blocks that are now shifted by the length of pipeline_blocks
542- indices_to_remove = list (range (
543- at + len (pipeline_blocks ),
544- at + len (pipeline_blocks ) * 2
545- ))
546-
531+ indices_to_remove = list (range (at + len (pipeline_blocks ), at + len (pipeline_blocks ) * 2 ))
532+
547533 # Remove the old blocks
548534 self .remove_blocks (indices_to_remove )
549535
550536 @classmethod
551537 @validate_hf_hub_args
552538 def from_pretrained (cls , pretrained_model_or_path , ** kwargs ):
553-
554539 # (1) create the base pipeline
555540 cache_dir = kwargs .pop ("cache_dir" , None )
556541 force_download = kwargs .pop ("force_download" , False )
@@ -579,47 +564,41 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
579564 modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING [_get_model (base_pipeline_class_name )]
580565 modular_pipeline_class = _get_pipeline_class (cls , config = None , class_name = modular_pipeline_class_name )
581566
582-
583567 # (3) create the pipeline blocks
584568 pipeline_blocks = [
585- block_class .from_pipe (base_pipeline )
586- for block_class in modular_pipeline_class .default_pipeline_blocks
569+ block_class .from_pipe (base_pipeline ) for block_class in modular_pipeline_class .default_pipeline_blocks
587570 ]
588571
589572 # (4) create the builder
590573 builder = modular_pipeline_class ()
591574 builder .add_blocks (pipeline_blocks )
592575
593576 return builder
594-
577+
595578 @classmethod
596579 def from_pipe (cls , pipeline , ** kwargs ):
597580 base_pipeline_class_name = pipeline .__class__ .__name__
598581 modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING [_get_model (base_pipeline_class_name )]
599582 modular_pipeline_class = _get_pipeline_class (cls , config = None , class_name = modular_pipeline_class_name )
600-
583+
601584 pipeline_blocks = []
602585 # Create each block, passing only unused items that the block expects
603586 for block_class in modular_pipeline_class .default_pipeline_blocks :
604587 expected_components = set (block_class .required_components + block_class .optional_components )
605588 expected_auxiliaries = set (block_class .required_auxiliaries )
606-
589+
607590 # Get init parameters to check for expected configs
608591 init_params = inspect .signature (block_class .__init__ ).parameters
609592 expected_configs = {
610- k for k in init_params
611- if k not in expected_components
612- and k not in expected_auxiliaries
593+ k for k in init_params if k not in expected_components and k not in expected_auxiliaries
613594 }
614-
595+
615596 block_kwargs = {}
616-
597+
617598 for key , value in kwargs .items ():
618- if (key in expected_components or
619- key in expected_auxiliaries or
620- key in expected_configs ):
599+ if key in expected_components or key in expected_auxiliaries or key in expected_configs :
621600 block_kwargs [key ] = value
622-
601+
623602 # Create the block with filtered kwargs
624603 block = block_class .from_pipe (pipeline , ** block_kwargs )
625604 pipeline_blocks .append (block )
@@ -630,10 +609,10 @@ def from_pipe(cls, pipeline, **kwargs):
630609
631610 # Warn about unused kwargs
632611 unused_kwargs = {
633- k : v for k , v in kwargs .items ()
612+ k : v
613+ for k , v in kwargs .items ()
634614 if not any (
635- k in block .components or k in block .auxiliaries or k in block .configs
636- for block in pipeline_blocks
615+ k in block .components or k in block .auxiliaries or k in block .configs for block in pipeline_blocks
637616 )
638617 }
639618 if unused_kwargs :
@@ -774,7 +753,6 @@ def __repr__(self):
774753 output += f"{ name } : { config !r} \n "
775754 output += "\n "
776755
777-
778756 # List the default call parameters
779757 output += "Default Call Parameters:\n "
780758 output += "------------------------\n "
0 commit comments