Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates to apply_pca_colormap #3086

Merged
merged 2 commits into from
Apr 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions nerfstudio/utils/colormaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,35 +171,46 @@ def apply_boolean_colormap(
return colored_image


def apply_pca_colormap(image: Float[Tensor, "*bs dim"]) -> Float[Tensor, "*bs rgb=3"]:
def apply_pca_colormap(
image: Float[Tensor, "*bs dim"], pca_mat: Optional[Float[Tensor, "dim rgb=3"]] = None, ignore_zeros=True
) -> Float[Tensor, "*bs rgb=3"]:
"""Convert feature image to 3-channel RGB via PCA. The first three principle
components are used for the color channels, with outlier rejection per-channel

Args:
image: image of arbitrary vectors
pca_mat: an optional argument of the PCA matrix, shape (dim, 3)
ignore_zeros: whether to ignore zero values in the input image (they won't affect the PCA computation)

Returns:
Tensor: Colored image
"""
original_shape = image.shape
image = image.view(-1, image.shape[-1])
_, _, v = torch.pca_lowrank(image)
image = torch.matmul(image, v[..., :3])
d = torch.abs(image - torch.median(image, dim=0).values)
if ignore_zeros:
valids = (image.abs().amax(dim=-1)) > 0
else:
valids = torch.ones(image.shape[0], dtype=torch.bool)

if pca_mat is None:
_, _, pca_mat = torch.pca_lowrank(image[valids, :], q=3, niter=20)
assert pca_mat is not None
image = torch.matmul(image, pca_mat[..., :3])
d = torch.abs(image[valids, :] - torch.median(image[valids, :], dim=0).values)
mdev = torch.median(d, dim=0).values
s = d / mdev
m = 3.0 # this is a hyperparam controlling how many std dev outside for outliers
rins = image[s[:, 0] < m, 0]
gins = image[s[:, 1] < m, 1]
bins = image[s[:, 2] < m, 2]

image[:, 0] -= rins.min()
image[:, 1] -= gins.min()
image[:, 2] -= bins.min()

image[:, 0] /= rins.max() - rins.min()
image[:, 1] /= gins.max() - gins.min()
image[:, 2] /= bins.max() - bins.min()
m = 2.0 # this is a hyperparam controlling how many std dev outside for outliers
rins = image[valids, :][s[:, 0] < m, 0]
gins = image[valids, :][s[:, 1] < m, 1]
bins = image[valids, :][s[:, 2] < m, 2]

image[valids, 0] -= rins.min()
image[valids, 1] -= gins.min()
image[valids, 2] -= bins.min()

image[valids, 0] /= rins.max() - rins.min()
image[valids, 1] /= gins.max() - gins.min()
image[valids, 2] /= bins.max() - bins.min()

image = torch.clamp(image, 0, 1)
image_long = (image * 255).long()
Expand Down
Loading