diff --git a/tests/test_torch.py b/tests/test_torch.py index 68ea0c6..82f238c 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -11,13 +11,6 @@ def setup_function(fn): print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__) -def test_cov(): - x = np.random.randn(10, 10) - cov_np = np.cov(x) - cov_t = torchstain.torch.utils.cov(torch.tensor(x)) - - np.testing.assert_almost_equal(cov_np, cov_t.numpy()) - def test_percentile(): x = np.random.randn(10, 10) p = 20 diff --git a/torchstain/torch/augmentors/macenko.py b/torchstain/torch/augmentors/macenko.py index 0def8e4..d34eaef 100644 --- a/torchstain/torch/augmentors/macenko.py +++ b/torchstain/torch/augmentors/macenko.py @@ -1,6 +1,6 @@ import torch from torchstain.base.augmentors.he_augmentor import HEAugmentor -from torchstain.torch.utils import cov, percentile +from torchstain.torch.utils import percentile """ Source code ported from: https://github.com/schaugf/HEnorm_python @@ -66,7 +66,7 @@ def __compute_matrices(self, I, Io, alpha, beta): OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) # compute eigenvectors - _, eigvecs = torch.linalg.eigh(cov(ODhat.T)) + _, eigvecs = torch.linalg.eigh(torch.cov(ODhat.T)) eigvecs = eigvecs[:, [1, 2]] HE = self.__find_HE(ODhat, eigvecs, alpha) diff --git a/torchstain/torch/normalizers/macenko.py b/torchstain/torch/normalizers/macenko.py index 569fd06..d672cc6 100644 --- a/torchstain/torch/normalizers/macenko.py +++ b/torchstain/torch/normalizers/macenko.py @@ -1,6 +1,6 @@ import torch from torchstain.base.normalizers.he_normalizer import HENormalizer -from torchstain.torch.utils import cov, percentile +from torchstain.torch.utils import percentile """ Source code ported from: https://github.com/schaugf/HEnorm_python @@ -61,7 +61,7 @@ def __compute_matrices(self, I, Io, alpha, beta): OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) # compute eigenvectors - _, eigvecs = torch.linalg.eigh(cov(ODhat.T)) + _, eigvecs = torch.linalg.eigh(torch.cov(ODhat.T)) eigvecs = eigvecs[:, [1, 2]] HE = self.__find_HE(ODhat, eigvecs, alpha) diff --git a/torchstain/torch/normalizers/multitarget.py b/torchstain/torch/normalizers/multitarget.py index d1396b4..a93c12e 100644 --- a/torchstain/torch/normalizers/multitarget.py +++ b/torchstain/torch/normalizers/multitarget.py @@ -1,5 +1,5 @@ import torch -from torchstain.torch.utils import cov, percentile +from torchstain.torch.utils import percentile """ Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077 @@ -50,7 +50,7 @@ def __compute_matrices_single(self, I, Io, alpha, beta): OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) # _, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True) - _, eigvecs = torch.linalg.eigh(cov(ODhat.T), UPLO='U') + _, eigvecs = torch.linalg.eigh(torch.cov(ODhat.T), UPLO='U') eigvecs = eigvecs[:, [1, 2]] HE = self.__find_HE(ODhat, eigvecs, alpha) @@ -77,7 +77,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15): OD = torch.cat(ODs, dim=0) ODhat = torch.cat(ODhats, dim=0) - eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] + eigvecs = torch.symeig(torch.cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] HE = self.__find_HE(ODhat, eigvecs, alpha) @@ -91,7 +91,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15): for I in Is )) - eigvecs = torch.stack([torch.symeig(cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] for ODhat in ODhats]).mean(dim=0) + eigvecs = torch.stack([torch.symeig(torch.cov(ODhat.T), eigenvectors=True)[1][:, [1, 2]] for ODhat in ODhats]).mean(dim=0) OD = torch.cat(ODs, dim=0) ODhat = torch.cat(ODhats, dim=0) diff --git a/torchstain/torch/utils/__init__.py b/torchstain/torch/utils/__init__.py index 4acea5a..48843ca 100644 --- a/torchstain/torch/utils/__init__.py +++ b/torchstain/torch/utils/__init__.py @@ -1,4 +1,3 @@ -from torchstain.torch.utils.cov import cov from torchstain.torch.utils.percentile import percentile from torchstain.torch.utils.stats import * from torchstain.torch.utils.split import * diff --git a/torchstain/torch/utils/cov.py b/torchstain/torch/utils/cov.py deleted file mode 100644 index 5a2d3c1..0000000 --- a/torchstain/torch/utils/cov.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch - -def cov(x): - """ - https://en.wikipedia.org/wiki/Covariance_matrix - """ - E_x = x.mean(dim=1) - x = x - E_x[:, None] - return torch.mm(x, x.T) / (x.size(1) - 1)