Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Depth supervision for nerfacto #1173

Merged
merged 21 commits into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/quickstart/data_conventions.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,7 @@ For a transform matrix, the first 3 columns are the +X, +Y, and +Z defining the
]
}
```

**Depth images**

To train with depth supervision, you can also provide a `depth_file_path` for each frame in your `transforms.json` and use one of the methods that support additional depth losses (e.g., depth-nerfacto). The depth images are assumed to be 16-bit or 32-bit, with values in millimeters to remain consistent with [Polyform](https://github.com/PolyCam/polyform). You can adjust this scaling factor using the `depth_unit_scale_factor` parameter in `NerfstudioDataParserConfig` (default `1e-3`, i.e. millimeters to meters). Note that by default, we resize the depth images to match the shape of the RGB images.
15 changes: 15 additions & 0 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,21 @@ def normalize(x) -> TensorType[...]:
return x / torch.linalg.norm(x)


def normalize_with_norm(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Normalize tensor along axis and return normalized value with norms.

    Args:
        x: tensor to normalize.
        dim: axis along which to normalize.

    Returns:
        Tuple of normalized tensor and corresponding norm. The norm keeps the
        reduced axis (with size 1) and is floored at ``_EPS`` so that the
        division never divides by zero.
    """

    # Floor the norm with a scalar bound instead of building a one-element tensor
    # and moving it to x's device/dtype on every call.
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True).clamp(min=_EPS)
    return x / norm, norm


def viewmatrix(lookat, up, pos) -> TensorType[...]:
"""Returns a camera transformation matrix.

Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import cv2
import torch
import torchvision
from torch.nn.functional import normalize
from torchtyping import TensorType

import nerfstudio.utils.poses as pose_utils
Expand Down Expand Up @@ -666,7 +665,7 @@ def _generate_rays_from_coords(
directions_stack = torch.sum(
directions_stack[..., None, :] * rotation, dim=-1
) # (..., 1, 3) * (..., 3, 3) -> (..., 3)
directions_stack = normalize(directions_stack, dim=-1)
directions_stack, directions_norm = camera_utils.normalize_with_norm(directions_stack, -1)
assert directions_stack.shape == (3,) + num_rays_shape + (3,)

origins = c2w[..., :3, 3] # (..., 3)
Expand All @@ -691,6 +690,7 @@ def _generate_rays_from_coords(
pixel_area=pixel_area,
camera_indices=camera_indices,
times=times,
metadata={"directions_norm": directions_norm[0].detach()},
)

def to_json(
Expand Down
34 changes: 34 additions & 0 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.configs.base_config import ViewerConfig
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig
from nerfstudio.data.datamanagers.depth_datamanager import DepthDataManagerConfig
from nerfstudio.data.datamanagers.semantic_datamanager import SemanticDataManagerConfig
from nerfstudio.data.datamanagers.variable_res_datamanager import (
VariableResDataManagerConfig,
Expand All @@ -44,6 +45,7 @@
from nerfstudio.engine.schedulers import SchedulerConfig
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind
from nerfstudio.models.depth_nerfacto import DepthNerfactoModelConfig
from nerfstudio.models.instant_ngp import InstantNGPModelConfig
from nerfstudio.models.mipnerf import MipNerfModel
from nerfstudio.models.nerfacto import NerfactoModelConfig
Expand All @@ -56,6 +58,7 @@
method_configs: Dict[str, TrainerConfig] = {}
descriptions = {
"nerfacto": "Recommended real-time model tuned for real captures. This model will be continually updated.",
"depth-nerfacto": "Nerfacto with depth supervision.",
"instant-ngp": "Implementation of Instant-NGP. Recommended real-time model for unbounded scenes.",
"instant-ngp-bounded": "Implementation of Instant-NGP. Recommended for bounded real and synthetic scenes",
"mipnerf": "High quality model for bounded scenes. (slow)",
Expand Down Expand Up @@ -97,6 +100,37 @@
vis="viewer",
)

# "depth-nerfacto": nerfacto with depth supervision. Pairs the depth datamanager
# config (DepthDataManagerConfig) with the depth nerfacto model config
# (DepthNerfactoModelConfig).
method_configs["depth-nerfacto"] = TrainerConfig(
    method_name="depth-nerfacto",
    steps_per_eval_batch=500,
    steps_per_save=2000,
    max_num_iterations=30000,
    mixed_precision=True,
    pipeline=VanillaPipelineConfig(
        # Datamanager whose datasets are DepthDataset instances.
        datamanager=DepthDataManagerConfig(
            dataparser=NerfstudioDataParserConfig(),
            train_num_rays_per_batch=4096,
            eval_num_rays_per_batch=4096,
            camera_optimizer=CameraOptimizerConfig(
                mode="SO3xR3", optimizer=AdamOptimizerConfig(lr=6e-4, eps=1e-8, weight_decay=1e-2)
            ),
        ),
        # 1 << 15 == 32768 rays per evaluation chunk.
        model=DepthNerfactoModelConfig(eval_num_rays_per_chunk=1 << 15),
    ),
    # Both parameter groups use Adam with identical settings and no scheduler.
    optimizers={
        "proposal_networks": {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": None,
        },
        "fields": {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": None,
        },
    },
    viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
    vis="viewer",
)

method_configs["instant-ngp"] = TrainerConfig(
method_name="instant-ngp",
steps_per_eval_batch=500,
Expand Down
48 changes: 48 additions & 0 deletions nerfstudio/data/datamanagers/depth_datamanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Depth datamanager.
"""

from dataclasses import dataclass, field
from typing import Type

from nerfstudio.data.datamanagers import base_datamanager
from nerfstudio.data.datasets.depth_dataset import DepthDataset


@dataclass
class DepthDataManagerConfig(base_datamanager.VanillaDataManagerConfig):
    """A depth datamanager - required to use with .setup()

    Subclasses ``VanillaDataManagerConfig`` and sets ``_target`` to the
    ``DepthDataManager`` class.
    """

    # The lambda defers the name lookup until the default is actually requested,
    # because ``DepthDataManager`` is defined further down in this module.
    _target: Type = field(default_factory=lambda: DepthDataManager)


class DepthDataManager(base_datamanager.VanillaDataManager):  # pylint: disable=abstract-method
    """Datamanager for data that also requires processing depth data.

    Builds ``DepthDataset`` objects for both the train and the eval split.

    Args:
        config: the DataManagerConfig used to instantiate class
    """

    def create_train_dataset(self) -> DepthDataset:
        """Parse the "train" split, remember its outputs and wrap them in a DepthDataset."""
        train_outputs = self.dataparser.get_dataparser_outputs(split="train")
        self.train_dataparser_outputs = train_outputs
        return DepthDataset(dataparser_outputs=train_outputs)

    def create_eval_dataset(self) -> DepthDataset:
        """Parse the split named by ``self.test_split`` and wrap it in a DepthDataset."""
        eval_outputs = self.dataparser.get_dataparser_outputs(split=self.test_split)
        return DepthDataset(dataparser_outputs=eval_outputs)
20 changes: 20 additions & 0 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class NerfstudioDataParserConfig(DataParserConfig):
"""Whether to automatically scale the poses to fit in +/- 1 bounding box."""
train_split_percentage: float = 0.9
"""The percent of images to use for training. The remaining images are for eval."""
depth_unit_scale_factor: float = 1e-3
"""Scales the depth values to meters. Default value is 0.001 for a millimeter to meter conversion."""


@dataclass
Expand All @@ -83,6 +85,7 @@ def _generate_dataparser_outputs(self, split="train"):

image_filenames = []
mask_filenames = []
depth_filenames = []
poses = []
num_skipped_image_filenames = 0

Expand Down Expand Up @@ -152,6 +155,12 @@ def _generate_dataparser_outputs(self, split="train"):
downsample_folder_prefix="masks_",
)
mask_filenames.append(mask_fname)

if "depth_file_path" in frame:
depth_filepath = PurePath(frame["depth_file_path"])
depth_fname = self._get_fname(depth_filepath, data_dir, downsample_folder_prefix="depths_")
depth_filenames.append(depth_fname)

if num_skipped_image_filenames >= 0:
CONSOLE.log(f"Skipping {num_skipped_image_filenames} files in dataset split {split}.")
assert (
Expand All @@ -166,6 +175,12 @@ def _generate_dataparser_outputs(self, split="train"):
Different number of image and mask filenames.
You should check that mask_path is specified for every frame (or zero frames) in transforms.json.
"""
assert len(depth_filenames) == 0 or (
len(depth_filenames) == len(image_filenames)
), """
Different number of image and depth filenames.
You should check that depth_file_path is specified for every frame (or zero frames) in transforms.json.
"""

# filter image_filenames and poses based on train/eval split percentage
num_images = len(image_filenames)
Expand Down Expand Up @@ -208,6 +223,7 @@ def _generate_dataparser_outputs(self, split="train"):
# Choose image_filenames and poses based on split, but after auto orient and scaling the poses.
image_filenames = [image_filenames[i] for i in indices]
mask_filenames = [mask_filenames[i] for i in indices] if len(mask_filenames) > 0 else []
depth_filenames = [depth_filenames[i] for i in indices] if len(depth_filenames) > 0 else []
poses = poses[indices]

# in x,y,z order
Expand Down Expand Up @@ -265,6 +281,10 @@ def _generate_dataparser_outputs(self, split="train"):
mask_filenames=mask_filenames if len(mask_filenames) > 0 else None,
dataparser_scale=scale_factor,
dataparser_transform=transform_matrix,
metadata={
"depth_filenames": depth_filenames if len(depth_filenames) > 0 else None,
"depth_unit_scale_factor": self.config.depth_unit_scale_factor,
},
)
return dataparser_outputs

Expand Down
54 changes: 54 additions & 0 deletions nerfstudio/data/datasets/depth_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Depth dataset.
"""

from typing import Dict

from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import get_depth_image_from_path


class DepthDataset(InputDataset):
    """Dataset that returns images and depths.

    Args:
        dataparser_outputs: description of where and how to read input images.
            Its ``metadata`` must contain a non-None ``"depth_filenames"`` entry
            and a ``"depth_unit_scale_factor"`` entry.
        scale_factor: The scaling factor for the dataparser outputs.
    """

    def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0):
        super().__init__(dataparser_outputs, scale_factor)
        # Fail with an actionable message when the dataset carries no depth images,
        # instead of a bare AssertionError.
        assert (
            dataparser_outputs.metadata.get("depth_filenames") is not None
        ), """
        DepthDataset requires depth images but the dataparser outputs contain no depth filenames.
        You should check that depth_file_path is specified for every frame in transforms.json.
        """
        self.depth_filenames = self.metadata["depth_filenames"]
        self.depth_unit_scale_factor = self.metadata["depth_unit_scale_factor"]

    def get_metadata(self, data: Dict) -> Dict:
        """Load the depth image that belongs to the image referenced by ``data``.

        Args:
            data: dictionary that contains the index of the image under ``"image_idx"``.

        Returns:
            Dictionary with the loaded depth image stored under ``"depth_image"``.
        """
        image_idx = data["image_idx"]
        filepath = self.depth_filenames[image_idx]
        # Target size is read from the cameras held by the dataparser outputs.
        height = int(self._dataparser_outputs.cameras.height[image_idx])
        width = int(self._dataparser_outputs.cameras.width[image_idx])

        # Scale depth images to meter units and also by scaling applied to cameras
        scale_factor = self.depth_unit_scale_factor * self._dataparser_outputs.dataparser_scale
        depth_image = get_depth_image_from_path(
            filepath=filepath, height=height, width=width, scale_factor=scale_factor
        )

        return {"depth_image": depth_image}
27 changes: 27 additions & 0 deletions nerfstudio/data/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pathlib import Path
from typing import List, Tuple, Union

import cv2
import numpy as np
import torch
from PIL import Image
Expand Down Expand Up @@ -51,3 +52,29 @@ def get_semantics_and_mask_tensors_from_path(
semantics = torch.from_numpy(np.array(pil_image, dtype="int64"))[..., None]
mask = torch.sum(semantics == mask_indices, dim=-1, keepdim=True) == 0
return semantics, mask


def get_depth_image_from_path(
    filepath: Path,
    height: int,
    width: int,
    scale_factor: float,
    interpolation: int = cv2.INTER_NEAREST,
) -> torch.Tensor:
    """Loads, rescales and resizes depth images.

    Assumes filepath points to a 16-bit or 32-bit depth image.

    Args:
        filepath: Path to depth image.
        height: Target depth image height.
        width: Target depth image width.
        scale_factor: Factor by which to scale depth image.
        interpolation: Depth value interpolation for resizing.

    Returns:
        Depth image torch tensor with shape [height, width, 1].

    Raises:
        FileNotFoundError: If no file exists at ``filepath``.
        ValueError: If the file exists but could not be decoded as an image.
    """
    image = cv2.imread(str(filepath.absolute()), cv2.IMREAD_ANYDEPTH)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising, which
        # would otherwise surface as an AttributeError on the next line.
        if not filepath.exists():
            raise FileNotFoundError(f"Depth image not found: {filepath}")
        raise ValueError(f"Could not decode depth image: {filepath}")
    image = image.astype(np.float64) * scale_factor
    # cv2.resize takes the target size as (width, height) and returns an array
    # indexed as [height, width].
    image = cv2.resize(image, (width, height), interpolation=interpolation)
    return torch.from_numpy(image[:, :, np.newaxis])
Loading