diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index e63c37bd0b..85230267c3 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -84,7 +84,7 @@ def test_dlpack_device(usm_type, all_root_devices): assert type(dev) is tuple assert len(dev) == 2 assert dev[0] == device_oneAPI - assert sycl_dev == all_root_devices[dev[1]] + assert dev[1] == sycl_dev.get_device_id() def test_dlpack_exporter(typestr, usm_type, all_root_devices): @@ -834,7 +834,7 @@ def test_sycl_device_to_dldevice(all_root_devices): assert type(dev) is tuple assert len(dev) == 2 assert dev[0] == device_oneAPI - assert dev[1] == all_root_devices.index(sycl_dev) + assert dev[1] == sycl_dev.get_device_id() def test_dldevice_to_sycl_device(all_root_devices): @@ -842,7 +842,7 @@ def test_dldevice_to_sycl_device(all_root_devices): dldev = dpt.empty(0, device=sycl_dev).__dlpack_device__() dev = dpt.dldevice_to_sycl_device(dldev) assert type(dev) is dpctl.SyclDevice - assert dev == all_root_devices[dldev[1]] + assert dev.get_device_id() == sycl_dev.get_device_id() def test_dldevice_conversion_arg_validation():