@@ -126,13 +126,20 @@ class PeftLoraLoaderMixinTests:
126126 text_encoder_target_modules = ["q_proj" , "k_proj" , "v_proj" , "out_proj" ]
127127 denoiser_target_modules = ["to_q" , "to_k" , "to_v" , "to_out.0" ]
128128
129- def get_dummy_components (self , use_dora = False , lora_alpha = None ):
129+ cached_non_lora_output = None
130+
131+ def get_base_pipe_output (self ):
132+ if self .cached_non_lora_output is None :
133+ self .cached_non_lora_output = self ._compute_baseline_output ()
134+ return self .cached_non_lora_output
135+
136+ def get_dummy_components (self , scheduler_cls = None , use_dora = False , lora_alpha = None ):
130137 if self .unet_kwargs and self .transformer_kwargs :
131138 raise ValueError ("Both `unet_kwargs` and `transformer_kwargs` cannot be specified." )
132139 if self .has_two_text_encoders and self .has_three_text_encoders :
133140 raise ValueError ("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True." )
134141
135- scheduler_cls = self .scheduler_cls
142+ scheduler_cls = scheduler_cls if scheduler_cls is not None else self .scheduler_cls
136143 rank = 4
137144 lora_alpha = rank if lora_alpha is None else lora_alpha
138145
@@ -238,15 +245,16 @@ def get_dummy_inputs(self, with_generator=True):
238245
239246 return noise , input_ids , pipeline_inputs
240247
241- # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
242- def get_dummy_tokens (self ):
243- max_seq_length = 77
244-
245- inputs = torch . randint ( 2 , 56 , size = ( 1 , max_seq_length ), generator = torch . manual_seed ( 0 ) )
248+ def _compute_baseline_output ( self ):
249+ components , _ , _ = self . get_dummy_components (self . scheduler_cls )
250+ pipe = self . pipeline_class ( ** components )
251+ pipe = pipe . to ( torch_device )
252+ pipe . set_progress_bar_config ( disable = None )
246253
247- prepared_inputs = {}
248- prepared_inputs ["input_ids" ] = inputs
249- return prepared_inputs
254+ # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
255+ # explicitly.
256+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
257+ return pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
250258
251259 def _get_lora_state_dicts (self , modules_to_save ):
252260 state_dicts = {}
@@ -316,14 +324,8 @@ def test_simple_inference(self):
316324 """
317325 Tests a simple inference and makes sure it works as expected
318326 """
319- components , text_lora_config , _ = self .get_dummy_components ()
320- pipe = self .pipeline_class (** components )
321- pipe = pipe .to (torch_device )
322- pipe .set_progress_bar_config (disable = None )
323-
324- _ , _ , inputs = self .get_dummy_inputs ()
325- output_no_lora = pipe (** inputs )[0 ]
326- self .assertTrue (output_no_lora .shape == self .output_shape )
327+ output_no_lora = self .get_base_pipe_output ()
328+ assert output_no_lora .shape == self .output_shape
327329
328330 def test_simple_inference_with_text_lora (self ):
329331 """
@@ -336,9 +338,7 @@ def test_simple_inference_with_text_lora(self):
336338 pipe .set_progress_bar_config (disable = None )
337339 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
338340
339- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
340- self .assertTrue (output_no_lora .shape == self .output_shape )
341-
341+ output_no_lora = self .get_base_pipe_output ()
342342 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
343343
344344 output_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -414,9 +414,6 @@ def test_low_cpu_mem_usage_with_loading(self):
414414 pipe .set_progress_bar_config (disable = None )
415415 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
416416
417- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
418- self .assertTrue (output_no_lora .shape == self .output_shape )
419-
420417 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
421418
422419 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -466,8 +463,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
466463 pipe .set_progress_bar_config (disable = None )
467464 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
468465
469- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
470- self .assertTrue (output_no_lora .shape == self .output_shape )
466+ output_no_lora = self .get_base_pipe_output ()
471467
472468 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
473469
@@ -503,8 +499,7 @@ def test_simple_inference_with_text_lora_fused(self):
503499 pipe .set_progress_bar_config (disable = None )
504500 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
505501
506- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
507- self .assertTrue (output_no_lora .shape == self .output_shape )
502+ output_no_lora = self .get_base_pipe_output ()
508503
509504 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
510505
@@ -534,8 +529,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
534529 pipe .set_progress_bar_config (disable = None )
535530 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
536531
537- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
538- self .assertTrue (output_no_lora .shape == self .output_shape )
532+ output_no_lora = self .get_base_pipe_output ()
539533
540534 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
541535
@@ -566,9 +560,6 @@ def test_simple_inference_with_text_lora_save_load(self):
566560 pipe .set_progress_bar_config (disable = None )
567561 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
568562
569- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
570- self .assertTrue (output_no_lora .shape == self .output_shape )
571-
572563 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
573564
574565 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -616,8 +607,7 @@ def test_simple_inference_with_partial_text_lora(self):
616607 pipe .set_progress_bar_config (disable = None )
617608 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
618609
619- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
620- self .assertTrue (output_no_lora .shape == self .output_shape )
610+ output_no_lora = self .get_base_pipe_output ()
621611
622612 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
623613
@@ -666,9 +656,6 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
666656 pipe .set_progress_bar_config (disable = None )
667657 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
668658
669- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
670- self .assertTrue (output_no_lora .shape == self .output_shape )
671-
672659 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
673660 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
674661
@@ -708,9 +695,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
708695 pipe .set_progress_bar_config (disable = None )
709696 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
710697
711- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
712- self .assertTrue (output_no_lora .shape == self .output_shape )
713-
714698 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
715699
716700 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -747,9 +731,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
747731 pipe .set_progress_bar_config (disable = None )
748732 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
749733
750- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
751- self .assertTrue (output_no_lora .shape == self .output_shape )
752-
734+ output_no_lora = self .get_base_pipe_output ()
753735 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
754736
755737 output_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -790,8 +772,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
790772 pipe .set_progress_bar_config (disable = None )
791773 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
792774
793- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
794- self .assertTrue (output_no_lora .shape == self .output_shape )
775+ output_no_lora = self .get_base_pipe_output ()
795776
796777 pipe , denoiser = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
797778
@@ -825,8 +806,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
825806 pipe .set_progress_bar_config (disable = None )
826807 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
827808
828- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
829- self .assertTrue (output_no_lora .shape == self .output_shape )
809+ output_no_lora = self .get_base_pipe_output ()
830810
831811 pipe , denoiser = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
832812
@@ -900,7 +880,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
900880 pipe .set_progress_bar_config (disable = None )
901881 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
902882
903- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
883+ output_no_lora = self . get_base_pipe_output ()
904884
905885 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
906886 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1024,7 +1004,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10241004 pipe .set_progress_bar_config (disable = None )
10251005 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
10261006
1027- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1007+ output_no_lora = self . get_base_pipe_output ()
10281008
10291009 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
10301010 self .assertTrue (check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" )
@@ -1080,7 +1060,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
10801060 pipe .set_progress_bar_config (disable = None )
10811061 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
10821062
1083- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1063+ output_no_lora = self . get_base_pipe_output ()
10841064
10851065 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
10861066 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1240,7 +1220,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
12401220 pipe .set_progress_bar_config (disable = None )
12411221 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
12421222
1243- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1223+ output_no_lora = self . get_base_pipe_output ()
12441224
12451225 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
12461226 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1331,7 +1311,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
13311311 pipe .set_progress_bar_config (disable = None )
13321312 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
13331313
1334- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1314+ output_no_lora = self . get_base_pipe_output ()
13351315
13361316 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
13371317 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1551,7 +1531,6 @@ def test_get_list_adapters(self):
15511531
15521532 self .assertDictEqual (pipe .get_list_adapters (), dicts_to_be_checked )
15531533
1554- @require_peft_version_greater (peft_version = "0.6.2" )
15551534 def test_simple_inference_with_text_lora_denoiser_fused_multi (
15561535 self , expected_atol : float = 1e-3 , expected_rtol : float = 1e-3
15571536 ):
@@ -1565,9 +1544,6 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
15651544 pipe .set_progress_bar_config (disable = None )
15661545 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
15671546
1568- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
1569- self .assertTrue (output_no_lora .shape == self .output_shape )
1570-
15711547 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
15721548 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
15731549 self .assertTrue (check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" )
@@ -1641,8 +1617,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
16411617 pipe .set_progress_bar_config (disable = None )
16421618 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
16431619
1644- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
1645- self .assertTrue (output_no_lora .shape == self .output_shape )
1620+ output_no_lora = self .get_base_pipe_output ()
16461621
16471622 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
16481623 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
@@ -1685,7 +1660,6 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
16851660 "LoRA should change the output" ,
16861661 )
16871662
1688- @require_peft_version_greater (peft_version = "0.9.0" )
16891663 def test_simple_inference_with_dora (self ):
16901664 components , text_lora_config , denoiser_lora_config = self .get_dummy_components (use_dora = True )
16911665 pipe = self .pipeline_class (** components )
@@ -1695,7 +1669,6 @@ def test_simple_inference_with_dora(self):
16951669
16961670 output_no_dora_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
16971671 self .assertTrue (output_no_dora_lora .shape == self .output_shape )
1698-
16991672 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
17001673
17011674 output_dora_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -1783,7 +1756,6 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
17831756 pipe = pipe .to (torch_device )
17841757 pipe .set_progress_bar_config (disable = None )
17851758 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
1786-
17871759 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
17881760
17891761 pipe .unet = torch .compile (pipe .unet , mode = "reduce-overhead" , fullgraph = True )
@@ -1820,7 +1792,7 @@ def test_logs_info_when_no_lora_keys_found(self):
18201792 pipe .set_progress_bar_config (disable = None )
18211793
18221794 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
1823- original_out = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
1795+ output_no_lora = self . get_base_pipe_output ()
18241796
18251797 no_op_state_dict = {"lora_foo" : torch .tensor (2.0 ), "lora_bar" : torch .tensor (3.0 )}
18261798 logger = logging .get_logger ("diffusers.loaders.peft" )
@@ -1832,7 +1804,7 @@ def test_logs_info_when_no_lora_keys_found(self):
18321804
18331805 denoiser = getattr (pipe , "unet" ) if self .unet_kwargs is not None else getattr (pipe , "transformer" )
18341806 self .assertTrue (cap_logger .out .startswith (f"No LoRA keys associated to { denoiser .__class__ .__name__ } " ))
1835- self .assertTrue (np .allclose (original_out , out_after_lora_attempt , atol = 1e-5 , rtol = 1e-5 ))
1807+ self .assertTrue (np .allclose (output_no_lora , out_after_lora_attempt , atol = 1e-5 , rtol = 1e-5 ))
18361808
18371809 # test only for text encoder
18381810 for lora_module in self .pipeline_class ._lora_loadable_modules :
@@ -1864,9 +1836,7 @@ def test_set_adapters_match_attention_kwargs(self):
18641836 pipe .set_progress_bar_config (disable = None )
18651837 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
18661838
1867- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
1868- self .assertTrue (output_no_lora .shape == self .output_shape )
1869-
1839+ output_no_lora = self .get_base_pipe_output ()
18701840 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config )
18711841
18721842 lora_scale = 0.5
@@ -2212,9 +2182,6 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
22122182 pipe = self .pipeline_class (** components ).to (torch_device )
22132183 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
22142184
2215- output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2216- self .assertTrue (output_no_lora .shape == self .output_shape )
2217-
22182185 pipe , _ = self .add_adapters_to_pipeline (
22192186 pipe , text_lora_config = text_lora_config , denoiser_lora_config = denoiser_lora_config
22202187 )
@@ -2260,7 +2227,7 @@ def test_inference_load_delete_load_adapters(self):
22602227 pipe .set_progress_bar_config (disable = None )
22612228 _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
22622229
2263- output_no_lora = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
2230+ output_no_lora = self . get_base_pipe_output ()
22642231
22652232 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
22662233 pipe .text_encoder .add_adapter (text_lora_config )
0 commit comments