Update splatfacto with gsplat 1.0 #3200

Merged · 20 commits · Jun 8, 2024
2 changes: 1 addition & 1 deletion docs/nerfology/methods/splat.md
@@ -39,7 +39,7 @@ We provide a few additional variants:
| `splatfacto-big` | More Gaussians, Higher Quality | ~12GB | Slower |


A full evaluation of Nerfstudio's implementation of Gaussian Splatting against the original Inria method can be found [here](https://docs.gsplat.studio/tests/eval.html).
A full evaluation of Nerfstudio's implementation of Gaussian Splatting against the original Inria method can be found [here](https://docs.gsplat.studio/main/tests/eval.html).

#### Quality and Regularization
The default settings provided maintain a balance between speed, quality, and splat file size, but if you care more about quality than training speed or size, you can decrease the alpha cull threshold
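The thresholds referenced above live on the model config. A minimal sketch of a quality-over-speed setup, assuming the current `SplatfactoModelConfig` field names (`cull_alpha_thresh`, `continue_cull_post_densification`) still apply:

```python
# Hedged sketch: favor quality over training speed and splat size.
from nerfstudio.models.splatfacto import SplatfactoModelConfig

config = SplatfactoModelConfig(
    cull_alpha_thresh=0.005,                 # keep more translucent gaussians
    continue_cull_post_densification=False,  # stop culling once densification ends
)
```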
12 changes: 8 additions & 4 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -63,7 +63,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):
new images. If -1, never pick new images."""
eval_image_indices: Optional[Tuple[int, ...]] = (0,)
"""Specifies the image indices to use during eval; if None, uses all."""
cache_images: Literal["cpu", "gpu"] = "cpu"
cache_images: Literal["cpu", "gpu"] = "gpu"
"""Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device."""
cache_images_type: Literal["uint8", "float32"] = "float32"
"""The image type returned from manager, caching images in uint8 saves memory"""
@@ -247,11 +247,15 @@ def undistort_idx(idx: int) -> Dict[str, torch.Tensor]:
cache["image"] = cache["image"].to(self.device)
if "mask" in cache:
cache["mask"] = cache["mask"].to(self.device)
if "depth" in cache:
cache["depth"] = cache["depth"].to(self.device)
self.train_cameras = self.train_dataset.cameras.to(self.device)
elif cache_images_device == "cpu":
for cache in undistorted_images:
cache["image"] = cache["image"].pin_memory()
if "mask" in cache:
cache["mask"] = cache["mask"].pin_memory()
self.train_cameras = self.train_dataset.cameras
else:
assert_never(cache_images_device)
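For context on the `cpu` branch: `pin_memory()` keeps the cached tensors in page-locked host memory, so the later `.to(device)` copy can run asynchronously. A minimal sketch of the pattern, assuming a CUDA device:

```python
import torch

image = torch.rand(1080, 1920, 3)  # cached on the host, as in the "cpu" branch
image = image.pin_memory()         # page-locked staging buffer

# later (e.g. in next_train) the copy can overlap with GPU compute
gpu_image = image.to("cuda", non_blocking=True)
```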

@@ -340,11 +344,11 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
if len(self.train_unseen_cameras) == 0:
self.train_unseen_cameras = self.sample_train_cameras()

data = deepcopy(self.cached_train[image_idx])
data = self.cached_train[image_idx]
data["image"] = data["image"].to(self.device)

assert len(self.train_dataset.cameras.shape) == 1, "Assumes single batch dimension"
camera = self.train_dataset.cameras[image_idx : image_idx + 1].to(self.device)
assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension"
camera = self.train_cameras[image_idx : image_idx + 1].to(self.device)
if camera.metadata is None:
camera.metadata = {}
camera.metadata["cam_idx"] = image_idx
220 changes: 89 additions & 131 deletions nerfstudio/models/splatfacto.py
@@ -25,10 +25,13 @@

import numpy as np
import torch
from gsplat._torch_impl import quat_to_rotmat
from gsplat.project_gaussians import project_gaussians
from gsplat.rasterize import rasterize_gaussians
from gsplat.sh import num_sh_bases, spherical_harmonics
from gsplat.cuda_legacy._torch_impl import quat_to_rotmat

try:
from gsplat.rendering import rasterization
except ImportError:
print("Please install gsplat>=1.0.0")
from gsplat.cuda_legacy._wrapper import num_sh_bases
from pytorch_msssim import SSIM
from torch.nn import Parameter
from typing_extensions import Literal
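gsplat 1.0 collapses the legacy `project_gaussians`/`rasterize_gaussians`/`spherical_harmonics` pipeline into a single `rasterization` entry point. The guard above only prints on failure; a hedged, stricter variant (not what the PR does) would fail fast:

```python
# Alternative sketch: raise immediately with a clear message.
try:
    from gsplat.rendering import rasterization
except ImportError as e:  # gsplat < 1.0 has no unified rendering module
    raise ImportError("splatfacto requires gsplat>=1.0.0 (pip install -U gsplat)") from e
```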
@@ -96,6 +99,25 @@ def resize_image(image: torch.Tensor, d: int):
return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)


@torch.compile()
def get_viewmat(optimized_camera_to_world):
"""
Converts nerfstudio camera-to-world matrices to gsplat world-to-camera view matrices; compiled with torch.compile for speed.
"""
R = optimized_camera_to_world[:, :3, :3] # 3 x 3
T = optimized_camera_to_world[:, :3, 3:4] # 3 x 1
# flip the z and y axes to align with gsplat conventions
R = R * torch.tensor([[[1, -1, -1]]], device=R.device, dtype=R.dtype)
# analytic matrix inverse to get world2camera matrix
R_inv = R.transpose(1, 2)
T_inv = -torch.bmm(R_inv, T)
viewmat = torch.zeros(R.shape[0], 4, 4, device=R.device, dtype=R.dtype)
viewmat[:, 3, 3] = 1.0  # homogeneous
viewmat[:, :3, :3] = R_inv
viewmat[:, :3, 3:4] = T_inv
return viewmat
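gsplat expects OpenCV-style world-to-camera matrices, while nerfstudio stores OpenGL-style camera-to-world, so `get_viewmat` flips the y/z axes and inverts analytically (cheaper and more compile-friendly than `torch.linalg.inv`). An illustrative sanity check, not part of the PR:

```python
import torch

c2w = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)
c2w[:, :3, 3] = torch.randn(2, 3)   # random camera centers, identity rotations

flipped = c2w.clone()
flipped[:, :3, 1:3] *= -1           # negate the y and z rotation columns

# analytic inverse should match the explicit 4x4 inverse of the flipped pose
assert torch.allclose(get_viewmat(c2w), torch.linalg.inv(flipped), atol=1e-5)
```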


@dataclass
class SplatfactoModelConfig(ModelConfig):
"""Splatfacto Model Config, nerfstudio's implementation of Gaussian Splatting"""
@@ -127,8 +149,6 @@ class SplatfactoModelConfig(ModelConfig):
"""number of samples to split gaussians into"""
sh_degree_interval: int = 1000
"""every n intervals turn on another sh degree"""
cull_screen_size: float = 0.15
"""if a gaussian is more than this percent of screen space, cull it"""
split_screen_size: float = 0.05
"""if a gaussian is more than this percent of screen space, split it"""
stop_screen_size_at: int = 4000
@@ -191,7 +211,6 @@ def populate_modules(self):
else:
means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
self.xys_grad_norm = None
self.max_2Dsize = None
distances, _ = self.k_nearest_sklearn(means.data, 3)
distances = torch.from_numpy(distances)
# find the average of the three nearest neighbors for each point and use that as the scale
@@ -395,25 +414,14 @@ def after_train(self, step: int):
with torch.no_grad():
# keep track of a moving average of grad norms
visible_mask = (self.radii > 0).flatten()
assert self.xys.absgrad is not None # type: ignore
grads = self.xys.absgrad.detach().norm(dim=-1) # type: ignore
grads = self.xys.absgrad[0][visible_mask].norm(dim=-1) # type: ignore
# print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}")
if self.xys_grad_norm is None:
self.xys_grad_norm = grads
self.vis_counts = torch.ones_like(self.xys_grad_norm)
else:
assert self.vis_counts is not None
self.vis_counts[visible_mask] = self.vis_counts[visible_mask] + 1
self.xys_grad_norm[visible_mask] = grads[visible_mask] + self.xys_grad_norm[visible_mask]

# update the max screen size, as a ratio of number of pixels
if self.max_2Dsize is None:
self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32)
newradii = self.radii.detach()[visible_mask]
self.max_2Dsize[visible_mask] = torch.maximum(
self.max_2Dsize[visible_mask],
newradii / float(max(self.last_size[0], self.last_size[1])),
)
self.xys_grad_norm = torch.zeros(self.num_points, device=self.device, dtype=torch.float32)
self.vis_counts = torch.ones(self.num_points, device=self.device, dtype=torch.float32)
assert self.vis_counts is not None
self.vis_counts[visible_mask] += 1
self.xys_grad_norm[visible_mask] += grads

def set_crop(self, crop_box: Optional[OrientedBox]):
self.crop_box = crop_box
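Background on the `absgrad` used in `after_train` above: gsplat accumulates the absolute values of the screen-space gradients (in the spirit of AbsGS), which avoids the cancellation that can starve densification when signed per-pixel contributions are summed. A toy illustration with made-up numbers:

```python
import torch

# per-pixel gradient contributions to one gaussian's 2D mean (invented values)
g = torch.tensor([[0.3, 0.1],
                  [-0.3, 0.1]])

print(g.sum(dim=0).norm())        # signed sum: the x components cancel -> 0.2
print(g.abs().sum(dim=0).norm())  # absolute sum: much stronger signal -> ~0.63
```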
@@ -438,12 +446,10 @@ def refinement_after(self, optimizers: Optimizers, step):
)
if do_densification:
# then we densify
assert self.xys_grad_norm is not None and self.vis_counts is not None and self.max_2Dsize is not None
assert self.xys_grad_norm is not None and self.vis_counts is not None
avg_grad_norm = (self.xys_grad_norm / self.vis_counts) * 0.5 * max(self.last_size[0], self.last_size[1])
high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze()
splits = (self.scales.exp().max(dim=-1).values > self.config.densify_size_thresh).squeeze()
if self.step < self.config.stop_screen_size_at:
splits |= (self.max_2Dsize > self.config.split_screen_size).squeeze()
splits &= high_grads
nsamps = self.config.n_split_samples
split_params = self.split_gaussians(splits, nsamps)
@@ -456,16 +462,6 @@ def refinement_after(self, optimizers: Optimizers, step):
torch.cat([param.detach(), split_params[name], dup_params[name]], dim=0)
)

# append zeros to the max_2Dsize tensor
self.max_2Dsize = torch.cat(
[
self.max_2Dsize,
torch.zeros_like(split_params["scales"][:, 0]),
torch.zeros_like(dup_params["scales"][:, 0]),
],
dim=0,
)

split_idcs = torch.where(splits)[0]
self.dup_in_all_optim(optimizers, split_idcs, nsamps)

@@ -510,7 +506,6 @@ def refinement_after(self, optimizers: Optimizers, step):

self.xys_grad_norm = None
self.vis_counts = None
self.max_2Dsize = None

def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None):
"""
@@ -527,10 +522,6 @@ def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None):
if self.step > self.config.refine_every * self.config.reset_alpha_every:
# cull huge ones
toobigs = (torch.exp(self.scales).max(dim=-1).values > self.config.cull_scale_thresh).squeeze()
if self.step < self.config.stop_screen_size_at:
# cull big screen space
assert self.max_2Dsize is not None
toobigs = toobigs | (self.max_2Dsize > self.config.cull_screen_size).squeeze()
culls = culls | toobigs
toobigs_count = torch.sum(toobigs).item()
for name, param in self.gauss_params.items():
@@ -670,12 +661,14 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
if not isinstance(camera, Cameras):
print("Called get_outputs with not a camera")
return {}
assert camera.shape[0] == 1, "Only one camera at a time"

optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...]

# get the background color
if self.training:
assert camera.shape[0] == 1, "Only one camera at a time"
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)

if self.config.background_color == "random":
background = torch.rand(3, device=self.device)
elif self.config.background_color == "white":
@@ -685,6 +678,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
else:
background = self.background_color.to(self.device)
else:
optimized_camera_to_world = camera.camera_to_worlds

if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
else:
@@ -696,25 +691,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
return self.get_empty_outputs(int(camera.width.item()), int(camera.height.item()), background)
else:
crop_ids = None
camera_downscale = self._get_downscale_factor()
camera.rescale_output_resolution(1 / camera_downscale)
# shift the camera to center of scene looking at center
R = optimized_camera_to_world[:3, :3] # 3 x 3
T = optimized_camera_to_world[:3, 3:4] # 3 x 1

# flip the z and y axes to align with gsplat conventions
R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype))
R = R @ R_edit
# analytic matrix inverse to get world2camera matrix
R_inv = R.T
T_inv = -R_inv @ T
viewmat = torch.eye(4, device=R.device, dtype=R.dtype)
viewmat[:3, :3] = R_inv
viewmat[:3, 3:4] = T_inv
# calculate the FOV of the camera given fx and fy, width and height
cx = camera.cx.item()
cy = camera.cy.item()
W, H = int(camera.width.item()), int(camera.height.item())
camera_scale_fac = 1.0 / self._get_downscale_factor()
viewmat = get_viewmat(optimized_camera_to_world)
W, H = int(camera.width[0] * camera_scale_fac), int(camera.height[0] * camera_scale_fac)
self.last_size = (H, W)

if crop_ids is not None:
Expand All @@ -734,79 +713,58 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)
BLOCK_WIDTH = 16 # this controls the tile size of rasterization, 16 is a good default
self.xys, depths, self.radii, conics, comp, num_tiles_hit, cov3d = project_gaussians( # type: ignore
means_crop,
torch.exp(scales_crop),
1,
quats_crop / quats_crop.norm(dim=-1, keepdim=True),
viewmat.squeeze()[:3, :],
camera.fx.item(),
camera.fy.item(),
cx,
cy,
H,
W,
BLOCK_WIDTH,
) # type: ignore

# rescale the camera back to original dimensions before returning
camera.rescale_output_resolution(camera_downscale)

if (self.radii).sum() == 0:
return self.get_empty_outputs(W, H, background)

if self.config.sh_degree > 0:
viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3] # (N, 3)
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
rgbs = spherical_harmonics(n, viewdirs, colors_crop) # input unnormalized viewdirs
rgbs = torch.clamp(rgbs + 0.5, min=0.0) # type: ignore
else:
rgbs = torch.sigmoid(colors_crop[:, 0, :])

assert (num_tiles_hit > 0).any() # type: ignore

K = camera.get_intrinsics_matrices().cuda()
K[:, :2, :] *= camera_scale_fac
# apply the compensation of screen space blurring to gaussians
if self.config.rasterize_mode == "antialiased":
opacities = torch.sigmoid(opacities_crop) * comp[:, None]
elif self.config.rasterize_mode == "classic":
opacities = torch.sigmoid(opacities_crop)
else:
if self.config.rasterize_mode not in ["antialiased", "classic"]:
raise ValueError(f"Unknown rasterize_mode: {self.config.rasterize_mode}")

rgb, alpha = rasterize_gaussians( # type: ignore
self.xys,
depths,
self.radii,
conics,
num_tiles_hit, # type: ignore
rgbs,
opacities,
H,
W,
BLOCK_WIDTH,
background=background,
return_alpha=True,
) # type: ignore
alpha = alpha[..., None]
rgb = torch.clamp(rgb, max=1.0) # type: ignore
depth_im = None
if self.config.output_depth_during_training or not self.training:
depth_im = rasterize_gaussians( # type: ignore
self.xys,
depths,
self.radii,
conics,
num_tiles_hit, # type: ignore
depths[:, None].repeat(1, 3),
opacities,
H,
W,
BLOCK_WIDTH,
background=torch.zeros(3, device=self.device),
)[..., 0:1] # type: ignore
depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max())

return {"rgb": rgb, "depth": depth_im, "accumulation": alpha, "background": background} # type: ignore
render_mode = "RGB+ED"
else:
render_mode = "RGB"

if self.config.sh_degree > 0:
sh_degree_to_use = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
else:
sh_degree_to_use = None

render, alpha, info = rasterization(
means=means_crop,
quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True),
scales=torch.exp(scales_crop),
opacities=torch.sigmoid(opacities_crop).squeeze(-1),
colors=colors_crop,
viewmats=viewmat, # [1, 4, 4]
Ks=K, # [1, 3, 3]
width=W,
height=H,
tile_size=BLOCK_WIDTH,
packed=False,
near_plane=0.01,
far_plane=1e10,
render_mode=render_mode,
sh_degree=sh_degree_to_use,
sparse_grad=False,
absgrad=True,
rasterize_mode=self.config.rasterize_mode,
# set some threshold to disregard small gaussians for faster rendering.
# radius_clip=3.0,
)
if self.training and info["means2d"].requires_grad:
info["means2d"].retain_grad()
self.xys = info["means2d"] # [1, N, 2]
self.radii = info["radii"][0] # [N]

alpha = alpha[:, ...]
rgb = render[:, ..., :3] + (1 - alpha) * background
rgb = torch.clamp(rgb, 0.0, 1.0)
if render_mode == "RGB+ED":
depth_im = render[:, ..., 3:4]
depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
else:
depth_im = None
return {"rgb": rgb.squeeze(0), "depth": depth_im, "accumulation": alpha.squeeze(0), "background": background} # type: ignore

def get_gt_img(self, image: torch.Tensor):
"""Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
3 changes: 3 additions & 0 deletions nerfstudio/scripts/process_data.py
@@ -152,6 +152,9 @@ def main(self) -> None:
zip_ref.extractall(self.output_dir)
extracted_folder = zip_ref.namelist()[0].split("/")[0]
self.data = self.output_dir / extracted_folder
if not (self.data / "keyframes").exists():
# new versions of polycam data have a different structure, strip the last dir off
self.data = self.output_dir

if (self.data / "keyframes" / "corrected_images").exists() and not self.use_uncorrected_images:
polycam_image_dir = self.data / "keyframes" / "corrected_images"
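A hedged sketch of the layout handling above: older Polycam exports wrap everything in a top-level folder, newer ones put `keyframes/` at the zip root, so the code falls back to the extraction directory itself:

```python
from pathlib import Path

output_dir = Path("processed")          # hypothetical extraction target
data = output_dir / "extracted_folder"  # first entry of the zip listing
if not (data / "keyframes").exists():
    data = output_dir                   # newer exports: keyframes/ sits at the root
```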