Skip to content

Commit d44d799

Browse files
committed
add CMTL least L21, doesnot work yet, set up test already
1 parent 0ddaf96 commit d44d799

File tree

3 files changed

+91
-15
lines changed

3 files changed

+91
-15
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from .test_data import get_data
2+
from ..MTL_Cluster_Least_L21 import MTL_Cluster_Least_L21
3+
from ...evaluations.utils import MTL_data_extract, MTL_data_split, opts
4+
import numpy as np
5+
import math
6+
7+
# iris data
8+
X_train, X_test, Y_train, Y_test, df = get_data()
9+
opts = opts(1500,2)
10+
opts.tol = 10**(-6)
11+
12+
13+
# customized data
14+
clus_var = 900
15+
task_var = 16
16+
nois_var = 150
17+
18+
clus_num = 2
19+
clus_task_num = 10
20+
task_num = clus_num * clus_task_num
21+
sample_size = 100
22+
dimension = 20
23+
comm_dim = 2
24+
clus_dim = math.floor((dimension - comm_dim)/2)
25+
26+
# generate cluster model
27+
cluster_weight = np.random.randn(dimension, clus_num)* clus_var
28+
for i in range(clus_num):
29+
bll = np.random.permutation(range(dimension-clus_num))<=clus_dim
30+
blc = np.array([False]*clus_num)
31+
bll = np.hstack((bll, blc))
32+
cluster_weight[:,i][bll]=0
33+
cluster_weight[-1-comm_dim:, :]=0
34+
W = np.tile(cluster_weight, (1, clus_task_num))
35+
cluster_index = np.tile(range(clus_num), (1, clus_task_num)).T
36+
37+
# generate task and intra-cluster variance
38+
W_it = np.random.randn(dimension, task_num) * task_var
39+
for i in range(task_num):
40+
bll = np.hstack(((W[:-1-comm_dim+1,i]==0).reshape(1,-1), np.zeros((1,comm_dim))==1))
41+
W_it[:,i][bll.flatten()]=0
42+
W = W+W_it
43+
44+
W = W + np.random.randn(dimension, task_num)*nois_var
45+
46+
X = [0]*task_num
47+
Y = [0]*task_num
48+
for i in range(task_num):
49+
X[i] = np.random.randn(sample_size, dimension)
50+
xw = X[i] @ W[:,i]
51+
s= xw.shape
52+
xw = xw + np.random.randn(s[0]) * nois_var
53+
Y[i] = np.sign(xw)
54+
55+
class Test_CMTL_Least_classification(object):
56+
def test_basic_mat(self):
57+
clf = MTL_Cluster_Least_L21(opts, 3)
58+
clf.fit(X, Y)
59+
corr = clf.analyse()
60+
print(corr)
61+
62+
63+
def test_iris_accuracy(self):
64+
clf = MTL_Cluster_Least_L21(opts, 3)
65+
clf.fit(X_train, Y_train)
66+
corr = clf.analyse()
67+
print(corr)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import numpy as np
2+
import pandas as pd
3+
from sklearn import datasets
4+
from ...evaluations.utils import MTL_data_extract, MTL_data_split, opts
5+
6+
7+
def get_data():
8+
iris = datasets.load_iris()
9+
df = pd.DataFrame(data= np.c_[iris['data'], iris['target']],
10+
columns= iris['feature_names'] + ['target'])
11+
df['cat1']=0
12+
df['cat2']=0
13+
df['target'] = df['target'].astype(int)
14+
df.loc[df['petal width (cm)']<=0.8, 'cat1'] = 0
15+
df.loc[(df['petal width (cm)']>0.8) & (df['petal width (cm)']<=1.6), 'cat1'] = 1
16+
df.loc[(df['petal width (cm)']>1.6) & (df['petal width (cm)']<=2.4), 'cat1'] = 2
17+
df.loc[df['petal length (cm)']<=2.3, 'cat2'] = 0
18+
df.loc[(df['petal length (cm)']>2.3) & (df['petal length (cm)']<=4.6), 'cat2'] = 1
19+
df.loc[(df['petal length (cm)']>4.6) & (df['petal length (cm)']<=6.9), 'cat2'] = 2
20+
X_i, Y_i = MTL_data_extract(df, 'cat2', 'target')
21+
X_train, X_test, Y_train, Y_test = MTL_data_split(X_i, Y_i, test_size=0.4)
22+
return X_train, X_test, Y_train, Y_test, df

Vampyr_MTL/functions/tests/test_softmax_L21_hinge.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pandas as pd
44
from sklearn import datasets
55
from ...evaluations.utils import MTL_data_extract, MTL_data_split, opts
6+
from .test_data import get_data
67

78
# class opts:
89
# def __init__(self, maxIter, init):
@@ -12,21 +13,7 @@
1213

1314
opts = opts(1000,2)
1415

15-
## iris data
16-
iris = datasets.load_iris()
17-
df = pd.DataFrame(data= np.c_[iris['data'], iris['target']],
18-
columns= iris['feature_names'] + ['target'])
19-
df['cat1']=0
20-
df['cat2']=0
21-
df['target'] = df['target'].astype(int)
22-
df.loc[df['petal width (cm)']<=0.8, 'cat1'] = 0
23-
df.loc[(df['petal width (cm)']>0.8) & (df['petal width (cm)']<=1.6), 'cat1'] = 1
24-
df.loc[(df['petal width (cm)']>1.6) & (df['petal width (cm)']<=2.4), 'cat1'] = 2
25-
df.loc[df['petal length (cm)']<=2.3, 'cat2'] = 0
26-
df.loc[(df['petal length (cm)']>2.3) & (df['petal length (cm)']<=4.6), 'cat2'] = 1
27-
df.loc[(df['petal length (cm)']>4.6) & (df['petal length (cm)']<=6.9), 'cat2'] = 2
28-
X_i, Y_i = MTL_data_extract(df, 'cat2', 'target')
29-
X_train, X_test, Y_train, Y_test = MTL_data_split(X_i, Y_i, test_size=0.4)
16+
X_train, X_test, Y_train, Y_test, df = get_data()
3017

3118
df2 = df.copy()
3219
df2.loc[df['target']==0, 'target'] = 'flower1'

0 commit comments

Comments
 (0)