-
Notifications
You must be signed in to change notification settings - Fork 5
feat: parallel scan extension for CPU #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
1f392ba
draft: scan extension on cpu
yoyolicoris d77846a
include cpp extension in setup.py
yoyolicoris 37b3340
fix: remove extra arg
yoyolicoris e110d14
fix: compile errors
yoyolicoris 4611ce2
use dev versioning
yoyolicoris 5e28b30
load library file when being imported
yoyolicoris 5b69d2a
test equivalence to numba version
yoyolicoris 7b54fb1
refactor: return tensor instead of void function
yoyolicoris 1e7ea16
refactor: rename RecurrenceCUDA to Recurrence to cover CPU device
yoyolicoris 314c5a2
refactor: update functions to use Recurrence for CPU and CUDA devices
yoyolicoris 1472acf
refactor: remove contiguous check besides output tensor
yoyolicoris 5e7a914
refactor: add warning for missing _C*.so file and check extension loa…
yoyolicoris 6cde45b
ci: add workflow step to build CPP extension and copy shared objects
yoyolicoris 32590c9
apply suggestions and remove comments
yoyolicoris c4517b6
refactor: apply google style format
yoyolicoris File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| import torch | ||
| import torch.nn.functional as F | ||
| import pytest | ||
| from torchlpc.core import lpc_np | ||
|
|
||
|
|
||
| from .test_grad import create_test_inputs | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "samples", | ||
| [64, 4097], | ||
| ) | ||
| @pytest.mark.parametrize( | ||
| "cmplx", | ||
| [True, False], | ||
| ) | ||
| def test_scan_cpu_equiv(samples: int, cmplx: bool): | ||
| batch_size = 4 | ||
| x = torch.randn( | ||
| batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64 | ||
| ) | ||
| A = torch.rand_like(x) * 1.8 - 0.9 | ||
| zi = torch.randn(batch_size, dtype=x.dtype) | ||
|
|
||
| numba_y = torch.from_numpy( | ||
| lpc_np( | ||
| x.cpu().numpy(), | ||
| -A.cpu().unsqueeze(2).numpy(), | ||
| zi.cpu().unsqueeze(1).numpy(), | ||
| ) | ||
| ) | ||
| ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi) | ||
|
|
||
| assert torch.allclose(numba_y, ext_y) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| #include <torch/script.h> | ||
| #include <torch/torch.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
||
| template <typename scalar_t> | ||
| void scan_cpu(const at::Tensor &input, const at::Tensor &weights, | ||
| const at::Tensor &initials, const at::Tensor &output) { | ||
| TORCH_CHECK(input.dim() == 2, "Input must be 2D"); | ||
| TORCH_CHECK(initials.dim() == 1, "Initials must be 1D"); | ||
| TORCH_CHECK(weights.sizes() == input.sizes(), | ||
| "Weights must have the same size as input"); | ||
| TORCH_CHECK(output.sizes() == input.sizes(), | ||
| "Output must have the same size as input"); | ||
| TORCH_CHECK(initials.size(0) == input.size(0), | ||
| "The first dimension of initials must be the same as the first " | ||
| "dimension of input"); | ||
| TORCH_INTERNAL_ASSERT(input.device().is_cpu(), "Input must be on CPU"); | ||
| TORCH_INTERNAL_ASSERT(initials.device().is_cpu(), | ||
| "Initials must be on CPU"); | ||
| TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU"); | ||
| TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU"); | ||
| TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous"); | ||
|
|
||
| auto input_contiguous = input.contiguous(); | ||
yoyolicoris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto weights_contiguous = weights.contiguous(); | ||
| auto initials_contiguous = initials.contiguous(); | ||
yoyolicoris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| auto n_batch = input.size(0); | ||
| auto T = input.size(1); | ||
| auto total_size = input.numel(); | ||
|
|
||
| std::pair<scalar_t, scalar_t> buffer[total_size]; | ||
|
|
||
| const scalar_t *input_ptr = input_contiguous.data_ptr<scalar_t>(); | ||
| const scalar_t *initials_ptr = initials_contiguous.data_ptr<scalar_t>(); | ||
| const scalar_t *weights_ptr = weights_contiguous.data_ptr<scalar_t>(); | ||
| scalar_t *output_ptr = output.data_ptr<scalar_t>(); | ||
|
|
||
| std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer, | ||
| [](const scalar_t &a, const scalar_t &b) { | ||
| return std::make_pair(a, b); | ||
| }); | ||
|
|
||
| at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) { | ||
| for (auto b = start; b < end; b++) { | ||
| std::inclusive_scan( | ||
| buffer + b * T, buffer + (b + 1) * T, buffer + b * T, | ||
| [](const std::pair<scalar_t, scalar_t> &a, | ||
| const std::pair<scalar_t, scalar_t> &b) { | ||
| return std::make_pair(a.first * b.first, | ||
| a.second * b.first + b.second); | ||
| }, | ||
| std::make_pair((scalar_t)1.0, initials_ptr[b])); | ||
| } | ||
| }); | ||
|
|
||
| std::transform( | ||
| buffer, buffer + total_size, output_ptr, | ||
| [](const std::pair<scalar_t, scalar_t> &a) { return a.second; }); | ||
| } | ||
|
|
||
| at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights, | ||
| const at::Tensor &initials) { | ||
| TORCH_CHECK(input.is_floating_point() || input.is_complex(), | ||
| "Input must be floating point or complex"); | ||
| TORCH_CHECK(initials.scalar_type() == input.scalar_type(), | ||
| "Initials must have the same scalar type as input"); | ||
| TORCH_CHECK(weights.scalar_type() == input.scalar_type(), | ||
| "Weights must have the same scalar type as input"); | ||
|
|
||
| auto output = at::empty_like(input); | ||
|
|
||
| AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( | ||
| input.scalar_type(), "scan_cpu", | ||
| [&] { scan_cpu<scalar_t>(input, weights, initials, output); }); | ||
| return output; | ||
| } | ||
|
|
||
| TORCH_LIBRARY(torchlpc, m) { | ||
| m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor"); | ||
| } | ||
|
|
||
| TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.