11import torch
22from torchstain .torch .utils import cov , percentile
3+
34"""
45Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077
56"""
6- class MultiMacenkoNormalizer :
7- def __init__ (self , norm_mode = ' avg-post' ):
7+ class TorchMultiMacenkoNormalizer :
8+ def __init__ (self , norm_mode = " avg-post" ):
89 self .norm_mode = norm_mode
910 self .HERef = torch .tensor ([[0.5626 , 0.2159 ],
1011 [0.7201 , 0.8012 ],
1112 [0.4062 , 0.5581 ]])
1213 self .maxCRef = torch .tensor ([1.9705 , 1.0308 ])
13- self .updated_lstsq = hasattr (torch .linalg , ' lstsq' )
14+ self .updated_lstsq = hasattr (torch .linalg , " lstsq" )
1415
1516 def __convert_rgb2od (self , I , Io , beta ):
1617 I = I .permute (1 , 2 , 0 )
@@ -48,7 +49,8 @@ def __find_concentration(self, OD, HE):
4849 def __compute_matrices_single (self , I , Io , alpha , beta ):
4950 OD , ODhat = self .__convert_rgb2od (I , Io = Io , beta = beta )
5051
51- _ , eigvecs = torch .symeig (cov (ODhat .T ), eigenvectors = True )
52+ # _, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)
53+ _ , eigvecs = torch .linalg .eigh (cov (ODhat .T ), UPLO = 'U' )
5254 eigvecs = eigvecs [:, [1 , 2 ]]
5355
5456 HE = self .__find_HE (ODhat , eigvecs , alpha )
@@ -59,15 +61,15 @@ def __compute_matrices_single(self, I, Io, alpha, beta):
5961 return HE , C , maxC
6062
6163 def fit (self , Is , Io = 240 , alpha = 1 , beta = 0.15 ):
62- if self .norm_mode == ' avg-post' :
64+ if self .norm_mode == " avg-post" :
6365 HEs , _ , maxCs = zip (* (
6466 self .__compute_matrices_single (I , Io , alpha , beta )
6567 for I in Is
6668 ))
6769
6870 self .HERef = torch .stack (HEs ).mean (dim = 0 )
6971 self .maxCRef = torch .stack (maxCs ).mean (dim = 0 )
70- elif self .norm_mode == ' concat' :
72+ elif self .norm_mode == " concat" :
7173 ODs , ODhats = zip (* (
7274 self .__convert_rgb2od (I , Io , beta )
7375 for I in Is
@@ -83,7 +85,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
8385 maxCs = torch .stack ([percentile (C [0 , :], 99 ), percentile (C [1 , :], 99 )])
8486 self .HERef = HE
8587 self .maxCRef = maxCs
86- elif self .norm_mode == ' avg-pre' :
88+ elif self .norm_mode == " avg-pre" :
8789 ODs , ODhats = zip (* (
8890 self .__convert_rgb2od (I , Io , beta )
8991 for I in Is
@@ -100,7 +102,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
100102 maxCs = torch .stack ([percentile (C [0 , :], 99 ), percentile (C [1 , :], 99 )])
101103 self .HERef = HE
102104 self .maxCRef = maxCs
103- elif self .norm_mode == ' fixed-single' or self .norm_mode == ' stochastic-single' :
105+ elif self .norm_mode == " fixed-single" or self .norm_mode == " stochastic-single" :
104106 # single img
105107 self .HERef , _ , self .maxCRef = self .__compute_matrices_single (Is [0 ], Io , alpha , beta )
106108 else :
@@ -127,4 +129,4 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
127129 E [E > 255 ] = 255
128130 E = E .T .reshape (h , w , c ).int ()
129131
130- return Inorm , H , E
132+ return Inorm , H , E
0 commit comments