Skip to content

Commit eae6bca

Browse files
committed
fixing CMTL
1 parent d44d799 commit eae6bca

File tree

2 files changed

+310
-9
lines changed

2 files changed

+310
-9
lines changed
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import numpy as np
2+
from .init_opts import init_opts
3+
from numpy import linalg as LA
4+
from tqdm import tqdm
5+
from tqdm import trange
6+
import sys
7+
import time
8+
from scipy.sparse import identity
9+
from scipy import linalg
10+
from scipy.sparse.linalg import spsolve
11+
from scipy.sparse import isspmatrix
12+
13+
class MTL_Cluster_Least_L21:
14+
def __init__(self, opts, k, rho1=10, rho2=0.2):
15+
self.opts = init_opts(opts)
16+
self.rho1 = rho1
17+
self.rho2 = rho2
18+
self.rho_L2 = 0
19+
self.k = k
20+
if hasattr(opts, 'rho_L2'):
21+
rho_L2 = opts.rho_L2
22+
23+
def fit(self, X, Y, **kwargs):
24+
"""
25+
X: np.array: t x n x d
26+
Y: np.array t x n x 1
27+
"""
28+
if 'rho' in kwargs.keys():
29+
print(kwargs)
30+
self.rho1 = kwargs['rho']
31+
X_new = []
32+
for i in range(len(X)):
33+
X_new.append(np.transpose(X[i]))
34+
X = X_new
35+
self.X = X
36+
self.Y = Y
37+
# transpose to size: t x d x n
38+
self.task_num = len(X)
39+
self.dimension, _ = X[0].shape
40+
self.eta = self.rho2/self.rho1
41+
self.c = self.rho1 * self.eta * (1+self.eta)
42+
funcVal = []
43+
44+
self.XY = [0]* self.task_num
45+
W0_prep = []
46+
for t in range(self.task_num):
47+
self.XY[i] = X[i] @ Y[i]
48+
W0_prep.append(self.XY[i].reshape((-1,1)))
49+
W0_prep = np.hstack(W0_prep)
50+
if hasattr(self.opts,'W0'):
51+
W0=self.opts.W0
52+
elif self.opts.init==2:
53+
W0 = np.zeros((self.dimension, self.task_num))
54+
elif self.opts.init == 0:
55+
W0 =W0_prep
56+
else:
57+
W0 = np.random.normal(0, 1, (self.dimension, self.task_num))
58+
59+
60+
M0 = np.array(identity(self.task_num)) * self.k / self.task_num
61+
# this flag checks if gradient descent only makes significant step
62+
63+
bFlag=0
64+
Wz= W0
65+
Wz_old = W0
66+
Mz = M0.toarray()
67+
Mz_old = M0.toarray()
68+
69+
t = 1
70+
t_old = 0
71+
72+
it = 0
73+
gamma = 1.0
74+
gamma_inc = 2
75+
76+
for it in trange(self.opts.maxIter, file=sys.stdout, desc='outer loop'):
77+
alpha = (t_old - 1)/t
78+
Ws = (1 + alpha) * Wz - alpha * Wz_old
79+
if(isspmatrix(Mz)):
80+
Mz = Mz.toarray()
81+
if(isspmatrix(Mz_old)):
82+
Mz_old = Mz_old.toarray()
83+
Ms = (1 + alpha) * Mz - alpha * Mz_old
84+
# compute function value and gradients of the search point
85+
gWs, gMs, Fs = self.gradVal_eval(Ws, Ms)
86+
87+
in_it = 0
88+
for in_it in trange(1000,file=sys.stdout, leave=False, unit_scale=True, desc='inner loop'):
89+
Wzp = Ws - gWs/gamma
90+
Mzp, Mzp_Pz, Mzp_DiagSigz = self.singular_projection (Ms - gMs/gamma, self.k)
91+
Fzp = self.funVal_eval(Wzp, Mzp_Pz, Mzp_DiagSigz)
92+
93+
delta_Wzs = Wzp - Ws
94+
delta_Mzs = Mzp - Ms
95+
96+
r_sum = (LA.norm(delta_Wzs)**2 + LA.norm(delta_Mzs)**2)/2
97+
Fzp_gamma = Fs + np.sum(delta_Wzs*gWs) + np.sum(delta_Mzs*gMs) + gamma * r_sum
98+
if (r_sum <=1e-20):
99+
bFlag=1 # this shows that, the gradient step makes little improvement
100+
break
101+
if (Fzp <= Fzp_gamma):
102+
break
103+
else:
104+
gamma = gamma * gamma_inc
105+
Wz_old = Wz
106+
Wz = Wzp
107+
Mz_old = Mz
108+
Mz = Mzp
109+
funcVal.append(Fzp)
110+
111+
if (bFlag):
112+
print('\n The program terminates as the gradient step changes the solution very small.')
113+
break
114+
if (self.opts.tFlag == 0):
115+
if it >= 2:
116+
if (abs(funcVal[-1] - funcVal[-2]) <= self.opts.tol):
117+
break
118+
119+
elif(self.opts.tFlag == 1):
120+
if it >= 2:
121+
if (abs(funcVal[-1] - funcVal[-2]) <= self.opts.tol * funcVal[-2]):
122+
break
123+
124+
elif(self.opts.tFlag == 2):
125+
if (funcVal[-1] <= self.opts.tol):
126+
break
127+
128+
elif(self.opts.tFlag == 3):
129+
if it >= self.opts.maxIter:
130+
break
131+
132+
t_old = t
133+
t = 0.5 * (1 + (1 + 4 * t ** 2) ** 0.5)
134+
135+
self.W = Wzp
136+
self.M = Mzp
137+
self.funcVal = funcVal
138+
139+
def singular_projection (self, Msp, k):
140+
"""[summary]
141+
142+
Args:
143+
Msp ([type]): [description]
144+
k ([type]): [description]
145+
146+
Returns:
147+
[type]: [description]
148+
"""
149+
# l2.1 norm projection.
150+
EValue, EVector = linalg.eig(Msp)
151+
Pz = np.real(EVector)
152+
diag_EValue = np.real(EValue)
153+
DiagSigz, _, _ = self.bsa_ihb(diag_EValue, np.ones(diag_EValue.shape), k, np.ones(diag_EValue.shape))
154+
Mzp = Pz @ np.diag(DiagSigz) @ Pz.T
155+
Mzp_Pz = Pz
156+
Mzp_DiagSigz = DiagSigz
157+
return Mzp, Mzp_Pz, Mzp_DiagSigz
158+
159+
def bsa_ihb(self, a, b, r, u):
160+
'''
161+
Singular Projection
162+
min 1/2*||x - a||_2^2
163+
s.t. b'*x = r, 0<= x <= u, b > 0
164+
'''
165+
break_flag = 0
166+
t_l = a/b
167+
t_u = (a - u)/b
168+
T = np.concatenate((t_l, t_u), axis=0)
169+
t_L = -np.Infinity
170+
t_U = np.Infinity
171+
g_tL = 0.
172+
g_tU = 0.
173+
174+
it = 0
175+
while(len(T)!=0):
176+
it +=1
177+
g_t = 0.
178+
t_hat = np.median(T)
179+
180+
U = t_hat < t_u
181+
M = (t_u <= t_hat) & (t_hat <= t_l)
182+
183+
if np.sum(U):
184+
g_t += np.sum(b[U]*u[U])
185+
if np.sum(M):
186+
g_t += np.sum(b[M]*(a[M]-t_hat*b[M]))
187+
if g_t > r:
188+
t_L = t_hat
189+
T = T[T>t_hat]
190+
g_tL = g_t
191+
elif g_t <r:
192+
t_U = t_hat
193+
T = T[T<t_hat]
194+
g_tU = g_t
195+
else:
196+
t_star = t_hat
197+
break_flag = 1
198+
break
199+
if not break_flag:
200+
eps = g_tU - g_tL
201+
t_star = t_L - (g_tL - r) * (t_U - t_L)/(eps)
202+
est = a-t_star * b
203+
if(np.isnan(est).any()):
204+
est[np.isnan(est)] = 0
205+
x_star = np.minimum(u, np.max(est, 0))
206+
return x_star, t_star, it
207+
208+
def gradVal_eval(self, W, M):
209+
IM = self.eta * identity(self.task_num)+M
210+
# could be sparse matrix to solve
211+
invEtaMWt = linalg.inv(IM) @ W.T
212+
if self.opts.pFlag:
213+
# grad_W = zeros(zeros(W));
214+
# # parfor i = 1:task_num
215+
# # grad_W (i, :) = X{i}*(X{i}' * W(:,i)-Y{i})
216+
pass
217+
else:
218+
grad_W = []
219+
for i in range(self.task_num):
220+
XWi = self.X[i].T @ W[:,i]
221+
XTXWi = self.X[i] @ XWi
222+
grad_W.append((XTXWi - self.XY[i]).reshape(-1,1))
223+
grad_W = np.hstack(grad_W)
224+
grad_W = grad_W + 2 * self.c * invEtaMWt.T
225+
W2 = W.T @ W
226+
grad_M = - self.c * W2@linalg.inv(IM)@linalg.inv(IM)
227+
228+
funcVal = 0
229+
if self.opts.pFlag:
230+
pass
231+
else:
232+
for i in range(self.task_num):
233+
funcVal = funcVal + 0.5 * LA.norm ((self.Y[i] - self.X[i].T @ W[:, i]), ord=2)**2
234+
funcVal = funcVal + self.c * np.trace( W @ invEtaMWt)
235+
return grad_W, grad_M, funcVal
236+
237+
def funVal_eval(self, W, M_Pz, M_DiagSigz):
238+
invIM = M_Pz @ (np.diag(1/(self.eta + np.array(M_DiagSigz)))) @ M_Pz.T
239+
invEtaMWt = invIM @ W.T
240+
funcVal = 0
241+
if self.opts.pFlag:
242+
# parfor i = 1: task_num
243+
# # funcVal = funcVal + 0.5 * norm (Y{i} - X{i}' * W(:, i))^2;
244+
# # end
245+
pass
246+
else:
247+
for i in range(self.task_num):
248+
funcVal = funcVal + 0.5 * LA.norm ((self.Y[i] - self.X[i].T @ W[:, i]), ord=2)**2
249+
funcVal = funcVal + self.c * np.trace(W @ invEtaMWt)
250+
return funcVal
251+
252+
def get_params(self, deep = False):
253+
return {'rho1':self.rho1, 'rho2':self.rho2,'opts':self.opts, 'k':self.k}
254+
255+
def analyse(self):
256+
# returns correlation matrix
257+
258+
# kmCMTL_OrderedModel = np.zeros(self.W.shape)
259+
# for i in range(self.k):
260+
# clusModel = self.W[:, i:self.task_num*self.k:self.k]
261+
# kmCMTL_OrderedModel[:, (i)*self.task_num: (i+1)* self.task_num] = clusModel
262+
# return 1-np.corrcoef(kmCMTL_OrderedModel)
263+
return self.W
264+
265+
266+
267+

Vampyr_MTL/functions/tests/test_C_Least_L21.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from ...evaluations.utils import MTL_data_extract, MTL_data_split, opts
44
import numpy as np
55
import math
6+
from scipy import linalg
7+
import plotly.express as px
68

79
# iris data
810
X_train, X_test, Y_train, Y_test, df = get_data()
9-
opts = opts(1500,2)
11+
opts = opts(1500,0)
1012
opts.tol = 10**(-6)
1113

1214

@@ -51,17 +53,49 @@
5153
s= xw.shape
5254
xw = xw + np.random.randn(s[0]) * nois_var
5355
Y[i] = np.sign(xw)
54-
5556
class Test_CMTL_Least_classification(object):
    """Pytest suite for Cluster Least Classification L21 (CMTL).

    Relies on module-level fixtures: `opts`, the synthetic `X`/`Y` data,
    the ground-truth `W`, and the imported `MTL_Cluster_Least_L21`.
    """

    def test_bsa_ihb(self):
        """bsa_ihb projection inside CMTL: for the eigenvalues of
        [[1,2],[3,4]] with r=3 the bracketing degenerates, giving a NaN
        multiplier, a zero projection, and exactly 3 iterations."""
        A = np.array([[1, 2], [3, 4]])
        EValue, EVector = linalg.eig(A)
        diag_EValue = np.real(EValue).reshape((-1, 1))
        clf = MTL_Cluster_Least_L21(opts, 3)

        x_star, t_star, it = clf.bsa_ihb(diag_EValue, np.ones(diag_EValue.shape), 3, np.ones(diag_EValue.shape))
        np.testing.assert_array_equal(x_star, np.array([[0], [0]]))
        # IDIOM FIX: truthiness check instead of comparing to True
        assert np.isnan(t_star)
        assert it == 3

    def test_basic_mat(self):
        """Fit on the synthetic data and visually compare the learned weights
        against the ground-truth W (plotly heatmaps; no numeric assertion)."""
        clf = MTL_Cluster_Least_L21(opts, 3)
        clf.fit(X, Y)
        corr = clf.analyse()
        fig = px.imshow(corr)
        fig2 = px.imshow(W)
        fig.update_layout(
            title={
                'text': "predict",
            })
        fig.show()
        fig2.update_layout(
            title={
                'text': "real",
            })
        fig2.show()

0 commit comments

Comments
 (0)