From b342f5d6fa77338397b579679cbcab9c93d32d0f Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Fri, 24 Jan 2025 21:56:45 +0800
Subject: [PATCH 1/4] feat: add LPC CPU implementation and wrapper function

---
 torchlpc/csrc/scan_cpu.cpp | 68 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 67 insertions(+), 1 deletion(-)

diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp
index dd85657..8463341 100644
--- a/torchlpc/csrc/scan_cpu.cpp
+++ b/torchlpc/csrc/scan_cpu.cpp
@@ -62,6 +62,47 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
         [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
 }
 
+template <typename scalar_t>
+void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
+    // Ensure input dimensions are correct
+    TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
+    TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
+    TORCH_CHECK(padded_out.size(0) == a.size(0),
+                "Batch size of out and x must match");
+    TORCH_CHECK(padded_out.size(1) == (a.size(1) + a.size(2)),
+                "Time dimension of out must match x and a");
+    TORCH_INTERNAL_ASSERT(a.device().is_cpu(), "a must be on CPU");
+    TORCH_INTERNAL_ASSERT(padded_out.device().is_cpu(),
+                          "Output must be on CPU");
+    TORCH_INTERNAL_ASSERT(padded_out.is_contiguous(),
+                          "Output must be contiguous");
+
+    // Get the dimensions
+    const auto B = a.size(0);
+    const auto T = a.size(1);
+    const auto order = a.size(2);
+
+    auto a_contiguous = a.contiguous();
+
+    const scalar_t *a_ptr = a_contiguous.data_ptr<scalar_t>();
+    scalar_t *out_ptr = padded_out.data_ptr<scalar_t>();
+
+    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
+        for (auto b = start; b < end; b++) {
+            auto out_offset = b * (T + order) + order;
+            auto a_offset = b * T * order;
+            for (int64_t t = 0; t < T; t++) {
+                scalar_t y = out_ptr[out_offset + t];
+                for (int64_t i = 0; i < order; i++) {
+                    y -= a_ptr[a_offset + t * order + i] *
+                         out_ptr[out_offset + t - i - 1];
+                }
+                out_ptr[out_offset + t] = y;
+            }
+        }
+    });
+}
+
 at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
                             const at::Tensor &initials) {
     TORCH_CHECK(input.is_floating_point() || input.is_complex(),
@@ -79,8 +120,33 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
     return output;
 }
 
+at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
+                   const at::Tensor &zi) {
+    TORCH_CHECK(x.is_floating_point() || x.is_complex(),
+                "Input must be floating point or complex");
+    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
+                "Coefficients must have the same scalar type as input");
+    TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
+                "Initial conditions must have the same scalar type as input");
+
+    TORCH_CHECK(x.dim() == 2, "Input must be 2D");
+    TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
+    TORCH_CHECK(x.size(0) == zi.size(0),
+                "Batch size of input and initial conditions must match");
+
+    auto out = at::cat({zi.flip(1), x}, 1).contiguous();
+
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+        x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
+    return out.slice(1, zi.size(1), out.size(1)).contiguous();
+}
+
 TORCH_LIBRARY(torchlpc, m) {
     m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
+    m.def("torchlpc::lpc_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
 }
 
-TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); }
+TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
+    m.impl("scan_cpu", &scan_cpu_wrapper);
+    m.impl("lpc_cpu", &lpc_cpu);
+}
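For orientation, the recursion implemented by lpc_cpu_core in the patch above is the time-varying all-pole (LPC) filter y[b, t] = x[b, t] - sum_i a[b, t, i] * y[b, t - 1 - i], with the flipped initial conditions supplying the first `order` samples of history. The NumPy sketch below mirrors the C++ loop line for line; it is an illustration only, not part of the patch, and the function and variable names are made up here:

    import numpy as np

    def lpc_reference(x, a, zi):
        # x: (B, T) input, a: (B, T, order) coefficients, zi: (B, order) initial state.
        B, T = x.shape
        order = zi.shape[1]
        # Prepend the flipped initial conditions, like at::cat({zi.flip(1), x}, 1).
        padded = np.concatenate([zi[:, ::-1], x], axis=1)
        for b in range(B):
            for t in range(T):
                y = padded[b, order + t]
                for i in range(order):
                    y -= a[b, t, i] * padded[b, order + t - 1 - i]
                padded[b, order + t] = y
        # Drop the padded history, like out.slice(1, zi.size(1), out.size(1)).
        return padded[:, order:]

Because each output sample depends on the previous outputs, the time loop is inherently sequential; the C++ implementation therefore parallelises only across the batch dimension via at::parallel_for.
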
From b31b5e7c5461ae20f0f1fa4f74fb8bb3bd33fd0b Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Fri, 24 Jan 2025 21:57:03 +0800
Subject: [PATCH 2/4] test: add equivalence tests for lpc_cpu function

---
 tests/test_extension.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/tests/test_extension.py b/tests/test_extension.py
index ef94e49..bb281cc 100644
--- a/tests/test_extension.py
+++ b/tests/test_extension.py
@@ -33,3 +33,22 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool):
     ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)
 
     assert torch.allclose(numba_y, ext_y)
+
+
+@pytest.mark.parametrize(
+    "samples",
+    [1024],
+)
+@pytest.mark.parametrize(
+    "cmplx",
+    [True, False],
+)
+def test_lpc_cpu_equiv(samples: int, cmplx: bool):
+    batch_size = 4
+    x, A, zi = tuple(
+        x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
+    )
+    numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
+    ext_y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
+
+    assert torch.allclose(numba_y, ext_y)

From a4898def046b7475e0621b5332935d320dcd86b9 Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Fri, 24 Jan 2025 21:57:39 +0800
Subject: [PATCH 3/4] feat: openmp compilation flag

---
 setup.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index f16fc6d..f476167 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,5 @@
 import setuptools
+import torch
 from torch.utils import cpp_extension
 
 NAME = "torchlpc"
@@ -10,6 +11,14 @@
 with open("README.md", "r") as fh:
     long_description = fh.read()
 
+
+extra_link_args = []
+extra_compile_args = {}
+# check if openmp is available
+if torch.backends.openmp.is_available():
+    extra_compile_args["cxx"] = ["-fopenmp"]
+    extra_link_args.append("-lgomp")
+
 setuptools.setup(
     name=NAME,
     version=VERSION,
@@ -27,7 +36,12 @@
         "Operating System :: OS Independent",
     ],
     ext_modules=[
-        cpp_extension.CppExtension("torchlpc._C", ["torchlpc/csrc/scan_cpu.cpp"])
+        cpp_extension.CppExtension(
+            "torchlpc._C",
+            ["torchlpc/csrc/scan_cpu.cpp"],
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+        )
     ],
     cmdclass={"build_ext": cpp_extension.BuildExtension},
 )

From 06d830e575e4d58308e58689a18f7288defc800f Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Fri, 24 Jan 2025 23:24:01 +0800
Subject: [PATCH 4/4] feat: use cpp lpc and add deprecation warning

---
 torchlpc/core.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torchlpc/core.py b/torchlpc/core.py
index 1ae9a64..543d54f 100644
--- a/torchlpc/core.py
+++ b/torchlpc/core.py
@@ -1,3 +1,4 @@
+import warnings
 import torch
 import numpy as np
 import torch.nn.functional as F
@@ -5,6 +6,7 @@
 from typing import Any, Tuple, Optional, Callable, List
 from numba import jit, njit, prange, cuda, float32, float64, complex64, complex128
 
+from . import EXTENSION_LOADED
 
 lpc_cuda_kernel_float32: Callable = None
 lpc_cuda_kernel_float64: Callable = None
@@ -159,7 +161,12 @@ class LPC(Function):
     def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
         if x.is_cuda:
             y = lpc_cuda(x.detach(), A.detach(), zi.detach())
+        elif EXTENSION_LOADED:
+            y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
         else:
+            warnings.warn(
+                "Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."
+            )
             y = lpc_np(
                 x.detach().cpu().numpy(),
                 A.detach().cpu().numpy(),
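
Taken together, the patches register the new operator with the dispatcher, test it against the Numba reference, enable OpenMP where available, and make LPC.forward prefer the compiled path. A minimal usage sketch follows, assuming that importing the package loads the compiled extension (as the EXTENSION_LOADED flag above implies); the shapes and values are made-up examples consistent with the checks in lpc_cpu:

    import torch
    import torchlpc  # assumed to load torchlpc._C and register torch.ops.torchlpc.*

    B, T, order = 4, 1024, 2
    x = torch.randn(B, T, dtype=torch.double)                # input, shape (batch, time)
    A = 0.1 * torch.randn(B, T, order, dtype=torch.double)   # time-varying coefficients, kept small
    zi = torch.randn(B, order, dtype=torch.double)           # initial conditions

    y = torch.ops.torchlpc.lpc_cpu(x, A, zi)                 # output has the same shape as x

test_lpc_cpu_equiv in patch 2 exercises exactly this call and asserts allclose against the lpc_np reference.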