Skip to content

Commit 2d7119b

Browse files
Julien RousselJulien Roussel
authored andcommitted
doctest mended
1 parent 4b784dd commit 2d7119b

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

qolmat/imputations/softimpute.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ class SoftImpute(BaseEstimator, TransformerMixin):
4545
>>> import numpy as np
4646
>>> from qolmat.imputations.softimpute import SoftImpute
4747
>>> D = np.array([[1, 2, np.nan, 4], [1, 5, 3, np.nan], [4, 2, 3, 2], [1, 1, 5, 4]])
48-
>>> M, A = SoftImpute(random_state=11).decompose(D)
48+
>>> Omega = ~np.isnan(D)
49+
>>> M, A = SoftImpute(random_state=11).decompose(D, Omega)
4950
>>> print(M + A)
50-
[[1. 2. 3.7242757 4. ]
51-
[1. 5. 3. 1.97846028]
51+
[[1. 2. 4.12611456 4. ]
52+
[1. 5. 3. 0.87217939]
5253
[4. 2. 3. 2. ]
5354
[1. 1. 5. 4. ]]
5455
"""
@@ -159,9 +160,11 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]:
159160
ratio = SoftImpute._check_convergence(U_old, D_old, V_old, U, D, V)
160161
if self.verbose:
161162
print(f"Iteration {iter_}: ratio = {round(ratio, 4)}")
162-
if ratio < self.tolerance:
163-
print(f"Convergence reached at iteration {iter_} with ratio = {round(ratio, 4)}")
164-
break
163+
if ratio < self.tolerance:
164+
print(
165+
f"Convergence reached at iteration {iter_} with ratio = {round(ratio, 4)}"
166+
)
167+
break
165168

166169
Xstar = np.where(Omega, X - A @ B.T, 0) + A @ B.T
167170
M = Xstar @ V

0 commit comments

Comments
 (0)