1111from scipy .sparse import isspmatrix
1212
1313class MTL_Cluster_Least_L21 :
14- def __init__ (self , opts , k , rho1 = 10 , rho2 = 0.2 ):
14+ def __init__ (self , opts , k , rho1 = 10 , rho2 = 0.1 ):
1515 self .opts = init_opts (opts )
1616 self .rho1 = rho1
1717 self .rho2 = rho2
@@ -44,8 +44,8 @@ def fit(self, X, Y, **kwargs):
4444 self .XY = [0 ]* self .task_num
4545 W0_prep = []
4646 for t in range (self .task_num ):
47- self .XY [i ] = X [i ] @ Y [i ]
48- W0_prep .append (self .XY [i ].reshape ((- 1 ,1 )))
47+ self .XY [t ] = X [t ] @ Y [t ]
48+ W0_prep .append (self .XY [t ].reshape ((- 1 ,1 )))
4949 W0_prep = np .hstack (W0_prep )
5050 if hasattr (self .opts ,'W0' ):
5151 W0 = self .opts .W0
@@ -65,14 +65,14 @@ def fit(self, X, Y, **kwargs):
6565 Wz_old = W0
6666 Mz = M0 .toarray ()
6767 Mz_old = M0 .toarray ()
68-
68+
6969 t = 1
7070 t_old = 0
7171
7272 it = 0
7373 gamma = 1.0
7474 gamma_inc = 2
75-
75+ # for it in trange(2, file=sys.stdout, desc='outer loop'):
7676 for it in trange (self .opts .maxIter , file = sys .stdout , desc = 'outer loop' ):
7777 alpha = (t_old - 1 )/ t
7878 Ws = (1 + alpha ) * Wz - alpha * Wz_old
@@ -85,9 +85,13 @@ def fit(self, X, Y, **kwargs):
8585 gWs , gMs , Fs = self .gradVal_eval (Ws , Ms )
8686
8787 in_it = 0
88+ # for in_it in trange(2,file=sys.stdout, leave=False, unit_scale=True, desc='inner loop'):
8889 for in_it in trange (1000 ,file = sys .stdout , leave = False , unit_scale = True , desc = 'inner loop' ):
8990 Wzp = Ws - gWs / gamma
9091 Mzp , Mzp_Pz , Mzp_DiagSigz = self .singular_projection (Ms - gMs / gamma , self .k )
92+ # print(Mzp)
93+ # print(Mzp_Pz)
94+ # print(Mzp_DiagSigz)
9195 Fzp = self .funVal_eval (Wzp , Mzp_Pz , Mzp_DiagSigz )
9296
9397 delta_Wzs = Wzp - Ws
@@ -147,7 +151,10 @@ def singular_projection (self, Msp, k):
147151 [type]: [description]
148152 """
149153 # l2.1 norm projection.
150- EValue , EVector = linalg .eig (Msp )
154+ EValue , EVector = LA .eig (Msp )
155+ idx = EValue .argsort ()
156+ EValue = EValue [idx ]
157+ EVector = EVector [:,idx ]
151158 Pz = np .real (EVector )
152159 diag_EValue = np .real (EValue )
153160 DiagSigz , _ , _ = self .bsa_ihb (diag_EValue , np .ones (diag_EValue .shape ), k , np .ones (diag_EValue .shape ))
@@ -224,7 +231,6 @@ def gradVal_eval(self, W, M):
224231 grad_W = grad_W + 2 * self .c * invEtaMWt .T
225232 W2 = W .T @ W
226233 grad_M = - self .c * W2 @linalg .inv (IM )@linalg .inv (IM )
227-
228234 funcVal = 0
229235 if self .opts .pFlag :
230236 pass
@@ -252,15 +258,19 @@ def funVal_eval(self, W, M_Pz, M_DiagSigz):
252258 def get_params (self , deep = False ):
253259 return {'rho1' :self .rho1 , 'rho2' :self .rho2 ,'opts' :self .opts , 'k' :self .k }
254260
261+ def get_weights (self ):
262+ return self .W
263+
255264 def analyse (self ):
256265 # returns correlation matrix
257266
258- # kmCMTL_OrderedModel = np.zeros(self.W.shape)
259- # for i in range(self.k):
260- # clusModel = self.W[:, i:self.task_num*self.k:self.k]
261- # kmCMTL_OrderedModel[:, (i)*self.task_num: (i+1)* self.task_num] = clusModel
262- # return 1-np.corrcoef(kmCMTL_OrderedModel)
263- return self .W
267+ kmCMTL_OrderedModel = np .zeros (self .W .shape )
268+ clus_task_num = self .task_num // self .k
269+ for i in range (self .k ):
270+ clusModel = self .W [:, i :self .task_num :self .k ]
271+ kmCMTL_OrderedModel [:, (i )* clus_task_num : (i + 1 )* clus_task_num ] = clusModel
272+ return 1 - np .corrcoef (kmCMTL_OrderedModel )
273+ # return self.W
264274
265275
266276
0 commit comments