Skip to content

Commit c4cd235

Browse files
committed
Actually check the values of objective in tests and temporary fix for building without openmp
1 parent 5df6c53 commit c4cd235

File tree

2 files changed

+110
-4
lines changed

2 files changed

+110
-4
lines changed

build_extension.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import subprocess
3+
import warnings
14
from typing import Any, Dict
25

36
from setuptools import Extension
@@ -21,18 +24,67 @@
2124
include_dirs=["kmsr"],
2225
)
2326

24-
# Thank you https://github.com/dstein64/kmeans1d!
27+
28+
def check_openmp_support() -> bool:
29+
openmp_test_code = """
30+
#include <omp.h>
31+
#include <stdio.h>
32+
int main() {
33+
int nthreads;
34+
#pragma omp parallel
35+
{
36+
nthreads = omp_get_num_threads();
37+
}
38+
printf("Number of threads = %d\\n", nthreads);
39+
return 0;
40+
}
41+
"""
42+
43+
with open("test_openmp.c", "w") as f:
44+
f.write(openmp_test_code)
45+
46+
try:
47+
# Try to compile the code with OpenMP support
48+
result = subprocess.run(
49+
["gcc", "-fopenmp", "test_openmp.c", "-o", "test_openmp"],
50+
stdout=subprocess.PIPE,
51+
stderr=subprocess.PIPE,
52+
)
53+
if result.returncode != 0:
54+
return False
55+
56+
# Run the compiled program
57+
result = subprocess.run(
58+
["./test_openmp"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
59+
)
60+
if result.returncode == 0:
61+
return True
62+
else:
63+
return False
64+
finally:
65+
os.remove("test_openmp.c")
66+
if os.path.exists("test_openmp"):
67+
os.remove("test_openmp")
2568

2669

2770
class BuildExt(build_ext):
2871
"""A custom build extension for adding -stdlib arguments for clang++."""
2972

3073
def build_extensions(self) -> None:
74+
support = check_openmp_support()
75+
3176
# '-std=c++11' is added to `extra_compile_args` so the code can compile
3277
# with clang++. This works across compilers (ignored by MSVC).
3378
for extension in self.extensions:
3479
extension.extra_compile_args.append("-std=c++11")
35-
extension.extra_compile_args.append("-fopenmp")
80+
if support:
81+
extension.extra_compile_args.append("-fopenmp")
82+
extension.extra_link_args.append("-lomp")
83+
else:
84+
warnings.warn(
85+
"\x1b[31;20m OpenMP is not installed on this system. "
86+
"Please install it to have all the benefits from the program.\x1b[0m"
87+
)
3688

3789
try:
3890
build_ext.build_extensions(self)
@@ -43,7 +95,6 @@ def build_extensions(self) -> None:
4395
for extension in self.extensions:
4496
extension.extra_compile_args.append("-stdlib=libc++")
4597
extension.extra_link_args.append("-stdlib=libc++")
46-
extension.extra_link_args.append("-lomp")
4798
build_ext.build_extensions(self)
4899

49100

tests/test.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,54 @@ def check_cost(inertia: float, radii: np.ndarray) -> Any:
1212
return np.isclose(inertia, sum(radii))
1313

1414

15+
EXPECTED_VALUES = {
16+
2: { # Dim
17+
2: { # K
18+
"schmidt": (45.91, 45.93),
19+
"heuristic": (48.25, 48.27),
20+
"gonzales": (48.25, 48.27),
21+
"kmeans": (48.25, 48.27),
22+
},
23+
3: {
24+
"schmidt": (55.86, 56.17),
25+
"heuristic": (57.73, 57.75),
26+
"gonzales": (57.73, 57.75),
27+
"kmeans": (57.73, 57.75),
28+
},
29+
4: {
30+
"schmidt": (50.53, 50.95),
31+
"heuristic": (52.34, 52.36),
32+
"gonzales": (52.34, 52.36),
33+
"kmeans": (52.34, 52.36),
34+
},
35+
},
36+
3: {
37+
2: {
38+
"schmidt": (41.24, 42.84),
39+
"heuristic": (42.96, 42.98),
40+
"gonzales": (42.96, 44.81),
41+
"kmeans": (42.96, 44.81),
42+
},
43+
3: {
44+
"schmidt": (51.79, 51.81),
45+
"heuristic": (51.79, 51.81),
46+
"gonzales": (51.79, 60.65),
47+
"kmeans": (51.79, 73.38),
48+
},
49+
4: {
50+
"schmidt": (66.37, 66.39),
51+
"heuristic": (68.64, 68.66),
52+
"gonzales": (68.64, 68.66),
53+
"kmeans": (68.64, 68.66),
54+
},
55+
},
56+
}
57+
58+
1559
class TestKMSR(unittest.TestCase):
1660
def test_fit(self) -> None:
1761
random.seed(42)
62+
s = "{"
1863
for k in range(2, 5):
1964
for dim in range(2, 4):
2065
_, points = generate_clusters(
@@ -37,7 +82,17 @@ def test_fit(self) -> None:
3782
self.assertTrue(
3883
check_cost(kmsr.inertia_, kmsr.cluster_radii_)
3984
)
40-
# print(algo, sum(costs) / len(costs))
85+
assert (
86+
EXPECTED_VALUES[dim][k][algo][0]
87+
<= kmsr.inertia_
88+
<= EXPECTED_VALUES[dim][k][algo][1]
89+
), (
90+
f"{EXPECTED_VALUES[dim][k][algo][0]} "
91+
f"<= {kmsr.inertia_} "
92+
f"<= {EXPECTED_VALUES[dim][k][algo][1]}"
93+
)
94+
95+
print(s)
4196

4297

4398
if __name__ == "__main__":

0 commit comments

Comments
 (0)