Skip to content

Commit

Permalink
Allow use of black or white background in Nerfacto models (#1441)
Browse files Browse the repository at this point in the history
* Allow black/white background color in Nerfacto

* Use white background in Nerfacto Blender Benchmark

* Fixed unused import

* Updated index for license headers script
  • Loading branch information
iSach authored Feb 18, 2023
1 parent 2357332 commit 8e4ae6b
Show file tree
Hide file tree
Showing 17 changed files with 14 additions and 22 deletions.
3 changes: 3 additions & 0 deletions nerfstudio/.vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python.analysis.typeCheckingMode": "basic"
}
7 changes: 4 additions & 3 deletions nerfstudio/model_components/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from typing_extensions import Literal

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.utils import colors
from nerfstudio.utils.math import components_from_spherical_harmonics, safe_normalize

BACKGROUND_COLOR_OVERRIDE: Optional[TensorType[3]] = None
Expand Down Expand Up @@ -70,7 +71,7 @@ def combine_rgb(
cls,
rgb: TensorType["bs":..., "num_samples", 3],
weights: TensorType["bs":..., "num_samples", 1],
background_color: Union[Literal["random", "black", "last_sample"], TensorType[3]] = "random",
background_color: Union[Literal["random", "white", "black", "last_sample"], TensorType[3]] = "random",
ray_indices: Optional[TensorType["num_samples"]] = None,
num_rays: Optional[int] = None,
) -> TensorType["bs":..., 3]:
Expand Down Expand Up @@ -102,8 +103,8 @@ def combine_rgb(
background_color = rgb[..., -1, :]
if background_color == "random":
background_color = torch.rand_like(comp_rgb).to(rgb.device)
if background_color == "black":
background_color = torch.zeros_like(comp_rgb).to(rgb.device)
if isinstance(background_color, str) and background_color in colors.COLORS_DICT:
background_color = colors.COLORS_DICT[background_color].to(rgb.device)

assert isinstance(background_color, torch.Tensor)
comp_rgb = comp_rgb + background_color.to(weights.device) * (1.0 - accumulated_weight)
Expand Down
8 changes: 2 additions & 6 deletions nerfstudio/models/instant_ngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
RGBRenderer,
)
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils import colormaps, colors
from nerfstudio.utils import colormaps


@dataclass
Expand Down Expand Up @@ -123,11 +123,7 @@ def populate_modules(self):
)

# renderers
background_color = "random"
if self.config.background_color in ["white", "black"]:
background_color = colors.COLORS_DICT[self.config.background_color]

self.renderer_rgb = RGBRenderer(background_color=background_color)
self.renderer_rgb = RGBRenderer(background_color=self.config.background_color)
self.renderer_accumulation = AccumulationRenderer()
self.renderer_depth = DepthRenderer(method="expected")

Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/models/nerfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class NerfactoModelConfig(ModelConfig):
"""How far along the ray to start sampling."""
far_plane: float = 1000.0
"""How far along the ray to stop sampling."""
background_color: Literal["random", "last_sample"] = "last_sample"
background_color: Literal["random", "last_sample", "black", "white"] = "last_sample"
"""Whether to randomize the background color."""
num_levels: int = 16
"""Number of levels of the hashmap for the base mlp."""
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/models/nerfplayer_nerfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class NerfplayerNerfactoModelConfig(NerfactoModelConfig):
"""How far along the ray to start sampling."""
far_plane: float = 1000.0
"""How far along the ray to stop sampling."""
background_color: Literal["random", "last_sample"] = "random"
background_color: Literal["random", "last_sample", "black", "white"] = "random"
"""Whether to randomize the background color. (Random is reported to be better on DyCheck.)"""
num_levels: int = 16
"""Hashing grid parameter."""
Expand Down
12 changes: 2 additions & 10 deletions nerfstudio/models/nerfplayer_ngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
)
from nerfstudio.models.base_model import Model
from nerfstudio.models.instant_ngp import InstantNGPModelConfig, NGPModel
from nerfstudio.utils import colors


@dataclass
Expand Down Expand Up @@ -128,13 +127,6 @@ def populate_modules(self):
) # need to update the density_fn later during forward (for input time)

# renderers
self.train_background_color = self.config.train_background_color
self.eval_background_color = self.config.eval_background_color
if self.config.train_background_color in ["white", "black"]:
self.train_background_color = colors.COLORS_DICT[self.config.train_background_color]
if self.config.eval_background_color in ["white", "black"]:
self.eval_background_color = colors.COLORS_DICT[self.config.eval_background_color]

self.renderer_rgb = RGBRenderer() # will update bgcolor later during forward
self.renderer_accumulation = AccumulationRenderer()
self.renderer_depth = DepthRenderer(method="expected")
Expand Down Expand Up @@ -175,9 +167,9 @@ def get_outputs(self, ray_bundle: RayBundle):

# update bgcolor in the renderer; usually random color for training and fixed color for inference
if self.training:
self.renderer_rgb.background_color = self.train_background_color
self.renderer_rgb.background_color = self.config.train_background_color
else:
self.renderer_rgb.background_color = self.eval_background_color
self.renderer_rgb.background_color = self.config.eval_background_color
rgb = self.renderer_rgb(
rgb=field_outputs[FieldHeadNames.RGB],
weights=weights,
Expand Down
Empty file modified scripts/benchmarking/launch_eval_blender.sh
100755 → 100644
Empty file.
2 changes: 1 addition & 1 deletion scripts/benchmarking/launch_train_blender.sh
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fi
method_opts=()
if [ "$method_name" = "nerfacto" ]; then
# https://github.com/nerfstudio-project/nerfstudio/issues/806#issuecomment-1284327844
method_opts=(--pipeline.model.near-plane 2. --pipeline.model.far-plane 6. --pipeline.datamanager.camera-optimizer.mode off --pipeline.model.use-average-appearance-embedding False)
method_opts=(--pipeline.model.background-color white --pipeline.model.near-plane 2. --pipeline.model.far-plane 6. --pipeline.datamanager.camera-optimizer.mode off --pipeline.model.use-average-appearance-embedding False)
fi

shift $((OPTIND-1))
Expand Down
Empty file modified scripts/completions/install.py
100755 → 100644
Empty file.
Empty file modified scripts/docs/add_nb_tags.py
100755 → 100644
Empty file.
Empty file modified scripts/docs/build_docs.py
100755 → 100644
Empty file.
Empty file modified scripts/downloads/download_data.py
100755 → 100644
Empty file.
Empty file modified scripts/eval.py
100755 → 100644
Empty file.
Empty file modified scripts/github/run_actions.py
100755 → 100644
Empty file.
Empty file modified scripts/process_data.py
100755 → 100644
Empty file.
Empty file modified scripts/train.py
100755 → 100644
Empty file.
Empty file modified tests/data/lego_test/transforms_train.json
100755 → 100644
Empty file.

0 comments on commit 8e4ae6b

Please sign in to comment.