Skip to content

Commit 15174b4

Browse files
committed
Added tests for learn module
1 parent 3255371 commit 15174b4

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

tests/test_unit_learn.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
2+
import unittest
3+
4+
import numpy as np
5+
import networkx as nx
6+
7+
from sklearn.utils.estimator_checks import check_estimator
8+
from sklearn.cluster import DBSCAN
9+
10+
from tdamapper.learn import (
11+
MapperAlgorithm,
12+
MapperClustering,
13+
)
14+
from tdamapper.core import TrivialClustering, TrivialCover
15+
from tdamapper.cover import BallCover
16+
17+
18+
def euclidean(x, y):
19+
return np.linalg.norm(x - y)
20+
21+
22+
def dataset(dim=10, num=1000):
23+
return [np.random.rand(dim) for _ in range(num)]
24+
25+
26+
class TestMapper(unittest.TestCase):
27+
28+
def run_tests(self, estimator):
29+
for est, check in check_estimator(estimator, generate_only=True):
30+
check(est)
31+
32+
def test_mapper_learn(self):
33+
data = dataset()
34+
mp = MapperAlgorithm(TrivialCover(), TrivialClustering())
35+
g = mp.fit_transform(data, data)
36+
self.assertEqual(1, len(g))
37+
self.assertEqual([], list(g.neighbors(0)))
38+
ccs = list(nx.connected_components(g))
39+
self.assertEqual(1, len(ccs))
40+
41+
def test_mapper_learn_est(self):
42+
est = MapperAlgorithm()
43+
self.run_tests(est)
44+
45+
def test_mapper_clustering_trivial(self):
46+
est = MapperClustering()
47+
self.run_tests(est)
48+
49+
def test_mapper_clustering_ball(self):
50+
est = MapperClustering(cover=BallCover(metric=euclidean))
51+
self.run_tests(est)

0 commit comments

Comments
 (0)