Skip to content

Commit

Permalink
Update splatfacto.py (nerfstudio-project#2804)
Browse files Browse the repository at this point in the history
Shape error: The shape of batch['mask'] is [H, W, 1], not[H, W]
  • Loading branch information
Harr7y authored and ArpegorPSGH committed Jun 22, 2024
1 parent 387f355 commit 0837031
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,9 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
# Set masked part of both ground-truth and rendered image to black.
# This is a little bit sketchy for the SSIM loss.
if "mask" in batch:
assert batch["mask"].shape == gt_img.shape[:2] == pred_img.shape[:2]
mask = batch["mask"][..., None].to(self.device)
# batch["mask"] : [H, W, 1]
assert batch["mask"].shape[:2] == gt_img.shape[:2] == pred_img.shape[:2]
mask = batch["mask"].to(self.device)
gt_img = gt_img * mask
pred_img = pred_img * mask

Expand Down

0 comments on commit 0837031

Please sign in to comment.