Skip to content

Commit

Permalink
Add average_init_density to improve robustness of nerfacto training (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored Jan 28, 2024
1 parent fb1fee1 commit d8b517d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 2 deletions.
3 changes: 3 additions & 0 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
),
model=NerfactoModelConfig(
eval_num_rays_per_chunk=1 << 15,
average_init_density=0.01,
camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"),
),
),
Expand Down Expand Up @@ -140,6 +141,7 @@
max_res=4096,
proposal_weights_anneal_max_num_iters=5000,
log2_hashmap_size=21,
average_init_density=0.01,
camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"),
),
),
Expand Down Expand Up @@ -187,6 +189,7 @@
max_res=8192,
proposal_weights_anneal_max_num_iters=5000,
log2_hashmap_size=21,
average_init_density=0.01,
camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"),
),
),
Expand Down
4 changes: 3 additions & 1 deletion nerfstudio/fields/density_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ def __init__(
base_res: int = 16,
log2_hashmap_size: int = 18,
features_per_level: int = 2,
average_init_density: float = 1.0,
implementation: Literal["tcnn", "torch"] = "tcnn",
) -> None:
super().__init__()
self.register_buffer("aabb", aabb)
self.spatial_distortion = spatial_distortion
self.use_linear = use_linear
self.average_init_density = average_init_density

self.register_buffer("max_res", torch.tensor(max_res))
self.register_buffer("num_levels", torch.tensor(num_levels))
Expand Down Expand Up @@ -111,7 +113,7 @@ def get_density(self, ray_samples: RaySamples) -> Tuple[Tensor, None]:
# Rectifying the density with an exponential is much more stable than a ReLU or
# softplus, because it enables high post-activation (float32) density outputs
# from smaller internal (float16) parameters.
density = trunc_exp(density_before_activation)
density = self.average_init_density * trunc_exp(density_before_activation)
density = density * selector[..., None]
return density, None

Expand Down
4 changes: 3 additions & 1 deletion nerfstudio/fields/nerfacto_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
use_pred_normals: bool = False,
use_average_appearance_embedding: bool = False,
spatial_distortion: Optional[SpatialDistortion] = None,
average_init_density: float = 1.0,
implementation: Literal["tcnn", "torch"] = "tcnn",
) -> None:
super().__init__()
Expand All @@ -116,6 +117,7 @@ def __init__(
self.use_pred_normals = use_pred_normals
self.pass_semantic_gradients = pass_semantic_gradients
self.base_res = base_res
self.average_init_density = average_init_density
self.step = 0

self.direction_encoding = SHEncoding(
Expand Down Expand Up @@ -218,7 +220,7 @@ def get_density(self, ray_samples: RaySamples) -> Tuple[Tensor, Tensor]:
# Rectifying the density with an exponential is much more stable than a ReLU or
# softplus, because it enables high post-activation (float32) density outputs
# from smaller internal (float16) parameters.
density = trunc_exp(density_before_activation.to(positions))
density = self.average_init_density * trunc_exp(density_before_activation.to(positions))
density = density * selector[..., None]
return density, base_mlp_out

Expand Down
5 changes: 5 additions & 0 deletions nerfstudio/models/nerfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class NerfactoModelConfig(ModelConfig):
"""Which implementation to use for the model."""
appearance_embed_dim: int = 32
"""Dimension of the appearance embedding."""
average_init_density: float = 1.0
"""Average initial density output from MLP. """
camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="SO3xR3"))
"""Config of the camera optimizer to use"""

Expand Down Expand Up @@ -162,6 +164,7 @@ def populate_modules(self):
use_pred_normals=self.config.predict_normals,
use_average_appearance_embedding=self.config.use_average_appearance_embedding,
appearance_embedding_dim=self.config.appearance_embed_dim,
average_init_density=self.config.average_init_density,
implementation=self.config.implementation,
)

Expand All @@ -179,6 +182,7 @@ def populate_modules(self):
self.scene_box.aabb,
spatial_distortion=scene_contraction,
**prop_net_args,
average_init_density=self.config.average_init_density,
implementation=self.config.implementation,
)
self.proposal_networks.append(network)
Expand All @@ -190,6 +194,7 @@ def populate_modules(self):
self.scene_box.aabb,
spatial_distortion=scene_contraction,
**prop_net_args,
average_init_density=self.config.average_init_density,
implementation=self.config.implementation,
)
self.proposal_networks.append(network)
Expand Down

0 comments on commit d8b517d

Please sign in to comment.