Skip to content

Commit 89458c9

Browse files
committed
CMTL done
1 parent eae6bca commit 89458c9

File tree

2 files changed

+62
-17
lines changed

2 files changed

+62
-17
lines changed

Vampyr_MTL/functions/MTL_Cluster_Least_L21.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from scipy.sparse import isspmatrix
1212

1313
class MTL_Cluster_Least_L21:
14-
def __init__(self, opts, k, rho1=10, rho2=0.2):
14+
def __init__(self, opts, k, rho1=10, rho2=0.1):
1515
self.opts = init_opts(opts)
1616
self.rho1 = rho1
1717
self.rho2 = rho2
@@ -44,8 +44,8 @@ def fit(self, X, Y, **kwargs):
4444
self.XY = [0]* self.task_num
4545
W0_prep = []
4646
for t in range(self.task_num):
47-
self.XY[i] = X[i] @ Y[i]
48-
W0_prep.append(self.XY[i].reshape((-1,1)))
47+
self.XY[t] = X[t] @ Y[t]
48+
W0_prep.append(self.XY[t].reshape((-1,1)))
4949
W0_prep = np.hstack(W0_prep)
5050
if hasattr(self.opts,'W0'):
5151
W0=self.opts.W0
@@ -65,14 +65,14 @@ def fit(self, X, Y, **kwargs):
6565
Wz_old = W0
6666
Mz = M0.toarray()
6767
Mz_old = M0.toarray()
68-
68+
6969
t = 1
7070
t_old = 0
7171

7272
it = 0
7373
gamma = 1.0
7474
gamma_inc = 2
75-
75+
# for it in trange(2, file=sys.stdout, desc='outer loop'):
7676
for it in trange(self.opts.maxIter, file=sys.stdout, desc='outer loop'):
7777
alpha = (t_old - 1)/t
7878
Ws = (1 + alpha) * Wz - alpha * Wz_old
@@ -85,9 +85,13 @@ def fit(self, X, Y, **kwargs):
8585
gWs, gMs, Fs = self.gradVal_eval(Ws, Ms)
8686

8787
in_it = 0
88+
# for in_it in trange(2,file=sys.stdout, leave=False, unit_scale=True, desc='inner loop'):
8889
for in_it in trange(1000,file=sys.stdout, leave=False, unit_scale=True, desc='inner loop'):
8990
Wzp = Ws - gWs/gamma
9091
Mzp, Mzp_Pz, Mzp_DiagSigz = self.singular_projection (Ms - gMs/gamma, self.k)
92+
# print(Mzp)
93+
# print(Mzp_Pz)
94+
# print(Mzp_DiagSigz)
9195
Fzp = self.funVal_eval(Wzp, Mzp_Pz, Mzp_DiagSigz)
9296

9397
delta_Wzs = Wzp - Ws
@@ -147,7 +151,10 @@ def singular_projection (self, Msp, k):
147151
[type]: [description]
148152
"""
149153
# l2.1 norm projection.
150-
EValue, EVector = linalg.eig(Msp)
154+
EValue, EVector = LA.eig(Msp)
155+
idx = EValue.argsort()
156+
EValue = EValue[idx]
157+
EVector = EVector[:,idx]
151158
Pz = np.real(EVector)
152159
diag_EValue = np.real(EValue)
153160
DiagSigz, _, _ = self.bsa_ihb(diag_EValue, np.ones(diag_EValue.shape), k, np.ones(diag_EValue.shape))
@@ -224,7 +231,6 @@ def gradVal_eval(self, W, M):
224231
grad_W = grad_W + 2 * self.c * invEtaMWt.T
225232
W2 = W.T @ W
226233
grad_M = - self.c * W2@linalg.inv(IM)@linalg.inv(IM)
227-
228234
funcVal = 0
229235
if self.opts.pFlag:
230236
pass
@@ -252,15 +258,19 @@ def funVal_eval(self, W, M_Pz, M_DiagSigz):
252258
def get_params(self, deep = False):
253259
return {'rho1':self.rho1, 'rho2':self.rho2,'opts':self.opts, 'k':self.k}
254260

261+
def get_weights(self):
262+
return self.W
263+
255264
def analyse(self):
256265
# returns correlation matrix
257266

258-
# kmCMTL_OrderedModel = np.zeros(self.W.shape)
259-
# for i in range(self.k):
260-
# clusModel = self.W[:, i:self.task_num*self.k:self.k]
261-
# kmCMTL_OrderedModel[:, (i)*self.task_num: (i+1)* self.task_num] = clusModel
262-
# return 1-np.corrcoef(kmCMTL_OrderedModel)
263-
return self.W
267+
kmCMTL_OrderedModel = np.zeros(self.W.shape)
268+
clus_task_num = self.task_num//self.k
269+
for i in range(self.k):
270+
clusModel = self.W[:, i:self.task_num:self.k]
271+
kmCMTL_OrderedModel[:, (i)*clus_task_num: (i+1)* clus_task_num] = clusModel
272+
return 1-np.corrcoef(kmCMTL_OrderedModel)
273+
# return self.W
264274

265275

266276

Vampyr_MTL/functions/tests/test_C_Least_L21.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,24 +76,59 @@ def test_bsa_ihb(self):
7676
assert np.isnan(t_star) == True
7777
assert it == 3
7878
def test_basic_mat(self):
79-
clf = MTL_Cluster_Least_L21(opts, 3)
79+
clus_num = 2
80+
clus_task_num = 10
81+
task_num = clus_num * clus_task_num
82+
clf = MTL_Cluster_Least_L21(opts, 2)
8083
clf.fit(X, Y)
8184
corr = clf.analyse()
8285
# print(corr)
83-
fig = px.imshow(corr)
84-
fig2 = px.imshow(W)
86+
fig = px.imshow(corr, color_continuous_scale='Bluered_r')
8587
fig.update_layout(
8688
title={
8789
'text': "predict",
8890
})
8991
fig.show()
92+
OrderedTrueModel = np.zeros(W.shape)
93+
clus_task_num = task_num//clus_num
94+
for i in range(clus_num):
95+
clusModel = W[:, i:task_num:clus_num]
96+
OrderedTrueModel[:, (i)*clus_task_num: (i+1)* clus_task_num] = clusModel
97+
corr2 = 1-np.corrcoef(OrderedTrueModel)
98+
fig2 = px.imshow(corr2, color_continuous_scale='Bluered_r')
9099
fig2.update_layout(
91100
title={
92101
'text': "real",
93102
})
94103
fig2.show()
95104

96-
105+
def test_check_simplified(self):
106+
# generate cluster model
107+
cluster_weight = np.ones((dimension, clus_num))* clus_var
108+
W = np.tile(cluster_weight, (1, clus_task_num))
109+
cluster_index = np.tile(range(clus_num), (1, clus_task_num)).T
110+
111+
# generate task and intra-cluster variance
112+
W_it = np.zeros((dimension, task_num)) * task_var
113+
for i in range(task_num):
114+
bll = np.hstack(((W[:-1-comm_dim+1,i]==0).reshape(1,-1), np.zeros((1,comm_dim))==1))
115+
W_it[:,i][bll.flatten()]=0
116+
W = W+W_it
117+
118+
W = W + np.zeros((dimension, task_num))*nois_var
119+
120+
X = [0]*task_num
121+
Y = [0]*task_num
122+
for i in range(task_num):
123+
X[i] = np.ones((sample_size, dimension))
124+
xw = X[i] @ W[:,i]
125+
s= xw.shape
126+
xw = xw + np.ones((s[0])) * nois_var
127+
Y[i] = np.sign(xw)
128+
129+
clf = MTL_Cluster_Least_L21(opts, 2)
130+
clf.fit(X, Y)
131+
corr = clf.analyse()
97132
# def test_iris_accuracy(self):
98133
# clf = MTL_Cluster_Least_L21(opts, 3)
99134
# clf.fit(X_train, Y_train)

0 commit comments

Comments
 (0)