11import time
22
3- import pandas as pd
3+ import gtda .mapper as gm
4+ import kmapper as km
45import numpy as np
5-
6- from sklearn .decomposition import PCA
7- from sklearn .datasets import fetch_openml , load_digits
6+ import pandas as pd
87from sklearn .base import ClusterMixin
9-
10- from tdamapper . clustering import TrivialClustering
8+ from sklearn . datasets import fetch_openml , load_digits
9+ from sklearn . decomposition import PCA
1110
1211import tdamapper as tm
13- import gtda .mapper as gm
14- import kmapper as km
12+ from tdamapper .clustering import TrivialClustering
1513
1614
1715def _segment (cardinality , dimension , noise = 0.1 , start = None , end = None ):
@@ -40,17 +38,17 @@ def digits(k):
4038
4139
4240def mnist (k ):
43- X = _load_openml (' mnist_784' )
41+ X = _load_openml (" mnist_784" )
4442 return PCA (k ).fit_transform (X )
4543
4644
4745def cifar10 (k ):
48- X = _load_openml (' CIFAR_10' )
46+ X = _load_openml (" CIFAR_10" )
4947 return PCA (k ).fit_transform (X )
5048
5149
5250def fashion_mnist (k ):
53- X = _load_openml (' Fashion-MNIST' )
51+ X = _load_openml (" Fashion-MNIST" )
5452 return PCA (k ).fit_transform (X )
5553
5654
@@ -73,9 +71,7 @@ def run_gm(X, n, p):
7371 t0 = time .time ()
7472 pipe = gm .make_mapper_pipeline (
7573 filter_func = lambda x : x ,
76- cover = gm .CubicalCover (
77- n_intervals = n ,
78- overlap_frac = p ),
74+ cover = gm .CubicalCover (n_intervals = n , overlap_frac = p ),
7975 clusterer = TrivialEstimator (),
8076 )
8177 mapper_graph = pipe .fit_transform (X )
@@ -89,10 +85,10 @@ def run_tm(X, n, p):
8985 cover = tm .cover .CubicalCover (
9086 n_intervals = n ,
9187 overlap_frac = p ,
92- #leaf_capacity=1000,
93- #leaf_radius=1.0 / (2.0 - 2.0 * p),
94- #kind='hierarchical',
95- #pivoting='random',
88+ # leaf_capacity=1000,
89+ # leaf_radius=1.0 / (2.0 - 2.0 * p),
90+ # kind='hierarchical',
91+ # pivoting='random',
9692 ),
9793 clustering = TrivialEstimator (),
9894 ).fit_transform (X , X )
@@ -106,25 +102,24 @@ def run_km(X, n, p):
106102 graph = mapper .map (
107103 lens = X ,
108104 X = X ,
109- cover = km .Cover (
110- n_cubes = n ,
111- perc_overlap = p
112- ),
105+ cover = km .Cover (n_cubes = n , perc_overlap = p ),
113106 clusterer = TrivialEstimator (),
114107 )
115108 t1 = time .time ()
116109 return t1 - t0
117110
118111
119112def run_bench (benches , datasets , dimensions , overlaps , intervals ):
120- df_bench = pd .DataFrame ({
121- 'bench' : [],
122- 'dataset' : [],
123- 'p' : [],
124- 'n' : [],
125- 'k' : [],
126- 'time' : [],
127- })
113+ df_bench = pd .DataFrame (
114+ {
115+ "bench" : [],
116+ "dataset" : [],
117+ "p" : [],
118+ "n" : [],
119+ "k" : [],
120+ "time" : [],
121+ }
122+ )
128123 launch_time = int (time .time ())
129124 for bench_name , bench in benches :
130125 for dataset_name , dataset in datasets :
@@ -133,34 +128,33 @@ def run_bench(benches, datasets, dimensions, overlaps, intervals):
133128 for p in overlaps :
134129 for n in intervals :
135130 t = bench (X , n , p )
136- df_delta = pd .DataFrame ({
137- 'bench' : bench_name ,
138- 'dataset' : dataset_name ,
139- 'p' : p ,
140- 'n' : n ,
141- 'k' : k ,
142- 'time' : t ,
143- }, index = [0 ])
131+ df_delta = pd .DataFrame (
132+ {
133+ "bench" : bench_name ,
134+ "dataset" : dataset_name ,
135+ "p" : p ,
136+ "n" : n ,
137+ "k" : k ,
138+ "time" : t ,
139+ },
140+ index = [0 ],
141+ )
144142 print (df_delta )
145143 df_bench = pd .concat ([df_bench , df_delta ], ignore_index = True )
146- df_bench .to_csv (f' ./benchmark_{ launch_time } .csv' , index = False )
144+ df_bench .to_csv (f" ./benchmark_{ launch_time } .csv" , index = False )
147145
148146
149- if __name__ == ' __main__' :
150- run_tm (line (1 ), 1 , 0.5 ) # first run to jit-compile numba decorated functions
147+ if __name__ == " __main__" :
148+ run_tm (line (1 ), 1 , 0.5 ) # first run to jit-compile numba decorated functions
151149
152150 run_bench (
153- overlaps = [
154- 0.125 ,
155- 0.25 ,
156- 0.5
157- ],
151+ overlaps = [0.125 , 0.25 , 0.5 ],
158152 datasets = [
159- (' line' , line ),
160- (' digits' , digits ),
161- (' mnist' , mnist ),
162- (' cifar10' , cifar10 ),
163- (' fashion_mnist' , fashion_mnist ),
153+ (" line" , line ),
154+ (" digits" , digits ),
155+ (" mnist" , mnist ),
156+ (" cifar10" , cifar10 ),
157+ (" fashion_mnist" , fashion_mnist ),
164158 ],
165159 intervals = [
166160 10 ,
@@ -172,9 +166,9 @@ def run_bench(benches, datasets, dimensions, overlaps, intervals):
172166 4 ,
173167 5 ,
174168 ],
175- benches = [
176- (' tda-mapper' , run_tm ),
177- (' kepler-mapper' , run_km ),
178- (' giotto-tda' , run_gm ),
169+ benches = [
170+ (" tda-mapper" , run_tm ),
171+ (" kepler-mapper" , run_km ),
172+ (" giotto-tda" , run_gm ),
179173 ],
180- )
174+ )
0 commit comments