diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index b23eabe002..6733db406b 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -155,6 +155,8 @@ class SplatfactoModelConfig(ModelConfig): """threshold of ratio of gaussian max to min scale before applying regularization loss from the PhysGaussian paper """ + output_depth_during_training: bool = False + """If True, output depth during training. Otherwise, only output depth during evaluation.""" class SplatfactoModel(Model): @@ -766,7 +768,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: alpha = alpha[..., None] rgb = torch.clamp(rgb, max=1.0) # type: ignore depth_im = None - if not self.training: + if self.config.output_depth_during_training or not self.training: depth_im = rasterize_gaussians( # type: ignore self.xys, depths,