1414
1515import inspect
1616from dataclasses import dataclass , field
17- from typing import Any , Dict , List , Tuple , Union
17+ from typing import Any , Dict , List , Tuple , Union , Type
18+ from collections import OrderedDict
1819
1920import torch
2021from tqdm .auto import tqdm
3031from .pipeline_loading_utils import _fetch_class_library_tuple , _get_pipeline_class
3132from .pipeline_utils import DiffusionPipeline
3233
34+ import warnings
35+
3336
3437if is_accelerate_available ():
3538 import accelerate
@@ -99,6 +102,7 @@ class PipelineBlock:
99102 optional_components = []
100103 required_components = []
101104 required_auxiliaries = []
105+ optional_auxiliaries = []
102106
103107 @property
104108 def inputs (self ) -> Tuple [Tuple [str , Any ], ...]:
@@ -122,7 +126,7 @@ def __init__(self, **kwargs):
122126 for key , value in kwargs .items ():
123127 if key in self .required_components or key in self .optional_components :
124128 self .components [key ] = value
125- elif key in self .required_auxiliaries :
129+ elif key in self .required_auxiliaries or key in self . optional_auxiliaries :
126130 self .auxiliaries [key ] = value
127131 else :
128132 self .configs [key ] = value
@@ -152,10 +156,11 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs):
152156 components_to_add [component_name ] = component
153157
154158 # add auxiliaries
159+ expected_auxiliaries = set (cls .required_auxiliaries + cls .optional_auxiliaries )
155160 # - auxiliaries that are passed in kwargs
156- auxiliaries_to_add = {k : kwargs .pop (k ) for k in cls . required_auxiliaries if k in kwargs }
161+ auxiliaries_to_add = {k : kwargs .pop (k ) for k in expected_auxiliaries if k in kwargs }
157162 # - auxiliaries that are in the pipeline
158- for aux_name in cls . required_auxiliaries :
163+ for aux_name in expected_auxiliaries :
159164 if hasattr (pipe , aux_name ) and aux_name not in auxiliaries_to_add :
160165 auxiliaries_to_add [aux_name ] = getattr (pipe , aux_name )
161166 block_kwargs = {** components_to_add , ** auxiliaries_to_add }
@@ -167,7 +172,7 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs):
167172 expected_configs = {
168173 k
169174 for k in pipe .config .keys ()
170- if k in init_params and k not in expected_components and k not in cls . required_auxiliaries
175+ if k in init_params and k not in expected_components and k not in expected_auxiliaries
171176 }
172177
173178 for config_name in expected_configs :
@@ -210,6 +215,188 @@ def __repr__(self):
210215 )
211216
212217
218+ def combine_inputs (* input_lists : List [Tuple [str , Any ]]) -> List [Tuple [str , Any ]]:
219+ """
220+ Combines multiple lists of (name, default_value) tuples.
221+ For duplicate inputs, updates only if current value is None and new value is not None.
222+ Warns if multiple non-None default values exist for the same input.
223+ """
224+ combined_dict = {}
225+ for inputs in input_lists :
226+ for name , value in inputs :
227+ if name in combined_dict :
228+ current_value = combined_dict [name ]
229+ if current_value is not None and value is not None and current_value != value :
230+ warnings .warn (
231+ f"Multiple different default values found for input '{ name } ': "
232+ f"{ current_value } and { value } . Using { current_value } ."
233+ )
234+ if current_value is None and value is not None :
235+ combined_dict [name ] = value
236+ else :
237+ combined_dict [name ] = value
238+ return list (combined_dict .items ())
239+
240+
241+
242+ class AutoStep (PipelineBlock ):
243+ base_blocks = [] # list of block classes
244+ trigger_inputs = [] # list of trigger inputs (None for default block)
245+ required_components = []
246+ optional_components = []
247+ required_auxiliaries = []
248+ optional_auxiliaries = []
249+
250+ def __init__ (self , ** kwargs ):
251+ self .blocks = []
252+
253+ for block_cls , trigger in zip (self .base_blocks , self .trigger_inputs ):
254+ # Check components
255+ missing_components = [
256+ component for component in block_cls .required_components
257+ if component not in kwargs
258+ ]
259+
260+ # Check auxiliaries
261+ missing_auxiliaries = [
262+ auxiliary for auxiliary in block_cls .required_auxiliaries
263+ if auxiliary not in kwargs
264+ ]
265+
266+ if not missing_components and not missing_auxiliaries :
267+ # Only get kwargs that the block's __init__ accepts
268+ block_params = inspect .signature (block_cls .__init__ ).parameters
269+ block_kwargs = {
270+ k : v for k , v in kwargs .items ()
271+ if k in block_params
272+ }
273+ self .blocks .append (block_cls (** block_kwargs ))
274+
275+ # Print message about trigger condition
276+ if trigger is None :
277+ print (f"Added default block: { block_cls .__name__ } " )
278+ else :
279+ print (f"Added block { block_cls .__name__ } - will be dispatched if '{ trigger } ' input is not None" )
280+ else :
281+ if trigger is None :
282+ print (f"Cannot add default block { block_cls .__name__ } :" )
283+ else :
284+ print (f"Cannot add block { block_cls .__name__ } (triggered by '{ trigger } '):" )
285+ if missing_components :
286+ print (f" - Missing components: { missing_components } " )
287+ if missing_auxiliaries :
288+ print (f" - Missing auxiliaries: { missing_auxiliaries } " )
289+
290+ @property
291+ def components (self ):
292+ # Combine components from all blocks
293+ components = {}
294+ for block in self .blocks :
295+ components .update (block .components )
296+ return components
297+
298+ @property
299+ def auxiliaries (self ):
300+ # Combine auxiliaries from all blocks
301+ auxiliaries = {}
302+ for block in self .blocks :
303+ auxiliaries .update (block .auxiliaries )
304+ return auxiliaries
305+
306+ @property
307+ def configs (self ):
308+ # Combine configs from all blocks
309+ configs = {}
310+ for block in self .blocks :
311+ configs .update (block .configs )
312+ return configs
313+
314+ @property
315+ def inputs (self ) -> List [Tuple [str , Any ]]:
316+ return combine_inputs (* (block .inputs for block in self .blocks ))
317+
318+ @property
319+ def intermediates_inputs (self ) -> List [str ]:
320+ return list (set ().union (* (
321+ block .intermediates_inputs for block in self .blocks
322+ )))
323+
324+ @property
325+ def intermediates_outputs (self ) -> List [str ]:
326+ return list (set ().union (* (
327+ block .intermediates_outputs for block in self .blocks
328+ )))
329+
330+ def __call__ (self , pipeline , state ):
331+ # Check triggers in priority order
332+ for idx , trigger in enumerate (self .trigger_inputs [:- 1 ]): # Skip last (None) trigger
333+ if state .get_input (trigger ) is not None :
334+ return self .blocks [idx ](pipeline , state )
335+ # If no triggers match, use the default block (last one)
336+ return self .blocks [- 1 ](pipeline , state )
337+
338+
339+ def make_auto_step (pipeline_block_map : OrderedDict ) -> Type [AutoStep ]:
340+ """
341+ Creates a new AutoStep subclass with updated class attributes based on the pipeline block map.
342+
343+ Args:
344+ pipeline_block_map: OrderedDict mapping trigger inputs to pipeline block classes.
345+ Order determines priority (earlier entries take precedence).
346+ Must include None key for the default block.
347+ """
348+ blocks = list (pipeline_block_map .values ())
349+ triggers = list (pipeline_block_map .keys ())
350+
351+ # Get all expected components (either required or optional by any block)
352+ expected_components = []
353+ for block in blocks :
354+ for component in (block .required_components + block .optional_components ):
355+ if component not in expected_components :
356+ expected_components .append (component )
357+
358+ # A component is required if it's in required_components of all blocks
359+ required_components = [
360+ component for component in expected_components
361+ if all (component in block .required_components for block in blocks )
362+ ]
363+
364+ # All other expected components are optional
365+ optional_components = [
366+ component for component in expected_components
367+ if component not in required_components
368+ ]
369+
370+ # Get all expected auxiliaries (either required or optional by any block)
371+ expected_auxiliaries = []
372+ for block in blocks :
373+ for auxiliary in (block .required_auxiliaries + getattr (block , 'optional_auxiliaries' , [])):
374+ if auxiliary not in expected_auxiliaries :
375+ expected_auxiliaries .append (auxiliary )
376+
377+ # An auxiliary is required if it's in required_auxiliaries of all blocks
378+ required_auxiliaries = [
379+ auxiliary for auxiliary in expected_auxiliaries
380+ if all (auxiliary in block .required_auxiliaries for block in blocks )
381+ ]
382+
383+ # All other expected auxiliaries are optional
384+ optional_auxiliaries = [
385+ auxiliary for auxiliary in expected_auxiliaries
386+ if auxiliary not in required_auxiliaries
387+ ]
388+
389+ # Create new class with updated attributes
390+ return type ('AutoStep' , (AutoStep ,), {
391+ 'base_blocks' : blocks ,
392+ 'trigger_inputs' : triggers ,
393+ 'required_components' : required_components ,
394+ 'optional_components' : optional_components ,
395+ 'required_auxiliaries' : required_auxiliaries ,
396+ 'optional_auxiliaries' : optional_auxiliaries ,
397+ })
398+
399+
213400class ModularPipelineBuilder (ConfigMixin ):
214401 """
215402 Base class for all Modular pipelines.
@@ -585,7 +772,7 @@ def from_pipe(cls, pipeline, **kwargs):
585772 # Create each block, passing only unused items that the block expects
586773 for block_class in modular_pipeline_class .default_pipeline_blocks :
587774 expected_components = set (block_class .required_components + block_class .optional_components )
588- expected_auxiliaries = set (block_class .required_auxiliaries )
775+ expected_auxiliaries = set (block_class .required_auxiliaries + block_class . optional_auxiliaries )
589776
590777 # Get init parameters to check for expected configs
591778 init_params = inspect .signature (block_class .__init__ ).parameters
0 commit comments