sdfstudio dataparser/dataset (nerfstudio-project#1381)
* add downloads for sdfstudio datasets

* add sdfstudio parser to work with sdfstudio data

* fix licensing and doc strings

* fix linter errors

* move depths/normals to metadata, create sdfdataset

* remove extra scenebox parameters

* more linting errors fix

* fix wrong annotation

* add in missing depth/normal files

* minor fixes

---------
pablovela5620 authored Feb 12, 2023
1 parent 71eacd5 commit a37c73f
Showing 4 changed files with 327 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nerfstudio/data/datamanagers/base_datamanager.py
@@ -49,6 +49,7 @@
from nerfstudio.data.dataparsers.phototourism_dataparser import (
    PhototourismDataParserConfig,
)
from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PixelSampler
from nerfstudio.data.utils.dataloaders import (
@@ -75,6 +76,7 @@
"dnerf-data": DNeRFDataParserConfig(),
"phototourism-data": PhototourismDataParserConfig(),
"dycheck-data": DycheckDataParserConfig(),
"sdfstudio-data": SDFStudioDataParserConfig(),
},
prefix_names=False, # Omit prefixes in subcommands themselves.
)
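With this entry registered, the new parser becomes selectable as the sdfstudio-data dataparser subcommand on the command line (e.g. appended to an ns-train invocation), alongside the existing choices such as dnerf-data and dycheck-data.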
158 changes: 158 additions & 0 deletions nerfstudio/data/dataparsers/sdfstudio_dataparser.py
@@ -0,0 +1,158 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datapaser for sdfstudio formatted data"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Type

import torch
from rich.console import Console

from nerfstudio.cameras import camera_utils
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.data.dataparsers.base_dataparser import (
    DataParser,
    DataParserConfig,
    DataparserOutputs,
)
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.utils.io import load_from_json

CONSOLE = Console()


@dataclass
class SDFStudioDataParserConfig(DataParserConfig):
"""Scene dataset parser config"""

_target: Type = field(default_factory=lambda: SDFStudio)
"""target class to instantiate"""
data: Path = Path("data/DTU/scan65")
"""Directory specifying location of data."""
include_mono_prior: bool = False
"""whether or not to load monocular depth and normal """
include_foreground_mask: bool = False
"""whether or not to load foreground mask"""
downscale_factor: int = 1
scene_scale: float = 2.0
"""
Sets the bounding cube to have edge length of this size.
The longest dimension of the Friends axis-aligned bbox will be scaled to this value.
"""
skip_every_for_val_split: int = 1
"""sub sampling validation images"""
auto_orient: bool = False


@dataclass
class SDFStudio(DataParser):
"""SDFStudio Dataset"""

config: SDFStudioDataParserConfig

    def _generate_dataparser_outputs(self, split="train"):  # pylint: disable=unused-argument,too-many-statements
        # load meta data
        meta = load_from_json(self.config.data / "meta_data.json")
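        # Expected (abridged) shape of meta_data.json, inferred from the fields read
        # below; concrete values are illustrative only:
        # {
        #   "height": ..., "width": ...,
        #   "has_mono_prior": true,
        #   "scene_box": {"aabb": [[x_min, y_min, z_min], [x_max, y_max, z_max]]},
        #   "frames": [
        #     {"rgb_path": "...", "mono_depth_path": "...", "mono_normal_path": "...",
        #      "intrinsics": <4x4 list>, "camtoworld": <4x4 list>},
        #     ...
        #   ]
        # }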

        indices = list(range(len(meta["frames"])))
        # Keep every n-th frame for the validation split to avoid out-of-memory during
        # evaluation (skip_every_for_val_split=1 keeps all frames).
        if split != "train" and self.config.skip_every_for_val_split >= 1:
            indices = indices[:: self.config.skip_every_for_val_split]

        image_filenames = []
        depth_filenames = []
        normal_filenames = []
        transform = None
        fx = []
        fy = []
        cx = []
        cy = []
        camera_to_worlds = []
        for i, frame in enumerate(meta["frames"]):
            if i not in indices:
                continue

            image_filename = self.config.data / frame["rgb_path"]
            depth_filename = self.config.data / frame["mono_depth_path"]
            normal_filename = self.config.data / frame["mono_normal_path"]

            intrinsics = torch.tensor(frame["intrinsics"])
            camtoworld = torch.tensor(frame["camtoworld"])

            # append data
            image_filenames.append(image_filename)
            depth_filenames.append(depth_filename)
            normal_filenames.append(normal_filename)
            fx.append(intrinsics[0, 0])
            fy.append(intrinsics[1, 1])
            cx.append(intrinsics[0, 2])
            cy.append(intrinsics[1, 2])
            camera_to_worlds.append(camtoworld)

        fx = torch.stack(fx)
        fy = torch.stack(fy)
        cx = torch.stack(cx)
        cy = torch.stack(cy)
        camera_to_worlds = torch.stack(camera_to_worlds)

        # Convert from COLMAP's/OpenCV's camera coordinate system to nerfstudio's
        camera_to_worlds[:, 0:3, 1:3] *= -1
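        # (OpenCV convention: +x right, +y down, +z forward; nerfstudio uses +x right,
        # +y up, +z backward, hence the sign flip on the y and z columns.)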

        if self.config.auto_orient:
            camera_to_worlds, transform = camera_utils.auto_orient_and_center_poses(
                camera_to_worlds,
                method="up",
                center_poses=False,
            )

        # scene box from meta data
        meta_scene_box = meta["scene_box"]
        aabb = torch.tensor(meta_scene_box["aabb"], dtype=torch.float32)
        scene_box = SceneBox(
            aabb=aabb,
        )

        height, width = meta["height"], meta["width"]
        cameras = Cameras(
            fx=fx,
            fy=fy,
            cx=cx,
            cy=cy,
            height=height,
            width=width,
            camera_to_worlds=camera_to_worlds[:, :3, :4],
            camera_type=CameraType.PERSPECTIVE,
        )

        # TODO: support downscaling via self.config.downscale_factor
        # cameras.rescale_output_resolution(scaling_factor=1.0 / self.config.downscale_factor)

        assert meta["has_mono_prior"] == self.config.include_mono_prior, (
            f"include_mono_prior={self.config.include_mono_prior} does not match "
            f"has_mono_prior in {self.config.data}/meta_data.json"
        )

        dataparser_outputs = DataparserOutputs(
            image_filenames=image_filenames,
            cameras=cameras,
            scene_box=scene_box,
            metadata={
                "depth_filenames": depth_filenames if len(depth_filenames) > 0 else None,
                "normal_filenames": normal_filenames if len(normal_filenames) > 0 else None,
                "transform": transform,
                "camera_to_worlds": camera_to_worlds if len(camera_to_worlds) > 0 else None,
                "include_mono_prior": self.config.include_mono_prior,
            },
        )
        return dataparser_outputs
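
Usage sketch (not part of the diff; the data path below is hypothetical):

from pathlib import Path

from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig

config = SDFStudioDataParserConfig(data=Path("data/sdfstudio-demo-data/dtu-scan65"))
parser = config.setup()  # InstantiateConfig.setup() builds the SDFStudio parser from _target
outputs = parser.get_dataparser_outputs(split="train")
print(len(outputs.image_filenames), outputs.cameras.shape)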
97 changes: 97 additions & 0 deletions nerfstudio/data/datasets/sdf_dataset.py
@@ -0,0 +1,97 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SDFStudio dataset.
"""

from pathlib import Path
from typing import Dict

import numpy as np
import torch

from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.datasets.base_dataset import InputDataset


class SDFDataset(InputDataset):
"""Dataset that returns images and depths.
Args:
dataparser_outputs: description of where and how to read input images.
scale_factor: The scaling factor for the dataparser outputs.
"""

def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0):
super().__init__(dataparser_outputs, scale_factor)

# can be none if monoprior not included
self.depth_filenames = self.metadata["depth_filenames"]
self.normal_filenames = self.metadata["normal_filenames"]
self.camera_to_worlds = self.metadata["camera_to_worlds"]
# can be none if auto orient not enabled in dataparser
self.transform = self.metadata["normal_filenames"]
self.include_mono_prior = self.metadata["include_mono_prior"]

    def get_metadata(self, data: Dict) -> Dict:
        # TODO: support foreground masks
        metadata = {}
        if self.include_mono_prior:
            depth_filepath = self.depth_filenames[data["image_idx"]]
            normal_filepath = self.normal_filenames[data["image_idx"]]
            camtoworld = self.camera_to_worlds[data["image_idx"]]

            # load the monocular depth/normal priors and move the normals into world space
            depth_image, normal_image = self.get_depths_and_normals(
                depth_filepath=depth_filepath, normal_filename=normal_filepath, camtoworld=camtoworld
            )
            metadata["depth_image"] = depth_image
            metadata["normal_image"] = normal_image

        return metadata

    def get_depths_and_normals(self, depth_filepath: Path, normal_filename: Path, camtoworld: torch.Tensor):
        """Load a monocular depth map and normal map and transform the normals to world space.

        Args:
            depth_filepath: path to the depth (.npy) file
            normal_filename: path to the normal (.npy) file
            camtoworld: camera-to-world transformation matrix
        """

        # load mono depth
        depth = np.load(depth_filepath)
        depth = torch.from_numpy(depth).float()

        # load mono normal
        normal = np.load(normal_filename)

        # transform normal to world coordinate system
        normal = normal * 2.0 - 1.0  # omnidata stores normals in [0, 1]; map back to [-1, 1]
        normal = torch.from_numpy(normal).float()

        rot = camtoworld[:3, :3]
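        # Normals are direction vectors, so only the rotation part of the
        # camera-to-world transform applies; translation is irrelevant here.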

        normal_map = normal.reshape(3, -1)
        normal_map = torch.nn.functional.normalize(normal_map, p=2, dim=0)

        normal_map = rot @ normal_map
        normal = normal_map.permute(1, 0).reshape(*normal.shape[1:], 3)

        if self.transform is not None:
            h, w, _ = normal.shape
            normal = self.transform[:3, :3] @ normal.reshape(-1, 3).permute(1, 0)
            normal = normal.permute(1, 0).reshape(h, w, 3)

        return depth, normal
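
Usage sketch (not part of the diff; assumes outputs was produced by the parser above with include_mono_prior=True and the prior files exist on disk):

from nerfstudio.data.datasets.sdf_dataset import SDFDataset

dataset = SDFDataset(outputs, scale_factor=1.0)
sample = dataset[0]  # InputDataset.__getitem__ merges get_metadata() into the batch
depth = sample["depth_image"]    # depth map loaded from the .npy prior
normal = sample["normal_image"]  # (H, W, 3) world-space unit normals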
70 changes: 70 additions & 0 deletions scripts/downloads/download_data.py
@@ -287,13 +287,83 @@ def download(self, save_dir: Path):
        os.remove(download_path)


# credit to https://autonomousvision.github.io/sdfstudio/
# pylint: disable=line-too-long
sdfstudio_downloads = {
    "sdfstudio-demo-data": "https://s3.eu-central-1.amazonaws.com/avg-projects/monosdf/data/sdfstudio-demo-data.tar",
    "dtu": "https://s3.eu-central-1.amazonaws.com/avg-projects/monosdf/data/DTU.tar",
    "replica": "https://s3.eu-central-1.amazonaws.com/avg-projects/monosdf/data/Replica.tar",
    "scannet": "https://s3.eu-central-1.amazonaws.com/avg-projects/monosdf/data/scannet.tar",
    "tanks-and-temple": "https://s3.eu-central-1.amazonaws.com/avg-projects/monosdf/data/tnt_advanced.tar",
    "tanks-and-temple-highres": "https://s3.eu-central-1.amazonaws.com/avg-projects/monosdf/data/highresTNT.tar",
    "heritage": "https://s3.eu-central-1.amazonaws.com/avg-projects/monosdf/data/Heritage-Recon.tar",
    "neural-rgbd-data": "http://kaldir.vc.in.tum.de/neural_rgbd/neural_rgbd_data.zip",
    "all": None,
}

SDFstudioCaptureName = tyro.extras.literal_type_from_choices(sdfstudio_downloads.keys())


@dataclass
class SDFstudioDemoDownload(DatasetDownload):
"""Download the sdfstudio dataset."""

dataset_name: SDFstudioCaptureName = "sdfstudio-demo-data"

def download(self, save_dir: Path):
"""Download the D-NeRF dataset (https://github.com/albertpumarola/D-NeRF)."""
# TODO: give this code the same structure as download_nerfstudio

if self.dataset_name == "all":
for dataset_name in sdfstudio_downloads:
if dataset_name != "all":
SDFstudioDemoDownload(dataset_name=dataset_name).download(save_dir)
return

assert (
self.dataset_name in sdfstudio_downloads
), f"Capture name {self.dataset_name} not found in {sdfstudio_downloads.keys()}"

url = sdfstudio_downloads[self.dataset_name]

target_path = str(save_dir / self.dataset_name)
os.makedirs(target_path, exist_ok=True)

file_format = url[-4:]

download_path = Path(f"{target_path}{file_format}")
tmp_path = str(save_dir / ".temp")
shutil.rmtree(tmp_path, ignore_errors=True)
os.makedirs(tmp_path, exist_ok=True)

os.system(f"curl -L {url} > {download_path}")
if file_format == ".tar":
with tarfile.open(download_path, "r") as tar_ref:
tar_ref.extractall(str(tmp_path))
elif file_format == ".zip":
with zipfile.ZipFile(download_path, "r") as zip_ref:
zip_ref.extractall(str(target_path))
return
else:
raise NotImplementedError

inner_folders = os.listdir(tmp_path)
assert len(inner_folders) == 1, "There is more than one folder inside this zip file."
folder = os.path.join(tmp_path, inner_folders[0])
shutil.rmtree(target_path)
shutil.move(folder, target_path)
shutil.rmtree(tmp_path)
os.remove(download_path)
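
Usage sketch (not part of the diff): the programmatic equivalent of running
python scripts/downloads/download_data.py sdfstudio --dataset-name dtu
(flag spelling assumes tyro's default naming):

from pathlib import Path

SDFstudioDemoDownload(dataset_name="dtu").download(save_dir=Path("data"))  # extracts to data/dtu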


Commands = Union[
    Annotated[BlenderDownload, tyro.conf.subcommand(name="blender")],
    Annotated[FriendsDownload, tyro.conf.subcommand(name="friends")],
    Annotated[NerfstudioDownload, tyro.conf.subcommand(name="nerfstudio")],
    Annotated[Record3dDownload, tyro.conf.subcommand(name="record3d")],
    Annotated[DNerfDownload, tyro.conf.subcommand(name="dnerf")],
    Annotated[PhototourismDownload, tyro.conf.subcommand(name="phototourism")],
    Annotated[SDFstudioDemoDownload, tyro.conf.subcommand(name="sdfstudio")],
]

