Skip to content

Commit bb844c3

Browse files
authored
Merge pull request #212 from lucasimi/develop
Develop
2 parents f5b51fe + 9df87bf commit bb844c3

37 files changed

+990
-794
lines changed

app/requirements.txt

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
1-
streamlit>=1.40.0
2-
numpy>=1.25.2
3-
scikit-learn>=1.5.0
4-
umap-learn>=0.5.7
5-
pandas>=2.1.0
6-
tda-mapper==0.9.0
7-
plotly < 6.0.0
1+
streamlit>=1.40.0,<2.0.0
2+
numpy>=1.25.2,<2.0.0
3+
scikit-learn>=1.5.0,<1.6.0
4+
umap-learn>=0.5.7,<0.6.0
5+
pandas>=2.1.0,<3.0.0
6+
tda-mapper>=0.9.0,<0.10.0
7+
plotly>=6.0.0,<7.0.0

benchmarks/benchmark.py

Lines changed: 50 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,15 @@
11
import time
22

3-
import pandas as pd
3+
import gtda.mapper as gm
4+
import kmapper as km
45
import numpy as np
5-
6-
from sklearn.decomposition import PCA
7-
from sklearn.datasets import fetch_openml, load_digits
6+
import pandas as pd
87
from sklearn.base import ClusterMixin
9-
10-
from tdamapper.clustering import TrivialClustering
8+
from sklearn.datasets import fetch_openml, load_digits
9+
from sklearn.decomposition import PCA
1110

1211
import tdamapper as tm
13-
import gtda.mapper as gm
14-
import kmapper as km
12+
from tdamapper.clustering import TrivialClustering
1513

1614

1715
def _segment(cardinality, dimension, noise=0.1, start=None, end=None):
@@ -40,17 +38,17 @@ def digits(k):
4038

4139

4240
def mnist(k):
43-
X = _load_openml('mnist_784')
41+
X = _load_openml("mnist_784")
4442
return PCA(k).fit_transform(X)
4543

4644

4745
def cifar10(k):
48-
X = _load_openml('CIFAR_10')
46+
X = _load_openml("CIFAR_10")
4947
return PCA(k).fit_transform(X)
5048

5149

5250
def fashion_mnist(k):
53-
X = _load_openml('Fashion-MNIST')
51+
X = _load_openml("Fashion-MNIST")
5452
return PCA(k).fit_transform(X)
5553

5654

@@ -73,9 +71,7 @@ def run_gm(X, n, p):
7371
t0 = time.time()
7472
pipe = gm.make_mapper_pipeline(
7573
filter_func=lambda x: x,
76-
cover=gm.CubicalCover(
77-
n_intervals=n,
78-
overlap_frac=p),
74+
cover=gm.CubicalCover(n_intervals=n, overlap_frac=p),
7975
clusterer=TrivialEstimator(),
8076
)
8177
mapper_graph = pipe.fit_transform(X)
@@ -89,10 +85,10 @@ def run_tm(X, n, p):
8985
cover=tm.cover.CubicalCover(
9086
n_intervals=n,
9187
overlap_frac=p,
92-
#leaf_capacity=1000,
93-
#leaf_radius=1.0 / (2.0 - 2.0 * p),
94-
#kind='hierarchical',
95-
#pivoting='random',
88+
# leaf_capacity=1000,
89+
# leaf_radius=1.0 / (2.0 - 2.0 * p),
90+
# kind='hierarchical',
91+
# pivoting='random',
9692
),
9793
clustering=TrivialEstimator(),
9894
).fit_transform(X, X)
@@ -106,25 +102,24 @@ def run_km(X, n, p):
106102
graph = mapper.map(
107103
lens=X,
108104
X=X,
109-
cover=km.Cover(
110-
n_cubes=n,
111-
perc_overlap=p
112-
),
105+
cover=km.Cover(n_cubes=n, perc_overlap=p),
113106
clusterer=TrivialEstimator(),
114107
)
115108
t1 = time.time()
116109
return t1 - t0
117110

118111

119112
def run_bench(benches, datasets, dimensions, overlaps, intervals):
120-
df_bench = pd.DataFrame({
121-
'bench': [],
122-
'dataset': [],
123-
'p': [],
124-
'n': [],
125-
'k': [],
126-
'time': [],
127-
})
113+
df_bench = pd.DataFrame(
114+
{
115+
"bench": [],
116+
"dataset": [],
117+
"p": [],
118+
"n": [],
119+
"k": [],
120+
"time": [],
121+
}
122+
)
128123
launch_time = int(time.time())
129124
for bench_name, bench in benches:
130125
for dataset_name, dataset in datasets:
@@ -133,34 +128,33 @@ def run_bench(benches, datasets, dimensions, overlaps, intervals):
133128
for p in overlaps:
134129
for n in intervals:
135130
t = bench(X, n, p)
136-
df_delta = pd.DataFrame({
137-
'bench': bench_name,
138-
'dataset': dataset_name,
139-
'p': p,
140-
'n': n,
141-
'k': k,
142-
'time': t,
143-
}, index=[0])
131+
df_delta = pd.DataFrame(
132+
{
133+
"bench": bench_name,
134+
"dataset": dataset_name,
135+
"p": p,
136+
"n": n,
137+
"k": k,
138+
"time": t,
139+
},
140+
index=[0],
141+
)
144142
print(df_delta)
145143
df_bench = pd.concat([df_bench, df_delta], ignore_index=True)
146-
df_bench.to_csv(f'./benchmark_{launch_time}.csv', index=False)
144+
df_bench.to_csv(f"./benchmark_{launch_time}.csv", index=False)
147145

148146

149-
if __name__ == '__main__':
150-
run_tm(line(1), 1, 0.5) # fist run to jit-compile numba decorated functions
147+
if __name__ == "__main__":
148+
run_tm(line(1), 1, 0.5) # first run to jit-compile numba decorated functions
151149

152150
run_bench(
153-
overlaps=[
154-
0.125,
155-
0.25,
156-
0.5
157-
],
151+
overlaps=[0.125, 0.25, 0.5],
158152
datasets=[
159-
('line', line),
160-
('digits', digits),
161-
('mnist', mnist),
162-
('cifar10', cifar10),
163-
('fashion_mnist', fashion_mnist),
153+
("line", line),
154+
("digits", digits),
155+
("mnist", mnist),
156+
("cifar10", cifar10),
157+
("fashion_mnist", fashion_mnist),
164158
],
165159
intervals=[
166160
10,
@@ -172,9 +166,9 @@ def run_bench(benches, datasets, dimensions, overlaps, intervals):
172166
4,
173167
5,
174168
],
175-
benches = [
176-
('tda-mapper', run_tm),
177-
('kepler-mapper', run_km),
178-
('giotto-tda', run_gm),
169+
benches=[
170+
("tda-mapper", run_tm),
171+
("kepler-mapper", run_km),
172+
("giotto-tda", run_gm),
179173
],
180-
)
174+
)

0 commit comments

Comments
 (0)