|
1 | 1 | import setuptools |
| 2 | +import os |
| 3 | +import glob |
2 | 4 | import torch |
3 | | -from torch.utils import cpp_extension |
| 5 | +from torch.utils.cpp_extension import ( |
| 6 | + CppExtension, |
| 7 | + CUDAExtension, |
| 8 | + BuildExtension, |
| 9 | + CUDA_HOME, |
| 10 | +) |
4 | 11 |
|
5 | | -NAME = "torchlpc" |
| 12 | +library_name = "torchlpc" |
6 | 13 | VERSION = "0.7.dev" |
7 | 14 | MAINTAINER = "Chin-Yun Yu" |
8 | 15 | EMAIL = "chin-yun.yu@qmul.ac.uk" |
|
12 | 19 | long_description = fh.read() |
13 | 20 |
|
14 | 21 |
|
| 22 | +# if torch.__version__ >= "2.6.0": |
| 23 | +# py_limited_api = True |
| 24 | +# else: |
| 25 | +py_limited_api = False |
| 26 | + |
| 27 | + |
| 28 | +def get_extensions(): |
| 29 | + use_cuda = torch.cuda.is_available() and CUDA_HOME is not None |
| 30 | + use_openmp = torch.backends.openmp.is_available() |
| 31 | + extension = CUDAExtension if use_cuda else CppExtension |
| 32 | + |
| 33 | + extra_link_args = [] |
| 34 | + extra_compile_args = {} |
| 35 | + if use_openmp: |
| 36 | + extra_compile_args["cxx"] = ["-fopenmp"] |
| 37 | + extra_link_args.append("-lgomp") |
| 38 | + |
| 39 | + this_dir = os.path.dirname(os.path.curdir) |
| 40 | + extensions_dir = os.path.join(this_dir, library_name, "csrc") |
| 41 | + sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) |
| 42 | + |
| 43 | + extensions_cuda_dir = os.path.join(extensions_dir, "cuda") |
| 44 | + cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) |
| 45 | + |
| 46 | + if use_cuda: |
| 47 | + sources += cuda_sources |
| 48 | + |
| 49 | + ext_modules = [ |
| 50 | + extension( |
| 51 | + f"{library_name}._C", |
| 52 | + sources, |
| 53 | + extra_compile_args=extra_compile_args, |
| 54 | + extra_link_args=extra_link_args, |
| 55 | + py_limited_api=py_limited_api, |
| 56 | + ) |
| 57 | + ] |
| 58 | + |
| 59 | + return ext_modules |
| 60 | + |
| 61 | + |
15 | 62 | extra_link_args = [] |
16 | 63 | extra_compile_args = {} |
17 | | -# check if openmp is available |
18 | | -if torch.backends.openmp.is_available(): |
19 | | - extra_compile_args["cxx"] = ["-fopenmp"] |
20 | | - extra_link_args.append("-lgomp") |
21 | 64 |
|
22 | 65 | setuptools.setup( |
23 | | - name=NAME, |
| 66 | + name=library_name, |
24 | 67 | version=VERSION, |
25 | 68 | author=MAINTAINER, |
26 | 69 | author_email=EMAIL, |
|
32 | 75 | install_requires=["torch>=2.0", "numpy", "numba"], |
33 | 76 | classifiers=[ |
34 | 77 | "Programming Language :: Python :: 3", |
35 | | - "License :: OSI Approved :: MIT License", |
36 | 78 | "Operating System :: OS Independent", |
37 | 79 | ], |
38 | | - ext_modules=[ |
39 | | - cpp_extension.CppExtension( |
40 | | - "torchlpc._C", |
41 | | - ["torchlpc/csrc/scan_cpu.cpp"], |
42 | | - extra_compile_args=extra_compile_args, |
43 | | - extra_link_args=extra_link_args, |
44 | | - ) |
45 | | - ], |
46 | | - cmdclass={"build_ext": cpp_extension.BuildExtension}, |
| 80 | + license="MIT", |
| 81 | + ext_modules=get_extensions(), |
| 82 | + cmdclass={"build_ext": BuildExtension}, |
| 83 | + options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {}, |
47 | 84 | ) |
0 commit comments