diff --git a/setup.py b/setup.py index f476167..c52786f 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,15 @@ import setuptools +import os +import glob import torch -from torch.utils import cpp_extension +from torch.utils.cpp_extension import ( + CppExtension, + CUDAExtension, + BuildExtension, + CUDA_HOME, +) -NAME = "torchlpc" +library_name = "torchlpc" VERSION = "0.7.dev" MAINTAINER = "Chin-Yun Yu" EMAIL = "chin-yun.yu@qmul.ac.uk" @@ -12,15 +19,51 @@ long_description = fh.read() +# if torch.__version__ >= "2.6.0": +# py_limited_api = True +# else: +py_limited_api = False + + +def get_extensions(): + use_cuda = torch.cuda.is_available() and CUDA_HOME is not None + use_openmp = torch.backends.openmp.is_available() + extension = CUDAExtension if use_cuda else CppExtension + + extra_link_args = [] + extra_compile_args = {} + if use_openmp: + extra_compile_args["cxx"] = ["-fopenmp"] + extra_link_args.append("-lgomp") + + this_dir = os.path.abspath(os.path.dirname(__file__)) + extensions_dir = os.path.join(this_dir, library_name, "csrc") + sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) + + extensions_cuda_dir = os.path.join(extensions_dir, "cuda") + cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) + + if use_cuda: + sources += cuda_sources + + ext_modules = [ + extension( + f"{library_name}._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + py_limited_api=py_limited_api, + ) + ] + + return ext_modules + + extra_link_args = [] extra_compile_args = {} -# check if openmp is available -if torch.backends.openmp.is_available(): - extra_compile_args["cxx"] = ["-fopenmp"] - extra_link_args.append("-lgomp") setuptools.setup( - name=NAME, + name=library_name, version=VERSION, author=MAINTAINER, author_email=EMAIL, @@ -32,16 +75,10 @@ install_requires=["torch>=2.0", "numpy", "numba"], classifiers=[ "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - ext_modules=[ - cpp_extension.CppExtension( - "torchlpc._C", - ["torchlpc/csrc/scan_cpu.cpp"], - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - ], - cmdclass={"build_ext": cpp_extension.BuildExtension}, + license="MIT", + ext_modules=get_extensions(), + cmdclass={"build_ext": BuildExtension}, + options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {}, ) diff --git a/tests/test_extension.py b/tests/test_extension.py index bb281cc..c3fa107 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F import pytest -from torchlpc.core import lpc_np +from torchlpc.core import lpc_np, lpc_cuda from .test_grad import create_test_inputs @@ -15,24 +15,53 @@ "cmplx", [True, False], ) -def test_scan_cpu_equiv(samples: int, cmplx: bool): +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) +def test_scan_equiv(samples: int, cmplx: bool, device: str): batch_size = 4 x = torch.randn( - batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64 + batch_size, + samples, + dtype=torch.float32 if not cmplx else torch.complex64, + device=device, ) - A = torch.rand_like(x) * 1.8 - 0.9 - zi = torch.randn(batch_size, dtype=x.dtype) - - numba_y = torch.from_numpy( - lpc_np( - x.cpu().numpy(), - -A.cpu().unsqueeze(2).numpy(), - zi.cpu().unsqueeze(1).numpy(), + 
if cmplx: + A = torch.rand( + batch_size, samples, dtype=x.dtype, device=device + ).sqrt() * torch.exp( + 2j + * torch.rand(batch_size, samples, dtype=x.dtype, device=device) + * torch.pi ) - ) - ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi) + else: + A = torch.rand_like(x) * 1.8 - 0.9 + zi = torch.randn(batch_size, dtype=x.dtype, device=device) - assert torch.allclose(numba_y, ext_y) + if device == "cuda": + numba_y = lpc_cuda(x, -A.unsqueeze(2), zi.unsqueeze(1)) + else: + numba_y = torch.from_numpy( + lpc_np( + x.cpu().numpy(), + -A.cpu().unsqueeze(2).numpy(), + zi.cpu().unsqueeze(1).numpy(), + ) + ) + ext_y = torch.ops.torchlpc.scan(x, A, zi) + + assert torch.allclose(numba_y, ext_y, atol=5e-7), torch.max( + torch.abs(numba_y - ext_y) + ).item() @pytest.mark.parametrize( @@ -43,12 +72,12 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool): "cmplx", [True, False], ) -def test_lpc_cpu_equiv(samples: int, cmplx: bool): +def test_lpc_equiv(samples: int, cmplx: bool): batch_size = 4 x, A, zi = tuple( x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx) ) numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy())) - ext_y = torch.ops.torchlpc.lpc_cpu(x, A, zi) + ext_y = torch.ops.torchlpc.lpc(x, A, zi) assert torch.allclose(numba_y, ext_y) diff --git a/tests/test_grad.py b/tests/test_grad.py index 028c6f5..0279634 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -123,6 +123,10 @@ def test_float64_vs_32_cuda(): "zi_requires_grad", [True, False], ) +@pytest.mark.parametrize( + "cmplx", + [True, False], +) @pytest.mark.parametrize( "device", [ @@ -139,13 +143,25 @@ def test_parallel_scan( x_requires_grad: bool, a_requires_grad: bool, zi_requires_grad: bool, + cmplx: bool, device: str, ): batch_size = 2 samples = 123 - x = torch.randn(batch_size, samples, dtype=torch.double, device=device) - A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1 - zi = torch.randn(batch_size, dtype=torch.double, device=device) + dtype = torch.complex128 if cmplx else torch.double + x = torch.randn(batch_size, samples, dtype=dtype, device=device) + if cmplx: + A = torch.rand( + batch_size, samples, dtype=torch.double, device=device + ).sqrt() * torch.exp( + 1j + * torch.rand(batch_size, samples, dtype=torch.double, device=device) + * 2 + * torch.pi + ) + else: + A = torch.rand(batch_size, samples, dtype=dtype, device=device) * 2 - 1 + zi = torch.randn(batch_size, dtype=dtype, device=device) A.requires_grad = a_requires_grad x.requires_grad = x_requires_grad diff --git a/torchlpc/__init__.py b/torchlpc/__init__.py index 79f86b6..620d360 100644 --- a/torchlpc/__init__.py +++ b/torchlpc/__init__.py @@ -3,16 +3,24 @@ from pathlib import Path import warnings -so_files = list(Path(__file__).parent.glob("_C*.so")) -# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" -if len(so_files) == 1: - torch.ops.load_library(so_files[0]) +# so_files = list(Path(__file__).parent.glob("_C*.so")) +# # assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" +# if len(so_files) == 1: +# torch.ops.load_library(so_files[0]) +# EXTENSION_LOADED = True +# elif len(so_files) > 1: +# raise ValueError(f"Expected one _C*.so file, found {len(so_files)}") +# else: +# warnings.warn("No _C*.so file found. Custom extension not loaded.") +# EXTENSION_LOADED = False + +try: + from . 
import _C + EXTENSION_LOADED = True -elif len(so_files) > 1: - raise ValueError(f"Expected one _C*.so file, found {len(so_files)}") -else: - warnings.warn("No _C*.so file found. Custom extension not loaded.") +except ImportError: EXTENSION_LOADED = False + warnings.warn("Custom extension not loaded. Falling back to Numba implementation.") from .core import LPC diff --git a/torchlpc/core.py b/torchlpc/core.py index 543d54f..2101601 100644 --- a/torchlpc/core.py +++ b/torchlpc/core.py @@ -162,7 +162,7 @@ def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor: if x.is_cuda: y = lpc_cuda(x.detach(), A.detach(), zi.detach()) elif EXTENSION_LOADED: - y = torch.ops.torchlpc.lpc_cpu(x, A, zi) + y = torch.ops.torchlpc.lpc(x, A, zi) else: warnings.warn( "Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0." diff --git a/torchlpc/csrc/cuda/LICENSE.txt b/torchlpc/csrc/cuda/LICENSE.txt new file mode 100644 index 0000000..5487462 --- /dev/null +++ b/torchlpc/csrc/cuda/LICENSE.txt @@ -0,0 +1,19 @@ +Copyright (c) <2017> + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/torchlpc/csrc/cuda/linear_recurrence.cu b/torchlpc/csrc/cuda/linear_recurrence.cu new file mode 100644 index 0000000..c723100 --- /dev/null +++ b/torchlpc/csrc/cuda/linear_recurrence.cu @@ -0,0 +1,291 @@ +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x + y - 1) / y) + +#define gpuErrChk(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +void gpuAssert(cudaError_t code, const char *file, int line) { + if (code != cudaSuccess) { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, + line); + } +} + +__device__ int2 divide_work(int n_jobs, int n_workers, int worker_idx) { + // Each worker will do a continuous slice of either n_jobs / n_workers + // or ceil_div(n_jobs, n_workers). 
The return value is an int2 representing + // a half open interval of jobs for the worker to perform (perform jobs + // i for a <= i < b) + + int cd = CEIL_DIV(n_jobs, n_workers); + int d = n_jobs / n_workers; + + int doing_cd = n_jobs % n_workers; + + int2 retval; + if (worker_idx < doing_cd) { + retval.x = worker_idx * cd; + retval.y = retval.x + cd; + } else { + retval.x = doing_cd * cd + (worker_idx - doing_cd) * d; + retval.y = retval.x + d; + } + + return retval; +} + +__device__ int2 compute_warp_start_stop(int block_idx, int warp_idx, + int n_blocks, int n_steps) { + int2 block_ss = divide_work(n_steps, n_blocks, block_idx); + int block_start = block_ss.x; + int block_stop = block_ss.y; + int block_jobs = block_stop - block_start; + + int2 warp_ss = divide_work(block_jobs, 32, warp_idx); + int warp_start = block_start + warp_ss.x; + int warp_stop = block_start + warp_ss.y; + + int2 retval; + retval.x = warp_start; + retval.y = warp_stop; + return retval; +} + +// decay storage, h_storage: +// each a n_dims x 33 x n_blocks matrix on GPU with 33rd column for block +// reduction +template +__global__ void reduction_kernel(const scalar_t *decays, + const scalar_t *impulses, + const scalar_t *initial_state, + scalar_t *_decay_storage, scalar_t *_h_storage, + int n_dims, int n_steps) { + int warp = threadIdx.x / 32; + int lane = threadIdx.x % 32; + + scalar_t *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims]; + scalar_t *h_storage = &_h_storage[blockIdx.x * 33 * n_dims]; + + int2 start_stop = + compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps); + int warp_start = start_stop.x; + int warp_stop = start_stop.y; + + /* + * Reduce within warps. + * After this loop exits, the storage arrays should contain the reduction + * from warp_start to warp_stop (including initial state) at index + * (feature_idx, warp, block). + */ + for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) { + scalar_t cum_decay = static_cast(1.0); + scalar_t h = static_cast(0.0); + if (blockIdx.x == 0 && lane == 0 && initial_state != NULL) { + h = initial_state[i]; + } + + for (int t = warp_start; t < warp_stop; t++) { + cum_decay *= decays[i * n_steps + t]; + h = decays[i * n_steps + t] * h + impulses[i * n_steps + t]; + } + + // TODO: store into shared memory, work in shared memory sized blocks + // store into global memory + decay_storage[i + lane * n_dims] = cum_decay; + h_storage[i + lane * n_dims] = h; + } + + __syncthreads(); + + /* + * Reduce over warps. + * After this loop exits, the storage arrays should contain the reduction + * from block_start to block_finish (including initial state) at index + * (feature_idx, 32, block). + */ + // TODO: parallel reduction (or scan). Need to worry about changing the warp + // reduction values (as I use them again later) + for (int i = threadIdx.x; i < n_dims; i += blockDim.x) { + scalar_t cum_decay = static_cast(1.0); + scalar_t h = static_cast(0.0); + for (int t = 0; t < 32; t++) { + cum_decay *= decay_storage[i + t * n_dims]; + h = decay_storage[i + t * n_dims] * h + h_storage[i + t * n_dims]; + } + decay_storage[i + 32 * n_dims] = cum_decay; + h_storage[i + 32 * n_dims] = h; + } +} + +template +__global__ void block_scan_kernel(scalar_t *decay_storage, scalar_t *h_storage, + int n_dims, int n_blocks) { + /* + * Scan over blocks. 
+ * After this loop exits, the storage arrays should contain the cumulative + * sum from block_idx 0 to i (inclusive) at index (feature_idx, 32, i) This + * means (feature_idx, 32, 2) contains the reduction of blocks 0, 1, and 2. + */ + // TODO: parallel scan (tricky because number of blocks isn't necessarily + // smaller than number of warps that can fit in a single block) + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n_dims; + i += blockDim.x * gridDim.x) { + for (int t = 1; t < n_blocks; t++) { + int cur_idx = i + 32 * n_dims + t * 33 * n_dims; + int prev_idx = i + 32 * n_dims + (t - 1) * 33 * n_dims; + + // TODO: remove unneccessary reads from global memory (prev_idx + // accesses) + h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] + + h_storage[cur_idx]; + decay_storage[cur_idx] *= decay_storage[prev_idx]; + } + } +} + +template +__global__ void warp_scan_kernel(const scalar_t *decays, + const scalar_t *impulses, + const scalar_t *initial_state, scalar_t *out, + scalar_t *decay_storage, scalar_t *h_storage, + int n_dims, int n_steps) { + int warp = threadIdx.x / 32; + int lane = threadIdx.x % 32; + + // Note: Due to the index ordering of the storage arrays, the following + // indices are equivalent: + // + // i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims + // i + 32 * n_dims + (blockIdx.x - 1) * 33 * n_dims + // + // when t is 0. This means something that looks like negative indexing + // (t-1) can be used to safely access the stored value for the previous + // warp (even if the previous warp belonged to the previous block). + + /* + * Scan over warps. + * After this loop executes, the storage arrays should contain the + * cumulative sum from the beginning of sequence (including initial + * condition) up to and including the indexed warp and block. + */ + // TODO: parallel scan + for (int i = threadIdx.x; i < n_dims; i += blockDim.x) { + for (int t = 0; t < 32; t++) { + if (t == 0 && blockIdx.x == 0) { + // the reduction over warp 0 (including initial condition) is + // correct val for scan, so there's no work to do + continue; + } + + int cur_idx = i + t * n_dims + blockIdx.x * 33 * n_dims; + int prev_idx = i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims; + h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] + + h_storage[cur_idx]; + decay_storage[cur_idx] *= decay_storage[prev_idx]; + } + } + + __syncthreads(); + + int2 start_stop = + compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps); + int warp_start = start_stop.x; + int warp_stop = start_stop.y; + + /* + * Scan within warps. + * This loop writes to the output array. Each warp reads in it's initial + * state (either from the "initial_state" or the storage arrays) and then + * writes to output for indices warp_start up to warp_stop. + */ + for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) { + scalar_t h = static_cast(0.0); + if (blockIdx.x == 0 && lane == 0) { + if (initial_state != NULL) { + h = initial_state[i]; + } + } else { + h = h_storage[i + (lane - 1) * n_dims + blockIdx.x * 33 * n_dims]; + } + + for (int t = warp_start; t < warp_stop; t++) { + h = decays[i * n_steps + t] * h + impulses[i * n_steps + t]; + out[i * n_steps + t] = h; + } + } +} + +/* + * This is the main method for the prefix sum kernels. 
+ * decays, impulses, out: + * each a n_dims x n_steps column major matrix located on GPU + * initial_state: + * array of size n_dims located on GPU + */ +template +void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, + const scalar_t *initial_state, scalar_t *out, + int n_dims, int n_steps) { + // we want at least 32 elements per block, but no reason to run + // with more than the maximum number of concurrent blocks + // NOTE: 128 is decided empirically. + int n_blocks = min(CEIL_DIV(n_steps, 32), 128); + + // TODO: make user pass in working memory? This allows integration + // with CNMeM (used by Theano) + int reduction_mem_sz = 2 * n_blocks * 33 * n_dims * sizeof(float); + scalar_t *d_reduction_mem; + gpuErrChk(cudaMalloc(&d_reduction_mem, reduction_mem_sz)); + scalar_t *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims]; + scalar_t *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims]; + + // TODO: run kernels on non-default stream? + reduction_kernel<<>>(decays, impulses, initial_state, + d_decay_storage, d_h_storage, n_dims, + n_steps); + + block_scan_kernel<<>>(d_decay_storage, d_h_storage, n_dims, + n_blocks); + + warp_scan_kernel<<>>(decays, impulses, initial_state, out, + d_decay_storage, d_h_storage, n_dims, + n_steps); + + gpuErrChk(cudaFree(d_reduction_mem)); +} + +at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, + const at::Tensor &initials) { + TORCH_CHECK(input.is_floating_point() || input.is_complex(), + "Input must be floating point or complex"); + TORCH_CHECK(initials.scalar_type() == input.scalar_type(), + "Initials must have the same scalar type as input"); + TORCH_CHECK(weights.scalar_type() == input.scalar_type(), + "Weights must have the same scalar type as input"); + + auto input_contiguous = input.contiguous(); + auto weights_contiguous = weights.contiguous(); + auto output = at::empty_like(input_contiguous); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + input.scalar_type(), "compute_linear_recurrence", [&] { + compute_linear_recurrence( + weights_contiguous.const_data_ptr(), + input_contiguous.const_data_ptr(), + initials.const_data_ptr(), + output.mutable_data_ptr(), input_contiguous.size(0), + input_contiguous.size(1)); + }); + return output.contiguous(); +} + +TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); } diff --git a/torchlpc/csrc/scan_cpu.cpp b/torchlpc/csrc/scan_cpu.cpp index 8463341..6b47cad 100644 --- a/torchlpc/csrc/scan_cpu.cpp +++ b/torchlpc/csrc/scan_cpu.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -5,6 +6,23 @@ #include #include +extern "C" { +/* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so associated with this extension + built from this file, so that all the TORCH_LIBRARY calls below are run.*/ +PyObject *PyInit__C(void) { + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. 
*/ + NULL, /* methods */ + }; + return PyModule_Create(&module_def); +} +} + template void scan_cpu(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &initials, const at::Tensor &output) { @@ -34,10 +52,11 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights, std::pair buffer[total_size]; - const scalar_t *input_ptr = input_contiguous.data_ptr(); - const scalar_t *initials_ptr = initials_contiguous.data_ptr(); - const scalar_t *weights_ptr = weights_contiguous.data_ptr(); - scalar_t *output_ptr = output.data_ptr(); + const scalar_t *input_ptr = input_contiguous.const_data_ptr(); + const scalar_t *initials_ptr = + initials_contiguous.const_data_ptr(); + const scalar_t *weights_ptr = weights_contiguous.const_data_ptr(); + scalar_t *output_ptr = output.mutable_data_ptr(); std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer, [](const scalar_t &a, const scalar_t &b) { @@ -84,8 +103,8 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) { auto a_contiguous = a.contiguous(); - const scalar_t *a_ptr = a_contiguous.data_ptr(); - scalar_t *out_ptr = padded_out.data_ptr(); + const scalar_t *a_ptr = a_contiguous.const_data_ptr(); + scalar_t *out_ptr = padded_out.mutable_data_ptr(); at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) { for (auto b = start; b < end; b++) { @@ -142,11 +161,11 @@ at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a, } TORCH_LIBRARY(torchlpc, m) { - m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor"); - m.def("torchlpc::lpc_cpu(Tensor a, Tensor b, Tensor c) -> Tensor"); + m.def("torchlpc::scan(Tensor a, Tensor b, Tensor c) -> Tensor"); + m.def("torchlpc::lpc(Tensor a, Tensor b, Tensor c) -> Tensor"); } TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { - m.impl("scan_cpu", &scan_cpu_wrapper); - m.impl("lpc_cpu", &lpc_cpu); + m.impl("scan", &scan_cpu_wrapper); + m.impl("lpc", &lpc_cpu); } diff --git a/torchlpc/recurrence.py b/torchlpc/recurrence.py index 05b9fd5..bcb5334 100644 --- a/torchlpc/recurrence.py +++ b/torchlpc/recurrence.py @@ -9,16 +9,16 @@ from . 
import EXTENSION_LOADED -class Recurrence(Function): - @staticmethod - def forward( - decay: torch.Tensor, - impulse: torch.Tensor, - initial_state: torch.Tensor, - ) -> torch.Tensor: - n_dims, n_steps = decay.shape - if decay.is_cuda: - if n_dims * WARPSIZE < n_steps: +def _cuda_recurrence( + impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor +) -> torch.Tensor: + n_dims, n_steps = decay.shape + if n_dims * WARPSIZE < n_steps: + if EXTENSION_LOADED: + runner = torch.ops.torchlpc.scan + else: + + def runner(impulse, decay, initial_state): out = torch.empty_like(impulse) compute_linear_recurrence( cuda.as_cuda_array(decay.detach()), @@ -28,21 +28,45 @@ def forward( n_dims, n_steps, ) - else: - out = lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)) + return out + + else: + runner = lambda impulse, decay, initial_state: lpc_cuda( + impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1) + ) + return runner(impulse, decay, initial_state) + + +def _cpu_recurrence( + impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor +) -> torch.Tensor: + num_threads = torch.get_num_threads() + n_dims, _ = decay.shape + # This is just a rough estimation of the computational cost + if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3: + runner = torch.ops.torchlpc.scan + else: + runner = lambda impulse, decay, initial_state: torch.from_numpy( + lpc_np( + impulse.detach().numpy(), + -decay.unsqueeze(2).detach().numpy(), + initial_state.unsqueeze(1).detach().numpy(), + ) + ) + return runner(impulse, decay, initial_state) + + +class Recurrence(Function): + @staticmethod + def forward( + decay: torch.Tensor, + impulse: torch.Tensor, + initial_state: torch.Tensor, + ) -> torch.Tensor: + if decay.is_cuda: + out = _cuda_recurrence(impulse, decay, initial_state) else: - num_threads = torch.get_num_threads() - # This is just a rough estimation of the computational cost - if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3: - out = torch.ops.torchlpc.scan_cpu(impulse, decay, initial_state) - else: - out = torch.from_numpy( - lpc_np( - impulse.detach().numpy(), - -decay.unsqueeze(2).detach().numpy(), - initial_state.unsqueeze(1).detach().numpy(), - ) - ) + out = _cpu_recurrence(impulse, decay, initial_state) return out @staticmethod
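
Usage sketch (reviewer commentary, not part of the patch): with this change the CPU and CUDA kernels are registered under the single op names torchlpc::scan and torchlpc::lpc, and TORCH_LIBRARY_IMPL selects the implementation from the tensor's device. The scan op evaluates the first-order recurrence y[t] = A[t] * y[t-1] + x[t] with per-sequence initial state zi, which is exactly what test_scan_equiv checks against the Numba reference lpc_np(x, -A[..., None], zi[:, None]). The snippet below is a minimal sketch assuming the extension was built and installed via the new get_extensions() path in setup.py; the tensor shapes and variable names are illustrative only.

import torch
import torchlpc  # importing the package attempts `from . import _C`
from torchlpc.core import lpc_np

batch, steps = 4, 1024
x = torch.randn(batch, steps)
A = torch.rand(batch, steps) * 1.8 - 0.9  # keep |A| < 1 so the recurrence stays stable
zi = torch.randn(batch)                   # initial state, one value per sequence

if torchlpc.EXTENSION_LOADED:
    # Same op name on CPU and CUDA; the registered backend is chosen by device.
    y = torch.ops.torchlpc.scan(x, A, zi)

    # Cross-check against the Numba fallback, mirroring test_scan_equiv.
    y_ref = torch.from_numpy(
        lpc_np(x.numpy(), -A.unsqueeze(2).numpy(), zi.unsqueeze(1).numpy())
    )
    assert torch.allclose(y, y_ref, atol=5e-7)

    if torch.cuda.is_available():
        # The CUDA path goes through the new linear_recurrence.cu kernel.
        y_cuda = torch.ops.torchlpc.scan(x.cuda(), A.cuda(), zi.cuda())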