Skip to content

Commit f14ec94

Browse files
committed
Fix Marigold unit-test case failure on XPU
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
1 parent 7e7e62c commit f14ec94

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
if is_transformers_available():
4949
import transformers
5050
from transformers import PreTrainedModel, PreTrainedTokenizerBase
51-
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
5251
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
5352
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
5453

@@ -112,7 +111,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
112111
]
113112

114113
if is_transformers_available():
115-
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
114+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
116115

117116
# model_pytorch, diffusion_model_pytorch, ...
118117
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -191,7 +190,7 @@ def filter_model_files(filenames):
191190
]
192191

193192
if is_transformers_available():
194-
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
193+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
195194

196195
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
197196

@@ -212,7 +211,7 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
212211
]
213212

214213
if is_transformers_available():
215-
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
214+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
216215

217216
# model_pytorch, diffusion_model_pytorch, ...
218217
weight_prefixes = [w.split(".")[0] for w in weight_names]

tests/pipelines/marigold/test_marigold_depth.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434

3535
from ...testing_utils import (
36+
Expectations,
3637
backend_empty_cache,
3738
enable_full_determinism,
3839
floats_tensor,
@@ -356,7 +357,7 @@ def test_marigold_depth_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
356357
match_input_resolution=True,
357358
)
358359

359-
def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
360+
def test_marigold_depth_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1(self):
360361
self._test_marigold_depth(
361362
is_fp16=False,
362363
device=torch_device,
@@ -369,7 +370,7 @@ def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
369370
match_input_resolution=True,
370371
)
371372

372-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
373+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1(self):
373374
self._test_marigold_depth(
374375
is_fp16=True,
375376
device=torch_device,
@@ -382,7 +383,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
382383
match_input_resolution=True,
383384
)
384385

385-
def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
386+
def test_marigold_depth_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1(self):
386387
self._test_marigold_depth(
387388
is_fp16=True,
388389
device=torch_device,
@@ -395,20 +396,31 @@ def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
395396
match_input_resolution=True,
396397
)
397398

398-
def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
399+
def test_marigold_depth_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1(self):
400+
# fmt: off
401+
expected_slices = Expectations(
402+
{
403+
("cuda", 7): np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
404+
("xpu", 3): np.array([0.1084, 0.1096, 0.1108, 0.1080, 0.1083, 0.1080, 0.1085, 0.1057, 0.0996]),
406+
}
407+
)
408+
expected_slice = expected_slices.get_expectation()
409+
# fmt: on
410+
399411
self._test_marigold_depth(
400412
is_fp16=True,
401413
device=torch_device,
402414
generator_seed=0,
403-
expected_slice=np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
415+
expected_slice=expected_slice,
404416
num_inference_steps=2,
405417
processing_resolution=768,
406418
ensemble_size=1,
407419
batch_size=1,
408420
match_input_resolution=True,
409421
)
410422

411-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
423+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1(self):
412424
self._test_marigold_depth(
413425
is_fp16=True,
414426
device=torch_device,
@@ -421,7 +433,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
421433
match_input_resolution=True,
422434
)
423435

424-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
436+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
425437
self._test_marigold_depth(
426438
is_fp16=True,
427439
device=torch_device,
@@ -435,7 +447,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
435447
match_input_resolution=True,
436448
)
437449

438-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
450+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
439451
self._test_marigold_depth(
440452
is_fp16=True,
441453
device=torch_device,
@@ -449,7 +461,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
449461
match_input_resolution=True,
450462
)
451463

452-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
464+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0(self):
453465
self._test_marigold_depth(
454466
is_fp16=True,
455467
device=torch_device,

0 commit comments

Comments (0)