|
| 1 | +import numpy |
| 2 | +import pytest |
| 3 | + |
| 4 | +import dpnp as cp |
| 5 | +from dpnp.tests.third_party.cupy import testing |
| 6 | + |
| 7 | +# TODO: remove once all dtype aliases added |
| 8 | +cp.int8 = numpy.int8 |
| 9 | +cp.uint8 = numpy.uint8 |
| 10 | +cp.int16 = numpy.int16 |
| 11 | + |
| 12 | +# "example string" or |
| 13 | +# ("example string", "xfail message") |
| 14 | +examples = [ |
| 15 | + "uint8(1) + 2", |
| 16 | + "array([1], uint8) + int64(1)", |
| 17 | + "array([1], uint8) + array(1, int64)", |
| 18 | + "array([1.], float32) + float64(1.)", |
| 19 | + "array([1.], float32) + array(1., float64)", |
| 20 | + "array([1], uint8) + 1", |
| 21 | + "array([1], uint8) + 200", |
| 22 | + "array([100], uint8) + 200", |
| 23 | + "array([1], uint8) + 300", |
| 24 | + "uint8(1) + 300", |
| 25 | + "uint8(100) + 200", |
| 26 | + "float32(1) + 3e100", |
| 27 | + "array([1.0], float32) + 1e-14 == 1.0", |
| 28 | + "array([0.1], float32) == float64(0.1)", |
| 29 | + "array(1.0, float32) + 1e-14 == 1.0", |
| 30 | + "array([1.], float32) + 3", |
| 31 | + "array([1.], float32) + int64(3)", |
| 32 | + "3j + array(3, complex64)", |
| 33 | + "float32(1) + 1j", |
| 34 | + "int32(1) + 5j", |
| 35 | + # additional examples from the NEP text |
| 36 | + "int16(2) + 2", |
| 37 | + "int16(4) + 4j", |
| 38 | + "float32(5) + 5j", |
| 39 | + "bool_(True) + 1", |
| 40 | + "True + uint8(2)", |
| 41 | + # not in the NEP |
| 42 | + "1.0 + array([1, 2, 3], int8)", |
| 43 | + "array([1], float32) + 1j", |
| 44 | +] |
| 45 | + |
| 46 | + |
| 47 | +@testing.with_requires("numpy>=2.0") |
| 48 | +@pytest.mark.parametrize("example", examples) |
| 49 | +@testing.numpy_cupy_allclose(atol=1e-15, accept_error=OverflowError) |
| 50 | +def test_nep50_examples(xp, example): |
| 51 | + dct = { |
| 52 | + "array": xp.array, |
| 53 | + "uint8": xp.uint8, |
| 54 | + "int64": xp.int64, |
| 55 | + "float32": xp.float32, |
| 56 | + "float64": xp.float64, |
| 57 | + "int16": xp.int16, |
| 58 | + "bool_": xp.bool_, |
| 59 | + "int32": xp.int32, |
| 60 | + "complex64": xp.complex64, |
| 61 | + "int8": xp.int8, |
| 62 | + } |
| 63 | + |
| 64 | + if isinstance(example, tuple): |
| 65 | + example, mesg = example |
| 66 | + pytest.xfail(mesg) |
| 67 | + |
| 68 | + result = eval(example, dct) |
| 69 | + return result |
0 commit comments