Skip to content

Commit 9df09e2

Browse files
Merge pull request #65 from EIDOSLAB/development
Fix MultiMacenkoNormalizer
2 parents 7b2620e + 7c2a95f commit 9df09e2

File tree

5 files changed

+56
-15
lines changed

5 files changed

+56
-15
lines changed

tests/test_torch.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,40 @@ def test_macenko_torch():
5757
# assess whether the normalized images are identical across backends
5858
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)
5959

60+
def test_multitarget_macenko_torch():
61+
size = 1024
62+
curr_file_path = os.path.dirname(os.path.realpath(__file__))
63+
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
64+
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))
65+
66+
# setup preprocessing and preprocess image to be normalized
67+
T = transforms.Compose([
68+
transforms.ToTensor(),
69+
transforms.Lambda(lambda x: x * 255)
70+
])
71+
target = T(target)
72+
t_to_transform = T(to_transform)
73+
74+
# initialize normalizers for each backend and fit to target image
75+
single_normalizer = torchstain.normalizers.MacenkoNormalizer(backend="torch")
76+
single_normalizer.fit(target)
77+
78+
multi_normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend="torch", norm_mode="avg-post")
79+
multi_normalizer.fit([target, target, target])
80+
81+
82+
# transform
83+
result_single, _, _ = single_normalizer.normalize(I=t_to_transform, stains=True)
84+
result_multi, _, _ = multi_normalizer.normalize(I=t_to_transform, stains=True)
85+
86+
# convert to numpy and set dtype
87+
result_single = result_single.numpy().astype("float32") / 255.
88+
result_multi = result_multi.numpy().astype("float32") / 255.
89+
90+
# assess whether the normalized images are identical across backends
91+
np.testing.assert_almost_equal(result_single.flatten(), result_multi.flatten(), decimal=2, verbose=True)
92+
93+
6094
def test_reinhard_torch():
6195
size = 1024
6296
curr_file_path = os.path.dirname(os.path.realpath(__file__))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .he_normalizer import HENormalizer
22
from .macenko import MacenkoNormalizer
3+
from .multitarget import MultiMacenkoNormalizer
34
from .reinhard import ReinhardNormalizer
Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
def MultiMacenkoNormalizer(backend='torch', **kwargs):
2-
if backend == 'torch':
3-
from torchstain.torch.normalizers.multitarget import MultiMacenkoNormalizer
4-
return MultiMacenkoNormalizer(**kwargs)
1+
def MultiMacenkoNormalizer(backend="torch", **kwargs):
2+
if backend == "numpy":
3+
raise NotImplementedError("MultiMacenkoNormalizer is not implemented for NumPy backend")
4+
elif backend == "torch":
5+
from torchstain.torch.normalizers import TorchMultiMacenkoNormalizer
6+
return TorchMultiMacenkoNormalizer(**kwargs)
7+
elif backend == "tensorflow":
8+
raise NotImplementedError("MultiMacenkoNormalizer is not implemented for TensorFlow backend")
59
else:
6-
raise Exception(f'Unsupported backend {backend}')
10+
raise Exception(f"Unsupported backend {backend}")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer
2-
from torchstain.torch.normalizers.multitarget import MultiMacenkoNormalizer
2+
from torchstain.torch.normalizers.multitarget import TorchMultiMacenkoNormalizer
33
from torchstain.torch.normalizers.reinhard import TorchReinhardNormalizer

torchstain/torch/normalizers/multitarget.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import torch
22
from torchstain.torch.utils import cov, percentile
3+
34
"""
45
Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077
56
"""
6-
class MultiMacenkoNormalizer:
7-
def __init__(self, norm_mode='avg-post'):
7+
class TorchMultiMacenkoNormalizer:
8+
def __init__(self, norm_mode="avg-post"):
89
self.norm_mode = norm_mode
910
self.HERef = torch.tensor([[0.5626, 0.2159],
1011
[0.7201, 0.8012],
1112
[0.4062, 0.5581]])
1213
self.maxCRef = torch.tensor([1.9705, 1.0308])
13-
self.updated_lstsq = hasattr(torch.linalg, 'lstsq')
14+
self.updated_lstsq = hasattr(torch.linalg, "lstsq")
1415

1516
def __convert_rgb2od(self, I, Io, beta):
1617
I = I.permute(1, 2, 0)
@@ -48,7 +49,8 @@ def __find_concentration(self, OD, HE):
4849
def __compute_matrices_single(self, I, Io, alpha, beta):
4950
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)
5051

51-
_, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)
52+
# _, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)
53+
_, eigvecs = torch.linalg.eigh(cov(ODhat.T), UPLO='U')
5254
eigvecs = eigvecs[:, [1, 2]]
5355

5456
HE = self.__find_HE(ODhat, eigvecs, alpha)
@@ -59,15 +61,15 @@ def __compute_matrices_single(self, I, Io, alpha, beta):
5961
return HE, C, maxC
6062

6163
def fit(self, Is, Io=240, alpha=1, beta=0.15):
62-
if self.norm_mode == 'avg-post':
64+
if self.norm_mode == "avg-post":
6365
HEs, _, maxCs = zip(*(
6466
self.__compute_matrices_single(I, Io, alpha, beta)
6567
for I in Is
6668
))
6769

6870
self.HERef = torch.stack(HEs).mean(dim=0)
6971
self.maxCRef = torch.stack(maxCs).mean(dim=0)
70-
elif self.norm_mode == 'concat':
72+
elif self.norm_mode == "concat":
7173
ODs, ODhats = zip(*(
7274
self.__convert_rgb2od(I, Io, beta)
7375
for I in Is
@@ -83,7 +85,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
8385
maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
8486
self.HERef = HE
8587
self.maxCRef = maxCs
86-
elif self.norm_mode == 'avg-pre':
88+
elif self.norm_mode == "avg-pre":
8789
ODs, ODhats = zip(*(
8890
self.__convert_rgb2od(I, Io, beta)
8991
for I in Is
@@ -100,7 +102,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
100102
maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
101103
self.HERef = HE
102104
self.maxCRef = maxCs
103-
elif self.norm_mode == 'fixed-single' or self.norm_mode == 'stochastic-single':
105+
elif self.norm_mode == "fixed-single" or self.norm_mode == "stochastic-single":
104106
# single img
105107
self.HERef, _, self.maxCRef = self.__compute_matrices_single(Is[0], Io, alpha, beta)
106108
else:
@@ -127,4 +129,4 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
127129
E[E > 255] = 255
128130
E = E.T.reshape(h, w, c).int()
129131

130-
return Inorm, H, E
132+
return Inorm, H, E

0 commit comments

Comments
 (0)