From 1f392ba063c5370f7f6c0b7b447a8a04c3fc47a3 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Wed, 22 Jan 2025 17:33:05 +0000 Subject: [PATCH 01/15] draft: scan extension on cpu --- torchlpc/csrc/scan_cpu.cpp | 75 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 torchlpc/csrc/scan_cpu.cpp diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp new file mode 100644 index 0000000..b555e90 --- /dev/null +++ b/torchlpc/csrc/scan_cpu.cpp @@ -0,0 +1,75 @@ +#include +#include +#include +#include +#include + +template +at::Tensor scan_cpu(const at::Tensor &input, const at::Tensor &initials, const at::Tensor &weights) +{ + TORCH_CHECK(input.dim() == 2, "Input must be 2D"); + TORCH_CHECK(initials.dim() == 1, "Initials must be 1D"); + TORCH_CHECK(weights.sizes() == input.sizes(), "Weights must have the same size as input"); + TORCH_CHECK(initials.size(0) == input.size(0), "The first dimension of initials must be the same as the first dimension of input"); + // TORCH_INTERNAL_ASSERT(input.scalar_type() == at::kFloat, "Input must be float"); + // TORCH_INTERNAL_ASSERT(initials.scalar_type() == at::kFloat, "Initials must be float"); + // TORCH_INTERNAL_ASSERT(weights.scalar_type() == at::kFloat, "Weights must be float"); + TORCH_INTERNAL_ASSERT(input.device().is_cpu(), "Input must be on CPU"); + TORCH_INTERNAL_ASSERT(initials.device().is_cpu(), "Initials must be on CPU"); + TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU"); + TORCH_INTERNAL_ASSERT(input.is_contiguous(), "Input must be contiguous"); + TORCH_INTERNAL_ASSERT(initials.is_contiguous(), "Initials must be contiguous"); + TORCH_INTERNAL_ASSERT(weights.is_contiguous(), "Weights must be contiguous"); + + auto n_batch = input.size(0); + auto T = input.size(1); + auto total_size = input.numel(); + + std::array, total_size> buffer; + at::Tensor output = at::empty_like(input); + + const scalar_t *input_ptr = input.data_ptr(); + const scalar_t *initials_ptr = initials.data_ptr(); + const scalar_t *weights_ptr = weights.data_ptr(); + scalar_t *output_ptr = output.data_ptr(); + + std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer.begin(), std::make_pair); + + at::parallel_for(0, n_batch, 1, [buffer, T, initials_ptr](int64_t start, int64_t end) + { + for (auto b = start; b < end; b++) + { + std::inclusive_scan( + buffer.begin() + b * T, + buffer.begin() + (b + 1) * T, + buffer.begin() + b * T, + [](std::pair &a, const std::pair &b) { + return std::make_pair(a.first * b.first, a.second * b.first + b.second); + }, std::make_pair(1.0, initials_ptr[b])); + } }); + + std::transform(buffer.begin(), buffer.end(), output_ptr, [](const std::pair &a) + { return a.second; }); + + return output; +} + +at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &initials, const at::Tensor &weights) +{ + TORCH_CHECK(input.is_floating_point() || input.is_complex(), "Input must be floating point or complex"); + TORCH_CHECK(initials.is_floating_point() || initials.is_complex(), "Initials must be floating point or complex"); + TORCH_CHECK(weights.is_floating_point() || weights.is_complex(), "Weights must be floating point or complex"); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "scan_cpu", [&] + { return scan_cpu(input, initials, weights); }); +} + +TORCH_LIBRARY(torchlpc, m) +{ + m.def("torchlpc::scan_cpu(Tensor input, Tensor initials, Tensor weights) -> Tensor", &scan_cpu); +} + +TORCH_LIBRARY_IMPL(torchlpc, CPU, m) +{ + m.impl("scan_cpu", &scan_cpu_wrapper); +} \ No newline at end of file From d77846aae5f4455a22da76befff9de50a779804e Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Wed, 22 Jan 2025 17:33:21 +0000 Subject: [PATCH 02/15] include cpp extension in setup.py --- setup.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/setup.py b/setup.py index 5718687..850f83d 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import setuptools +from torch.utils import cpp_extension NAME = "torchlpc" VERSION = "0.6" @@ -25,4 +26,13 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], + ext_modules=[ + cpp_extension.CppExtension("torchlpc", ["torchlpc/csrc/scan_cpu.cpp"]) + ], + cmdclass={"build_ext": cpp_extension.BuildExtension}, + # include_dirs=[ + # "/Library/Developer/CommandLineTools/usr/lib/clang/16/include", + # "/Library/Developer/CommandLineTools/usr/include", + # "/Library/Developer/CommandLineTools/SDKs/MacOSX15.2.sdk/usr/include", + # ], ) From 37b3340686c1eca6b26fa73655a4baf8d1d99033 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Wed, 22 Jan 2025 17:35:16 +0000 Subject: [PATCH 03/15] fix: remove extra arg --- torchlpc/csrc/scan_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp index b555e90..ed71645 100644 --- a/torchlpc/csrc/scan_cpu.cpp +++ b/torchlpc/csrc/scan_cpu.cpp @@ -66,7 +66,7 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &initials, TORCH_LIBRARY(torchlpc, m) { - m.def("torchlpc::scan_cpu(Tensor input, Tensor initials, Tensor weights) -> Tensor", &scan_cpu); + m.def("torchlpc::scan_cpu(Tensor input, Tensor initials, Tensor weights) -> Tensor"); } TORCH_LIBRARY_IMPL(torchlpc, CPU, m) From e110d14f07ba63d1327a86883a58adc2c77d2102 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 18:18:12 +0800 Subject: [PATCH 04/15] fix: compile errors --- torchlpc/csrc/scan_cpu.cpp | 50 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp index ed71645..72ddfdb 100644 --- a/torchlpc/csrc/scan_cpu.cpp +++ b/torchlpc/csrc/scan_cpu.cpp @@ -1,75 +1,75 @@ #include #include #include -#include +#include #include template -at::Tensor scan_cpu(const at::Tensor &input, const at::Tensor &initials, const at::Tensor &weights) +void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials, at::Tensor &output) { TORCH_CHECK(input.dim() == 2, "Input must be 2D"); TORCH_CHECK(initials.dim() == 1, "Initials must be 1D"); TORCH_CHECK(weights.sizes() == input.sizes(), "Weights must have the same size as input"); + TORCH_CHECK(output.sizes() == input.sizes(), "Output must have the same size as input"); TORCH_CHECK(initials.size(0) == input.size(0), "The first dimension of initials must be the same as the first dimension of input"); - // TORCH_INTERNAL_ASSERT(input.scalar_type() == at::kFloat, "Input must be float"); - // TORCH_INTERNAL_ASSERT(initials.scalar_type() == at::kFloat, "Initials must be float"); - // TORCH_INTERNAL_ASSERT(weights.scalar_type() == at::kFloat, "Weights must be float"); TORCH_INTERNAL_ASSERT(input.device().is_cpu(), "Input must be on CPU"); TORCH_INTERNAL_ASSERT(initials.device().is_cpu(), "Initials must be on CPU"); TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU"); + TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU"); TORCH_INTERNAL_ASSERT(input.is_contiguous(), "Input must be contiguous"); - TORCH_INTERNAL_ASSERT(initials.is_contiguous(), "Initials must be contiguous"); TORCH_INTERNAL_ASSERT(weights.is_contiguous(), "Weights must be contiguous"); + TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous"); auto n_batch = input.size(0); auto T = input.size(1); auto total_size = input.numel(); - std::array, total_size> buffer; - at::Tensor output = at::empty_like(input); + std::pair buffer[total_size]; const scalar_t *input_ptr = input.data_ptr(); const scalar_t *initials_ptr = initials.data_ptr(); const scalar_t *weights_ptr = weights.data_ptr(); scalar_t *output_ptr = output.data_ptr(); - std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer.begin(), std::make_pair); + std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer, + [](const scalar_t &a, const scalar_t &b) + { return std::make_pair(a, b); }); - at::parallel_for(0, n_batch, 1, [buffer, T, initials_ptr](int64_t start, int64_t end) + at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) { for (auto b = start; b < end; b++) { std::inclusive_scan( - buffer.begin() + b * T, - buffer.begin() + (b + 1) * T, - buffer.begin() + b * T, - [](std::pair &a, const std::pair &b) { - return std::make_pair(a.first * b.first, a.second * b.first + b.second); - }, std::make_pair(1.0, initials_ptr[b])); + buffer + b * T, + buffer + (b + 1) * T, + buffer + b * T, + [](const std::pair &a, const std::pair &b) { + return std::make_pair(a.first * b.first, a.second * b.first + b.second); + }, + std::make_pair((scalar_t)1.0, initials_ptr[b])); } }); - std::transform(buffer.begin(), buffer.end(), output_ptr, [](const std::pair &a) + std::transform(buffer, buffer + total_size, output_ptr, [](const std::pair &a) { return a.second; }); - - return output; } -at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &initials, const at::Tensor &weights) +void scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials, at::Tensor &output) { TORCH_CHECK(input.is_floating_point() || input.is_complex(), "Input must be floating point or complex"); - TORCH_CHECK(initials.is_floating_point() || initials.is_complex(), "Initials must be floating point or complex"); - TORCH_CHECK(weights.is_floating_point() || weights.is_complex(), "Weights must be floating point or complex"); + TORCH_CHECK(initials.scalar_type() == input.scalar_type(), "Initials must have the same scalar type as input"); + TORCH_CHECK(weights.scalar_type() == input.scalar_type(), "Weights must have the same scalar type as input"); + TORCH_CHECK(output.scalar_type() == input.scalar_type(), "Output must have the same scalar type as input"); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "scan_cpu", [&] - { return scan_cpu(input, initials, weights); }); + { scan_cpu(input, weights, initials, output); }); } TORCH_LIBRARY(torchlpc, m) { - m.def("torchlpc::scan_cpu(Tensor input, Tensor initials, Tensor weights) -> Tensor"); + m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c, Tensor(a!) out) -> ()"); } TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); -} \ No newline at end of file +} From 4611ce2e66a2a3020ef0dfc1fbc177bb4345e804 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 18:18:54 +0800 Subject: [PATCH 05/15] use dev versioning --- setup.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 850f83d..f16fc6d 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from torch.utils import cpp_extension NAME = "torchlpc" -VERSION = "0.6" +VERSION = "0.7.dev" MAINTAINER = "Chin-Yun Yu" EMAIL = "chin-yun.yu@qmul.ac.uk" @@ -27,12 +27,7 @@ "Operating System :: OS Independent", ], ext_modules=[ - cpp_extension.CppExtension("torchlpc", ["torchlpc/csrc/scan_cpu.cpp"]) + cpp_extension.CppExtension("torchlpc._C", ["torchlpc/csrc/scan_cpu.cpp"]) ], cmdclass={"build_ext": cpp_extension.BuildExtension}, - # include_dirs=[ - # "/Library/Developer/CommandLineTools/usr/lib/clang/16/include", - # "/Library/Developer/CommandLineTools/usr/include", - # "/Library/Developer/CommandLineTools/SDKs/MacOSX15.2.sdk/usr/include", - # ], ) From 5e28b302e7aaef8815a7f0474a9d1c1a9c524b50 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 18:19:19 +0800 Subject: [PATCH 06/15] load library file when being imported --- torchlpc/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchlpc/__init__.py b/torchlpc/__init__.py index acecbed..68e2fee 100644 --- a/torchlpc/__init__.py +++ b/torchlpc/__init__.py @@ -1,5 +1,6 @@ import torch from typing import Optional +from pathlib import Path from .core import LPC from .parallel_scan import WARPSIZE @@ -8,6 +9,11 @@ __all__ = ["sample_wise_lpc"] +so_files = list(Path(__file__).parent.glob("_C*.so")) +assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" +torch.ops.load_library(so_files[0]) + + def sample_wise_lpc( x: torch.Tensor, a: torch.Tensor, zi: Optional[torch.Tensor] = None ) -> torch.Tensor: From 5b69d2ae3ba6b0e13ff804ed951b5e8841985f2f Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 18:19:31 +0800 Subject: [PATCH 07/15] test equivalence to numba version --- tests/test_extension.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/test_extension.py diff --git a/tests/test_extension.py b/tests/test_extension.py new file mode 100644 index 0000000..2a7f2b4 --- /dev/null +++ b/tests/test_extension.py @@ -0,0 +1,36 @@ +import torch +import torch.nn.functional as F +import pytest +from torchlpc.core import lpc_np + + +from .test_grad import create_test_inputs + + +@pytest.mark.parametrize( + "samples", + [64, 4097], +) +@pytest.mark.parametrize( + "cmplx", + [True, False], +) +def test_scan_cpu_equiv(samples: int, cmplx: bool): + batch_size = 4 + x = torch.randn( + batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64 + ) + A = torch.rand_like(x) * 1.8 - 0.9 + zi = torch.randn(batch_size, dtype=x.dtype) + + numba_y = torch.from_numpy( + lpc_np( + x.cpu().numpy(), + -A.cpu().unsqueeze(2).numpy(), + zi.cpu().unsqueeze(1).numpy(), + ) + ) + ext_y = torch.empty_like(x) + torch.ops.torchlpc.scan_cpu(x, A, zi, ext_y) + + assert torch.allclose(numba_y, ext_y) From 7b54fb15d01f5393004db6232ab34db0a13f7e65 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 19:16:27 +0800 Subject: [PATCH 08/15] refactor: return tensor instead of void function --- tests/test_extension.py | 3 +-- torchlpc/csrc/scan_cpu.cpp | 9 ++++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_extension.py b/tests/test_extension.py index 2a7f2b4..ef94e49 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -30,7 +30,6 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool): zi.cpu().unsqueeze(1).numpy(), ) ) - ext_y = torch.empty_like(x) - torch.ops.torchlpc.scan_cpu(x, A, zi, ext_y) + ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi) assert torch.allclose(numba_y, ext_y) diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp index 72ddfdb..9df56db 100644 --- a/torchlpc/csrc/scan_cpu.cpp +++ b/torchlpc/csrc/scan_cpu.cpp @@ -53,20 +53,23 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tens { return a.second; }); } -void scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials, at::Tensor &output) +at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials) { TORCH_CHECK(input.is_floating_point() || input.is_complex(), "Input must be floating point or complex"); TORCH_CHECK(initials.scalar_type() == input.scalar_type(), "Initials must have the same scalar type as input"); TORCH_CHECK(weights.scalar_type() == input.scalar_type(), "Weights must have the same scalar type as input"); - TORCH_CHECK(output.scalar_type() == input.scalar_type(), "Output must have the same scalar type as input"); + // TORCH_CHECK(output.scalar_type() == input.scalar_type(), "Output must have the same scalar type as input"); + + auto output = at::empty_like(input); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "scan_cpu", [&] { scan_cpu(input, weights, initials, output); }); + return output; } TORCH_LIBRARY(torchlpc, m) { - m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c, Tensor(a!) out) -> ()"); + m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor"); } TORCH_LIBRARY_IMPL(torchlpc, CPU, m) From 1e7ea161f457070663d967fafc1fdb2e56814d4e Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 19:47:59 +0800 Subject: [PATCH 09/15] refactor: rename RecurrenceCUDA to Recurrence to cover CPU device --- torchlpc/recurrence.py | 49 ++++++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/torchlpc/recurrence.py b/torchlpc/recurrence.py index 0ca5835..2db8f91 100644 --- a/torchlpc/recurrence.py +++ b/torchlpc/recurrence.py @@ -4,10 +4,11 @@ from numba import cuda from typing import Tuple, Optional, Any, List -from .parallel_scan import compute_linear_recurrence +from .parallel_scan import compute_linear_recurrence, WARPSIZE +from .core import lpc_cuda, lpc_np -class RecurrenceCUDA(Function): +class Recurrence(Function): @staticmethod def forward( decay: torch.Tensor, @@ -15,15 +16,32 @@ def forward( initial_state: torch.Tensor, ) -> torch.Tensor: n_dims, n_steps = decay.shape - out = torch.empty_like(impulse) - compute_linear_recurrence( - cuda.as_cuda_array(decay.detach()), - cuda.as_cuda_array(impulse.detach()), - cuda.as_cuda_array(initial_state.detach()), - cuda.as_cuda_array(out), - n_dims, - n_steps, - ) + if decay.is_cuda: + if n_dims * WARPSIZE < n_steps: + out = torch.empty_like(impulse) + compute_linear_recurrence( + cuda.as_cuda_array(decay.detach()), + cuda.as_cuda_array(impulse.detach()), + cuda.as_cuda_array(initial_state.detach()), + cuda.as_cuda_array(out), + n_dims, + n_steps, + ) + else: + out = lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)) + else: + num_threads = torch.get_num_threads() + # This is just a rough estimation of the computational cost + if min(n_dims, num_threads) < num_threads / 3: + out = torch.ops.torchlpc.scan_cpu(impulse, decay, initial_state) + else: + out = torch.from_numpy( + lpc_np( + impulse.detach().numpy(), + -decay.unsqueeze(2).detach().numpy(), + initial_state.unsqueeze(1).detach().numpy(), + ) + ) return out @staticmethod @@ -48,7 +66,7 @@ def backward( padded_decay = padded_decay[:, 1:] init = padded_grad_out.new_zeros(n_dims) - flipped_grad_impulse = RecurrenceCUDA.apply( + flipped_grad_impulse = Recurrence.apply( padded_decay.flip(1).conj_physical(), padded_grad_out.flip(1), init, @@ -91,7 +109,7 @@ def jvp( fwd_decay = concat_out * grad_decay fwd_impulse = fwd_impulse + fwd_decay - return RecurrenceCUDA.apply(decay, fwd_impulse, fwd_initial_state) + return Recurrence.apply(decay, fwd_impulse, fwd_initial_state) @staticmethod def vmap(info, in_dims, *args): @@ -107,5 +125,8 @@ def maybe_expand_bdim_at_front(x, x_bdim): ) ) - out = RecurrenceCUDA.apply(decay, impulse, initial_state) + out = Recurrence.apply(decay, impulse, initial_state) return out.reshape(info.batch_size, -1, *out.shape[1:]), 0 + + +RecurrenceCUDA = Recurrence From 314c5a27574d2d55ec21eaf922c9859164d367f7 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 19:48:52 +0800 Subject: [PATCH 10/15] refactor: update functions to use Recurrence for CPU and CUDA devices --- tests/test_grad.py | 28 ++++++++++++++++++++-------- tests/test_vmap.py | 27 +++++++++++++++++++-------- torchlpc/__init__.py | 26 +++++++++++++++++--------- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/tests/test_grad.py b/tests/test_grad.py index b771170..028c6f5 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -2,7 +2,7 @@ import torch from torch.autograd.gradcheck import gradcheck, gradgradcheck from torchlpc.core import LPC -from torchlpc.recurrence import RecurrenceCUDA +from torchlpc.recurrence import Recurrence def get_random_biquads(cmplx=False): @@ -123,21 +123,33 @@ def test_float64_vs_32_cuda(): "zi_requires_grad", [True, False], ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_cuda_parallel_scan( +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) +def test_parallel_scan( x_requires_grad: bool, a_requires_grad: bool, zi_requires_grad: bool, + device: str, ): batch_size = 2 samples = 123 - x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda") - A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1 - zi = torch.randn(batch_size, dtype=torch.double, device="cuda") + x = torch.randn(batch_size, samples, dtype=torch.double, device=device) + A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1 + zi = torch.randn(batch_size, dtype=torch.double, device=device) A.requires_grad = a_requires_grad x.requires_grad = x_requires_grad zi.requires_grad = zi_requires_grad - assert gradcheck(RecurrenceCUDA.apply, (A, x, zi), check_forward_ad=True) - assert gradgradcheck(RecurrenceCUDA.apply, (A, x, zi)) + assert gradcheck(Recurrence.apply, (A, x, zi), check_forward_ad=True) + assert gradgradcheck(Recurrence.apply, (A, x, zi)) diff --git a/tests/test_vmap.py b/tests/test_vmap.py index f50c57c..d99e8a1 100644 --- a/tests/test_vmap.py +++ b/tests/test_vmap.py @@ -3,7 +3,7 @@ from torch.func import jacfwd import pytest from torchlpc.core import LPC -from torchlpc.recurrence import RecurrenceCUDA +from torchlpc.recurrence import Recurrence from .test_grad import create_test_inputs @@ -48,14 +48,25 @@ def func(x, A, zi): assert torch.allclose(jac, arg.grad) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_cuda_parallel_scan_vmap(): +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) +def test_parallel_scan_vmap(device: str): batch_size = 3 samples = 255 - x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda") - A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1 - zi = torch.randn(batch_size, dtype=torch.double, device="cuda") - y = torch.randn(batch_size, samples, dtype=torch.double, device="cuda") + x = torch.randn(batch_size, samples, dtype=torch.double, device=device) + A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1 + zi = torch.randn(batch_size, dtype=torch.double, device=device) + y = torch.randn(batch_size, samples, dtype=torch.double, device=device) A.requires_grad = True x.requires_grad = True @@ -64,7 +75,7 @@ def test_cuda_parallel_scan_vmap(): args = (x, A, zi) def func(x, A, zi): - return F.mse_loss(RecurrenceCUDA.apply(A, x, zi), y) + return F.mse_loss(Recurrence.apply(A, x, zi), y) jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args) diff --git a/torchlpc/__init__.py b/torchlpc/__init__.py index 68e2fee..1e0aa95 100644 --- a/torchlpc/__init__.py +++ b/torchlpc/__init__.py @@ -2,16 +2,22 @@ from typing import Optional from pathlib import Path -from .core import LPC -from .parallel_scan import WARPSIZE -from .recurrence import RecurrenceCUDA +so_files = list(Path(__file__).parent.glob("_C*.so")) +# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" +if len(so_files) == 1: + torch.ops.load_library(so_files[0]) + EXTENSION_LOADED = True +elif len(so_files) > 1: + raise ValueError(f"Expected one _C*.so file, found {len(so_files)}") +else: + EXTENSION_LOADED = False -__all__ = ["sample_wise_lpc"] +from .core import LPC +# from .parallel_scan import WARPSIZE +from .recurrence import Recurrence -so_files = list(Path(__file__).parent.glob("_C*.so")) -assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" -torch.ops.load_library(so_files[0]) +__all__ = ["sample_wise_lpc"] def sample_wise_lpc( @@ -43,7 +49,9 @@ def sample_wise_lpc( else: assert zi.shape == (B, order) - if order == 1 and x.is_cuda and B * WARPSIZE < T: - return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1)) + # if order == 1 and x.is_cuda and B * WARPSIZE < T: + # return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1)) + if order == 1: + return Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1)) return LPC.apply(x, a, zi) From 1472acfffe2bae73ca335d35dab0e2efdb1ca878 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 19:49:23 +0800 Subject: [PATCH 11/15] refactor: remove contiguous check besides output tensor --- torchlpc/csrc/scan_cpu.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp index 9df56db..04d3292 100644 --- a/torchlpc/csrc/scan_cpu.cpp +++ b/torchlpc/csrc/scan_cpu.cpp @@ -16,19 +16,23 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tens TORCH_INTERNAL_ASSERT(initials.device().is_cpu(), "Initials must be on CPU"); TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU"); TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU"); - TORCH_INTERNAL_ASSERT(input.is_contiguous(), "Input must be contiguous"); - TORCH_INTERNAL_ASSERT(weights.is_contiguous(), "Weights must be contiguous"); + // TORCH_INTERNAL_ASSERT(input.is_contiguous(), "Input must be contiguous"); + // TORCH_INTERNAL_ASSERT(weights.is_contiguous(), "Weights must be contiguous"); TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous"); + auto input_contiguous = input.contiguous(); + auto weights_contiguous = weights.contiguous(); + auto initials_contiguous = initials.contiguous(); + auto n_batch = input.size(0); auto T = input.size(1); auto total_size = input.numel(); std::pair buffer[total_size]; - const scalar_t *input_ptr = input.data_ptr(); - const scalar_t *initials_ptr = initials.data_ptr(); - const scalar_t *weights_ptr = weights.data_ptr(); + const scalar_t *input_ptr = input_contiguous.data_ptr(); + const scalar_t *initials_ptr = initials_contiguous.data_ptr(); + const scalar_t *weights_ptr = weights_contiguous.data_ptr(); scalar_t *output_ptr = output.data_ptr(); std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer, From 5e7a9143d14259dc671cf6bdd819d6cdb4600193 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 19:57:26 +0800 Subject: [PATCH 12/15] refactor: add warning for missing _C*.so file and check extension loading in Recurrence --- torchlpc/__init__.py | 2 ++ torchlpc/recurrence.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/torchlpc/__init__.py b/torchlpc/__init__.py index 1e0aa95..79f86b6 100644 --- a/torchlpc/__init__.py +++ b/torchlpc/__init__.py @@ -1,6 +1,7 @@ import torch from typing import Optional from pathlib import Path +import warnings so_files = list(Path(__file__).parent.glob("_C*.so")) # assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" @@ -10,6 +11,7 @@ elif len(so_files) > 1: raise ValueError(f"Expected one _C*.so file, found {len(so_files)}") else: + warnings.warn("No _C*.so file found. Custom extension not loaded.") EXTENSION_LOADED = False from .core import LPC diff --git a/torchlpc/recurrence.py b/torchlpc/recurrence.py index 2db8f91..05b9fd5 100644 --- a/torchlpc/recurrence.py +++ b/torchlpc/recurrence.py @@ -6,6 +6,7 @@ from .parallel_scan import compute_linear_recurrence, WARPSIZE from .core import lpc_cuda, lpc_np +from . import EXTENSION_LOADED class Recurrence(Function): @@ -32,7 +33,7 @@ def forward( else: num_threads = torch.get_num_threads() # This is just a rough estimation of the computational cost - if min(n_dims, num_threads) < num_threads / 3: + if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3: out = torch.ops.torchlpc.scan_cpu(impulse, decay, initial_state) else: out = torch.from_numpy( From 6cde45bc4f23253ea4540b61f812857dec370ad9 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Thu, 23 Jan 2025 21:30:17 +0800 Subject: [PATCH 13/15] ci: add workflow step to build CPP extension and copy shared objects --- .github/workflows/python-package.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 03fbf19..a117ac0 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -36,6 +36,10 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Build CPP extension + run: | + python setup.py build + find build/ -name "_C*.so" -exec cp {} ./torchlpc/ \; - name: Test with pytest run: | pytest From 32590c9b064067d5585473619fe995a3e685e7a9 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Fri, 24 Jan 2025 18:19:42 +0800 Subject: [PATCH 14/15] apply suggestions and remove comments --- torchlpc/csrc/scan_cpu.cpp | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp index 04d3292..5db071c 100644 --- a/torchlpc/csrc/scan_cpu.cpp +++ b/torchlpc/csrc/scan_cpu.cpp @@ -5,7 +5,7 @@ #include template -void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials, at::Tensor &output) +void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials, const at::Tensor &output) { TORCH_CHECK(input.dim() == 2, "Input must be 2D"); TORCH_CHECK(initials.dim() == 1, "Initials must be 1D"); @@ -16,8 +16,6 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tens TORCH_INTERNAL_ASSERT(initials.device().is_cpu(), "Initials must be on CPU"); TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU"); TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU"); - // TORCH_INTERNAL_ASSERT(input.is_contiguous(), "Input must be contiguous"); - // TORCH_INTERNAL_ASSERT(weights.is_contiguous(), "Weights must be contiguous"); TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous"); auto input_contiguous = input.contiguous(); @@ -39,19 +37,19 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tens [](const scalar_t &a, const scalar_t &b) { return std::make_pair(a, b); }); - at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) - { - for (auto b = start; b < end; b++) - { + at::parallel_for(0, n_batch, 1, [ & ](int64_t start, int64_t end) { + for (auto b = start; b < end; b++) { std::inclusive_scan( buffer + b * T, buffer + (b + 1) * T, buffer + b * T, - [](const std::pair &a, const std::pair &b) { + [](const std::pair < scalar_t, scalar_t > & a, + const std::pair < scalar_t, scalar_t > & b) { return std::make_pair(a.first * b.first, a.second * b.first + b.second); }, - std::make_pair((scalar_t)1.0, initials_ptr[b])); - } }); + std::make_pair((scalar_t) 1.0, initials_ptr[b])); + } + }); std::transform(buffer, buffer + total_size, output_ptr, [](const std::pair &a) { return a.second; }); @@ -62,7 +60,6 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights, TORCH_CHECK(input.is_floating_point() || input.is_complex(), "Input must be floating point or complex"); TORCH_CHECK(initials.scalar_type() == input.scalar_type(), "Initials must have the same scalar type as input"); TORCH_CHECK(weights.scalar_type() == input.scalar_type(), "Weights must have the same scalar type as input"); - // TORCH_CHECK(output.scalar_type() == input.scalar_type(), "Output must have the same scalar type as input"); auto output = at::empty_like(input); From c4517b6a5bc65960e56e5d2459ee7fcd5ff09fc9 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Fri, 24 Jan 2025 18:27:06 +0800 Subject: [PATCH 15/15] refactor: apply google style format --- torchlpc/csrc/scan_cpu.cpp | 71 +++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp index 5db071c..dd85657 100644 --- a/torchlpc/csrc/scan_cpu.cpp +++ b/torchlpc/csrc/scan_cpu.cpp @@ -1,19 +1,25 @@ #include #include + #include -#include #include +#include template -void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials, const at::Tensor &output) -{ +void scan_cpu(const at::Tensor &input, const at::Tensor &weights, + const at::Tensor &initials, const at::Tensor &output) { TORCH_CHECK(input.dim() == 2, "Input must be 2D"); TORCH_CHECK(initials.dim() == 1, "Initials must be 1D"); - TORCH_CHECK(weights.sizes() == input.sizes(), "Weights must have the same size as input"); - TORCH_CHECK(output.sizes() == input.sizes(), "Output must have the same size as input"); - TORCH_CHECK(initials.size(0) == input.size(0), "The first dimension of initials must be the same as the first dimension of input"); + TORCH_CHECK(weights.sizes() == input.sizes(), + "Weights must have the same size as input"); + TORCH_CHECK(output.sizes() == input.sizes(), + "Output must have the same size as input"); + TORCH_CHECK(initials.size(0) == input.size(0), + "The first dimension of initials must be the same as the first " + "dimension of input"); TORCH_INTERNAL_ASSERT(input.device().is_cpu(), "Input must be on CPU"); - TORCH_INTERNAL_ASSERT(initials.device().is_cpu(), "Initials must be on CPU"); + TORCH_INTERNAL_ASSERT(initials.device().is_cpu(), + "Initials must be on CPU"); TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU"); TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU"); TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous"); @@ -34,46 +40,47 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tens scalar_t *output_ptr = output.data_ptr(); std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer, - [](const scalar_t &a, const scalar_t &b) - { return std::make_pair(a, b); }); + [](const scalar_t &a, const scalar_t &b) { + return std::make_pair(a, b); + }); - at::parallel_for(0, n_batch, 1, [ & ](int64_t start, int64_t end) { + at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) { for (auto b = start; b < end; b++) { std::inclusive_scan( - buffer + b * T, - buffer + (b + 1) * T, - buffer + b * T, - [](const std::pair < scalar_t, scalar_t > & a, - const std::pair < scalar_t, scalar_t > & b) { - return std::make_pair(a.first * b.first, a.second * b.first + b.second); + buffer + b * T, buffer + (b + 1) * T, buffer + b * T, + [](const std::pair &a, + const std::pair &b) { + return std::make_pair(a.first * b.first, + a.second * b.first + b.second); }, - std::make_pair((scalar_t) 1.0, initials_ptr[b])); + std::make_pair((scalar_t)1.0, initials_ptr[b])); } }); - std::transform(buffer, buffer + total_size, output_ptr, [](const std::pair &a) - { return a.second; }); + std::transform( + buffer, buffer + total_size, output_ptr, + [](const std::pair &a) { return a.second; }); } -at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials) -{ - TORCH_CHECK(input.is_floating_point() || input.is_complex(), "Input must be floating point or complex"); - TORCH_CHECK(initials.scalar_type() == input.scalar_type(), "Initials must have the same scalar type as input"); - TORCH_CHECK(weights.scalar_type() == input.scalar_type(), "Weights must have the same scalar type as input"); +at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights, + const at::Tensor &initials) { + TORCH_CHECK(input.is_floating_point() || input.is_complex(), + "Input must be floating point or complex"); + TORCH_CHECK(initials.scalar_type() == input.scalar_type(), + "Initials must have the same scalar type as input"); + TORCH_CHECK(weights.scalar_type() == input.scalar_type(), + "Weights must have the same scalar type as input"); auto output = at::empty_like(input); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "scan_cpu", [&] - { scan_cpu(input, weights, initials, output); }); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + input.scalar_type(), "scan_cpu", + [&] { scan_cpu(input, weights, initials, output); }); return output; } -TORCH_LIBRARY(torchlpc, m) -{ +TORCH_LIBRARY(torchlpc, m) { m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor"); } -TORCH_LIBRARY_IMPL(torchlpc, CPU, m) -{ - m.impl("scan_cpu", &scan_cpu_wrapper); -} +TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); }