Skip to content

Commit

Permalink
Add support for sparse depth maps (#1373)
Browse files Browse the repository at this point in the history
Add support for sparse depth maps (#2)

* sparse depth maps

* pr comments

* unified loss

* minor

* lint
  • Loading branch information
yimingzhou1 authored Feb 8, 2023
1 parent 2ff1d3b commit 8c9c2ba
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
18 changes: 12 additions & 6 deletions nerfstudio/model_components/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
URF_SIGMA_SCALE_FACTOR = 3.0


class DephtLossType(Enum):
class DepthLossType(Enum):
"""Types of depth losses for depth supervision."""

DS_NERF = 1
Expand Down Expand Up @@ -225,8 +225,11 @@ def ds_nerf_depth_loss(
Returns:
Depth loss scalar.
"""
depth_mask = termination_depth > 0

loss = -torch.log(weights + EPS) * torch.exp(-((steps - termination_depth[:, None]) ** 2) / (2 * sigma)) * lengths
return torch.mean(loss.sum(-2))
loss = loss.sum(-2) * depth_mask
return torch.mean(loss)


def urban_radiance_field_depth_loss(
Expand All @@ -247,6 +250,8 @@ def urban_radiance_field_depth_loss(
Returns:
Depth loss scalar.
"""
depth_mask = termination_depth > 0

# Expected depth loss
expected_depth_loss = (termination_depth - predicted_depth) ** 2

Expand All @@ -262,7 +267,8 @@ def urban_radiance_field_depth_loss(
line_of_sight_loss_empty = (line_of_sight_loss_empty_mask * weights**2).sum(-2)
line_of_sight_loss = line_of_sight_loss_near + line_of_sight_loss_empty

return torch.mean(expected_depth_loss + line_of_sight_loss)
loss = (expected_depth_loss + line_of_sight_loss) * depth_mask
return torch.mean(loss)


def depth_loss(
Expand All @@ -273,7 +279,7 @@ def depth_loss(
sigma: TensorType[0],
directions_norm: TensorType[..., 1],
is_euclidean: bool,
depth_loss_type: DephtLossType,
depth_loss_type: DepthLossType,
) -> TensorType[0]:
"""Implementation of depth losses.
Expand All @@ -294,11 +300,11 @@ def depth_loss(
termination_depth = termination_depth * directions_norm
steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2

if depth_loss_type == DephtLossType.DS_NERF:
if depth_loss_type == DepthLossType.DS_NERF:
lengths = ray_samples.frustums.ends - ray_samples.frustums.starts
return ds_nerf_depth_loss(weights, termination_depth, steps, lengths, sigma)

if depth_loss_type == DephtLossType.URF:
if depth_loss_type == DepthLossType.URF:
return urban_radiance_field_depth_loss(weights, termination_depth, predicted_depth, steps, sigma)

raise NotImplementedError("Provided depth loss type not implemented.")
9 changes: 6 additions & 3 deletions nerfstudio/models/depth_nerfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.model_components.losses import DephtLossType, depth_loss
from nerfstudio.model_components.losses import DepthLossType, depth_loss
from nerfstudio.models.nerfacto import NerfactoModel, NerfactoModelConfig
from nerfstudio.utils import colormaps

Expand All @@ -46,7 +46,7 @@ class DepthNerfactoModelConfig(NerfactoModelConfig):
"""Starting uncertainty around depth values in meters (defaults to 0.2m)."""
sigma_decay_rate: float = 0.99985
"""Rate of exponential decay."""
depth_loss_type: DephtLossType = DephtLossType.DS_NERF
depth_loss_type: DepthLossType = DepthLossType.DS_NERF
"""Depth loss type."""


Expand Down Expand Up @@ -118,7 +118,10 @@ def get_image_metrics_and_images(
far_plane=torch.max(ground_truth_depth),
)
images["depth"] = torch.cat([ground_truth_depth_colormap, predicted_depth_colormap], dim=1)
metrics["depth_mse"] = torch.nn.functional.mse_loss(outputs["depth"], ground_truth_depth)
depth_mask = ground_truth_depth > 0
metrics["depth_mse"] = torch.nn.functional.mse_loss(
outputs["depth"][depth_mask], ground_truth_depth[depth_mask]
)
return metrics, images

def _get_sigma(self):
Expand Down

0 comments on commit 8c9c2ba

Please sign in to comment.