Commit 086bbb4

refactor: create dummy _C module for python loading

1 parent 6d4a2de

2 files changed: +41 -14 lines changed

torchlpc/__init__.py

Lines changed: 16 additions & 8 deletions
@@ -3,16 +3,24 @@
 from pathlib import Path
 import warnings
 
-so_files = list(Path(__file__).parent.glob("_C*.so"))
-# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
-if len(so_files) == 1:
-    torch.ops.load_library(so_files[0])
+# so_files = list(Path(__file__).parent.glob("_C*.so"))
+# # assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
+# if len(so_files) == 1:
+#     torch.ops.load_library(so_files[0])
+#     EXTENSION_LOADED = True
+# elif len(so_files) > 1:
+#     raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
+# else:
+#     warnings.warn("No _C*.so file found. Custom extension not loaded.")
+#     EXTENSION_LOADED = False
+
+try:
+    from . import _C
+
     EXTENSION_LOADED = True
-elif len(so_files) > 1:
-    raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
-else:
-    warnings.warn("No _C*.so file found. Custom extension not loaded.")
+except ImportError:
     EXTENSION_LOADED = False
+    warnings.warn("Custom extension not loaded. Falling back to Numba implementation.")
 
 from .core import LPC
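
Why the bare from . import _C is enough: importing the extension module loads its shared object, and loading the shared object runs its static initializers, which is where PyTorch's TORCH_LIBRARY registrations live. A minimal sketch of such a registration (the torchlpc namespace and the schema below are assumptions for illustration; the actual registrations sit further down in scan_cpu.cpp, outside this diff):

#include <torch/script.h>

// Sketch only: a TORCH_LIBRARY block expands to a static registration object,
// so it runs as soon as the shared object is loaded, whether via the old
// torch.ops.load_library path or via a plain Python import of torchlpc._C.
// The namespace and schema here are assumed for illustration.
TORCH_LIBRARY(torchlpc, m) {
    m.def("scan(Tensor input, Tensor weights, Tensor initials, Tensor(a!) output) -> ()");
}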

torchlpc/csrc/scan_cpu.cpp

Lines changed: 25 additions & 6 deletions
@@ -1,10 +1,28 @@
+#include <Python.h>
 #include <torch/script.h>
 #include <torch/torch.h>
 
 #include <algorithm>
 #include <utility>
 #include <vector>
 
+extern "C" {
+/* Creates a dummy empty _C module that can be imported from Python.
+   The import from Python will load the .so associated with this extension
+   built from this file, so that all the TORCH_LIBRARY calls below are run.*/
+PyObject *PyInit__C(void) {
+    static struct PyModuleDef module_def = {
+        PyModuleDef_HEAD_INIT,
+        "_C",   /* name of module */
+        NULL,   /* module documentation, may be NULL */
+        -1,     /* size of per-interpreter state of the module,
+                   or -1 if the module keeps state in global variables. */
+        NULL,   /* methods */
+    };
+    return PyModule_Create(&module_def);
+}
+}
+
 template <typename scalar_t>
 void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
               const at::Tensor &initials, const at::Tensor &output) {
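
For comparison, PyTorch extensions often get an importable module from pybind11 instead of a hand-written PyInit function; a sketch of that alternative (not what this commit does):

#include <torch/extension.h>

// Alternative sketch, not what this commit does: an empty PYBIND11_MODULE also
// produces an importable module whose load triggers the TORCH_LIBRARY
// registrations. TORCH_EXTENSION_NAME is defined by torch.utils.cpp_extension
// at build time, keeping the module name in sync with the build target.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
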
@@ -34,10 +52,11 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
 
     std::pair<scalar_t, scalar_t> buffer[total_size];
 
-    const scalar_t *input_ptr = input_contiguous.data_ptr<scalar_t>();
-    const scalar_t *initials_ptr = initials_contiguous.data_ptr<scalar_t>();
-    const scalar_t *weights_ptr = weights_contiguous.data_ptr<scalar_t>();
-    scalar_t *output_ptr = output.data_ptr<scalar_t>();
+    const scalar_t *input_ptr = input_contiguous.const_data_ptr<scalar_t>();
+    const scalar_t *initials_ptr =
+        initials_contiguous.const_data_ptr<scalar_t>();
+    const scalar_t *weights_ptr = weights_contiguous.const_data_ptr<scalar_t>();
+    scalar_t *output_ptr = output.mutable_data_ptr<scalar_t>();
 
     std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
                    [](const scalar_t &a, const scalar_t &b) {
@@ -84,8 +103,8 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
 
     auto a_contiguous = a.contiguous();
 
-    const scalar_t *a_ptr = a_contiguous.data_ptr<scalar_t>();
-    scalar_t *out_ptr = padded_out.data_ptr<scalar_t>();
+    const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>();
+    scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>();
 
     at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
         for (auto b = start; b < end; b++) {
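
The pointer-accessor changes in both hunks swap data_ptr for the const-correct accessors available in recent PyTorch: const_data_ptr for tensors that are only read, mutable_data_ptr for tensors that are written. A self-contained sketch of the pattern (the function and variable names are illustrative, not from this commit):

#include <torch/torch.h>

// Illustrative sketch of the accessor pattern adopted above. const_data_ptr
// documents read-only access; mutable_data_ptr declares the intent to write.
// Assumes contiguous float tensors of equal numel; names are examples only.
void add_inplace(const at::Tensor &x, at::Tensor &y) {
    const float *x_ptr = x.const_data_ptr<float>();
    float *y_ptr = y.mutable_data_ptr<float>();
    for (int64_t i = 0; i < x.numel(); i++) {
        y_ptr[i] += x_ptr[i];
    }
}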
