diff --git a/nerfstudio/utils/colormaps.py b/nerfstudio/utils/colormaps.py index d772d40b4e..0f1fc40b76 100644 --- a/nerfstudio/utils/colormaps.py +++ b/nerfstudio/utils/colormaps.py @@ -171,35 +171,46 @@ def apply_boolean_colormap( return colored_image -def apply_pca_colormap(image: Float[Tensor, "*bs dim"]) -> Float[Tensor, "*bs rgb=3"]: +def apply_pca_colormap( + image: Float[Tensor, "*bs dim"], pca_mat: Optional[Float[Tensor, "dim rgb=3"]] = None, ignore_zeros=True +) -> Float[Tensor, "*bs rgb=3"]: """Convert feature image to 3-channel RGB via PCA. The first three principle components are used for the color channels, with outlier rejection per-channel Args: image: image of arbitrary vectors + pca_mat: an optional argument of the PCA matrix, shape (dim, 3) + ignore_zeros: whether to ignore zero values in the input image (they won't affect the PCA computation) Returns: Tensor: Colored image """ original_shape = image.shape image = image.view(-1, image.shape[-1]) - _, _, v = torch.pca_lowrank(image) - image = torch.matmul(image, v[..., :3]) - d = torch.abs(image - torch.median(image, dim=0).values) + if ignore_zeros: + valids = (image.abs().amax(dim=-1)) > 0 + else: + valids = torch.ones(image.shape[0], dtype=torch.bool) + + if pca_mat is None: + _, _, pca_mat = torch.pca_lowrank(image[valids, :], q=3, niter=20) + assert pca_mat is not None + image = torch.matmul(image, pca_mat[..., :3]) + d = torch.abs(image[valids, :] - torch.median(image[valids, :], dim=0).values) mdev = torch.median(d, dim=0).values s = d / mdev - m = 3.0 # this is a hyperparam controlling how many std dev outside for outliers - rins = image[s[:, 0] < m, 0] - gins = image[s[:, 1] < m, 1] - bins = image[s[:, 2] < m, 2] - - image[:, 0] -= rins.min() - image[:, 1] -= gins.min() - image[:, 2] -= bins.min() - - image[:, 0] /= rins.max() - rins.min() - image[:, 1] /= gins.max() - gins.min() - image[:, 2] /= bins.max() - bins.min() + m = 2.0 # this is a hyperparam controlling how many std dev outside for outliers + rins = image[valids, :][s[:, 0] < m, 0] + gins = image[valids, :][s[:, 1] < m, 1] + bins = image[valids, :][s[:, 2] < m, 2] + + image[valids, 0] -= rins.min() + image[valids, 1] -= gins.min() + image[valids, 2] -= bins.min() + + image[valids, 0] /= rins.max() - rins.min() + image[valids, 1] /= gins.max() - gins.min() + image[valids, 2] /= bins.max() - bins.min() image = torch.clamp(image, 0, 1) image_long = (image * 255).long()