|
| 1 | +#include <Python.h> |
1 | 2 | #include <torch/script.h> |
2 | 3 | #include <torch/torch.h> |
3 | 4 |
|
4 | 5 | #include <algorithm> |
5 | 6 | #include <utility> |
6 | 7 | #include <vector> |
7 | 8 |
|
| 9 | +extern "C" { |
| 10 | +/* Creates a dummy empty _C module that can be imported from Python. |
| 11 | + The import from Python will load the .so associated with this extension |
| 12 | + built from this file, so that all the TORCH_LIBRARY calls below are run.*/ |
| 13 | +PyObject *PyInit__C(void) { |
| 14 | + static struct PyModuleDef module_def = { |
| 15 | + PyModuleDef_HEAD_INIT, |
| 16 | + "_C", /* name of module */ |
| 17 | + NULL, /* module documentation, may be NULL */ |
| 18 | + -1, /* size of per-interpreter state of the module, |
| 19 | + or -1 if the module keeps state in global variables. */ |
| 20 | + NULL, /* methods */ |
| 21 | + }; |
| 22 | + return PyModule_Create(&module_def); |
| 23 | +} |
| 24 | +} |
| 25 | + |
8 | 26 | template <typename scalar_t> |
9 | 27 | void scan_cpu(const at::Tensor &input, const at::Tensor &weights, |
10 | 28 | const at::Tensor &initials, const at::Tensor &output) { |
@@ -34,10 +52,11 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights, |
34 | 52 |
|
35 | 53 | std::pair<scalar_t, scalar_t> buffer[total_size]; |
36 | 54 |
|
37 | | - const scalar_t *input_ptr = input_contiguous.data_ptr<scalar_t>(); |
38 | | - const scalar_t *initials_ptr = initials_contiguous.data_ptr<scalar_t>(); |
39 | | - const scalar_t *weights_ptr = weights_contiguous.data_ptr<scalar_t>(); |
40 | | - scalar_t *output_ptr = output.data_ptr<scalar_t>(); |
| 55 | + const scalar_t *input_ptr = input_contiguous.const_data_ptr<scalar_t>(); |
| 56 | + const scalar_t *initials_ptr = |
| 57 | + initials_contiguous.const_data_ptr<scalar_t>(); |
| 58 | + const scalar_t *weights_ptr = weights_contiguous.const_data_ptr<scalar_t>(); |
| 59 | + scalar_t *output_ptr = output.mutable_data_ptr<scalar_t>(); |
41 | 60 |
|
42 | 61 | std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer, |
43 | 62 | [](const scalar_t &a, const scalar_t &b) { |
@@ -84,8 +103,8 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) { |
84 | 103 |
|
85 | 104 | auto a_contiguous = a.contiguous(); |
86 | 105 |
|
87 | | - const scalar_t *a_ptr = a_contiguous.data_ptr<scalar_t>(); |
88 | | - scalar_t *out_ptr = padded_out.data_ptr<scalar_t>(); |
| 106 | + const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>(); |
| 107 | +  scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>(); |
89 | 108 |
|
90 | 109 | at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) { |
91 | 110 | for (auto b = start; b < end; b++) { |
|
0 commit comments