Merged (changes from 11 commits)
71 changes: 54 additions & 17 deletions setup.py
@@ -1,8 +1,15 @@
import setuptools
import os
import glob
import torch
from torch.utils import cpp_extension
from torch.utils.cpp_extension import (
CppExtension,
CUDAExtension,
BuildExtension,
CUDA_HOME,
)

NAME = "torchlpc"
library_name = "torchlpc"
VERSION = "0.7.dev"
MAINTAINER = "Chin-Yun Yu"
EMAIL = "chin-yun.yu@qmul.ac.uk"
@@ -12,15 +19,51 @@
long_description = fh.read()


# if torch.__version__ >= "2.6.0":
# py_limited_api = True
# else:
py_limited_api = False


def get_extensions():
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
use_openmp = torch.backends.openmp.is_available()
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
extra_compile_args = {}
if use_openmp:
extra_compile_args["cxx"] = ["-fopenmp"]
extra_link_args.append("-lgomp")

this_dir = os.path.abspath(os.path.dirname(__file__))
extensions_dir = os.path.join(this_dir, library_name, "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))

if use_cuda:
sources += cuda_sources

ext_modules = [
extension(
f"{library_name}._C",
sources,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
py_limited_api=py_limited_api,
)
]

return ext_modules


extra_link_args = []
extra_compile_args = {}
# check if openmp is available
if torch.backends.openmp.is_available():
extra_compile_args["cxx"] = ["-fopenmp"]
extra_link_args.append("-lgomp")

setuptools.setup(
name=NAME,
name=library_name,
version=VERSION,
author=MAINTAINER,
author_email=EMAIL,
@@ -32,16 +75,10 @@
install_requires=["torch>=2.0", "numpy", "numba"],
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
ext_modules=[
cpp_extension.CppExtension(
"torchlpc._C",
["torchlpc/csrc/scan_cpu.cpp"],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
],
cmdclass={"build_ext": cpp_extension.BuildExtension},
license="MIT",
ext_modules=get_extensions(),
cmdclass={"build_ext": BuildExtension},
options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
)
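
Note on the new build logic: `get_extensions()` selects `CUDAExtension` only when both a CUDA device and `CUDA_HOME` are present, and adds OpenMP flags when PyTorch reports OpenMP support; the `py_limited_api` gate stays off for now (the commented block would enable it for torch >= 2.6.0), so the `bdist_wheel` option only applies conditionally. Below is a minimal standalone sketch, not part of the PR, that mirrors that selection logic so you can see what a build on a given machine would pick (the printed labels are illustrative only):

```python
# Sketch: mirror the backend selection in get_extensions() without building.
import torch
from torch.utils.cpp_extension import CUDA_HOME

use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
use_openmp = torch.backends.openmp.is_available()

print("extension class:", "CUDAExtension" if use_cuda else "CppExtension")
print("extra cxx flags:", ["-fopenmp"] if use_openmp else [])
print("link flags:", ["-lgomp"] if use_openmp else [])
```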
61 changes: 45 additions & 16 deletions tests/test_extension.py
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
import pytest
from torchlpc.core import lpc_np
from torchlpc.core import lpc_np, lpc_cuda


from .test_grad import create_test_inputs
Expand All @@ -15,24 +15,53 @@
"cmplx",
[True, False],
)
def test_scan_cpu_equiv(samples: int, cmplx: bool):
@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
],
)
def test_scan_equiv(samples: int, cmplx: bool, device: str):
batch_size = 4
x = torch.randn(
batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64
batch_size,
samples,
dtype=torch.float32 if not cmplx else torch.complex64,
device=device,
)
A = torch.rand_like(x) * 1.8 - 0.9
zi = torch.randn(batch_size, dtype=x.dtype)

numba_y = torch.from_numpy(
lpc_np(
x.cpu().numpy(),
-A.cpu().unsqueeze(2).numpy(),
zi.cpu().unsqueeze(1).numpy(),
if cmplx:
A = torch.rand(
batch_size, samples, dtype=x.dtype, device=device
).sqrt() * torch.exp(
2j
* torch.rand(batch_size, samples, dtype=x.dtype, device=device)
* torch.pi
)
)
ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)
else:
A = torch.rand_like(x) * 1.8 - 0.9
zi = torch.randn(batch_size, dtype=x.dtype, device=device)

assert torch.allclose(numba_y, ext_y)
if device == "cuda":
numba_y = lpc_cuda(x, -A.unsqueeze(2), zi.unsqueeze(1))
else:
numba_y = torch.from_numpy(
lpc_np(
x.cpu().numpy(),
-A.cpu().unsqueeze(2).numpy(),
zi.cpu().unsqueeze(1).numpy(),
)
)
ext_y = torch.ops.torchlpc.scan(x, A, zi)

assert torch.allclose(numba_y, ext_y, atol=5e-7), torch.max(
torch.abs(numba_y - ext_y)
).item()


@pytest.mark.parametrize(
@@ -43,12 +72,12 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool):
"cmplx",
[True, False],
)
def test_lpc_cpu_equiv(samples: int, cmplx: bool):
def test_lpc_equiv(samples: int, cmplx: bool):
batch_size = 4
x, A, zi = tuple(
x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
)
numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
ext_y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
ext_y = torch.ops.torchlpc.lpc(x, A, zi)

assert torch.allclose(numba_y, ext_y)
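
For readers following the renamed ops: with order-1 coefficients, `lpc_np(x, -A.unsqueeze(2), zi.unsqueeze(1))` reduces to the first-order recursion `y[t] = x[t] + A[t] * y[t-1]` with `y[-1] = zi`, which is what `torch.ops.torchlpc.scan` is checked against. A plain-loop sketch of that reference (for illustration only; `scan_reference` is a name used just here, and the tests use the Numba/CUDA implementations instead):

```python
# Sketch: the recursion test_scan_equiv verifies, as an explicit loop.
import torch

def scan_reference(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
    # x, A: (batch, samples); zi: (batch,). Works for real and complex dtypes.
    y = torch.empty_like(x)
    prev = zi
    for t in range(x.shape[1]):
        prev = x[:, t] + A[:, t] * prev  # y[t] = x[t] + A[t] * y[t-1]
        y[:, t] = prev
    return y
```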
22 changes: 19 additions & 3 deletions tests/test_grad.py
@@ -123,6 +123,10 @@ def test_float64_vs_32_cuda():
"zi_requires_grad",
[True, False],
)
@pytest.mark.parametrize(
"cmplx",
[True, False],
)
@pytest.mark.parametrize(
"device",
[
@@ -139,13 +143,25 @@ def test_parallel_scan(
x_requires_grad: bool,
a_requires_grad: bool,
zi_requires_grad: bool,
cmplx: bool,
device: str,
):
batch_size = 2
samples = 123
x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
zi = torch.randn(batch_size, dtype=torch.double, device=device)
dtype = torch.complex128 if cmplx else torch.double
x = torch.randn(batch_size, samples, dtype=dtype, device=device)
if cmplx:
A = torch.rand(
batch_size, samples, dtype=torch.double, device=device
).sqrt() * torch.exp(
1j
* torch.rand(batch_size, samples, dtype=torch.double, device=device)
* 2
* torch.pi
)
else:
A = torch.rand(batch_size, samples, dtype=dtype, device=device) * 2 - 1
zi = torch.randn(batch_size, dtype=dtype, device=device)

A.requires_grad = a_requires_grad
x.requires_grad = x_requires_grad
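
The complex-pole sampling used above (`rand(...).sqrt() * exp(1j * rand(...) * 2 * pi)`) draws poles uniformly over the open unit disk: taking `sqrt(U)` for the magnitude makes the samples area-uniform, and `|A| < 1` keeps the recursion stable, so the gradient checks stay well conditioned. A quick standalone check of both properties (a sketch, not part of the PR):

```python
import torch

n = 100_000
A = torch.rand(n, dtype=torch.double).sqrt() * torch.exp(
    1j * torch.rand(n, dtype=torch.double) * 2 * torch.pi
)
assert A.abs().max() < 1  # every pole strictly inside the unit circle
# Area-uniform: the fraction of samples with |A| < r should approach r**2.
print((A.abs() < 0.5).double().mean())  # ~0.25
```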
24 changes: 16 additions & 8 deletions torchlpc/__init__.py
@@ -3,16 +3,24 @@
from pathlib import Path
import warnings

so_files = list(Path(__file__).parent.glob("_C*.so"))
# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
if len(so_files) == 1:
torch.ops.load_library(so_files[0])
# so_files = list(Path(__file__).parent.glob("_C*.so"))
# # assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
# if len(so_files) == 1:
# torch.ops.load_library(so_files[0])
# EXTENSION_LOADED = True
# elif len(so_files) > 1:
# raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
# else:
# warnings.warn("No _C*.so file found. Custom extension not loaded.")
# EXTENSION_LOADED = False

try:
from . import _C

EXTENSION_LOADED = True
elif len(so_files) > 1:
raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
else:
warnings.warn("No _C*.so file found. Custom extension not loaded.")
except ImportError:
EXTENSION_LOADED = False
warnings.warn("Custom extension not loaded. Falling back to Numba implementation.")

from .core import LPC

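
With the extension now loaded via a regular `from . import _C` import, the old `.so`-globbing is gone and `EXTENSION_LOADED` becomes the single flag downstream code can rely on. A hypothetical caller-side check (a sketch; assumes torchlpc is installed):

```python
import torchlpc

if torchlpc.EXTENSION_LOADED:
    print("compiled ops available: torch.ops.torchlpc.*")
else:
    print("running on the Numba fallback")
```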
2 changes: 1 addition & 1 deletion torchlpc/core.py
@@ -162,7 +162,7 @@ def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
if x.is_cuda:
y = lpc_cuda(x.detach(), A.detach(), zi.detach())
elif EXTENSION_LOADED:
y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
y = torch.ops.torchlpc.lpc(x, A, zi)
else:
warnings.warn(
"Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."
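
End to end, the dispatch in `LPC.forward` now routes CUDA tensors to `lpc_cuda`, CPU tensors to the compiled `torch.ops.torchlpc.lpc` when available, and everything else to the Numba fallback (with a deprecation warning). A minimal usage sketch through the public API (`sample_wise_lpc` is torchlpc's documented entry point; the coefficient scaling here is just an assumption to keep the example filter stable):

```python
import torch
from torchlpc import sample_wise_lpc

B, T, order = 2, 1000, 2
x = torch.randn(B, T)
a = torch.randn(B, T, order) * 0.1  # small time-varying coefficients
zi = torch.zeros(B, order)

# Routed to lpc_cuda on CUDA tensors, the compiled op if EXTENSION_LOADED,
# and the Numba implementation otherwise.
y = sample_wise_lpc(x, a, zi)
```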
19 changes: 19 additions & 0 deletions torchlpc/csrc/cuda/LICENSE.txt
@@ -0,0 +1,19 @@
Copyright (c) <2017> <eric@ericmart.in>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.