
Commit 814d710

[tests] cache non lora pipeline outputs. (huggingface#12298)
* cache non lora pipeline outputs.
* up
* up
* up
* up
* Revert "up". This reverts commit 772c32e.
* up
* Revert "up". This reverts commit cca03df.
* up
* up
* add .
* up
* up
* up
* up
* up
* up
1 parent cc5b31f commit 814d710
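The core change: instead of every test constructing a pipeline and recomputing the baseline (non-LoRA) output, the shared mixin now memoizes that output behind a helper. Below is a minimal, self-contained sketch of the caching pattern (attribute and method names mirror the diff; the pipeline construction and inference are replaced by a stub counter so the sketch runs standalone):

    class _CachingMixinSketch:
        cached_non_lora_output = None
        calls = 0

        def get_base_pipe_output(self):
            # First call computes and stores the baseline; later calls reuse it.
            if self.cached_non_lora_output is None:
                self.cached_non_lora_output = self._compute_baseline_output()
            return self.cached_non_lora_output

        def _compute_baseline_output(self):
            # Stub standing in for: build the dummy pipeline, run it with a
            # freshly seeded generator, return the first output.
            type(self).calls += 1
            return "baseline-output"

    t = _CachingMixinSketch()
    assert t.get_base_pipe_output() == t.get_base_pipe_output()
    assert _CachingMixinSketch.calls == 1  # the expensive baseline ran only once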

File tree: 4 files changed, +42 −83 lines changed

tests/lora/test_lora_layers_cogview4.py

Lines changed: 0 additions & 3 deletions
@@ -129,9 +129,6 @@ def test_simple_inference_save_pretrained(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         with tempfile.TemporaryDirectory() as tmpdirname:

tests/lora/test_lora_layers_flux.py

Lines changed: 3 additions & 8 deletions
@@ -122,9 +122,6 @@ def test_with_alpha_in_state_dict(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
         pipe.transformer.add_adapter(denoiser_lora_config)
         self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
 
@@ -170,8 +167,7 @@ def test_lora_expansion_works_for_absent_keys(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         # Modify the config to have a layer which won't be present in the second LoRA we will load.
         modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -218,9 +214,7 @@ def test_lora_expansion_works_for_extra_keys(self):
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         # Modify the config to have a layer which won't be present in the first LoRA we will load.
         modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -329,6 +323,7 @@ def get_dummy_inputs(self, with_generator=True):
         noise = floats_tensor((batch_size, num_channels) + sizes)
         input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
 
+        np.random.seed(0)
         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
             "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),

tests/lora/test_lora_layers_wanvace.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def test_lora_exclude_modules_wanvace(self):
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         # only supported for `denoiser` now

tests/lora/utils.py

Lines changed: 38 additions & 71 deletions
@@ -126,13 +126,20 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
-    def get_dummy_components(self, use_dora=False, lora_alpha=None):
+    cached_non_lora_output = None
+
+    def get_base_pipe_output(self):
+        if self.cached_non_lora_output is None:
+            self.cached_non_lora_output = self._compute_baseline_output()
+        return self.cached_non_lora_output
+
+    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
         if self.has_two_text_encoders and self.has_three_text_encoders:
             raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
 
-        scheduler_cls = self.scheduler_cls
+        scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls
         rank = 4
         lora_alpha = rank if lora_alpha is None else lora_alpha
 
@@ -238,15 +245,16 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
-    def get_dummy_tokens(self):
-        max_seq_length = 77
-
-        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+    def _compute_baseline_output(self):
+        components, _, _ = self.get_dummy_components(self.scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
 
-        prepared_inputs = {}
-        prepared_inputs["input_ids"] = inputs
-        return prepared_inputs
+        # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+        # explicitly.
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        return pipe(**inputs, generator=torch.manual_seed(0))[0]
 
     def _get_lora_state_dicts(self, modules_to_save):
         state_dicts = {}
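`_compute_baseline_output` builds a fresh pipeline and, as the new comment stresses, takes inputs without a `generator` and passes a freshly seeded one at call time. A short standalone illustration of the distinction (plain torch, not from the diff): a reused generator object advances its state across calls, while reseeding per call keeps results reproducible.

    import torch

    # Reseeding per call: identical results every time.
    a = torch.randn(2, 2, generator=torch.manual_seed(0))
    b = torch.randn(2, 2, generator=torch.manual_seed(0))
    assert torch.equal(a, b)

    # Reusing one generator: its state advances, so the results differ.
    g = torch.manual_seed(0)
    c = torch.randn(2, 2, generator=g)
    d = torch.randn(2, 2, generator=g)
    assert not torch.equal(c, d)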
@@ -316,14 +324,8 @@ def test_simple_inference(self):
         """
         Tests a simple inference and makes sure it works as expected
         """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        _, _, inputs = self.get_dummy_inputs()
-        output_no_lora = pipe(**inputs)[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
+        assert output_no_lora.shape == self.output_shape
 
     def test_simple_inference_with_text_lora(self):
         """
@@ -336,9 +338,7 @@ def test_simple_inference_with_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
+        output_no_lora = self.get_base_pipe_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -414,9 +414,6 @@ def test_low_cpu_mem_usage_with_loading(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -466,8 +463,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
@@ -503,8 +499,7 @@ def test_simple_inference_with_text_lora_fused(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
@@ -534,8 +529,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
@@ -566,9 +560,6 @@ def test_simple_inference_with_text_lora_save_load(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -616,8 +607,7 @@ def test_simple_inference_with_partial_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
@@ -666,9 +656,6 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -708,9 +695,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -747,9 +731,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
+        output_no_lora = self.get_base_pipe_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -790,8 +772,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
@@ -825,8 +806,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
@@ -900,7 +880,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1024,7 +1004,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1080,7 +1060,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1240,7 +1220,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1331,7 +1311,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1551,7 +1531,6 @@ def test_get_list_adapters(self):
 
         self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
 
-    @require_peft_version_greater(peft_version="0.6.2")
     def test_simple_inference_with_text_lora_denoiser_fused_multi(
         self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
     ):
@@ -1565,9 +1544,6 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1641,8 +1617,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1685,7 +1660,6 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
             "LoRA should change the output",
         )
 
-    @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
         components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
         pipe = self.pipeline_class(**components)
@@ -1695,7 +1669,6 @@ def test_simple_inference_with_dora(self):
 
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)
-
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1783,7 +1756,6 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
@@ -1820,7 +1792,7 @@ def test_logs_info_when_no_lora_keys_found(self):
         pipe.set_progress_bar_config(disable=None)
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()
 
         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
         logger = logging.get_logger("diffusers.loaders.peft")
@@ -1832,7 +1804,7 @@
 
         denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
         self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
-        self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
+        self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
 
         # test only for text encoder
         for lora_module in self.pipeline_class._lora_loadable_modules:
@@ -1864,9 +1836,7 @@ def test_set_adapters_match_attention_kwargs(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
+        output_no_lora = self.get_base_pipe_output()
         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         lora_scale = 0.5
@@ -2212,9 +2182,6 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
         pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
@@ -2260,7 +2227,7 @@ def test_inference_load_delete_load_adapters(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()
 
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config)
