Skip to content

Commit eba72db

Browse files
oseiskar and jh-surh authored
Camera pose optimization for Splatfacto (#2891)
* Add pose optimization to Splatfacto
* Disable Splatfacto pose optimization by default
* Improve apply_to_camera for Gaussian Splatting pose optimization. Do not chain modifications to camera_to_worlds to improve numerical stability and enable L2 rot/trans penalties.
* Add separate mean and max rot/trans metrics to camera-opt
* Tweak pose optimization hyperparameters. Parameters used in the Gaussian Splatting on the Move paper v1
* Unit test fix for new camera_optimizer training metrics
* Adjust splatfacto-big camera pose optimization parameters. Same parameters as in normal Splatfacto
---------
Co-authored-by: jh-surh <[email protected]>
1 parent 2d9bbe5 commit eba72db

File tree

3 files changed

+54
-19
lines changed

3 files changed

+54
-19
lines changed

nerfstudio/cameras/camera_optimizers.py

+21-11
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from dataclasses import dataclass, field
2323
from typing import Literal, Optional, Type, Union
2424

25+
import numpy
2526
import torch
2627
import tyro
2728
from jaxtyping import Float, Int
@@ -151,15 +152,20 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
151152
raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
152153
raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()
153154

154-
def apply_to_camera(self, camera: Cameras) -> None:
155-
"""Apply the pose correction to the raybundle"""
156-
if self.config.mode != "off":
157-
assert camera.metadata is not None, "Must provide id of camera in its metadata"
158-
assert "cam_idx" in camera.metadata, "Must provide id of camera in its metadata"
159-
camera_idx = camera.metadata["cam_idx"]
160-
adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore
161-
adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
162-
camera.camera_to_worlds = torch.bmm(camera.camera_to_worlds, adj)
155+
def apply_to_camera(self, camera: Cameras) -> torch.Tensor:
156+
"""Apply the pose correction to the camera-to-world matrix in a Camera object"""
157+
if self.config.mode == "off":
158+
return camera.camera_to_worlds
159+
160+
assert camera.metadata is not None, "Must provide id of camera in its metadata"
161+
if "cam_idx" not in camera.metadata:
162+
# Evaluation cams?
163+
return camera.camera_to_worlds
164+
165+
camera_idx = camera.metadata["cam_idx"]
166+
adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore
167+
adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
168+
return torch.bmm(camera.camera_to_worlds, adj)
163169

164170
def get_loss_dict(self, loss_dict: dict) -> None:
165171
"""Add regularization"""
@@ -176,8 +182,12 @@ def get_correction_matrices(self):
176182
def get_metrics_dict(self, metrics_dict: dict) -> None:
177183
"""Get camera optimizer metrics"""
178184
if self.config.mode != "off":
179-
metrics_dict["camera_opt_translation"] = self.pose_adjustment[:, :3].norm()
180-
metrics_dict["camera_opt_rotation"] = self.pose_adjustment[:, 3:].norm()
185+
trans = self.pose_adjustment[:, :3].detach().norm(dim=-1)
186+
rot = self.pose_adjustment[:, 3:].detach().norm(dim=-1)
187+
metrics_dict["camera_opt_translation_max"] = trans.max()
188+
metrics_dict["camera_opt_translation_mean"] = trans.mean()
189+
metrics_dict["camera_opt_rotation_mean"] = numpy.rad2deg(rot.mean().cpu())
190+
metrics_dict["camera_opt_rotation_max"] = numpy.rad2deg(rot.max().cpu())
181191

182192
def get_param_groups(self, param_groups: dict) -> None:
183193
"""Get camera optimizer parameters"""

nerfstudio/configs/method_configs.py

+8-4
Original file line number | Diff line number | Diff line change
@@ -632,8 +632,10 @@
632632
},
633633
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
634634
"camera_opt": {
635-
"optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
636-
"scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
635+
"optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
636+
"scheduler": ExponentialDecaySchedulerConfig(
637+
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
638+
),
637639
},
638640
},
639641
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
@@ -684,8 +686,10 @@
684686
},
685687
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
686688
"camera_opt": {
687-
"optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
688-
"scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
689+
"optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
690+
"scheduler": ExponentialDecaySchedulerConfig(
691+
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
692+
),
689693
},
690694
},
691695
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),

nerfstudio/models/splatfacto.py

+25-4
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
from torch.nn import Parameter
3434
from typing_extensions import Literal
3535

36+
from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
3637
from nerfstudio.cameras.cameras import Cameras
3738
from nerfstudio.data.scene_box import OrientedBox
3839
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
@@ -146,6 +147,8 @@ class SplatfactoModelConfig(ModelConfig):
146147
However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that
147148
were implemented for classic mode can not render antialiased mode PLY properly without modifications.
148149
"""
150+
camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="off"))
151+
"""Config of the camera optimizer to use"""
149152

150153

151154
class SplatfactoModel(Model):
@@ -213,6 +216,10 @@ def populate_modules(self):
213216
}
214217
)
215218

219+
self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
220+
num_cameras=self.num_train_data, device="cpu"
221+
)
222+
216223
# metrics
217224
from torchmetrics.image import PeakSignalNoiseRatio
218225
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
@@ -609,6 +616,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
609616
Mapping of different parameter groups
610617
"""
611618
gps = self.get_gaussian_param_groups()
619+
self.camera_optimizer.get_param_groups(param_groups=gps)
612620
return gps
613621

614622
def _get_downscale_factor(self):
@@ -648,6 +656,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
648656

649657
# get the background color
650658
if self.training:
659+
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...]
660+
651661
if self.config.background_color == "random":
652662
background = torch.rand(3, device=self.device)
653663
elif self.config.background_color == "white":
@@ -657,6 +667,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
657667
else:
658668
background = self.background_color.to(self.device)
659669
else:
670+
optimized_camera_to_world = camera.camera_to_worlds[0, ...]
671+
660672
if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
661673
background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
662674
else:
@@ -674,8 +686,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
674686
camera_downscale = self._get_downscale_factor()
675687
camera.rescale_output_resolution(1 / camera_downscale)
676688
# shift the camera to center of scene looking at center
677-
R = camera.camera_to_worlds[0, :3, :3] # 3 x 3
678-
T = camera.camera_to_worlds[0, :3, 3:4] # 3 x 1
689+
R = optimized_camera_to_world[:3, :3] # 3 x 3
690+
T = optimized_camera_to_world[:3, 3:4] # 3 x 1
691+
679692
# flip the z and y axes to align with gsplat conventions
680693
R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype))
681694
R = R @ R_edit
@@ -738,7 +751,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
738751
self.xys.retain_grad()
739752

740753
if self.config.sh_degree > 0:
741-
viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3] # (N, 3)
754+
viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3] # (N, 3)
742755
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
743756
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
744757
rgbs = spherical_harmonics(n, viewdirs, colors_crop)
@@ -829,6 +842,8 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
829842
metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)
830843

831844
metrics_dict["gaussian_count"] = self.num_points
845+
846+
self.camera_optimizer.get_metrics_dict(metrics_dict)
832847
return metrics_dict
833848

834849
def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
@@ -867,11 +882,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
867882
else:
868883
scale_reg = torch.tensor(0.0).to(self.device)
869884

870-
return {
885+
loss_dict = {
871886
"main_loss": (1 - self.config.ssim_lambda) * Ll1 + self.config.ssim_lambda * simloss,
872887
"scale_reg": scale_reg,
873888
}
874889

890+
if self.training:
891+
# Add loss from camera optimizer
892+
self.camera_optimizer.get_loss_dict(loss_dict)
893+
894+
return loss_dict
895+
875896
@torch.no_grad()
876897
def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]:
877898
"""Takes in a camera, generates the raybundle, and computes the output of the model.

0 commit comments

Comments
 (0)