 import torch
 import torch.nn.functional as F
 import pytest
-from torchlpc.core import lpc_np
+from torchlpc.core import lpc_np, lpc_cuda


 from .test_grad import create_test_inputs
     "cmplx",
     [True, False],
 )
-def test_scan_cpu_equiv(samples: int, cmplx: bool):
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_scan_equiv(samples: int, cmplx: bool, device: str):
     batch_size = 4
     x = torch.randn(
-        batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64
+        batch_size,
+        samples,
+        dtype=torch.float32 if not cmplx else torch.complex64,
+        device=device,
     )
-    A = torch.rand_like(x) * 1.8 - 0.9
-    zi = torch.randn(batch_size, dtype=x.dtype)
-
-    numba_y = torch.from_numpy(
-        lpc_np(
-            x.cpu().numpy(),
-            -A.cpu().unsqueeze(2).numpy(),
-            zi.cpu().unsqueeze(1).numpy(),
+    if cmplx:
+        A = torch.rand(
+            batch_size, samples, dtype=x.dtype, device=device
+        ).sqrt() * torch.exp(
+            2j
+            * torch.rand(batch_size, samples, dtype=x.dtype, device=device)
+            * torch.pi
         )
-    )
-    ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)
+    else:
+        A = torch.rand_like(x) * 1.8 - 0.9
+    zi = torch.randn(batch_size, dtype=x.dtype, device=device)

-    assert torch.allclose(numba_y, ext_y)
+    if device == "cuda":
+        numba_y = lpc_cuda(x, -A.unsqueeze(2), zi.unsqueeze(1))
+    else:
+        numba_y = torch.from_numpy(
+            lpc_np(
+                x.cpu().numpy(),
+                -A.cpu().unsqueeze(2).numpy(),
+                zi.cpu().unsqueeze(1).numpy(),
+            )
+        )
+    ext_y = torch.ops.torchlpc.scan(x, A, zi)
+
+    assert torch.allclose(numba_y, ext_y, atol=5e-7), torch.max(
+        torch.abs(numba_y - ext_y)
+    ).item()


 @pytest.mark.parametrize(
@@ -43,12 +72,12 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool):
     "cmplx",
     [True, False],
 )
-def test_lpc_cpu_equiv(samples: int, cmplx: bool):
+def test_lpc_equiv(samples: int, cmplx: bool):
     batch_size = 4
     x, A, zi = tuple(
         x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
     )
     numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
-    ext_y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
+    ext_y = torch.ops.torchlpc.lpc(x, A, zi)

     assert torch.allclose(numba_y, ext_y)
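
For reference, a minimal sketch of calling the renamed, device-generic op outside of pytest, mirroring what the updated test checks. It assumes that importing torchlpc.core registers the torch.ops.torchlpc.* extension ops and that scan(x, A, zi) keeps the calling convention exercised above (its coefficients correspond to -A when passed to the Numba reference lpc_np):

import torch
from torchlpc.core import lpc_np  # importing torchlpc.core is assumed to register the extension ops

x = torch.randn(4, 1024)               # batch of real input signals
A = torch.rand_like(x) * 1.8 - 0.9     # time-varying coefficients kept inside (-0.9, 0.9)
zi = torch.randn(4)                    # one initial condition per batch element

y = torch.ops.torchlpc.scan(x, A, zi)  # renamed op (was scan_cpu); the new test also runs it on CUDA tensors
ref = torch.from_numpy(
    lpc_np(x.numpy(), -A.unsqueeze(2).numpy(), zi.unsqueeze(1).numpy())
)
assert torch.allclose(y, ref, atol=5e-7)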