|
| 1 | +import numpy as np |
| 2 | +from .init_opts import init_opts |
| 3 | +from numpy import linalg as LA |
| 4 | +from tqdm import tqdm |
| 5 | +from tqdm import trange |
| 6 | +import sys |
| 7 | +import time |
| 8 | +from scipy.sparse import identity |
| 9 | +from scipy import linalg |
| 10 | +from scipy.sparse.linalg import spsolve |
| 11 | +from scipy.sparse import isspmatrix |
| 12 | + |
class MTL_Cluster_Least_L21:
    """Convex relaxed clustered multi-task learning with least-squares loss.

    Jointly learns a weight matrix W (one column per task) and a task
    relation matrix M with accelerated proximal gradient descent (FISTA).
    The objective being minimized is

        sum_i 0.5 * ||Y_i - X_i^T W_i||^2 + c * tr(W (eta*I + M)^{-1} W^T)
        s.t.  tr(M) = k,  0 <= eig(M) <= 1,  M symmetric

    where eta = rho2 / rho1 and c = rho1 * eta * (1 + eta).
    """

    def __init__(self, opts, k, rho1=10, rho2=0.2):
        """
        Args:
            opts: options object, normalized through init_opts
                (expects maxIter, tol, tFlag, init, pFlag, ...).
            k (int): number of task clusters (target trace of M).
            rho1 (float): primary regularization weight.
            rho2 (float): secondary regularization weight.
        """
        self.opts = init_opts(opts)
        self.rho1 = rho1
        self.rho2 = rho2
        self.rho_L2 = 0
        self.k = k
        if hasattr(opts, 'rho_L2'):
            # BUGFIX: the original assigned to a local `rho_L2`, silently
            # discarding the user-supplied option value.
            self.rho_L2 = opts.rho_L2

    def fit(self, X, Y, **kwargs):
        """Fit W and M with FISTA plus a backtracking line search.

        Args:
            X: sequence of t arrays, each n x d (transposed internally
                to d x n).
            Y: sequence of t target arrays, each n x 1.
            **kwargs: optional; 'rho' overrides rho1 for this fit.
        """
        if 'rho' in kwargs:
            print(kwargs)
            self.rho1 = kwargs['rho']
        # transpose each task's design matrix to size d x n
        X = [np.transpose(x) for x in X]
        self.X = X
        self.Y = Y
        self.task_num = len(X)
        self.dimension, _ = X[0].shape
        self.eta = self.rho2 / self.rho1
        self.c = self.rho1 * self.eta * (1 + self.eta)
        funcVal = []

        # Precompute X_i @ Y_i: reused by every gradient evaluation and as
        # one of the warm-start choices for W0.
        self.XY = [0] * self.task_num
        W0_prep = []
        for t in range(self.task_num):
            # BUGFIX: the original indexed with a stale `i` left over from a
            # previous loop instead of the loop variable, so only one entry
            # of XY was ever (incorrectly) filled.
            self.XY[t] = X[t] @ Y[t]
            W0_prep.append(self.XY[t].reshape((-1, 1)))
        W0_prep = np.hstack(W0_prep)

        # Initial W: explicit W0 > zeros > X'Y > random, selected by opts.
        if hasattr(self.opts, 'W0'):
            W0 = self.opts.W0
        elif self.opts.init == 2:
            W0 = np.zeros((self.dimension, self.task_num))
        elif self.opts.init == 0:
            W0 = W0_prep
        else:
            W0 = np.random.normal(0, 1, (self.dimension, self.task_num))

        # Initial M: scaled identity satisfying the trace constraint tr(M)=k.
        # BUGFIX: the original wrapped the sparse identity in np.array(),
        # producing a 0-d object array on which .toarray() fails.
        M0 = identity(self.task_num) * self.k / self.task_num

        # bFlag is set when a gradient step changes the solution negligibly.
        bFlag = 0
        Wz = W0
        Wz_old = W0
        Mz = M0.toarray()
        Mz_old = M0.toarray()

        # FISTA momentum sequence t_k.
        t = 1
        t_old = 0

        gamma = 1.0     # current Lipschitz estimate (step size is 1/gamma)
        gamma_inc = 2   # growth factor for the backtracking search

        for it in trange(self.opts.maxIter, file=sys.stdout, desc='outer loop'):
            alpha = (t_old - 1) / t
            # Search points extrapolated with Nesterov momentum.
            Ws = (1 + alpha) * Wz - alpha * Wz_old
            if isspmatrix(Mz):
                Mz = Mz.toarray()
            if isspmatrix(Mz_old):
                Mz_old = Mz_old.toarray()
            Ms = (1 + alpha) * Mz - alpha * Mz_old
            # function value and gradients at the search point
            gWs, gMs, Fs = self.gradVal_eval(Ws, Ms)

            # Backtracking line search: grow gamma until the quadratic
            # upper bound at the search point majorizes the objective.
            for in_it in trange(1000, file=sys.stdout, leave=False,
                                unit_scale=True, desc='inner loop'):
                Wzp = Ws - gWs / gamma
                Mzp, Mzp_Pz, Mzp_DiagSigz = self.singular_projection(
                    Ms - gMs / gamma, self.k)
                Fzp = self.funVal_eval(Wzp, Mzp_Pz, Mzp_DiagSigz)

                delta_Wzs = Wzp - Ws
                delta_Mzs = Mzp - Ms

                r_sum = (LA.norm(delta_Wzs) ** 2 + LA.norm(delta_Mzs) ** 2) / 2
                Fzp_gamma = (Fs + np.sum(delta_Wzs * gWs)
                             + np.sum(delta_Mzs * gMs) + gamma * r_sum)
                if r_sum <= 1e-20:
                    bFlag = 1  # the gradient step makes little improvement
                    break
                if Fzp <= Fzp_gamma:
                    break
                gamma = gamma * gamma_inc

            Wz_old = Wz
            Wz = Wzp
            Mz_old = Mz
            Mz = Mzp
            funcVal.append(Fzp)

            if bFlag:
                print('\n The program terminates as the gradient step changes the solution very small.')
                break

            # Termination tests selected by opts.tFlag.
            if self.opts.tFlag == 0:
                # absolute decrease of the objective
                if it >= 2 and abs(funcVal[-1] - funcVal[-2]) <= self.opts.tol:
                    break
            elif self.opts.tFlag == 1:
                # relative decrease of the objective
                if it >= 2 and abs(funcVal[-1] - funcVal[-2]) <= self.opts.tol * funcVal[-2]:
                    break
            elif self.opts.tFlag == 2:
                # objective below an absolute target
                if funcVal[-1] <= self.opts.tol:
                    break
            elif self.opts.tFlag == 3:
                # iteration cap
                if it >= self.opts.maxIter:
                    break

            t_old = t
            t = 0.5 * (1 + (1 + 4 * t ** 2) ** 0.5)

        self.W = Wzp
        self.M = Mzp
        self.funcVal = funcVal

    def singular_projection(self, Msp, k):
        """Project a symmetric matrix onto {M : tr(M) = k, 0 <= eig(M) <= 1}.

        Eigen-decomposes Msp, projects the (real) spectrum onto the capped
        simplex via bsa_ihb, and reassembles the matrix.

        Args:
            Msp: symmetric (task_num x task_num) matrix to project.
            k: target trace (number of clusters).

        Returns:
            tuple: (projected matrix, eigenvector matrix, projected
            eigenvalues).
        """
        EValue, EVector = linalg.eig(Msp)
        # Msp is symmetric up to round-off, so spectrum/vectors are real.
        Pz = np.real(EVector)
        diag_EValue = np.real(EValue)
        DiagSigz, _, _ = self.bsa_ihb(
            diag_EValue, np.ones(diag_EValue.shape), k,
            np.ones(diag_EValue.shape))
        Mzp = Pz @ np.diag(DiagSigz) @ Pz.T
        return Mzp, Pz, DiagSigz

    def bsa_ihb(self, a, b, r, u):
        """Capped-simplex projection via a median-based breakpoint search.

            min  1/2 * ||x - a||_2^2
            s.t. b' x = r,  0 <= x <= u,  b > 0

        Args:
            a: target vector.
            b: positive weight vector of the equality constraint.
            r: right-hand side of the equality constraint.
            u: elementwise upper bound.

        Returns:
            tuple: (x_star, t_star, it) — the projection, the optimal
            multiplier, and the number of search iterations.
        """
        break_flag = 0
        # Breakpoints of the piecewise-linear function g(t) = b' x(t),
        # where x(t) = clip(a - t*b, 0, u).
        t_l = a / b
        t_u = (a - u) / b
        T = np.concatenate((t_l, t_u), axis=0)
        # np.inf: np.Infinity was removed in NumPy 2.0.
        t_L = -np.inf
        t_U = np.inf
        g_tL = 0.
        g_tU = 0.

        it = 0
        while len(T) != 0:
            it += 1
            g_t = 0.
            t_hat = np.median(T)

            U = t_hat < t_u                      # entries clamped at upper bound
            M = (t_u <= t_hat) & (t_hat <= t_l)  # entries strictly interior

            if np.sum(U):
                g_t += np.sum(b[U] * u[U])
            if np.sum(M):
                g_t += np.sum(b[M] * (a[M] - t_hat * b[M]))
            if g_t > r:
                t_L = t_hat
                T = T[T > t_hat]
                g_tL = g_t
            elif g_t < r:
                t_U = t_hat
                T = T[T < t_hat]
                g_tU = g_t
            else:
                t_star = t_hat
                break_flag = 1
                break
        if not break_flag:
            # Linear interpolation between the bracketing breakpoints.
            eps = g_tU - g_tL
            t_star = t_L - (g_tL - r) * (t_U - t_L) / eps
        est = a - t_star * b
        if np.isnan(est).any():
            est[np.isnan(est)] = 0
        # BUGFIX: np.max(est, 0) is a reduction along axis 0 (a scalar);
        # the intended elementwise clamp at zero is np.maximum(est, 0).
        x_star = np.minimum(u, np.maximum(est, 0))
        return x_star, t_star, it

    def gradVal_eval(self, W, M):
        """Gradients and objective value at the point (W, M).

        Args:
            W: d x t weight matrix.
            M: t x t relation matrix.

        Returns:
            tuple: (grad_W, grad_M, funcVal).
        """
        IM = self.eta * identity(self.task_num) + M
        # NOTE: could be solved as a (sparse) linear system via spsolve
        # instead of forming the dense inverse.
        invEtaMWt = linalg.inv(IM) @ W.T
        if self.opts.pFlag:
            # parallel path not implemented (MATLAB parfor equivalent):
            # grad_W(i, :) = X{i} * (X{i}' * W(:, i) - Y{i})
            pass
        else:
            grad_W = []
            for i in range(self.task_num):
                XWi = self.X[i].T @ W[:, i]
                XTXWi = self.X[i] @ XWi
                grad_W.append((XTXWi - self.XY[i]).reshape(-1, 1))
            grad_W = np.hstack(grad_W)
        # Add the gradient of the trace regularizer.
        grad_W = grad_W + 2 * self.c * invEtaMWt.T
        W2 = W.T @ W
        grad_M = - self.c * W2 @ linalg.inv(IM) @ linalg.inv(IM)

        funcVal = 0
        if self.opts.pFlag:
            pass
        else:
            for i in range(self.task_num):
                funcVal = funcVal + 0.5 * LA.norm(self.Y[i] - self.X[i].T @ W[:, i], ord=2) ** 2
        funcVal = funcVal + self.c * np.trace(W @ invEtaMWt)
        return grad_W, grad_M, funcVal

    def funVal_eval(self, W, M_Pz, M_DiagSigz):
        """Objective value at W given the eigen-factorization of M.

        Uses (eta*I + M)^{-1} = Pz diag(1/(eta + sig)) Pz' so no dense
        matrix inverse is required.

        Args:
            W: d x t weight matrix.
            M_Pz: eigenvector matrix of M.
            M_DiagSigz: eigenvalues of M (sequence).

        Returns:
            float: the objective value.
        """
        invIM = M_Pz @ np.diag(1 / (self.eta + np.array(M_DiagSigz))) @ M_Pz.T
        invEtaMWt = invIM @ W.T
        funcVal = 0
        if self.opts.pFlag:
            # parallel path not implemented
            pass
        else:
            for i in range(self.task_num):
                funcVal = funcVal + 0.5 * LA.norm(self.Y[i] - self.X[i].T @ W[:, i], ord=2) ** 2
        funcVal = funcVal + self.c * np.trace(W @ invEtaMWt)
        return funcVal

    def get_params(self, deep=False):
        """Return constructor parameters (scikit-learn style)."""
        return {'rho1': self.rho1, 'rho2': self.rho2, 'opts': self.opts, 'k': self.k}

    def analyse(self):
        """Return the learned d x t weight matrix W.

        The original cluster-reordering / correlation analysis is kept
        below for reference but was disabled upstream:
        # kmCMTL_OrderedModel = np.zeros(self.W.shape)
        # for i in range(self.k):
        #     clusModel = self.W[:, i:self.task_num*self.k:self.k]
        #     kmCMTL_OrderedModel[:, i*self.task_num:(i+1)*self.task_num] = clusModel
        # return 1 - np.corrcoef(kmCMTL_OrderedModel)
        """
        return self.W
| 264 | + |
| 265 | + |
| 266 | + |
| 267 | + |
0 commit comments