Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions torchstain/base/normalizers/macenko.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
def MacenkoNormalizer(backend='torch'):
if backend == 'numpy':
def MacenkoNormalizer(backend="torch", device="cpu"):
if backend == "numpy":
from torchstain.numpy.normalizers import NumpyMacenkoNormalizer

return NumpyMacenkoNormalizer()
elif backend == "torch":
from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer
return TorchMacenkoNormalizer()

return TorchMacenkoNormalizer(device=device)
elif backend == "tensorflow":
from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer

return TensorFlowMacenkoNormalizer()
else:
raise Exception(f'Unknown backend {backend}')
raise Exception(f"Unknown backend {backend}")
9 changes: 7 additions & 2 deletions torchstain/base/normalizers/multitarget.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
def MultiMacenkoNormalizer(backend="torch", **kwargs):
if backend == "numpy":
raise NotImplementedError("MultiMacenkoNormalizer is not implemented for NumPy backend")
raise NotImplementedError(
"MultiMacenkoNormalizer is not implemented for NumPy backend"
)
elif backend == "torch":
from torchstain.torch.normalizers import TorchMultiMacenkoNormalizer

return TorchMultiMacenkoNormalizer(**kwargs)
elif backend == "tensorflow":
raise NotImplementedError("MultiMacenkoNormalizer is not implemented for TensorFlow backend")
raise NotImplementedError(
"MultiMacenkoNormalizer is not implemented for TensorFlow backend"
)
else:
raise Exception(f"Unsupported backend {backend}")
64 changes: 47 additions & 17 deletions torchstain/torch/normalizers/macenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
Source code ported from: https://github.com/schaugf/HEnorm_python
Original implementation: https://github.com/mitkovetta/staining-normalization
"""


class TorchMacenkoNormalizer(HENormalizer):
def __init__(self):
def __init__(self, device="cpu"):
super().__init__()

self.HERef = torch.tensor([[0.5626, 0.2159],
[0.7201, 0.8012],
[0.4062, 0.5581]])
self.maxCRef = torch.tensor([1.9705, 1.0308])
self.device = device

self.HERef = torch.tensor(
[[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]], device=self.device
)
self.maxCRef = torch.tensor([1.9705, 1.0308], device=self.device)

# Avoid using deprecated torch.lstsq (since 1.9.0)
self.updated_lstsq = hasattr(torch.linalg, 'lstsq')
self.updated_lstsq = hasattr(torch.linalg, "lstsq")

def __convert_rgb2od(self, I, Io, beta):
I = I.permute(1, 2, 0)

Expand All @@ -38,12 +42,20 @@ def __find_HE(self, ODhat, eigvecs, alpha):
minPhi = percentile(phi, alpha)
maxPhi = percentile(phi, 100 - alpha)

vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1)
vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1)
vMin = torch.matmul(
eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
).unsqueeze(1)
vMax = torch.matmul(
eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
).unsqueeze(1)

# a heuristic to make the vector corresponding to hematoxylin first and the
# one corresponding to eosin second
HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1))
HE = torch.where(
vMin[0] > vMax[0],
torch.cat((vMin, vMax), dim=1),
torch.cat((vMax, vMin), dim=1),
)

return HE

Expand All @@ -54,14 +66,16 @@ def __find_concentration(self, OD, HE):
# determine concentrations of the individual stains
if not self.updated_lstsq:
return torch.lstsq(Y, HE)[0][:2]

return torch.linalg.lstsq(HE, Y)[0]

return torch.linalg.pinv(HE) @ OD.T
# this fails for large Y dimension HE.shape torch.Size([3, 2]) Y.shape torch.Size([3, 1048576])
# torch.linalg.lstsq(HE, Y)[0]

def __compute_matrices(self, I, Io, alpha, beta):
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

# compute eigenvectors
_, eigvecs = torch.linalg.eigh(cov(ODhat.T))
_, eigvecs = torch.linalg.eigh(cov(ODhat.T))
eigvecs = eigvecs[:, [1, 2]]

HE = self.__find_HE(ODhat, eigvecs, alpha)
Expand All @@ -72,13 +86,16 @@ def __compute_matrices(self, I, Io, alpha, beta):
return HE, C, maxC

def fit(self, I, Io=240, alpha=1, beta=0.15):

I = I.to(self.device)

HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)

self.HERef = HE
self.maxCRef = maxC

def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
''' Normalize staining appearence of H&E stained images
"""Normalize staining appearence of H&E stained images

Example use:
see example.py
Expand All @@ -98,7 +115,10 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
Reference:
A method for normalizing histology slides for quantitative analysis. M.
Macenko et al., ISBI 2009
'''
"""

I = I.to(self.device)

c, h, w = I.shape

HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)
Expand All @@ -114,11 +134,21 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
H, E = None, None

if stains:
H = torch.mul(Io, torch.exp(torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))))
H = torch.mul(
Io,
torch.exp(
torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
),
)
H[H > 255] = 255
H = H.T.reshape(h, w, c).int()

E = torch.mul(Io, torch.exp(torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))))
E = torch.mul(
Io,
torch.exp(
torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
),
)
E[E > 255] = 255
E = E.T.reshape(h, w, c).int()

Expand Down
100 changes: 66 additions & 34 deletions torchstain/torch/normalizers/multitarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,22 @@
"""
Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077
"""


class TorchMultiMacenkoNormalizer:
def __init__(self, norm_mode="avg-post"):
def __init__(self, norm_mode="avg-post", device="cpu"):
self.device = device

self.norm_mode = norm_mode
self.HERef = torch.tensor([[0.5626, 0.2159],
[0.7201, 0.8012],
[0.4062, 0.5581]])
self.maxCRef = torch.tensor([1.9705, 1.0308])
self.HERef = torch.tensor(
[[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]], device=self.device
)
self.maxCRef = torch.tensor([1.9705, 1.0308], device=self.device)
self.updated_lstsq = hasattr(torch.linalg, "lstsq")

def __convert_rgb2od(self, I, Io, beta):
I = I.permute(1, 2, 0)
OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1)/Io)
OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)
ODhat = OD[~torch.any(OD < beta, dim=1)]
return OD, ODhat

Expand All @@ -29,10 +33,18 @@ def __find_phi_bounds(self, ODhat, eigvecs, alpha):
return minPhi, maxPhi

def __find_HE_from_bounds(self, eigvecs, minPhi, maxPhi):
vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1)
vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1)

HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1))
vMin = torch.matmul(
eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
).unsqueeze(1)
vMax = torch.matmul(
eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
).unsqueeze(1)

HE = torch.where(
vMin[0] > vMax[0],
torch.cat((vMin, vMax), dim=1),
torch.cat((vMax, vMin), dim=1),
)

return HE

Expand All @@ -44,13 +56,17 @@ def __find_concentration(self, OD, HE):
Y = OD.T
if not self.updated_lstsq:
return torch.lstsq(Y, HE)[0][:2]
return torch.linalg.lstsq(HE, Y)[0]
return torch.linalg.pinv(HE) @ OD.T

# this fails for large Y dimension HE.shape torch.Size([3, 2]) Y.shape torch.Size([3, 1048576])
# return torch.linalg.lstsq(HE, Y)[0]


def __compute_matrices_single(self, I, Io, alpha, beta):
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

# _, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)
_, eigvecs = torch.linalg.eigh(cov(ODhat.T), UPLO='U')
_, eigvecs = torch.linalg.eigh(cov(ODhat.T), UPLO="U")
eigvecs = eigvecs[:, [1, 2]]

HE = self.__find_HE(ODhat, eigvecs, alpha)
Expand All @@ -61,54 +77,60 @@ def __compute_matrices_single(self, I, Io, alpha, beta):
return HE, C, maxC

def fit(self, Is, Io=240, alpha=1, beta=0.15):

Is = [I.to(self.device) for I in Is]

if self.norm_mode == "avg-post":
HEs, _, maxCs = zip(*(
self.__compute_matrices_single(I, Io, alpha, beta)
for I in Is
))
HEs, _, maxCs = zip(
*(self.__compute_matrices_single(I, Io, alpha, beta) for I in Is)
)

self.HERef = torch.stack(HEs).mean(dim=0)
self.maxCRef = torch.stack(maxCs).mean(dim=0)
elif self.norm_mode == "concat":
ODs, ODhats = zip(*(
self.__convert_rgb2od(I, Io, beta)
for I in Is
))
ODs, ODhats = zip(*(self.__convert_rgb2od(I, Io, beta) for I in Is))
OD = torch.cat(ODs, dim=0)
ODhat = torch.cat(ODhats, dim=0)

eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]]
eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]]

HE = self.__find_HE(ODhat, eigvecs, alpha)
HE = self.__find_HE(ODhat, eigvecs, alpha)

C = self.__find_concentration(OD, HE)
maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
self.HERef = HE
self.maxCRef = maxCs
elif self.norm_mode == "avg-pre":
ODs, ODhats = zip(*(
self.__convert_rgb2od(I, Io, beta)
for I in Is
))

eigvecs = torch.stack([torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] for ODhat in ODhats]).mean(dim=0)
ODs, ODhats = zip(*(self.__convert_rgb2od(I, Io, beta) for I in Is))

eigvecs = torch.stack(
[
torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]]
for ODhat in ODhats
]
).mean(dim=0)

OD = torch.cat(ODs, dim=0)
ODhat = torch.cat(ODhats, dim=0)
HE = self.__find_HE(ODhat, eigvecs, alpha)

HE = self.__find_HE(ODhat, eigvecs, alpha)

C = self.__find_concentration(OD, HE)
maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
self.HERef = HE
self.maxCRef = maxCs
elif self.norm_mode == "fixed-single" or self.norm_mode == "stochastic-single":
# single img
self.HERef, _, self.maxCRef = self.__compute_matrices_single(Is[0], Io, alpha, beta)
self.HERef, _, self.maxCRef = self.__compute_matrices_single(
Is[0], Io, alpha, beta
)
else:
raise "Unknown norm mode"

def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):

I = I.to(self.device)

c, h, w = I.shape

HE, C, maxC = self.__compute_matrices_single(I, Io, alpha, beta)
Expand All @@ -121,11 +143,21 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
H, E = None, None

if stains:
H = torch.mul(Io, torch.exp(torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))))
H = torch.mul(
Io,
torch.exp(
torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
),
)
H[H > 255] = 255
H = H.T.reshape(h, w, c).int()

E = torch.mul(Io, torch.exp(torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))))
E = torch.mul(
Io,
torch.exp(
torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
),
)
E[E > 255] = 255
E = E.T.reshape(h, w, c).int()

Expand Down