33
33
from torch .nn import Parameter
34
34
from typing_extensions import Literal
35
35
36
+ from nerfstudio .cameras .camera_optimizers import CameraOptimizer , CameraOptimizerConfig
36
37
from nerfstudio .cameras .cameras import Cameras
37
38
from nerfstudio .data .scene_box import OrientedBox
38
39
from nerfstudio .engine .callbacks import TrainingCallback , TrainingCallbackAttributes , TrainingCallbackLocation
@@ -146,6 +147,8 @@ class SplatfactoModelConfig(ModelConfig):
146
147
However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that
147
148
were implemented for classic mode can not render antialiased mode PLY properly without modifications.
148
149
"""
150
+ camera_optimizer : CameraOptimizerConfig = field (default_factory = lambda : CameraOptimizerConfig (mode = "off" ))
151
+ """Config of the camera optimizer to use"""
149
152
150
153
151
154
class SplatfactoModel (Model ):
@@ -213,6 +216,10 @@ def populate_modules(self):
213
216
}
214
217
)
215
218
219
+ self .camera_optimizer : CameraOptimizer = self .config .camera_optimizer .setup (
220
+ num_cameras = self .num_train_data , device = "cpu"
221
+ )
222
+
216
223
# metrics
217
224
from torchmetrics .image import PeakSignalNoiseRatio
218
225
from torchmetrics .image .lpip import LearnedPerceptualImagePatchSimilarity
@@ -609,6 +616,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
609
616
Mapping of different parameter groups
610
617
"""
611
618
gps = self .get_gaussian_param_groups ()
619
+ self .camera_optimizer .get_param_groups (param_groups = gps )
612
620
return gps
613
621
614
622
def _get_downscale_factor (self ):
@@ -648,6 +656,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
648
656
649
657
# get the background color
650
658
if self .training :
659
+ optimized_camera_to_world = self .camera_optimizer .apply_to_camera (camera )[0 , ...]
660
+
651
661
if self .config .background_color == "random" :
652
662
background = torch .rand (3 , device = self .device )
653
663
elif self .config .background_color == "white" :
@@ -657,6 +667,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
657
667
else :
658
668
background = self .background_color .to (self .device )
659
669
else :
670
+ optimized_camera_to_world = camera .camera_to_worlds [0 , ...]
671
+
660
672
if renderers .BACKGROUND_COLOR_OVERRIDE is not None :
661
673
background = renderers .BACKGROUND_COLOR_OVERRIDE .to (self .device )
662
674
else :
@@ -674,8 +686,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
674
686
camera_downscale = self ._get_downscale_factor ()
675
687
camera .rescale_output_resolution (1 / camera_downscale )
676
688
# shift the camera to center of scene looking at center
677
- R = camera .camera_to_worlds [0 , :3 , :3 ] # 3 x 3
678
- T = camera .camera_to_worlds [0 , :3 , 3 :4 ] # 3 x 1
689
+ R = optimized_camera_to_world [:3 , :3 ] # 3 x 3
690
+ T = optimized_camera_to_world [:3 , 3 :4 ] # 3 x 1
691
+
679
692
# flip the z and y axes to align with gsplat conventions
680
693
R_edit = torch .diag (torch .tensor ([1 , - 1 , - 1 ], device = self .device , dtype = R .dtype ))
681
694
R = R @ R_edit
@@ -738,7 +751,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
738
751
self .xys .retain_grad ()
739
752
740
753
if self .config .sh_degree > 0 :
741
- viewdirs = means_crop .detach () - camera . camera_to_worlds . detach ()[..., :3 , 3 ] # (N, 3)
754
+ viewdirs = means_crop .detach () - optimized_camera_to_world . detach ()[:3 , 3 ] # (N, 3)
742
755
viewdirs = viewdirs / viewdirs .norm (dim = - 1 , keepdim = True )
743
756
n = min (self .step // self .config .sh_degree_interval , self .config .sh_degree )
744
757
rgbs = spherical_harmonics (n , viewdirs , colors_crop )
@@ -829,6 +842,8 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
829
842
metrics_dict ["psnr" ] = self .psnr (predicted_rgb , gt_rgb )
830
843
831
844
metrics_dict ["gaussian_count" ] = self .num_points
845
+
846
+ self .camera_optimizer .get_metrics_dict (metrics_dict )
832
847
return metrics_dict
833
848
834
849
def get_loss_dict (self , outputs , batch , metrics_dict = None ) -> Dict [str , torch .Tensor ]:
@@ -867,11 +882,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
867
882
else :
868
883
scale_reg = torch .tensor (0.0 ).to (self .device )
869
884
870
- return {
885
+ loss_dict = {
871
886
"main_loss" : (1 - self .config .ssim_lambda ) * Ll1 + self .config .ssim_lambda * simloss ,
872
887
"scale_reg" : scale_reg ,
873
888
}
874
889
890
+ if self .training :
891
+ # Add loss from camera optimizer
892
+ self .camera_optimizer .get_loss_dict (loss_dict )
893
+
894
+ return loss_dict
895
+
875
896
@torch .no_grad ()
876
897
def get_outputs_for_camera (self , camera : Cameras , obb_box : Optional [OrientedBox ] = None ) -> Dict [str , torch .Tensor ]:
877
898
"""Takes in a camera, generates the raybundle, and computes the output of the model.
0 commit comments