Skip to content

Commit a93e9f8

Browse files
Julien RousselJulien Roussel
authored andcommitted
doctstring fixed
1 parent 22ffb90 commit a93e9f8

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

qolmat/imputations/softimpute.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ class SoftImpute(BaseEstimator, TransformerMixin):
4545
>>> import numpy as np
4646
>>> from qolmat.imputations.softimpute import SoftImpute
4747
>>> D = np.array([[1, 2, np.nan, 4], [1, 5, 3, np.nan], [4, 2, 3, 2], [1, 1, 5, 4]])
48-
>>> D = SoftImpute(random_state=11).fit_transform(D)
49-
>>> print(D)
48+
>>> M, A = SoftImpute(random_state=11).decompose(D)
49+
>>> print(M + A)
5050
[[1. 2. 3.7242757 4. ]
5151
[1. 5. 3. 1.97846028]
5252
[4. 2. 3. 2. ]
@@ -129,7 +129,7 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
129129
A = U * D
130130
B = V * D
131131
M = A @ B.T
132-
cost_start = self.cost_function(X, M, A, Omega, tau)
132+
cost_start = SoftImpute.cost_function(X, M, A, Omega, tau)
133133
for iter_ in range(self.max_iterations):
134134
U_old = U
135135
V_old = V
@@ -156,7 +156,7 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
156156
A = U * D
157157

158158
# Step 4 : Stopping upon convergence
159-
ratio = self._check_convergence(U_old, D_old, V_old, U, D, V)
159+
ratio = SoftImpute._check_convergence(U_old, D_old, V_old, U, D, V)
160160
if self.verbose:
161161
print(f"Iteration {iter_}: ratio = {round(ratio, 4)}")
162162
if ratio < self.tolerance:
@@ -171,7 +171,7 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
171171

172172
A = np.where(Omega, X - M, 0)
173173

174-
cost_end = self.cost_function(X, M, A, Omega, tau)
174+
cost_end = SoftImpute.cost_function(X, M, A, Omega, tau)
175175
if self.verbose and (cost_end > cost_start + 1e-9):
176176
warnings.warn(
177177
f"Convergence failed: cost function increased from"
@@ -180,8 +180,8 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
180180

181181
return M, A
182182

183+
@staticmethod
183184
def _check_convergence(
184-
self,
185185
U_old: NDArray,
186186
D_old: NDArray,
187187
V_old: NDArray,

0 commit comments

Comments
 (0)