Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions nerfstudio/models/nerfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@
scale_gradients_by_distance_squared,
)
from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler, UniformSampler
from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, NormalsRenderer, RGBRenderer
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
DepthRenderer,
NormalsRenderer,
RGBRenderer,
UncertaintyRenderer,
)
from nerfstudio.model_components.scene_colliders import NearFarCollider
from nerfstudio.model_components.shaders import NormalsShader
from nerfstudio.models.base_model import Model, ModelConfig
Expand All @@ -63,8 +69,12 @@ class NerfactoModelConfig(ModelConfig):
"""Dimension of hidden layers"""
hidden_dim_color: int = 64
"""Dimension of hidden layers for color network"""
use_transient_embedding: bool = False
"""Whether to use a transient embedding."""
hidden_dim_transient: int = 64
"""Dimension of hidden layers for transient network"""
transient_embed_dim: int = 16
"""Dimension of the transient embedding."""
num_levels: int = 16
"""Number of levels of the hashmap for the base mlp."""
base_res: int = 16
Expand Down Expand Up @@ -162,7 +172,9 @@ def populate_modules(self):
features_per_level=self.config.features_per_level,
log2_hashmap_size=self.config.log2_hashmap_size,
hidden_dim_color=self.config.hidden_dim_color,
use_transient_embedding=self.config.use_transient_embedding,
hidden_dim_transient=self.config.hidden_dim_transient,
transient_embedding_dim=self.config.transient_embed_dim,
spatial_distortion=scene_contraction,
num_images=self.num_train_data,
use_pred_normals=self.config.predict_normals,
Expand Down Expand Up @@ -234,6 +246,7 @@ def update_schedule(step):
self.renderer_accumulation = AccumulationRenderer()
self.renderer_depth = DepthRenderer(method="median")
self.renderer_expected_depth = DepthRenderer(method="expected")
self.renderer_uncertainty = UncertaintyRenderer()
self.renderer_normals = NormalsRenderer()

# shaders
Expand Down Expand Up @@ -305,11 +318,25 @@ def get_outputs(self, ray_bundle: RayBundle):
if self.config.use_gradient_scaling:
field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)

weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
if self.training and self.config.use_transient_embedding:
static_density = field_outputs[FieldHeadNames.DENSITY]
transient_density = field_outputs[FieldHeadNames.TRANSIENT_DENSITY]
weights_static = ray_samples.get_weights(static_density)
weights_transient = ray_samples.get_weights(transient_density)
weights = weights_static
rgb_static_component = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights_static)
rgb_transient_component = self.renderer_rgb(
rgb=field_outputs[FieldHeadNames.TRANSIENT_RGB], weights=weights_transient
)
rgb = rgb_static_component + rgb_transient_component
else:
weights_static = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
weights = weights_static
rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)

weights_list.append(weights)
ray_samples_list.append(ray_samples)

rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
with torch.no_grad():
depth = self.renderer_depth(weights=weights, ray_samples=ray_samples)
expected_depth = self.renderer_expected_depth(weights=weights, ray_samples=ray_samples)
Expand Down Expand Up @@ -345,6 +372,13 @@ def get_outputs(self, ray_bundle: RayBundle):

for i in range(self.config.num_proposal_iterations):
outputs[f"prop_depth_{i}"] = self.renderer_depth(weights=weights_list[i], ray_samples=ray_samples_list[i])

# transients
if self.training and self.config.use_transient_embedding:
uncertainty = self.renderer_uncertainty(field_outputs[FieldHeadNames.UNCERTAINTY], weights_transient)
outputs["uncertainty"] = uncertainty + 0.1 # NOTE(ethan): 0.1 is the minimum uncertainty, added as an offset to the rendered value
outputs["density_transient"] = field_outputs[FieldHeadNames.TRANSIENT_DENSITY]

return outputs

def get_metrics_dict(self, outputs, batch):
Expand All @@ -369,7 +403,15 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None):
gt_image=image,
)

loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)
if self.training and self.config.use_transient_embedding:
# transient loss
betas = outputs["uncertainty"]
loss_dict["uncertainty_loss"] = 3 + torch.log(betas).mean()
loss_dict["density_loss"] = 0.01 * outputs["density_transient"].mean()
loss_dict["rgb_loss"] = (((gt_rgb - pred_rgb) ** 2).sum(-1) / (betas[..., 0] ** 2)).mean()
else:
loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)

if self.training:
loss_dict["interlevel_loss"] = self.config.interlevel_loss_mult * interlevel_loss(
outputs["weights_list"], outputs["ray_samples_list"]
Expand Down
Loading