@@ -6,26 +6,29 @@
 #include <utility>
 #include <vector>

-extern "C" {
-/* Creates a dummy empty _C module that can be imported from Python.
-   The import from Python will load the .so associated with this extension
-   built from this file, so that all the TORCH_LIBRARY calls below are run.*/
-PyObject *PyInit__C(void) {
-    static struct PyModuleDef module_def = {
-        PyModuleDef_HEAD_INIT,
-        "_C", /* name of module */
-        NULL, /* module documentation, may be NULL */
-        -1,   /* size of per-interpreter state of the module,
-                 or -1 if the module keeps state in global variables. */
-        NULL, /* methods */
-    };
-    return PyModule_Create(&module_def);
-}
+extern "C"
+{
+    /* Creates a dummy empty _C module that can be imported from Python.
+       The import from Python will load the .so associated with this extension
+       built from this file, so that all the TORCH_LIBRARY calls below are run.*/
+    PyObject *PyInit__C(void)
+    {
+        static struct PyModuleDef module_def = {
+            PyModuleDef_HEAD_INIT,
+            "_C", /* name of module */
+            NULL, /* module documentation, may be NULL */
+            -1,   /* size of per-interpreter state of the module,
+                     or -1 if the module keeps state in global variables. */
+            NULL, /* methods */
+        };
+        return PyModule_Create(&module_def);
+    }
 }

 template <typename scalar_t>
 void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
-              const at::Tensor &initials, const at::Tensor &output) {
+              const at::Tensor &initials, const at::Tensor &output)
+{
     TORCH_CHECK(input.dim() == 2, "Input must be 2D");
     TORCH_CHECK(initials.dim() == 1, "Initials must be 1D");
     TORCH_CHECK(weights.sizes() == input.sizes(),
@@ -50,39 +53,34 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
     auto T = input.size(1);
     auto total_size = input.numel();

-    std::pair<scalar_t, scalar_t> buffer[total_size];
-
     const scalar_t *input_ptr = input_contiguous.const_data_ptr<scalar_t>();
     const scalar_t *initials_ptr =
         initials_contiguous.const_data_ptr<scalar_t>();
     const scalar_t *weights_ptr = weights_contiguous.const_data_ptr<scalar_t>();
     scalar_t *output_ptr = output.mutable_data_ptr<scalar_t>();

-    std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
-                   [](const scalar_t &a, const scalar_t &b) {
-                       return std::make_pair(a, b);
-                   });
-
-    at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
-        for (auto b = start; b < end; b++) {
-            std::inclusive_scan(
-                buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
-                [](const std::pair<scalar_t, scalar_t> &a,
-                   const std::pair<scalar_t, scalar_t> &b) {
-                    return std::make_pair(a.first * b.first,
-                                          a.second * b.first + b.second);
-                },
-                std::make_pair((scalar_t)1.0, initials_ptr[b]));
-        }
-    });
-
-    std::transform(
-        buffer, buffer + total_size, output_ptr,
-        [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
+    at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end)
+                     {
+        for (auto b = start; b < end; b++)
+        {
+            auto initial = initials_ptr[b];
+            auto weights_offset = weights_ptr + b * T;
+            auto input_offset = input_ptr + b * T;
+            auto output_offset = output_ptr + b * T;
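+            // first-order recurrence: y[t] = w[t] * y[t-1] + x[t], y[-1] = initials[b]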
+            for (int64_t t = 0; t < T; t++)
+            {
+                auto w = weights_offset[t];
+                auto x = input_offset[t];
+                initial = initial * w + x;
+                output_offset[t] = initial;
+            }
+        }; });
 }

 template <typename scalar_t>
-void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
+void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out)
+{
     // Ensure input dimensions are correct
     TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
     TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
@@ -106,24 +103,28 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
     const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>();
     scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>();

-    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
-        for (auto b = start; b < end; b++) {
-            auto out_offset = b * (T + order) + order;
-            auto a_offset = b * T * order;
-            for (int64_t t = 0; t < T; t++) {
-                scalar_t y = out_ptr[out_offset + t];
-                for (int64_t i = 0; i < order; i++) {
-                    y -= a_ptr[a_offset + t * order + i] *
-                         out_ptr[out_offset + t - i - 1];
+    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
+                     {
+        for (auto b = start; b < end; b++)
+        {
+            auto out_offset = out_ptr + b * (T + order) + order;
+            auto a_offset = a_ptr + b * T * order;
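+            // all-pole recursion, reading initial conditions from the
+            // `order` padding samples before out_offset:
+            // y[t] = x[t] - sum_i a[t][i] * y[t-1-i]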
+            for (int64_t t = 0; t < T; t++)
+            {
+                scalar_t y = out_offset[t];
+                for (int64_t i = 0; i < order; i++)
+                {
+                    y -= a_offset[t * order + i] * out_offset[t - i - 1];
                 }
-                out_ptr[out_offset + t] = y;
+                out_offset[t] = y;
             }
-        }
-    });
+        }; });
 }

 at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
-                            const at::Tensor &initials) {
+                            const at::Tensor &initials)
+{
     TORCH_CHECK(input.is_floating_point() || input.is_complex(),
                 "Input must be floating point or complex");
     TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
@@ -135,12 +135,14 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,

     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
         input.scalar_type(), "scan_cpu",
-        [&] { scan_cpu<scalar_t>(input, weights, initials, output); });
+        [&]
+        { scan_cpu<scalar_t>(input, weights, initials, output); });
     return output;
 }

 at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
-                   const at::Tensor &zi) {
+                   const at::Tensor &zi)
+{
     TORCH_CHECK(x.is_floating_point() || x.is_complex(),
                 "Input must be floating point or complex");
     TORCH_CHECK(a.scalar_type() == x.scalar_type(),
@@ -156,16 +158,19 @@ at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
     auto out = at::cat({zi.flip(1), x}, 1).contiguous();

     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
-        x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
+        x.scalar_type(), "lpc_cpu", [&]
+        { lpc_cpu_core<scalar_t>(a, out); });
     return out.slice(1, zi.size(1), out.size(1)).contiguous();
 }

-TORCH_LIBRARY(torchlpc, m) {
+TORCH_LIBRARY(torchlpc, m)
+{
     m.def("torchlpc::scan(Tensor a, Tensor b, Tensor c) -> Tensor");
     m.def("torchlpc::lpc(Tensor a, Tensor b, Tensor c) -> Tensor");
 }

-TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
+TORCH_LIBRARY_IMPL(torchlpc, CPU, m)
+{
     m.impl("scan", &scan_cpu_wrapper);
     m.impl("lpc", &lpc_cpu);
 }
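
Note on the `scan_cpu` change: the removed path materialized a `std::pair` buffer of `total_size` elements on the stack (a variable-length array, which is non-standard C++) and ran `std::inclusive_scan` over affine-map pairs, while the new path computes the same first-order recurrence y[t] = w[t] * y[t-1] + x[t] directly per batch row. A minimal standalone sketch (plain C++17, no Torch dependency, made-up values) showing that the two formulations agree:

```cpp
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

int main() {
    // One batch row: y[-1] = initial, y[t] = w[t] * y[t-1] + x[t].
    std::vector<double> w = {0.5, -0.25, 0.75, 1.0};
    std::vector<double> x = {1.0, 2.0, -1.0, 0.5};
    double initial = 0.5;

    // Direct sequential loop, as in the new scan_cpu.
    std::vector<double> direct(w.size());
    double state = initial;
    for (size_t t = 0; t < w.size(); t++) {
        state = state * w[t] + x[t];
        direct[t] = state;
    }

    // Pair formulation, as in the removed std::inclusive_scan path: a pair
    // (a, b) represents the affine map y -> a * y + b, composed in order.
    std::vector<std::pair<double, double>> buf(w.size());
    for (size_t t = 0; t < w.size(); t++) buf[t] = {w[t], x[t]};
    std::inclusive_scan(buf.begin(), buf.end(), buf.begin(),
                        [](const std::pair<double, double> &a,
                           const std::pair<double, double> &b) {
                            return std::make_pair(a.first * b.first,
                                                  a.second * b.first + b.second);
                        },
                        std::make_pair(1.0, initial));

    // Both columns print the same sequence.
    for (size_t t = 0; t < w.size(); t++)
        std::printf("t=%zu direct=%g scan=%g\n", t, direct[t], buf[t].second);
    return 0;
}
```

The sequential `std::inclusive_scan` overload applies the combine left to right, so the two columns match exactly; the associativity of affine-map composition is what made the removed formulation a valid (parallelizable) scan in the first place.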