Skip to content

Commit

Permalink
Depth supervision for Nerfacto (nerfstudio-project#1173)
Browse files Browse the repository at this point in the history
* Implement depth supervision

* Fixes

* Change normalize dtype

* refactor

* Change depth weight

* Small refactor

* Change docstring

* Remove unused import

* Add assumption to depth comment

* Implement URF loss

* Remove default depth loss type param

* URF fixes

* Fix depth loss scale

* Visualize gt depths

* Fixes & documentation

* Documentation and formatting

* Fix formatting problems

* Fix linter problems

* Fix tests

* Fix tests
  • Loading branch information
mpmisko authored Jan 2, 2023
1 parent c313804 commit 76a40d1
Show file tree
Hide file tree
Showing 12 changed files with 460 additions and 4 deletions.
4 changes: 4 additions & 0 deletions docs/quickstart/data_conventions.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,7 @@ For a transform matrix, the first 3 columns are the +X, +Y, and +Z defining the
]
}
```

**Depth images**

To train with depth supervision, you can also provide a `depth_file_path` for each frame in your `transforms.json` and use one of the methods that support additional depth losses (e.g., depth-nerfacto). The depths are assumed to be 16-bit or 32-bit and to be in millimeters to remain consistent with [Polyform](https://github.com/PolyCam/polyform). You can adjust this scaling factor using the `depth_unit_scale_factor` parameter in `NerfstudioDataParserConfig`. Note that by default, we resize the depth images to match the shape of the RGB images.
15 changes: 15 additions & 0 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,21 @@ def normalize(x) -> TensorType[...]:
return x / torch.linalg.norm(x)


def normalize_with_norm(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Normalize a tensor along one axis and also return the norms that were used.

    Args:
        x: tensor to normalize.
        dim: axis along which to normalize.

    Returns:
        Tuple of the normalized tensor and the corresponding norms. The norms keep
        ``dim`` as a size-1 axis, so they broadcast against ``x``.
    """
    # Floor the norm at the module-level _EPS so (near-)zero vectors are not divided by zero.
    # A scalar clamp avoids allocating a one-element tensor and moving it to x's
    # dtype/device on every call.
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True).clamp(min=_EPS)
    return x / norm, norm


def viewmatrix(lookat, up, pos) -> TensorType[...]:
"""Returns a camera transformation matrix.
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import cv2
import torch
import torchvision
from torch.nn.functional import normalize
from torchtyping import TensorType

import nerfstudio.utils.poses as pose_utils
Expand Down Expand Up @@ -666,7 +665,7 @@ def _generate_rays_from_coords(
directions_stack = torch.sum(
directions_stack[..., None, :] * rotation, dim=-1
) # (..., 1, 3) * (..., 3, 3) -> (..., 3)
directions_stack = normalize(directions_stack, dim=-1)
directions_stack, directions_norm = camera_utils.normalize_with_norm(directions_stack, -1)
assert directions_stack.shape == (3,) + num_rays_shape + (3,)

origins = c2w[..., :3, 3] # (..., 3)
Expand All @@ -691,6 +690,7 @@ def _generate_rays_from_coords(
pixel_area=pixel_area,
camera_indices=camera_indices,
times=times,
metadata={"directions_norm": directions_norm[0].detach()},
)

def to_json(
Expand Down
34 changes: 34 additions & 0 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.configs.base_config import ViewerConfig
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig
from nerfstudio.data.datamanagers.depth_datamanager import DepthDataManagerConfig
from nerfstudio.data.datamanagers.semantic_datamanager import SemanticDataManagerConfig
from nerfstudio.data.datamanagers.variable_res_datamanager import (
VariableResDataManagerConfig,
Expand All @@ -44,6 +45,7 @@
from nerfstudio.engine.schedulers import SchedulerConfig
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind
from nerfstudio.models.depth_nerfacto import DepthNerfactoModelConfig
from nerfstudio.models.instant_ngp import InstantNGPModelConfig
from nerfstudio.models.mipnerf import MipNerfModel
from nerfstudio.models.nerfacto import NerfactoModelConfig
Expand All @@ -56,6 +58,7 @@
method_configs: Dict[str, TrainerConfig] = {}
descriptions = {
"nerfacto": "Recommended real-time model tuned for real captures. This model will be continually updated.",
"depth-nerfacto": "Nerfacto with depth supervision.",
"instant-ngp": "Implementation of Instant-NGP. Recommended real-time model for unbounded scenes.",
"instant-ngp-bounded": "Implementation of Instant-NGP. Recommended for bounded real and synthetic scenes",
"mipnerf": "High quality model for bounded scenes. (slow)",
Expand Down Expand Up @@ -97,6 +100,37 @@
vis="viewer",
)

# Trainer configuration registered under "depth-nerfacto" (same string as method_name).
# It pairs DepthDataManagerConfig with DepthNerfactoModelConfig inside a vanilla pipeline.
method_configs["depth-nerfacto"] = TrainerConfig(
    method_name="depth-nerfacto",
    steps_per_eval_batch=500,
    steps_per_save=2000,
    max_num_iterations=30000,
    mixed_precision=True,
    pipeline=VanillaPipelineConfig(
        # Datamanager that also handles depth data, fed by the nerfstudio dataparser
        # with its default settings.
        datamanager=DepthDataManagerConfig(
            dataparser=NerfstudioDataParserConfig(),
            train_num_rays_per_batch=4096,
            eval_num_rays_per_batch=4096,
            camera_optimizer=CameraOptimizerConfig(
                mode="SO3xR3", optimizer=AdamOptimizerConfig(lr=6e-4, eps=1e-8, weight_decay=1e-2)
            ),
        ),
        # 1 << 15 == 32768 rays per chunk during evaluation.
        model=DepthNerfactoModelConfig(eval_num_rays_per_chunk=1 << 15),
    ),
    # Both optimizer entries use identical Adam settings and no scheduler.
    optimizers={
        "proposal_networks": {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": None,
        },
        "fields": {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": None,
        },
    },
    # The viewer uses the same chunk size as model evaluation (32768).
    viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
    vis="viewer",
)

method_configs["instant-ngp"] = TrainerConfig(
method_name="instant-ngp",
steps_per_eval_batch=500,
Expand Down
48 changes: 48 additions & 0 deletions nerfstudio/data/datamanagers/depth_datamanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Depth datamanager.
"""

from dataclasses import dataclass, field
from typing import Type

from nerfstudio.data.datamanagers import base_datamanager
from nerfstudio.data.datasets.depth_dataset import DepthDataset


@dataclass
class DepthDataManagerConfig(base_datamanager.VanillaDataManagerConfig):
    """A depth datamanager config - required to use with .setup()

    Adds no options on top of VanillaDataManagerConfig; it only points ``_target``
    at DepthDataManager.
    """

    # The lambda defers the lookup of DepthDataManager, which is defined later in
    # this module, until the default is actually requested.
    _target: Type = field(default_factory=lambda: DepthDataManager)


class DepthDataManager(base_datamanager.VanillaDataManager):  # pylint: disable=abstract-method
    """Data manager for data that, besides images, also requires depth data.

    Both the train and the eval dataset are built as DepthDataset instances.

    Args:
        config: the DataManagerConfig used to instantiate class
    """

    def create_train_dataset(self) -> DepthDataset:
        """Build the training DepthDataset and remember the parsed "train" split on the instance."""
        train_outputs = self.dataparser.get_dataparser_outputs(split="train")
        # Stored before the dataset is constructed, exactly like the attribute assignment it replaces.
        self.train_dataparser_outputs = train_outputs
        return DepthDataset(dataparser_outputs=train_outputs)

    def create_eval_dataset(self) -> DepthDataset:
        """Build the evaluation DepthDataset from the split named by ``self.test_split``."""
        eval_outputs = self.dataparser.get_dataparser_outputs(split=self.test_split)
        return DepthDataset(dataparser_outputs=eval_outputs)
20 changes: 20 additions & 0 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class NerfstudioDataParserConfig(DataParserConfig):
"""Whether to automatically scale the poses to fit in +/- 1 bounding box."""
train_split_percentage: float = 0.9
"""The percent of images to use for training. The remaining images are for eval."""
depth_unit_scale_factor: float = 1e-3
"""Scales the depth values to meters. Default value is 0.001 for a millimiter to meter conversion."""


@dataclass
Expand All @@ -83,6 +85,7 @@ def _generate_dataparser_outputs(self, split="train"):

image_filenames = []
mask_filenames = []
depth_filenames = []
poses = []
num_skipped_image_filenames = 0

Expand Down Expand Up @@ -152,6 +155,12 @@ def _generate_dataparser_outputs(self, split="train"):
downsample_folder_prefix="masks_",
)
mask_filenames.append(mask_fname)

if "depth_file_path" in frame:
depth_filepath = PurePath(frame["depth_file_path"])
depth_fname = self._get_fname(depth_filepath, data_dir, downsample_folder_prefix="depths_")
depth_filenames.append(depth_fname)

if num_skipped_image_filenames >= 0:
CONSOLE.log(f"Skipping {num_skipped_image_filenames} files in dataset split {split}.")
assert (
Expand All @@ -166,6 +175,12 @@ def _generate_dataparser_outputs(self, split="train"):
Different number of image and mask filenames.
You should check that mask_path is specified for every frame (or zero frames) in transforms.json.
"""
assert len(depth_filenames) == 0 or (
len(depth_filenames) == len(image_filenames)
), """
Different number of image and depth filenames.
You should check that depth_file_path is specified for every frame (or zero frames) in transforms.json.
"""

# filter image_filenames and poses based on train/eval split percentage
num_images = len(image_filenames)
Expand Down Expand Up @@ -208,6 +223,7 @@ def _generate_dataparser_outputs(self, split="train"):
# Choose image_filenames and poses based on split, but after auto orient and scaling the poses.
image_filenames = [image_filenames[i] for i in indices]
mask_filenames = [mask_filenames[i] for i in indices] if len(mask_filenames) > 0 else []
depth_filenames = [depth_filenames[i] for i in indices] if len(depth_filenames) > 0 else []
poses = poses[indices]

# in x,y,z order
Expand Down Expand Up @@ -265,6 +281,10 @@ def _generate_dataparser_outputs(self, split="train"):
mask_filenames=mask_filenames if len(mask_filenames) > 0 else None,
dataparser_scale=scale_factor,
dataparser_transform=transform_matrix,
metadata={
"depth_filenames": depth_filenames if len(depth_filenames) > 0 else None,
"depth_unit_scale_factor": self.config.depth_unit_scale_factor,
},
)
return dataparser_outputs

Expand Down
54 changes: 54 additions & 0 deletions nerfstudio/data/datasets/depth_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Depth dataset.
"""

from typing import Dict

from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import get_depth_image_from_path


class DepthDataset(InputDataset):
    """Dataset that returns images and depths.

    Args:
        dataparser_outputs: description of where and how to read input images.
            Its metadata must provide "depth_filenames" (not None) and
            "depth_unit_scale_factor".
        scale_factor: The scaling factor for the dataparser outputs.
    """

    def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0):
        super().__init__(dataparser_outputs, scale_factor)
        # Fail early with an actionable message instead of a bare AssertionError
        # when the dataparser did not provide any depth files.
        assert (
            "depth_filenames" in dataparser_outputs.metadata
            and dataparser_outputs.metadata["depth_filenames"] is not None
        ), (
            "DepthDataset requires depth images, but the dataparser outputs contain no 'depth_filenames'. "
            "Make sure every frame has a depth image (e.g. a depth_file_path entry in transforms.json)."
        )
        self.depth_filenames = self.metadata["depth_filenames"]
        self.depth_unit_scale_factor = self.metadata["depth_unit_scale_factor"]

    def get_metadata(self, data: Dict) -> Dict:
        """Load the depth image that belongs to the given sample.

        Args:
            data: sample dictionary; only its "image_idx" entry is read here.

        Returns:
            Dictionary with a single "depth_image" entry, loaded at the height and
            width recorded for that image's camera.
        """
        image_idx = data["image_idx"]
        cameras = self._dataparser_outputs.cameras
        filepath = self.depth_filenames[image_idx]
        height = int(cameras.height[image_idx])
        width = int(cameras.width[image_idx])

        # Scale depth images to meter units and also by scaling applied to cameras
        scale_factor = self.depth_unit_scale_factor * self._dataparser_outputs.dataparser_scale
        depth_image = get_depth_image_from_path(
            filepath=filepath, height=height, width=width, scale_factor=scale_factor
        )

        return {"depth_image": depth_image}
27 changes: 27 additions & 0 deletions nerfstudio/data/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pathlib import Path
from typing import List, Tuple, Union

import cv2
import numpy as np
import torch
from PIL import Image
Expand Down Expand Up @@ -51,3 +52,29 @@ def get_semantics_and_mask_tensors_from_path(
semantics = torch.from_numpy(np.array(pil_image, dtype="int64"))[..., None]
mask = torch.sum(semantics == mask_indices, dim=-1, keepdim=True) == 0
return semantics, mask


def get_depth_image_from_path(
    filepath: Path,
    height: int,
    width: int,
    scale_factor: float,
    interpolation: int = cv2.INTER_NEAREST,
) -> torch.Tensor:
    """Loads, rescales and resizes depth images.
    Assumes filepath points to a 16-bit or 32-bit depth image.

    Args:
        filepath: Path to depth image.
        height: Target depth image height.
        width: Target depth image width.
        scale_factor: Factor by which to scale depth image.
        interpolation: Depth value interpolation for resizing.

    Returns:
        Depth image torch tensor (float64) with shape [height, width, 1].

    Raises:
        OSError: If the image at ``filepath`` cannot be read.
    """
    image = cv2.imread(str(filepath.absolute()), cv2.IMREAD_ANYDEPTH)
    # cv2.imread signals failure by returning None rather than raising, so report
    # the offending path here instead of failing later on ``None.astype``.
    if image is None:
        raise OSError(f"Could not read depth image '{filepath}' (missing file or unsupported format).")
    image = image.astype(np.float64) * scale_factor
    # cv2.resize takes the target size as (width, height); the result is indexed [height, width].
    image = cv2.resize(image, (width, height), interpolation=interpolation)
    return torch.from_numpy(image[:, :, np.newaxis])
Loading

0 comments on commit 76a40d1

Please sign in to comment.