From b0d17ad28e21a99d113e0ee932265748f5ae53ae Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Tue, 20 May 2025 13:12:28 -0400 Subject: [PATCH 1/8] init test Signed-off-by: Liu, Kaixuan --- .../pipelines/pipeline_loading_utils.py | 35 +++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 3404ae5130fe..75ad4cfc1852 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -92,7 +92,7 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) -def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool: +def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool: """ Checking for safetensors compatibility: - The model is safetensors compatible only if there is a safetensors file for each model component present in @@ -103,6 +103,28 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" extension is replaced with ".safetensors" """ + weight_names = [ + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + FLAX_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ] + + if is_transformers_available(): + weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] + + # model_pytorch, diffusion_model_pytorch, ... + weight_prefixes = [w.split(".")[0] for w in weight_names] + # .bin, .safetensors, ... + weight_suffixs = [w.split(".")[-1] for w in weight_names] + # -00001-of-00002 + transformers_index_format = r"\d{5}-of-\d{5}" + # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` + non_variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" + ) + passed_components = passed_components or [] if folder_names: filenames = {f for f in filenames if os.path.split(f)[0] in folder_names} @@ -130,6 +152,8 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No for component, component_filenames in components.items(): matches = [] for component_filename in component_filenames: + if variant is None: + component_filename = filter_with_regex(component_filename, non_variant_file_re) filename, extension = os.path.splitext(component_filename) match_exists = extension == ".safetensors" @@ -158,6 +182,8 @@ def filter_model_files(filenames): return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)] +def filter_with_regex(filenames, pattern_re): + return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None} def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]: weight_names = [ @@ -207,9 +233,6 @@ def filter_for_compatible_extensions(filenames, ignore_patterns=None): # interested in the extension name return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)} - def filter_with_regex(filenames, pattern_re): - return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None} - # Group files by component components = {} for filename in filenames: @@ -997,7 +1020,7 @@ def _get_ignore_patterns( use_safetensors and not allow_pickle and not is_safetensors_compatible( - model_filenames, passed_components=passed_components, folder_names=model_folder_names + model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant ) ): raise EnvironmentError( @@ -1008,7 +1031,7 @@ def _get_ignore_patterns( ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] elif use_safetensors and is_safetensors_compatible( - model_filenames, passed_components=passed_components, folder_names=model_folder_names + model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant ): ignore_patterns = ["*.bin", "*.msgpack"] From 6bf814f12897e13a45444087316c17e7d4df76d3 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Tue, 20 May 2025 13:46:06 -0400 Subject: [PATCH 2/8] adjust Signed-off-by: Liu, Kaixuan --- src/diffusers/pipelines/pipeline_loading_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 75ad4cfc1852..f30e72ecd740 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -121,6 +121,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No # -00001-of-00002 transformers_index_format = r"\d{5}-of-\d{5}" # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` + variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" + ) non_variant_file_re = re.compile( rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" ) @@ -151,9 +154,11 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No # if variant is provided check if the variant of the safetensors exists for component, component_filenames in components.items(): matches = [] + if variant is not None: + component_filenames = filter_with_regex(component_filenames, variant_file_re) + else: + component_filenames = filter_with_regex(component_filenames, non_variant_file_re) for component_filename in component_filenames: - if variant is None: - component_filename = filter_with_regex(component_filename, non_variant_file_re) filename, extension = os.path.splitext(component_filename) match_exists = extension == ".safetensors" From 445f5cc539267bdc5ed94ac0e9bfb11b93e00e13 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 20 May 2025 08:59:14 +0000 Subject: [PATCH 3/8] Apply style fixes --- src/diffusers/pipelines/pipeline_loading_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index f30e72ecd740..a61c99f31150 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -187,9 +187,11 @@ def filter_model_files(filenames): return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)] + def filter_with_regex(filenames, pattern_re): return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None} + def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]: weight_names = [ WEIGHTS_NAME, From 8f578cae34d352da491f159509ef47daadaa8042 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Tue, 20 May 2025 19:34:33 -0400 Subject: [PATCH 4/8] add the variant check when there are no component folders Signed-off-by: Liu, Kaixuan --- src/diffusers/pipelines/pipeline_loading_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index a61c99f31150..032a6e161bbd 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -147,7 +147,11 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No # If there are no component folders check the main directory for safetensors files if not components: - return any(".safetensors" in filename for filename in filenames) + if variant is not None: + filtered_filenames = filter_with_regex(filenames, variant_file_re) + else: + filtered_filenames = filter_with_regex(filenames, non_variant_file_re) + return any(".safetensors" in filename for filename in filtered_filenames) # iterate over all files of a component # check if safetensor files exist for that component @@ -155,10 +159,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No for component, component_filenames in components.items(): matches = [] if variant is not None: - component_filenames = filter_with_regex(component_filenames, variant_file_re) + filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re) else: - component_filenames = filter_with_regex(component_filenames, non_variant_file_re) - for component_filename in component_filenames: + filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re) + for component_filename in filtered_component_filenames: filename, extension = os.path.splitext(component_filename) match_exists = extension == ".safetensors" From d408a653723e4351f2f803284452b7954f730eca Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Tue, 20 May 2025 20:18:44 -0400 Subject: [PATCH 5/8] update related test cases Signed-off-by: Liu, Kaixuan --- tests/pipelines/test_pipeline_utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 423c2b8ab146..f680cf2dcf18 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -87,21 +87,24 @@ def test_all_is_compatible_variant(self): "unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_model_is_compatible_variant(self): filenames = [ "unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_model_is_compatible_variant_mixed(self): filenames = [ "unet/diffusion_pytorch_model.bin", "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_model_is_not_compatible_variant(self): filenames = [ @@ -121,7 +124,8 @@ def test_transformer_model_is_compatible_variant(self): "text_encoder/pytorch_model.fp16.bin", "text_encoder/model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_transformer_model_is_not_compatible_variant(self): filenames = [ @@ -145,7 +149,8 @@ def test_transformer_model_is_compatible_variant_extra_folder(self): "unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"})) + self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"})) + self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16")) def test_transformer_model_is_not_compatible_variant_extra_folder(self): filenames = [ @@ -173,7 +178,8 @@ def test_transformers_is_compatible_variant_sharded(self): "text_encoder/model.fp16-00001-of-00002.safetensors", "text_encoder/model.fp16-00001-of-00002.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_is_compatible_sharded(self): filenames = [ @@ -189,13 +195,15 @@ def test_diffusers_is_compatible_variant_sharded(self): "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors", "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_is_compatible_only_variants(self): filenames = [ "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_is_compatible_no_components(self): filenames = [ From 7a9d3bd50a02b5c6f2caef1e56991d2fe09f9f72 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 21 May 2025 10:54:40 -0400 Subject: [PATCH 6/8] update related unit test cases Signed-off-by: Liu, Kaixuan --- tests/pipelines/test_pipelines.py | 99 +++++++++++++++++-------------- 1 file changed, 54 insertions(+), 45 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index caa7755904a5..15d8f6587028 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -167,9 +167,9 @@ def test_one_request_upon_cached(self): download_requests = [r.method for r in m.request_history] assert download_requests.count("HEAD") == 15, "15 calls to files" assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json" - assert len(download_requests) == 32, ( - "2 calls per file (15 files) + send_telemetry, model_info and model_index.json" - ) + assert ( + len(download_requests) == 32 + ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json" with requests_mock.mock(real_http=True) as m: DiffusionPipeline.download( @@ -179,9 +179,9 @@ def test_one_request_upon_cached(self): cache_requests = [r.method for r in m.request_history] assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD" assert cache_requests.count("GET") == 1, "model info is only GET" - assert len(cache_requests) == 2, ( - "We should call only `model_info` to check for _commit hash and `send_telemetry`" - ) + assert ( + len(cache_requests) == 2 + ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" def test_less_downloads_passed_object(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -217,9 +217,9 @@ def test_less_downloads_passed_object_calls(self): assert download_requests.count("HEAD") == 13, "13 calls to files" # 17 - 2 because no call to config or model file for `safety_checker` assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json" - assert len(download_requests) == 28, ( - "2 calls per file (13 files) + send_telemetry, model_info and model_index.json" - ) + assert ( + len(download_requests) == 28 + ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json" with requests_mock.mock(real_http=True) as m: DiffusionPipeline.download( @@ -229,9 +229,9 @@ def test_less_downloads_passed_object_calls(self): cache_requests = [r.method for r in m.request_history] assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD" assert cache_requests.count("GET") == 1, "model info is only GET" - assert len(cache_requests) == 2, ( - "We should call only `model_info` to check for _commit hash and `send_telemetry`" - ) + assert ( + len(cache_requests) == 2 + ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" def test_download_only_pytorch(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -538,26 +538,38 @@ def test_download_variant_partly(self): variant = "no_ema" with tempfile.TemporaryDirectory() as tmpdirname: - tmpdirname = StableDiffusionPipeline.download( - "hf-internal-testing/stable-diffusion-all-variants", - cache_dir=tmpdirname, - variant=variant, - use_safetensors=use_safetensors, - ) - all_root_files = [t[-1] for t in os.walk(tmpdirname)] - files = [item for sublist in all_root_files for item in sublist] - - unet_files = os.listdir(os.path.join(tmpdirname, "unet")) + if use_safetensors: + with self.assertRaises(OSError) as error_context: + tmpdirname = StableDiffusionPipeline.download( + "hf-internal-testing/stable-diffusion-all-variants", + cache_dir=tmpdirname, + variant=variant, + use_safetensors=use_safetensors, + ) + assert "Could not find the necessary `safetensors` weights" in str(error_context.exception) + else: + tmpdirname = StableDiffusionPipeline.download( + "hf-internal-testing/stable-diffusion-all-variants", + cache_dir=tmpdirname, + variant=variant, + use_safetensors=use_safetensors, + ) + all_root_files = [t[-1] for t in os.walk(tmpdirname)] + files = [item for sublist in all_root_files for item in sublist] - # Some of the downloaded files should be a non-variant file, check: - # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet - assert len(files) == 15, f"We should only download 15 files, not {len(files)}" - # only unet has "no_ema" variant - assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files - assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 - # vae, safety_checker and text_encoder should have no variant - assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 - assert not any(f.endswith(other_format) for f in files) + unet_files = os.listdir(os.path.join(tmpdirname, "unet")) + + # Some of the downloaded files should be a non-variant file, check: + # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet + assert len(files) == 15, f"We should only download 15 files, not {len(files)}" + # only unet has "no_ema" variant + assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files + assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 + # vae, safety_checker and text_encoder should have no variant + assert ( + sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 + ) + assert not any(f.endswith(other_format) for f in files) def test_download_variants_with_sharded_checkpoints(self): # Here we test for downloading of "variant" files belonging to the `unet` and @@ -588,19 +600,16 @@ def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self): logger = logging.get_logger("diffusers.pipelines.pipeline_utils") deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant" - for is_local in [True, False]: - with CaptureLogger(logger) as cap_logger: - with tempfile.TemporaryDirectory() as tmpdirname: - local_repo_id = repo_id - if is_local: - local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname) + with CaptureLogger(logger) as cap_logger: + with tempfile.TemporaryDirectory() as tmpdirname: + local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname) - _ = DiffusionPipeline.from_pretrained( - local_repo_id, - safety_checker=None, - variant="fp16", - use_safetensors=True, - ) + _ = DiffusionPipeline.from_pretrained( + local_repo_id, + safety_checker=None, + variant="fp16", + use_safetensors=True, + ) assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs" def test_download_safetensors_only_variant_exists_for_model(self): @@ -616,7 +625,7 @@ def test_download_safetensors_only_variant_exists_for_model(self): variant=variant, use_safetensors=use_safetensors, ) - assert "Error no file name" in str(error_context.exception) + assert "Could not find the necessary `safetensors` weights" in str(error_context.exception) # text encoder has fp16 variants so we can load it with tempfile.TemporaryDirectory() as tmpdirname: @@ -675,7 +684,7 @@ def test_download_safetensors_variant_does_not_exist_for_model(self): use_safetensors=use_safetensors, ) - assert "Error no file name" in str(error_context.exception) + assert "Could not find the necessary `safetensors` weights" in str(error_context.exception) def test_download_bin_variant_does_not_exist_for_model(self): variant = "no_ema" From 8d21655e8fd8bf495fd5fbb8084cd3153e505981 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 21 May 2025 11:14:19 -0400 Subject: [PATCH 7/8] adjust Signed-off-by: Liu, Kaixuan --- tests/pipelines/test_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 15d8f6587028..7f10b7d067fb 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -610,7 +610,7 @@ def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self): variant="fp16", use_safetensors=True, ) - assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs" + assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs" def test_download_safetensors_only_variant_exists_for_model(self): variant = None From 0ef3590de1aa70ba3c8a3f31d47e8b402af46359 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 21 May 2025 07:14:58 +0000 Subject: [PATCH 8/8] Apply style fixes --- tests/pipelines/test_pipelines.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 7f10b7d067fb..a2241236da20 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -167,9 +167,9 @@ def test_one_request_upon_cached(self): download_requests = [r.method for r in m.request_history] assert download_requests.count("HEAD") == 15, "15 calls to files" assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json" - assert ( - len(download_requests) == 32 - ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json" + assert len(download_requests) == 32, ( + "2 calls per file (15 files) + send_telemetry, model_info and model_index.json" + ) with requests_mock.mock(real_http=True) as m: DiffusionPipeline.download( @@ -179,9 +179,9 @@ def test_one_request_upon_cached(self): cache_requests = [r.method for r in m.request_history] assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD" assert cache_requests.count("GET") == 1, "model info is only GET" - assert ( - len(cache_requests) == 2 - ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + assert len(cache_requests) == 2, ( + "We should call only `model_info` to check for _commit hash and `send_telemetry`" + ) def test_less_downloads_passed_object(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -217,9 +217,9 @@ def test_less_downloads_passed_object_calls(self): assert download_requests.count("HEAD") == 13, "13 calls to files" # 17 - 2 because no call to config or model file for `safety_checker` assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json" - assert ( - len(download_requests) == 28 - ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json" + assert len(download_requests) == 28, ( + "2 calls per file (13 files) + send_telemetry, model_info and model_index.json" + ) with requests_mock.mock(real_http=True) as m: DiffusionPipeline.download( @@ -229,9 +229,9 @@ def test_less_downloads_passed_object_calls(self): cache_requests = [r.method for r in m.request_history] assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD" assert cache_requests.count("GET") == 1, "model info is only GET" - assert ( - len(cache_requests) == 2 - ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + assert len(cache_requests) == 2, ( + "We should call only `model_info` to check for _commit hash and `send_telemetry`" + ) def test_download_only_pytorch(self): with tempfile.TemporaryDirectory() as tmpdirname: