diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 03fbf19..a117ac0 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -36,6 +36,10 @@ jobs:
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Build CPP extension
+      run: |
+        python setup.py build
+        find build/ -name "_C*.so" -exec cp {} ./torchlpc/ \;
     - name: Test with pytest
       run: |
         pytest
diff --git a/setup.py b/setup.py
index 5718687..f16fc6d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,8 @@
 import setuptools
+from torch.utils import cpp_extension
 
 NAME = "torchlpc"
-VERSION = "0.6"
+VERSION = "0.7.dev"
 MAINTAINER = "Chin-Yun Yu"
 EMAIL = "chin-yun.yu@qmul.ac.uk"
 
@@ -25,4 +26,8 @@
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
     ],
+    ext_modules=[
+        cpp_extension.CppExtension("torchlpc._C", ["torchlpc/csrc/scan_cpu.cpp"])
+    ],
+    cmdclass={"build_ext": cpp_extension.BuildExtension},
 )
diff --git a/tests/test_extension.py b/tests/test_extension.py
new file mode 100644
index 0000000..ef94e49
--- /dev/null
+++ b/tests/test_extension.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn.functional as F
+import pytest
+from torchlpc.core import lpc_np
+
+
+from .test_grad import create_test_inputs
+
+
+@pytest.mark.parametrize(
+    "samples",
+    [64, 4097],
+)
+@pytest.mark.parametrize(
+    "cmplx",
+    [True, False],
+)
+def test_scan_cpu_equiv(samples: int, cmplx: bool):
+    batch_size = 4
+    x = torch.randn(
+        batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64
+    )
+    A = torch.rand_like(x) * 1.8 - 0.9
+    zi = torch.randn(batch_size, dtype=x.dtype)
+
+    numba_y = torch.from_numpy(
+        lpc_np(
+            x.cpu().numpy(),
+            -A.cpu().unsqueeze(2).numpy(),
+            zi.cpu().unsqueeze(1).numpy(),
+        )
+    )
+    ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)
+
+    assert torch.allclose(numba_y, ext_y)
diff --git a/tests/test_grad.py b/tests/test_grad.py
index b771170..028c6f5 100644
--- a/tests/test_grad.py
+++ b/tests/test_grad.py
@@ -2,7 +2,7 @@
 import torch
 from torch.autograd.gradcheck import gradcheck, gradgradcheck
 from torchlpc.core import LPC
-from torchlpc.recurrence import RecurrenceCUDA
+from torchlpc.recurrence import Recurrence
 
 
 def get_random_biquads(cmplx=False):
@@ -123,21 +123,33 @@ def test_float64_vs_32_cuda():
     "zi_requires_grad",
     [True, False],
 )
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_cuda_parallel_scan(
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_parallel_scan(
     x_requires_grad: bool,
     a_requires_grad: bool,
     zi_requires_grad: bool,
+    device: str,
 ):
     batch_size = 2
     samples = 123
-    x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
-    A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1
-    zi = torch.randn(batch_size, dtype=torch.double, device="cuda")
+    x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
+    A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
+    zi = torch.randn(batch_size, dtype=torch.double, device=device)
     A.requires_grad = a_requires_grad
     x.requires_grad = x_requires_grad
     zi.requires_grad = zi_requires_grad
 
-    assert gradcheck(RecurrenceCUDA.apply, (A, x, zi), check_forward_ad=True)
-    assert gradgradcheck(RecurrenceCUDA.apply, (A, x, zi))
+    assert gradcheck(Recurrence.apply, (A, x, zi), check_forward_ad=True)
+    assert gradgradcheck(Recurrence.apply, (A, x, zi))
 
diff --git a/tests/test_vmap.py b/tests/test_vmap.py
index f50c57c..d99e8a1 100644
--- a/tests/test_vmap.py
+++ b/tests/test_vmap.py
@@ -3,7 +3,7 @@
 from torch.func import jacfwd
 import pytest
 from torchlpc.core import LPC
-from torchlpc.recurrence import RecurrenceCUDA
+from torchlpc.recurrence import Recurrence
 
 from .test_grad import create_test_inputs
 
@@ -48,14 +48,25 @@ def func(x, A, zi):
     assert torch.allclose(jac, arg.grad)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_cuda_parallel_scan_vmap():
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_parallel_scan_vmap(device: str):
     batch_size = 3
     samples = 255
-    x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
-    A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1
-    zi = torch.randn(batch_size, dtype=torch.double, device="cuda")
-    y = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
+    x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
+    A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
+    zi = torch.randn(batch_size, dtype=torch.double, device=device)
+    y = torch.randn(batch_size, samples, dtype=torch.double, device=device)
 
     A.requires_grad = True
     x.requires_grad = True
@@ -64,7 +75,7 @@ def test_cuda_parallel_scan_vmap():
     args = (x, A, zi)
 
     def func(x, A, zi):
-        return F.mse_loss(RecurrenceCUDA.apply(A, x, zi), y)
+        return F.mse_loss(Recurrence.apply(A, x, zi), y)
 
     jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)
 
diff --git a/torchlpc/__init__.py b/torchlpc/__init__.py
index acecbed..79f86b6 100644
--- a/torchlpc/__init__.py
+++ b/torchlpc/__init__.py
@@ -1,9 +1,23 @@
 import torch
 from typing import Optional
+from pathlib import Path
+import warnings
+
+so_files = list(Path(__file__).parent.glob("_C*.so"))
+# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
+if len(so_files) == 1:
+    torch.ops.load_library(so_files[0])
+    EXTENSION_LOADED = True
+elif len(so_files) > 1:
+    raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
+else:
+    warnings.warn("No _C*.so file found. Custom extension not loaded.")
+    EXTENSION_LOADED = False
 
 from .core import LPC
-from .parallel_scan import WARPSIZE
-from .recurrence import RecurrenceCUDA
+
+# from .parallel_scan import WARPSIZE
+from .recurrence import Recurrence
 
 __all__ = ["sample_wise_lpc"]
 
@@ -37,7 +51,9 @@ def sample_wise_lpc(
     else:
         assert zi.shape == (B, order)
 
-    if order == 1 and x.is_cuda and B * WARPSIZE < T:
-        return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
+    # if order == 1 and x.is_cuda and B * WARPSIZE < T:
+    #     return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
+    if order == 1:
+        return Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1))
 
     return LPC.apply(x, a, zi)
diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp
new file mode 100644
index 0000000..dd85657
--- /dev/null
+++ b/torchlpc/csrc/scan_cpu.cpp
@@ -0,0 +1,86 @@
+#include <ATen/Parallel.h>
+#include <torch/script.h>
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+
+template <typename scalar_t>
+void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
+              const at::Tensor &initials, const at::Tensor &output) {
+    TORCH_CHECK(input.dim() == 2, "Input must be 2D");
+    TORCH_CHECK(initials.dim() == 1, "Initials must be 1D");
+    TORCH_CHECK(weights.sizes() == input.sizes(),
+                "Weights must have the same size as input");
+    TORCH_CHECK(output.sizes() == input.sizes(),
+                "Output must have the same size as input");
+    TORCH_CHECK(initials.size(0) == input.size(0),
+                "The first dimension of initials must be the same as the first "
+                "dimension of input");
+    TORCH_INTERNAL_ASSERT(input.device().is_cpu(), "Input must be on CPU");
+    TORCH_INTERNAL_ASSERT(initials.device().is_cpu(),
+                          "Initials must be on CPU");
+    TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU");
+    TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU");
+    TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous");
+
+    auto input_contiguous = input.contiguous();
+    auto weights_contiguous = weights.contiguous();
+    auto initials_contiguous = initials.contiguous();
+
+    auto n_batch = input.size(0);
+    auto T = input.size(1);
+    auto total_size = input.numel();
+
+    std::pair<scalar_t, scalar_t> buffer[total_size];
+
+    const scalar_t *input_ptr = input_contiguous.data_ptr<scalar_t>();
+    const scalar_t *initials_ptr = initials_contiguous.data_ptr<scalar_t>();
+    const scalar_t *weights_ptr = weights_contiguous.data_ptr<scalar_t>();
+    scalar_t *output_ptr = output.data_ptr<scalar_t>();
+
+    std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
+                   [](const scalar_t &a, const scalar_t &b) {
+                       return std::make_pair(a, b);
+                   });
+
+    at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
+        for (auto b = start; b < end; b++) {
+            std::inclusive_scan(
+                buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
+                [](const std::pair<scalar_t, scalar_t> &a,
+                   const std::pair<scalar_t, scalar_t> &b) {
+                    return std::make_pair(a.first * b.first,
+                                          a.second * b.first + b.second);
+                },
+                std::make_pair((scalar_t)1.0, initials_ptr[b]));
+        }
+    });
+
+    std::transform(
+        buffer, buffer + total_size, output_ptr,
+        [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
+}
+
+at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
+                            const at::Tensor &initials) {
+    TORCH_CHECK(input.is_floating_point() || input.is_complex(),
+                "Input must be floating point or complex");
+    TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
+                "Initials must have the same scalar type as input");
+    TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
+                "Weights must have the same scalar type as input");
+
+    auto output = at::empty_like(input);
+
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+        input.scalar_type(), "scan_cpu",
+        [&] { scan_cpu<scalar_t>(input, weights, initials, output); });
+    return output;
+}
+
+TORCH_LIBRARY(torchlpc, m) {
+    m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); }
diff --git a/torchlpc/recurrence.py b/torchlpc/recurrence.py
index 0ca5835..05b9fd5 100644
--- a/torchlpc/recurrence.py
+++ b/torchlpc/recurrence.py
@@ -4,10 +4,12 @@
 from numba import cuda
 from typing import Tuple, Optional, Any, List
 
-from .parallel_scan import compute_linear_recurrence
+from .parallel_scan import compute_linear_recurrence, WARPSIZE
+from .core import lpc_cuda, lpc_np
+from . import EXTENSION_LOADED
 
 
-class RecurrenceCUDA(Function):
+class Recurrence(Function):
     @staticmethod
     def forward(
         decay: torch.Tensor,
@@ -15,15 +17,32 @@ def forward(
         initial_state: torch.Tensor,
     ) -> torch.Tensor:
         n_dims, n_steps = decay.shape
-        out = torch.empty_like(impulse)
-        compute_linear_recurrence(
-            cuda.as_cuda_array(decay.detach()),
-            cuda.as_cuda_array(impulse.detach()),
-            cuda.as_cuda_array(initial_state.detach()),
-            cuda.as_cuda_array(out),
-            n_dims,
-            n_steps,
-        )
+        if decay.is_cuda:
+            if n_dims * WARPSIZE < n_steps:
+                out = torch.empty_like(impulse)
+                compute_linear_recurrence(
+                    cuda.as_cuda_array(decay.detach()),
+                    cuda.as_cuda_array(impulse.detach()),
+                    cuda.as_cuda_array(initial_state.detach()),
+                    cuda.as_cuda_array(out),
+                    n_dims,
+                    n_steps,
+                )
+            else:
+                out = lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1))
+        else:
+            num_threads = torch.get_num_threads()
+            # This is just a rough estimation of the computational cost
+            if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
+                out = torch.ops.torchlpc.scan_cpu(impulse, decay, initial_state)
+            else:
+                out = torch.from_numpy(
+                    lpc_np(
+                        impulse.detach().numpy(),
+                        -decay.unsqueeze(2).detach().numpy(),
+                        initial_state.unsqueeze(1).detach().numpy(),
+                    )
+                )
         return out
 
     @staticmethod
@@ -48,7 +67,7 @@ def backward(
         padded_decay = padded_decay[:, 1:]
 
         init = padded_grad_out.new_zeros(n_dims)
-        flipped_grad_impulse = RecurrenceCUDA.apply(
+        flipped_grad_impulse = Recurrence.apply(
             padded_decay.flip(1).conj_physical(),
             padded_grad_out.flip(1),
             init,
@@ -91,7 +110,7 @@ def jvp(
         fwd_decay = concat_out * grad_decay
 
         fwd_impulse = fwd_impulse + fwd_decay
-        return RecurrenceCUDA.apply(decay, fwd_impulse, fwd_initial_state)
+        return Recurrence.apply(decay, fwd_impulse, fwd_initial_state)
 
     @staticmethod
     def vmap(info, in_dims, *args):
@@ -107,5 +126,8 @@ def maybe_expand_bdim_at_front(x, x_bdim):
             )
         )
 
-        out = RecurrenceCUDA.apply(decay, impulse, initial_state)
+        out = Recurrence.apply(decay, impulse, initial_state)
         return out.reshape(info.batch_size, -1, *out.shape[1:]), 0
+
+
+RecurrenceCUDA = Recurrence
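Note (illustrative sketch, not part of the patch): one way to exercise the new CPU scan path after building the extension and copying the _C*.so into torchlpc/, as the CI step above does. The shapes, coefficient range, and tolerance below are assumptions for demonstration; only sample_wise_lpc, EXTENSION_LOADED, and torch.ops.torchlpc.scan_cpu come from this patch.

import torch
import torchlpc  # loads torchlpc/_C*.so if present and registers torch.ops.torchlpc.scan_cpu

batch, samples = 4, 1024
x = torch.randn(batch, samples)             # driving signal, shape (batch, time)
A = torch.rand(batch, samples) * 1.8 - 0.9  # time-varying first-order coefficient in (-0.9, 0.9)
zi = torch.randn(batch)                     # initial state per batch element

# Public entry point: with order == 1 this now goes through Recurrence,
# which picks the C++ op, the Numba fallback, or the CUDA kernels internally.
y_ref = torchlpc.sample_wise_lpc(x, -A.unsqueeze(2), zi.unsqueeze(1))

if torchlpc.EXTENSION_LOADED:
    # Direct call to the registered CPU kernel; should match the reference
    # up to floating-point tolerance.
    y_ext = torch.ops.torchlpc.scan_cpu(x, A, zi)
    print(torch.allclose(y_ref, y_ext, atol=1e-6))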