Skip to content

Commit 206497e

Browse files
committed
refactor: update tests to include device and complex parameterization for scan and lpc functions
1 parent 98f314f commit 206497e

File tree

2 files changed

+64
-19
lines changed

2 files changed

+64
-19
lines changed

tests/test_extension.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn.functional as F
33
import pytest
4-
from torchlpc.core import lpc_np
4+
from torchlpc.core import lpc_np, lpc_cuda
55

66

77
from .test_grad import create_test_inputs
@@ -15,24 +15,53 @@
1515
"cmplx",
1616
[True, False],
1717
)
18-
def test_scan_cpu_equiv(samples: int, cmplx: bool):
18+
@pytest.mark.parametrize(
19+
"device",
20+
[
21+
"cpu",
22+
pytest.param(
23+
"cuda",
24+
marks=pytest.mark.skipif(
25+
not torch.cuda.is_available(), reason="CUDA not available"
26+
),
27+
),
28+
],
29+
)
30+
def test_scan_equiv(samples: int, cmplx: bool, device: str):
1931
batch_size = 4
2032
x = torch.randn(
21-
batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64
33+
batch_size,
34+
samples,
35+
dtype=torch.float32 if not cmplx else torch.complex64,
36+
device=device,
2237
)
23-
A = torch.rand_like(x) * 1.8 - 0.9
24-
zi = torch.randn(batch_size, dtype=x.dtype)
25-
26-
numba_y = torch.from_numpy(
27-
lpc_np(
28-
x.cpu().numpy(),
29-
-A.cpu().unsqueeze(2).numpy(),
30-
zi.cpu().unsqueeze(1).numpy(),
38+
if cmplx:
39+
A = torch.rand(
40+
batch_size, samples, dtype=x.dtype, device=device
41+
).sqrt() * torch.exp(
42+
2j
43+
* torch.rand(batch_size, samples, dtype=x.dtype, device=device)
44+
* torch.pi
3145
)
32-
)
33-
ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)
46+
else:
47+
A = torch.rand_like(x) * 1.8 - 0.9
48+
zi = torch.randn(batch_size, dtype=x.dtype, device=device)
3449

35-
assert torch.allclose(numba_y, ext_y)
50+
if device == "cuda":
51+
numba_y = lpc_cuda(x, -A.unsqueeze(2), zi.unsqueeze(1))
52+
else:
53+
numba_y = torch.from_numpy(
54+
lpc_np(
55+
x.cpu().numpy(),
56+
-A.cpu().unsqueeze(2).numpy(),
57+
zi.cpu().unsqueeze(1).numpy(),
58+
)
59+
)
60+
ext_y = torch.ops.torchlpc.scan(x, A, zi)
61+
62+
assert torch.allclose(numba_y, ext_y, atol=5e-7), torch.max(
63+
torch.abs(numba_y - ext_y)
64+
).item()
3665

3766

3867
@pytest.mark.parametrize(
@@ -43,12 +72,12 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool):
4372
"cmplx",
4473
[True, False],
4574
)
46-
def test_lpc_cpu_equiv(samples: int, cmplx: bool):
75+
def test_lpc_equiv(samples: int, cmplx: bool):
4776
batch_size = 4
4877
x, A, zi = tuple(
4978
x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
5079
)
5180
numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
52-
ext_y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
81+
ext_y = torch.ops.torchlpc.lpc(x, A, zi)
5382

5483
assert torch.allclose(numba_y, ext_y)

tests/test_grad.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def test_float64_vs_32_cuda():
123123
"zi_requires_grad",
124124
[True, False],
125125
)
126+
@pytest.mark.parametrize(
127+
"cmplx",
128+
[True, False],
129+
)
126130
@pytest.mark.parametrize(
127131
"device",
128132
[
@@ -139,13 +143,25 @@ def test_parallel_scan(
139143
x_requires_grad: bool,
140144
a_requires_grad: bool,
141145
zi_requires_grad: bool,
146+
cmplx: bool,
142147
device: str,
143148
):
144149
batch_size = 2
145150
samples = 123
146-
x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
147-
A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
148-
zi = torch.randn(batch_size, dtype=torch.double, device=device)
151+
dtype = torch.complex128 if cmplx else torch.double
152+
x = torch.randn(batch_size, samples, dtype=dtype, device=device)
153+
if cmplx:
154+
A = torch.rand(
155+
batch_size, samples, dtype=torch.double, device=device
156+
).sqrt() * torch.exp(
157+
1j
158+
* torch.rand(batch_size, samples, dtype=torch.double, device=device)
159+
* 2
160+
* torch.pi
161+
)
162+
else:
163+
A = torch.rand(batch_size, samples, dtype=dtype, device=device) * 2 - 1
164+
zi = torch.randn(batch_size, dtype=dtype, device=device)
149165

150166
A.requires_grad = a_requires_grad
151167
x.requires_grad = x_requires_grad

0 commit comments

Comments
 (0)