Skip to content

Commit 98f314f

Browse files
committed
refactor: reorganize setup.py for building CUDA extensions
1 parent a4fd535 commit 98f314f

File tree

1 file changed

+54
-17
lines changed

1 file changed

+54
-17
lines changed

setup.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
import setuptools
2+
import os
3+
import glob
24
import torch
3-
from torch.utils import cpp_extension
5+
from torch.utils.cpp_extension import (
6+
CppExtension,
7+
CUDAExtension,
8+
BuildExtension,
9+
CUDA_HOME,
10+
)
411

5-
NAME = "torchlpc"
12+
library_name = "torchlpc"
613
VERSION = "0.7.dev"
714
MAINTAINER = "Chin-Yun Yu"
815
EMAIL = "chin-yun.yu@qmul.ac.uk"
@@ -12,15 +19,51 @@
1219
long_description = fh.read()
1320

1421

22+
# if torch.__version__ >= "2.6.0":
23+
# py_limited_api = True
24+
# else:
25+
py_limited_api = False
26+
27+
28+
def get_extensions():
29+
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
30+
use_openmp = torch.backends.openmp.is_available()
31+
extension = CUDAExtension if use_cuda else CppExtension
32+
33+
extra_link_args = []
34+
extra_compile_args = {}
35+
if use_openmp:
36+
extra_compile_args["cxx"] = ["-fopenmp"]
37+
extra_link_args.append("-lgomp")
38+
39+
this_dir = os.path.dirname(os.path.curdir)
40+
extensions_dir = os.path.join(this_dir, library_name, "csrc")
41+
sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
42+
43+
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
44+
cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
45+
46+
if use_cuda:
47+
sources += cuda_sources
48+
49+
ext_modules = [
50+
extension(
51+
f"{library_name}._C",
52+
sources,
53+
extra_compile_args=extra_compile_args,
54+
extra_link_args=extra_link_args,
55+
py_limited_api=py_limited_api,
56+
)
57+
]
58+
59+
return ext_modules
60+
61+
1562
extra_link_args = []
1663
extra_compile_args = {}
17-
# check if openmp is available
18-
if torch.backends.openmp.is_available():
19-
extra_compile_args["cxx"] = ["-fopenmp"]
20-
extra_link_args.append("-lgomp")
2164

2265
setuptools.setup(
23-
name=NAME,
66+
name=library_name,
2467
version=VERSION,
2568
author=MAINTAINER,
2669
author_email=EMAIL,
@@ -32,16 +75,10 @@
3275
install_requires=["torch>=2.0", "numpy", "numba"],
3376
classifiers=[
3477
"Programming Language :: Python :: 3",
35-
"License :: OSI Approved :: MIT License",
3678
"Operating System :: OS Independent",
3779
],
38-
ext_modules=[
39-
cpp_extension.CppExtension(
40-
"torchlpc._C",
41-
["torchlpc/csrc/scan_cpu.cpp"],
42-
extra_compile_args=extra_compile_args,
43-
extra_link_args=extra_link_args,
44-
)
45-
],
46-
cmdclass={"build_ext": cpp_extension.BuildExtension},
80+
license="MIT",
81+
ext_modules=get_extensions(),
82+
cmdclass={"build_ext": BuildExtension},
83+
options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
4784
)

0 commit comments

Comments
 (0)