
Commit c662e34

feat: use original cuda scan from linear RNN (#22)
* feat: add linear recurrence with MIT license
* refactor: linear recurrence code
* fix: correct ndims and nsteps
* refactor: rename scan_cpu and lpc_cpu functions to scan and lpc
* refactor: update function calls to use unified 'scan' operation
* refactor: reorganize setup.py for building CUDA extensions
* refactor: update tests to include device and complex parameterization for scan and lpc functions
* refactor: implement separate recurrence functions for improved clarity and maintainability
* refactor: create dummy _C module for Python loading
* fix: typo
* apply Copilot's suggestion
* refactor: use channel-first format, swapping the roles of lane and warp to run faster

1 parent: d372cee
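
A minimal usage sketch of the unified op after this commit (shapes follow the tests below; torchlpc and a built extension are assumed importable): the same torch.ops.torchlpc.scan name now serves both the CPU and the new CUDA kernel, dispatching on the input device.

import torch
import torchlpc  # importing the package registers the compiled ops, if built

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 1024, device=device)  # (batch, samples) input signal
A = torch.rand_like(x) * 1.8 - 0.9       # per-sample decay, |A| < 1 for stability
zi = torch.randn(4, device=device)       # one initial state per batch element
y = torch.ops.torchlpc.scan(x, A, zi)    # same op name on CPU and CUDA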

File tree: 9 files changed, +522 -79 lines

setup.py
Lines changed: 54 additions & 17 deletions

@@ -1,8 +1,15 @@
 import setuptools
+import os
+import glob
 import torch
-from torch.utils import cpp_extension
+from torch.utils.cpp_extension import (
+    CppExtension,
+    CUDAExtension,
+    BuildExtension,
+    CUDA_HOME,
+)

-NAME = "torchlpc"
+library_name = "torchlpc"
 VERSION = "0.7.dev"
 MAINTAINER = "Chin-Yun Yu"
 EMAIL = "chin-yun.yu@qmul.ac.uk"
@@ -12,15 +19,51 @@
     long_description = fh.read()


+# if torch.__version__ >= "2.6.0":
+#     py_limited_api = True
+# else:
+py_limited_api = False
+
+
+def get_extensions():
+    use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
+    use_openmp = torch.backends.openmp.is_available()
+    extension = CUDAExtension if use_cuda else CppExtension
+
+    extra_link_args = []
+    extra_compile_args = {}
+    if use_openmp:
+        extra_compile_args["cxx"] = ["-fopenmp"]
+        extra_link_args.append("-lgomp")
+
+    this_dir = os.path.abspath(os.path.dirname(__file__))
+    extensions_dir = os.path.join(this_dir, library_name, "csrc")
+    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+
+    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+    cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+
+    if use_cuda:
+        sources += cuda_sources
+
+    ext_modules = [
+        extension(
+            f"{library_name}._C",
+            sources,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+            py_limited_api=py_limited_api,
+        )
+    ]
+
+    return ext_modules
+
+
 extra_link_args = []
 extra_compile_args = {}
-# check if openmp is available
-if torch.backends.openmp.is_available():
-    extra_compile_args["cxx"] = ["-fopenmp"]
-    extra_link_args.append("-lgomp")

 setuptools.setup(
-    name=NAME,
+    name=library_name,
     version=VERSION,
     author=MAINTAINER,
     author_email=EMAIL,
@@ -32,16 +75,10 @@
     install_requires=["torch>=2.0", "numpy", "numba"],
     classifiers=[
         "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
     ],
-    ext_modules=[
-        cpp_extension.CppExtension(
-            "torchlpc._C",
-            ["torchlpc/csrc/scan_cpu.cpp"],
-            extra_compile_args=extra_compile_args,
-            extra_link_args=extra_link_args,
-        )
-    ],
-    cmdclass={"build_ext": cpp_extension.BuildExtension},
+    license="MIT",
+    ext_modules=get_extensions(),
+    cmdclass={"build_ext": BuildExtension},
+    options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
 )
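
The build now decides between the two extension types at install time. A small sketch that mirrors get_extensions()'s dispatch, useful for checking which variant a local checkout would compile; the checks are copied from the diff above, and nothing here is part of the package itself:

import torch
from torch.utils.cpp_extension import CUDA_HOME

# Same two conditions as get_extensions(): a CUDA toolkit must be installed
# (CUDA_HOME set) and visible to torch for the .cu sources to be compiled in.
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
print("extension class:", "CUDAExtension" if use_cuda else "CppExtension")
print("OpenMP flags added:", torch.backends.openmp.is_available())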

tests/test_extension.py
Lines changed: 45 additions & 16 deletions

@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 import pytest
-from torchlpc.core import lpc_np
+from torchlpc.core import lpc_np, lpc_cuda


 from .test_grad import create_test_inputs
@@ -15,24 +15,53 @@
     "cmplx",
     [True, False],
 )
-def test_scan_cpu_equiv(samples: int, cmplx: bool):
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_scan_equiv(samples: int, cmplx: bool, device: str):
     batch_size = 4
     x = torch.randn(
-        batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64
+        batch_size,
+        samples,
+        dtype=torch.float32 if not cmplx else torch.complex64,
+        device=device,
     )
-    A = torch.rand_like(x) * 1.8 - 0.9
-    zi = torch.randn(batch_size, dtype=x.dtype)
-
-    numba_y = torch.from_numpy(
-        lpc_np(
-            x.cpu().numpy(),
-            -A.cpu().unsqueeze(2).numpy(),
-            zi.cpu().unsqueeze(1).numpy(),
+    if cmplx:
+        A = torch.rand(
+            batch_size, samples, dtype=x.dtype, device=device
+        ).sqrt() * torch.exp(
+            2j
+            * torch.rand(batch_size, samples, dtype=x.dtype, device=device)
+            * torch.pi
         )
-    )
-    ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)
+    else:
+        A = torch.rand_like(x) * 1.8 - 0.9
+    zi = torch.randn(batch_size, dtype=x.dtype, device=device)

-    assert torch.allclose(numba_y, ext_y)
+    if device == "cuda":
+        numba_y = lpc_cuda(x, -A.unsqueeze(2), zi.unsqueeze(1))
+    else:
+        numba_y = torch.from_numpy(
+            lpc_np(
+                x.cpu().numpy(),
+                -A.cpu().unsqueeze(2).numpy(),
+                zi.cpu().unsqueeze(1).numpy(),
+            )
+        )
+    ext_y = torch.ops.torchlpc.scan(x, A, zi)
+
+    assert torch.allclose(numba_y, ext_y, atol=5e-7), torch.max(
+        torch.abs(numba_y - ext_y)
+    ).item()


 @pytest.mark.parametrize(
@@ -43,12 +72,12 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool):
     "cmplx",
     [True, False],
 )
-def test_lpc_cpu_equiv(samples: int, cmplx: bool):
+def test_lpc_equiv(samples: int, cmplx: bool):
     batch_size = 4
     x, A, zi = tuple(
         x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
     )
     numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
-    ext_y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
+    ext_y = torch.ops.torchlpc.lpc(x, A, zi)

     assert torch.allclose(numba_y, ext_y)
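
The complex branch above draws coefficients as sqrt(uniform) magnitudes with uniform phase. A short sketch of why, assuming stability is the goal: squaring the magnitude gives |A|^2 ~ U[0,1], i.e. poles spread uniformly over the unit disk, and |A| <= 1 keeps the recurrence y[t] = x[t] + A[t] * y[t-1] bounded over long sequences.

import torch

mag = torch.rand(4, 4096).sqrt()            # |A|^2 ~ U[0,1] -> uniform over the disk's area
phase = 2 * torch.pi * torch.rand(4, 4096)  # uniform phase in [0, 2*pi)
A = torch.polar(mag, phase)                 # complex tensor with |A| <= 1
assert A.abs().max().item() <= 1.0          # every pole inside the closed unit circle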

tests/test_grad.py
Lines changed: 19 additions & 3 deletions

@@ -123,6 +123,10 @@ def test_float64_vs_32_cuda():
     "zi_requires_grad",
     [True, False],
 )
+@pytest.mark.parametrize(
+    "cmplx",
+    [True, False],
+)
 @pytest.mark.parametrize(
     "device",
     [
@@ -139,13 +143,25 @@ def test_parallel_scan(
     x_requires_grad: bool,
     a_requires_grad: bool,
     zi_requires_grad: bool,
+    cmplx: bool,
     device: str,
 ):
     batch_size = 2
     samples = 123
-    x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
-    A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
-    zi = torch.randn(batch_size, dtype=torch.double, device=device)
+    dtype = torch.complex128 if cmplx else torch.double
+    x = torch.randn(batch_size, samples, dtype=dtype, device=device)
+    if cmplx:
+        A = torch.rand(
+            batch_size, samples, dtype=torch.double, device=device
+        ).sqrt() * torch.exp(
+            1j
+            * torch.rand(batch_size, samples, dtype=torch.double, device=device)
+            * 2
+            * torch.pi
+        )
+    else:
+        A = torch.rand(batch_size, samples, dtype=dtype, device=device) * 2 - 1
+    zi = torch.randn(batch_size, dtype=dtype, device=device)

     A.requires_grad = a_requires_grad
     x.requires_grad = x_requires_grad
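
test_parallel_scan toggles requires_grad per input; presumably the body then runs a numerical gradient check. A hedged sketch of such a check through the public entry point (sample_wise_lpc is an assumed import path; adjust if the actual API differs):

import torch
from torchlpc import sample_wise_lpc  # assumed public API

batch_size, samples = 2, 16
x = torch.randn(batch_size, samples, dtype=torch.double, requires_grad=True)
a = (torch.rand(batch_size, samples, 1, dtype=torch.double) * 2 - 1).requires_grad_()
zi = torch.randn(batch_size, 1, dtype=torch.double, requires_grad=True)

# gradcheck compares analytic gradients against finite differences, which is
# why this (like the test above) uses double precision.
assert torch.autograd.gradcheck(sample_wise_lpc, (x, a, zi))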

torchlpc/__init__.py
Lines changed: 16 additions & 8 deletions

@@ -3,16 +3,24 @@
 from pathlib import Path
 import warnings

-so_files = list(Path(__file__).parent.glob("_C*.so"))
-# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
-if len(so_files) == 1:
-    torch.ops.load_library(so_files[0])
+# so_files = list(Path(__file__).parent.glob("_C*.so"))
+# # assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
+# if len(so_files) == 1:
+#     torch.ops.load_library(so_files[0])
+#     EXTENSION_LOADED = True
+# elif len(so_files) > 1:
+#     raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
+# else:
+#     warnings.warn("No _C*.so file found. Custom extension not loaded.")
+#     EXTENSION_LOADED = False
+
+try:
+    from . import _C
+
     EXTENSION_LOADED = True
-elif len(so_files) > 1:
-    raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
-else:
-    warnings.warn("No _C*.so file found. Custom extension not loaded.")
+except ImportError:
     EXTENSION_LOADED = False
+    warnings.warn("Custom extension not loaded. Falling back to Numba implementation.")

 from .core import LPC
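
Loading the extension as a submodule instead of globbing for _C*.so files works for any platform's extension suffix and leaves a single flag for callers to branch on. A hypothetical consumer of that flag:

import torch
import torchlpc

x = torch.randn(2, 8)
A = torch.rand_like(x) * 1.8 - 0.9
zi = torch.zeros(2)

if torchlpc.EXTENSION_LOADED:  # set by the try/except import above
    y = torch.ops.torchlpc.scan(x, A, zi)  # compiled op path
else:
    print("no compiled _C module; torchlpc falls back to Numba internally")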

torchlpc/core.py
Lines changed: 1 addition & 1 deletion

@@ -162,7 +162,7 @@ def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
     if x.is_cuda:
         y = lpc_cuda(x.detach(), A.detach(), zi.detach())
     elif EXTENSION_LOADED:
-        y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
+        y = torch.ops.torchlpc.lpc(x, A, zi)
     else:
         warnings.warn(
             "Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."

torchlpc/csrc/cuda/LICENSE.txt
Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+Copyright (c) <2017> <eric@ericmart.in>
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
