From 7ef0e442c1c5d631f626cbebaddd85ee0a946c2e Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 31 Jan 2025 14:25:39 -0800 Subject: [PATCH] Fix malformed tests in test_usm_ndarray_dlpack These tests would fail on machines with more than 2 devices for a given platform due to an incorrect asusmption that the DLPack device ID would match that of the cached root devices, of which only 2 are kept per platform --- dpctl/tests/test_usm_ndarray_dlpack.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index e63c37bd0b..85230267c3 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -84,7 +84,7 @@ def test_dlpack_device(usm_type, all_root_devices): assert type(dev) is tuple assert len(dev) == 2 assert dev[0] == device_oneAPI - assert sycl_dev == all_root_devices[dev[1]] + assert dev[1] == sycl_dev.get_device_id() def test_dlpack_exporter(typestr, usm_type, all_root_devices): @@ -834,7 +834,7 @@ def test_sycl_device_to_dldevice(all_root_devices): assert type(dev) is tuple assert len(dev) == 2 assert dev[0] == device_oneAPI - assert dev[1] == all_root_devices.index(sycl_dev) + assert dev[1] == sycl_dev.get_device_id() def test_dldevice_to_sycl_device(all_root_devices): @@ -842,7 +842,7 @@ def test_dldevice_to_sycl_device(all_root_devices): dldev = dpt.empty(0, device=sycl_dev).__dlpack_device__() dev = dpt.dldevice_to_sycl_device(dldev) assert type(dev) is dpctl.SyclDevice - assert dev == all_root_devices[dldev[1]] + assert dev.get_device_id() == sycl_dev.get_device_id() def test_dldevice_conversion_arg_validation():