diff --git a/torchstain/base/normalizers/macenko.py b/torchstain/base/normalizers/macenko.py index 2c0c988..9923702 100644 --- a/torchstain/base/normalizers/macenko.py +++ b/torchstain/base/normalizers/macenko.py @@ -1,12 +1,15 @@ -def MacenkoNormalizer(backend='torch'): - if backend == 'numpy': +def MacenkoNormalizer(backend="torch", device="cpu"): + if backend == "numpy": from torchstain.numpy.normalizers import NumpyMacenkoNormalizer + return NumpyMacenkoNormalizer() elif backend == "torch": from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer - return TorchMacenkoNormalizer() + + return TorchMacenkoNormalizer(device=device) elif backend == "tensorflow": from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer + return TensorFlowMacenkoNormalizer() else: - raise Exception(f'Unknown backend {backend}') + raise Exception(f"Unknown backend {backend}") diff --git a/torchstain/base/normalizers/multitarget.py b/torchstain/base/normalizers/multitarget.py index 495569b..e60fcda 100644 --- a/torchstain/base/normalizers/multitarget.py +++ b/torchstain/base/normalizers/multitarget.py @@ -1,10 +1,15 @@ def MultiMacenkoNormalizer(backend="torch", **kwargs): if backend == "numpy": - raise NotImplementedError("MultiMacenkoNormalizer is not implemented for NumPy backend") + raise NotImplementedError( + "MultiMacenkoNormalizer is not implemented for NumPy backend" + ) elif backend == "torch": from torchstain.torch.normalizers import TorchMultiMacenkoNormalizer + return TorchMultiMacenkoNormalizer(**kwargs) elif backend == "tensorflow": - raise NotImplementedError("MultiMacenkoNormalizer is not implemented for TensorFlow backend") + raise NotImplementedError( + "MultiMacenkoNormalizer is not implemented for TensorFlow backend" + ) else: raise Exception(f"Unsupported backend {backend}") diff --git a/torchstain/torch/normalizers/macenko.py b/torchstain/torch/normalizers/macenko.py index 569fd06..4bd74b7 100644 --- a/torchstain/torch/normalizers/macenko.py +++ b/torchstain/torch/normalizers/macenko.py @@ -6,18 +6,22 @@ Source code ported from: https://github.com/schaugf/HEnorm_python Original implementation: https://github.com/mitkovetta/staining-normalization """ + + class TorchMacenkoNormalizer(HENormalizer): - def __init__(self): + def __init__(self, device="cpu"): super().__init__() - self.HERef = torch.tensor([[0.5626, 0.2159], - [0.7201, 0.8012], - [0.4062, 0.5581]]) - self.maxCRef = torch.tensor([1.9705, 1.0308]) + self.device = device + + self.HERef = torch.tensor( + [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]], device=self.device + ) + self.maxCRef = torch.tensor([1.9705, 1.0308], device=self.device) # Avoid using deprecated torch.lstsq (since 1.9.0) - self.updated_lstsq = hasattr(torch.linalg, 'lstsq') - + self.updated_lstsq = hasattr(torch.linalg, "lstsq") + def __convert_rgb2od(self, I, Io, beta): I = I.permute(1, 2, 0) @@ -38,12 +42,20 @@ def __find_HE(self, ODhat, eigvecs, alpha): minPhi = percentile(phi, alpha) maxPhi = percentile(phi, 100 - alpha) - vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1) - vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1) + vMin = torch.matmul( + eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi))) + ).unsqueeze(1) + vMax = torch.matmul( + eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi))) + ).unsqueeze(1) # a heuristic to make the vector corresponding to hematoxylin first and the # one corresponding to eosin second - HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1)) + HE = torch.where( + vMin[0] > vMax[0], + torch.cat((vMin, vMax), dim=1), + torch.cat((vMax, vMin), dim=1), + ) return HE @@ -54,14 +66,16 @@ def __find_concentration(self, OD, HE): # determine concentrations of the individual stains if not self.updated_lstsq: return torch.lstsq(Y, HE)[0][:2] - - return torch.linalg.lstsq(HE, Y)[0] + + return torch.linalg.pinv(HE) @ OD.T + # this fails for large Y dimension HE.shape torch.Size([3, 2]) Y.shape torch.Size([3, 1048576]) + # torch.linalg.lstsq(HE, Y)[0] def __compute_matrices(self, I, Io, alpha, beta): OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) # compute eigenvectors - _, eigvecs = torch.linalg.eigh(cov(ODhat.T)) + _, eigvecs = torch.linalg.eigh(cov(ODhat.T)) eigvecs = eigvecs[:, [1, 2]] HE = self.__find_HE(ODhat, eigvecs, alpha) @@ -72,13 +86,16 @@ def __compute_matrices(self, I, Io, alpha, beta): return HE, C, maxC def fit(self, I, Io=240, alpha=1, beta=0.15): + + I = I.to(self.device) + HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta) self.HERef = HE self.maxCRef = maxC def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): - ''' Normalize staining appearence of H&E stained images + """Normalize staining appearence of H&E stained images Example use: see example.py @@ -98,7 +115,10 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): Reference: A method for normalizing histology slides for quantitative analysis. M. Macenko et al., ISBI 2009 - ''' + """ + + I = I.to(self.device) + c, h, w = I.shape HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) @@ -114,11 +134,21 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): H, E = None, None if stains: - H = torch.mul(Io, torch.exp(torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0)))) + H = torch.mul( + Io, + torch.exp( + torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0)) + ), + ) H[H > 255] = 255 H = H.T.reshape(h, w, c).int() - E = torch.mul(Io, torch.exp(torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0)))) + E = torch.mul( + Io, + torch.exp( + torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0)) + ), + ) E[E > 255] = 255 E = E.T.reshape(h, w, c).int() diff --git a/torchstain/torch/normalizers/multitarget.py b/torchstain/torch/normalizers/multitarget.py index d1396b4..533f273 100644 --- a/torchstain/torch/normalizers/multitarget.py +++ b/torchstain/torch/normalizers/multitarget.py @@ -4,18 +4,22 @@ """ Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077 """ + + class TorchMultiMacenkoNormalizer: - def __init__(self, norm_mode="avg-post"): + def __init__(self, norm_mode="avg-post", device="cpu"): + self.device = device + self.norm_mode = norm_mode - self.HERef = torch.tensor([[0.5626, 0.2159], - [0.7201, 0.8012], - [0.4062, 0.5581]]) - self.maxCRef = torch.tensor([1.9705, 1.0308]) + self.HERef = torch.tensor( + [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]], device=self.device + ) + self.maxCRef = torch.tensor([1.9705, 1.0308], device=self.device) self.updated_lstsq = hasattr(torch.linalg, "lstsq") - + def __convert_rgb2od(self, I, Io, beta): I = I.permute(1, 2, 0) - OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1)/Io) + OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io) ODhat = OD[~torch.any(OD < beta, dim=1)] return OD, ODhat @@ -29,10 +33,18 @@ def __find_phi_bounds(self, ODhat, eigvecs, alpha): return minPhi, maxPhi def __find_HE_from_bounds(self, eigvecs, minPhi, maxPhi): - vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1) - vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1) - - HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1)) + vMin = torch.matmul( + eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi))) + ).unsqueeze(1) + vMax = torch.matmul( + eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi))) + ).unsqueeze(1) + + HE = torch.where( + vMin[0] > vMax[0], + torch.cat((vMin, vMax), dim=1), + torch.cat((vMax, vMin), dim=1), + ) return HE @@ -44,13 +56,17 @@ def __find_concentration(self, OD, HE): Y = OD.T if not self.updated_lstsq: return torch.lstsq(Y, HE)[0][:2] - return torch.linalg.lstsq(HE, Y)[0] + return torch.linalg.pinv(HE) @ OD.T + + # this fails for large Y dimension HE.shape torch.Size([3, 2]) Y.shape torch.Size([3, 1048576]) + # return torch.linalg.lstsq(HE, Y)[0] + def __compute_matrices_single(self, I, Io, alpha, beta): OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) # _, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True) - _, eigvecs = torch.linalg.eigh(cov(ODhat.T), UPLO='U') + _, eigvecs = torch.linalg.eigh(cov(ODhat.T), UPLO="U") eigvecs = eigvecs[:, [1, 2]] HE = self.__find_HE(ODhat, eigvecs, alpha) @@ -61,42 +77,43 @@ def __compute_matrices_single(self, I, Io, alpha, beta): return HE, C, maxC def fit(self, Is, Io=240, alpha=1, beta=0.15): + + Is = [I.to(self.device) for I in Is] + if self.norm_mode == "avg-post": - HEs, _, maxCs = zip(*( - self.__compute_matrices_single(I, Io, alpha, beta) - for I in Is - )) + HEs, _, maxCs = zip( + *(self.__compute_matrices_single(I, Io, alpha, beta) for I in Is) + ) self.HERef = torch.stack(HEs).mean(dim=0) self.maxCRef = torch.stack(maxCs).mean(dim=0) elif self.norm_mode == "concat": - ODs, ODhats = zip(*( - self.__convert_rgb2od(I, Io, beta) - for I in Is - )) + ODs, ODhats = zip(*(self.__convert_rgb2od(I, Io, beta) for I in Is)) OD = torch.cat(ODs, dim=0) ODhat = torch.cat(ODhats, dim=0) - eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] + eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] - HE = self.__find_HE(ODhat, eigvecs, alpha) + HE = self.__find_HE(ODhat, eigvecs, alpha) C = self.__find_concentration(OD, HE) maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)]) self.HERef = HE self.maxCRef = maxCs elif self.norm_mode == "avg-pre": - ODs, ODhats = zip(*( - self.__convert_rgb2od(I, Io, beta) - for I in Is - )) - - eigvecs = torch.stack([torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] for ODhat in ODhats]).mean(dim=0) + ODs, ODhats = zip(*(self.__convert_rgb2od(I, Io, beta) for I in Is)) + + eigvecs = torch.stack( + [ + torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] + for ODhat in ODhats + ] + ).mean(dim=0) OD = torch.cat(ODs, dim=0) ODhat = torch.cat(ODhats, dim=0) - - HE = self.__find_HE(ODhat, eigvecs, alpha) + + HE = self.__find_HE(ODhat, eigvecs, alpha) C = self.__find_concentration(OD, HE) maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)]) @@ -104,11 +121,16 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15): self.maxCRef = maxCs elif self.norm_mode == "fixed-single" or self.norm_mode == "stochastic-single": # single img - self.HERef, _, self.maxCRef = self.__compute_matrices_single(Is[0], Io, alpha, beta) + self.HERef, _, self.maxCRef = self.__compute_matrices_single( + Is[0], Io, alpha, beta + ) else: raise "Unknown norm mode" def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): + + I = I.to(self.device) + c, h, w = I.shape HE, C, maxC = self.__compute_matrices_single(I, Io, alpha, beta) @@ -121,11 +143,21 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): H, E = None, None if stains: - H = torch.mul(Io, torch.exp(torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0)))) + H = torch.mul( + Io, + torch.exp( + torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0)) + ), + ) H[H > 255] = 255 H = H.T.reshape(h, w, c).int() - E = torch.mul(Io, torch.exp(torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0)))) + E = torch.mul( + Io, + torch.exp( + torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0)) + ), + ) E[E > 255] = 255 E = E.T.reshape(h, w, c).int()