Skip to content

Commit 04331ac

Browse files
authored
Merge pull request #68 from wouterzwerink/torch-cov
Use torch.cov instead of custom function
2 parents d439a8b + 9408b23 commit 04331ac

File tree

6 files changed

+8
-27
lines changed

6 files changed

+8
-27
lines changed

tests/test_torch.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,6 @@
1111
def setup_function(fn):
1212
print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__)
1313

14-
15-
def test_cov():
16-
x = np.random.randn(10, 10)
17-
cov_np = np.cov(x)
18-
cov_t = torchstain.torch.utils.cov(torch.tensor(x))
19-
20-
np.testing.assert_almost_equal(cov_np, cov_t.numpy())
21-
22-
2314
def test_percentile():
2415
x = np.random.randn(10, 10)
2516
p = 20

torchstain/torch/augmentors/macenko.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from torchstain.base.augmentors.he_augmentor import HEAugmentor
3-
from torchstain.torch.utils import cov, percentile
3+
from torchstain.torch.utils import percentile
44

55
"""
66
Source code ported from: https://github.com/schaugf/HEnorm_python
@@ -66,7 +66,7 @@ def __compute_matrices(self, I, Io, alpha, beta):
6666
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)
6767

6868
# compute eigenvectors
69-
_, eigvecs = torch.linalg.eigh(cov(ODhat.T))
69+
_, eigvecs = torch.linalg.eigh(torch.cov(ODhat.T))
7070
eigvecs = eigvecs[:, [1, 2]]
7171

7272
HE = self.__find_HE(ODhat, eigvecs, alpha)

torchstain/torch/normalizers/macenko.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from torchstain.base.normalizers.he_normalizer import HENormalizer
3-
from torchstain.torch.utils import cov, percentile
3+
from torchstain.torch.utils import percentile
44

55
"""
66
Source code ported from: https://github.com/schaugf/HEnorm_python
@@ -61,7 +61,7 @@ def __compute_matrices(self, I, Io, alpha, beta):
6161
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)
6262

6363
# compute eigenvectors
64-
_, eigvecs = torch.linalg.eigh(cov(ODhat.T))
64+
_, eigvecs = torch.linalg.eigh(torch.cov(ODhat.T))
6565
eigvecs = eigvecs[:, [1, 2]]
6666

6767
HE = self.__find_HE(ODhat, eigvecs, alpha)

torchstain/torch/normalizers/multitarget.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from torchstain.torch.utils import cov, percentile
2+
from torchstain.torch.utils import percentile
33

44
"""
55
Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077
@@ -50,7 +50,7 @@ def __compute_matrices_single(self, I, Io, alpha, beta):
5050
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)
5151

5252
# _, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)
53-
_, eigvecs = torch.linalg.eigh(cov(ODhat.T), UPLO='U')
53+
_, eigvecs = torch.linalg.eigh(torch.cov(ODhat.T), UPLO='U')
5454
eigvecs = eigvecs[:, [1, 2]]
5555

5656
HE = self.__find_HE(ODhat, eigvecs, alpha)
@@ -77,7 +77,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
7777
OD = torch.cat(ODs, dim=0)
7878
ODhat = torch.cat(ODhats, dim=0)
7979

80-
eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]]
80+
eigvecs = torch.symeig(torch.cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]]
8181

8282
HE = self.__find_HE(ODhat, eigvecs, alpha)
8383

@@ -91,7 +91,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
9191
for I in Is
9292
))
9393

94-
eigvecs = torch.stack([torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] for ODhat in ODhats]).mean(dim=0)
94+
eigvecs = torch.stack([torch.symeig(torch.cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] for ODhat in ODhats]).mean(dim=0)
9595

9696
OD = torch.cat(ODs, dim=0)
9797
ODhat = torch.cat(ODhats, dim=0)

torchstain/torch/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from torchstain.torch.utils.cov import cov
21
from torchstain.torch.utils.percentile import percentile
32
from torchstain.torch.utils.stats import *
43
from torchstain.torch.utils.split import *

torchstain/torch/utils/cov.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)