1+ from .test_data import get_data
2+ from ..MTL_Cluster_Least_L21 import MTL_Cluster_Least_L21
3+ from ...evaluations .utils import MTL_data_extract , MTL_data_split , opts
4+ import numpy as np
5+ import math
6+
7+ # iris data
8+ X_train , X_test , Y_train , Y_test , df = get_data ()
9+ opts = opts (1500 ,2 )
10+ opts .tol = 10 ** (- 6 )
11+
12+
13+ # customized data
14+ clus_var = 900
15+ task_var = 16
16+ nois_var = 150
17+
18+ clus_num = 2
19+ clus_task_num = 10
20+ task_num = clus_num * clus_task_num
21+ sample_size = 100
22+ dimension = 20
23+ comm_dim = 2
24+ clus_dim = math .floor ((dimension - comm_dim )/ 2 )
25+
26+ # generate cluster model
27+ cluster_weight = np .random .randn (dimension , clus_num )* clus_var
28+ for i in range (clus_num ):
29+ bll = np .random .permutation (range (dimension - clus_num ))<= clus_dim
30+ blc = np .array ([False ]* clus_num )
31+ bll = np .hstack ((bll , blc ))
32+ cluster_weight [:,i ][bll ]= 0
33+ cluster_weight [- 1 - comm_dim :, :]= 0
34+ W = np .tile (cluster_weight , (1 , clus_task_num ))
35+ cluster_index = np .tile (range (clus_num ), (1 , clus_task_num )).T
36+
37+ # generate task and intra-cluster variance
38+ W_it = np .random .randn (dimension , task_num ) * task_var
39+ for i in range (task_num ):
40+ bll = np .hstack (((W [:- 1 - comm_dim + 1 ,i ]== 0 ).reshape (1 ,- 1 ), np .zeros ((1 ,comm_dim ))== 1 ))
41+ W_it [:,i ][bll .flatten ()]= 0
42+ W = W + W_it
43+
44+ W = W + np .random .randn (dimension , task_num )* nois_var
45+
46+ X = [0 ]* task_num
47+ Y = [0 ]* task_num
48+ for i in range (task_num ):
49+ X [i ] = np .random .randn (sample_size , dimension )
50+ xw = X [i ] @ W [:,i ]
51+ s = xw .shape
52+ xw = xw + np .random .randn (s [0 ]) * nois_var
53+ Y [i ] = np .sign (xw )
54+
55+ class Test_CMTL_Least_classification (object ):
56+ def test_basic_mat (self ):
57+ clf = MTL_Cluster_Least_L21 (opts , 3 )
58+ clf .fit (X , Y )
59+ corr = clf .analyse ()
60+ print (corr )
61+
62+
63+ def test_iris_accuracy (self ):
64+ clf = MTL_Cluster_Least_L21 (opts , 3 )
65+ clf .fit (X_train , Y_train )
66+ corr = clf .analyse ()
67+ print (corr )
0 commit comments