@@ -45,8 +45,8 @@ class SoftImpute(BaseEstimator, TransformerMixin):
4545 >>> import numpy as np
4646 >>> from qolmat.imputations.softimpute import SoftImpute
4747 >>> D = np.array([[1, 2, np.nan, 4], [1, 5, 3, np.nan], [4, 2, 3, 2], [1, 1, 5, 4]])
48- >>> D = SoftImpute(random_state=11).fit_transform (D)
49- >>> print(D )
48+ >>> M, A = SoftImpute(random_state=11).decompose (D)
49+ >>> print(M + A )
5050 [[1. 2. 3.7242757 4. ]
5151 [1. 5. 3. 1.97846028]
5252 [4. 2. 3. 2. ]
@@ -129,7 +129,7 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
129129 A = U * D
130130 B = V * D
131131 M = A @ B .T
132- cost_start = self .cost_function (X , M , A , Omega , tau )
132+ cost_start = SoftImpute .cost_function (X , M , A , Omega , tau )
133133 for iter_ in range (self .max_iterations ):
134134 U_old = U
135135 V_old = V
@@ -156,7 +156,7 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
156156 A = U * D
157157
158158 # Step 4 : Stopping upon convergence
159- ratio = self ._check_convergence (U_old , D_old , V_old , U , D , V )
159+ ratio = SoftImpute ._check_convergence (U_old , D_old , V_old , U , D , V )
160160 if self .verbose :
161161 print (f"Iteration { iter_ } : ratio = { round (ratio , 4 )} " )
162162 if ratio < self .tolerance :
@@ -171,7 +171,7 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
171171
172172 A = np .where (Omega , X - M , 0 )
173173
174- cost_end = self .cost_function (X , M , A , Omega , tau )
174+ cost_end = SoftImpute .cost_function (X , M , A , Omega , tau )
175175 if self .verbose and (cost_end > cost_start + 1e-9 ):
176176 warnings .warn (
177177 f"Convergence failed: cost function increased from"
@@ -180,8 +180,8 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
180180
181181 return M , A
182182
183+ @staticmethod
183184 def _check_convergence (
184- self ,
185185 U_old : NDArray ,
186186 D_old : NDArray ,
187187 V_old : NDArray ,
0 commit comments