Skip to content

Commit

Permalink
Merge branch 'main' into ethan/adding_normals
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanweber committed Nov 6, 2022
2 parents 62b20bd + 1a24f3e commit 89a2c94
Show file tree
Hide file tree
Showing 19 changed files with 229 additions and 197 deletions.
2 changes: 0 additions & 2 deletions docs/developer_guides/pipelines/fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ class FieldHeadNames(Enum):
TRANSIENT_RGB = "transient_rgb"
TRANSIENT_DENSITY = "transient_density"
SEMANTICS = "semantics"
SEMANTICS_STUFF = "semantics_stuff"
SEMANTICS_THING = "semantics_thing"
```

```{button-link} https://github.com/nerfstudio-project/nerfstudio/blob/master/nerfstudio/field_components/field_heads.py
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Data Parsers
Datasets
----------------

.. automodule:: nerfstudio.data.utils.datasets
.. automodule:: nerfstudio.data.datasets
:members:
:show-inheritance:

Expand Down
5 changes: 3 additions & 2 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
TrainerConfig,
ViewerConfig,
)
from nerfstudio.data.datamanagers import VanillaDataManagerConfig
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig
from nerfstudio.data.datamanagers.semantic_datamanager import SemanticDataManagerConfig
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.dataparsers.friends_dataparser import FriendsDataParserConfig
from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
Expand Down Expand Up @@ -127,7 +128,7 @@
method_name="semantic-nerfw",
trainer=TrainerConfig(steps_per_eval_batch=500, steps_per_save=2000, mixed_precision=True),
pipeline=VanillaPipelineConfig(
datamanager=VanillaDataManagerConfig(
datamanager=SemanticDataManagerConfig(
dataparser=FriendsDataParserConfig(), train_num_rays_per_batch=4096, eval_num_rays_per_batch=8192
),
model=SemanticNerfWModelConfig(eval_num_rays_per_chunk=1 << 16),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
)
from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
from nerfstudio.data.dataparsers.record3d_dataparser import Record3DDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import PixelSampler
from nerfstudio.data.utils.dataloaders import (
CacheDataloader,
FixedIndicesEvalDataloader,
RandIndicesEvalDataloader,
)
from nerfstudio.data.utils.datasets import InputDataset
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.ray_generators import RayGenerator
from nerfstudio.utils.misc import IterableWrapper
Expand Down Expand Up @@ -296,13 +296,22 @@ def __init__(
self.world_size = world_size
self.local_rank = local_rank
self.sampler = None
self.test_mode = test_mode

self.train_dataset = InputDataset(config.dataparser.setup().get_dataparser_outputs(split="train"))
self.eval_dataset = InputDataset(
config.dataparser.setup().get_dataparser_outputs(split="val" if not test_mode else "test")
)
self.train_dataset = self.create_train_dataset()
self.eval_dataset = self.create_eval_dataset()
super().__init__()

def create_train_dataset(self) -> InputDataset:
    """Builds the training dataset from the dataparser outputs of the "train" split."""
    dataparser = self.config.dataparser.setup()
    train_outputs = dataparser.get_dataparser_outputs(split="train")
    return InputDataset(train_outputs)

def create_eval_dataset(self) -> InputDataset:
    """Builds the evaluation dataset from the "test" split in test mode, otherwise from "val"."""
    if self.test_mode:
        eval_split = "test"
    else:
        eval_split = "val"
    dataparser = self.config.dataparser.setup()
    eval_outputs = dataparser.get_dataparser_outputs(split=eval_split)
    return InputDataset(eval_outputs)

def setup_train(self):
"""Sets up the data loaders for training"""
assert self.train_dataset is not None
Expand Down
45 changes: 45 additions & 0 deletions nerfstudio/data/datamanagers/semantic_datamanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Type

from nerfstudio.data.datamanagers.base_datamanager import (
VanillaDataManager,
VanillaDataManagerConfig,
)
from nerfstudio.data.datasets.semantic_dataset import SemanticDataset


@dataclass
class SemanticDataManagerConfig(VanillaDataManagerConfig):
    """Config for the semantic datamanager.

    Identical to `VanillaDataManagerConfig` except that `_target` defaults to
    `SemanticDataManager` (presumably the class that `.setup()` instantiates -- confirm
    against the base config).
    """

    # The lambda defers the name lookup until the default is actually requested, since
    # SemanticDataManager is defined further down in this module.
    _target: Type = field(default_factory=lambda: SemanticDataManager)


class SemanticDataManager(VanillaDataManager):  # pylint: disable=abstract-method
    """Data manager whose train and eval datasets are `SemanticDataset` instances.

    Used for data that also requires processing semantic data.

    Args:
        config: the DataManagerConfig used to instantiate class
    """

    def create_train_dataset(self) -> SemanticDataset:
        """Builds the semantic training dataset from the "train" split."""
        dataparser = self.config.dataparser.setup()
        train_outputs = dataparser.get_dataparser_outputs(split="train")
        return SemanticDataset(train_outputs)

    def create_eval_dataset(self) -> SemanticDataset:
        """Builds the semantic eval dataset from the "test" split in test mode, otherwise "val"."""
        if self.test_mode:
            eval_split = "test"
        else:
            eval_split = "val"
        dataparser = self.config.dataparser.setup()
        eval_outputs = dataparser.get_dataparser_outputs(split=eval_split)
        return SemanticDataset(eval_outputs)
35 changes: 13 additions & 22 deletions nerfstudio/data/dataparsers/base_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,14 @@
class Semantics:
"""Dataclass for semantic labels."""

stuff_filenames: List[Path]
"""filenames to load "stuff"/background data"""
stuff_classes: List[str]
"""class labels for "stuff" data"""
stuff_colors: torch.Tensor
"""color mapping for "stuff" classes"""
thing_filenames: List[Path]
"""filenames to load "thing"/foreground data"""
thing_classes: List[str]
"""class labels for "thing" data"""
thing_colors: torch.Tensor
"""color mapping for "thing" classes"""
filenames: List[Path]
"""filenames to load semantic data"""
classes: List[str]
"""class labels for data"""
colors: torch.Tensor
"""color mapping for classes"""
mask_classes: List[str] = field(default_factory=lambda: [])
"""classes to mask out from training for all modalities"""


@dataclass
Expand All @@ -61,16 +57,11 @@ class DataparserOutputs:
"""Color of dataset background."""
scene_box: SceneBox = SceneBox()
"""Scene box of dataset. Used to bound the scene or provide the scene scale depending on model."""
semantics: Optional[Semantics] = None
"""Semantics information."""
times: Optional[TensorType[1]] = None
"""Time in range [0,1] for when each image was taken."""
additional_inputs: Dict[str, Any] = to_immutable_dict({})
"""Dictionary of additional dataset information (e.g. semantics/point clouds/masks).
{input_name:
... {"func": function to process additional dataparser outputs,
... "kwargs": dictionary of data to pass into "func"}
}
mask_filenames: Optional[List[Path]] = None
"""Filenames for any masks that are required"""
metadata: Dict[str, Any] = to_immutable_dict({})
"""Dictionary of any metadata that may be required for the given experiment.
Will be processed by the InputDataset to create any additional tensors that may be required.
"""

def as_dict(self) -> dict:
Expand Down
56 changes: 5 additions & 51 deletions nerfstudio/data/dataparsers/friends_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
from pathlib import Path
from typing import Type

import numpy as np
import torch
from PIL import Image
from rich.console import Console

from nerfstudio.cameras.cameras import Cameras, CameraType
Expand All @@ -37,31 +35,6 @@
CONSOLE = Console()


def get_semantics_and_masks(image_idx: int, semantics: Semantics):
    """Loads the "thing" and "stuff" label images for one frame and derives a person mask.

    Args:
        image_idx: specific image index to work with
        semantics: semantics data

    Returns:
        Dictionary with three entries: "mask" (float32, 1 where the "thing" label differs
        from the "person" class index), "semantics_stuff" and "semantics_thing" (int32
        label tensors with a trailing singleton dimension).

    Raises:
        ValueError: if "person" is not one of `semantics.thing_classes`.
    """

    def _load_labels(filename) -> torch.Tensor:
        # The context manager releases the file handle once the pixels are copied into the array.
        with Image.open(filename) as pil_image:
            return torch.from_numpy(np.array(pil_image, dtype="int32"))[..., None]

    # handle mask
    person_index = semantics.thing_classes.index("person")
    # The "thing" labels feed both the mask and the returned semantics, so decode them once.
    thing_semantics = _load_labels(semantics.thing_filenames[image_idx])
    mask = (thing_semantics != person_index).to(torch.float32)  # 1 where valid
    # handle semantics
    stuff_semantics = _load_labels(semantics.stuff_filenames[image_idx])
    return {"mask": mask, "semantics_stuff": stuff_semantics, "semantics_thing": thing_semantics}


@dataclass
class FriendsDataParserConfig(DataParserConfig):
"""Friends dataset parser config"""
Expand Down Expand Up @@ -142,37 +115,19 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
camera_to_worlds[..., 3] *= scale # cameras

# --- semantics ---
semantics = None
if self.config.include_semantics:
thing_filenames = [
filenames = [
Path(
str(image_filename)
.replace(f"/{images_folder}/", f"/{segmentations_folder}/thing/")
.replace(".jpg", ".png")
)
for image_filename in image_filenames
]
stuff_filenames = [
Path(
str(image_filename)
.replace(f"/{images_folder}/", f"/{segmentations_folder}/stuff/")
.replace(".jpg", ".png")
)
for image_filename in image_filenames
]
panoptic_classes = load_from_json(self.config.data / "panoptic_classes.json")
stuff_classes = panoptic_classes["stuff"]
stuff_colors = torch.tensor(panoptic_classes["stuff_colors"], dtype=torch.float32) / 255.0
thing_classes = panoptic_classes["thing"]
thing_colors = torch.tensor(panoptic_classes["thing_colors"], dtype=torch.float32) / 255.0
semantics = Semantics(
stuff_classes=stuff_classes,
stuff_colors=stuff_colors,
stuff_filenames=stuff_filenames,
thing_classes=thing_classes,
thing_colors=thing_colors,
thing_filenames=thing_filenames,
)
classes = panoptic_classes["thing"]
colors = torch.tensor(panoptic_classes["thing_colors"], dtype=torch.float32) / 255.0
semantics = Semantics(filenames=filenames, classes=classes, colors=colors, mask_classes=["person"])

assert torch.all(cx[0] == cx), "Not all cameras have the same cx. Our Cameras class does not support this."
assert torch.all(cy[0] == cy), "Not all cameras have the same cy. Our Cameras class does not support this."
Expand All @@ -191,7 +146,6 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
image_filenames=image_filenames,
cameras=cameras,
scene_box=scene_box,
additional_inputs={"semantics": {"func": get_semantics_and_masks, "kwargs": {"semantics": semantics}}},
semantics=semantics,
metadata={"semantics": semantics} if self.config.include_semantics else {},
)
return dataparser_outputs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torchtyping import TensorType

from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.utils.data_utils import get_image_mask_tensor_from_path


class InputDataset(Dataset):
Expand All @@ -39,6 +40,7 @@ class InputDataset(Dataset):
def __init__(self, dataparser_outputs: DataparserOutputs):
super().__init__()
self.dataparser_outputs = dataparser_outputs
self.has_masks = self.dataparser_outputs.mask_filenames is not None

def __len__(self):
return len(self.dataparser_outputs.image_filenames)
Expand Down Expand Up @@ -71,7 +73,7 @@ def get_image(self, image_idx: int) -> TensorType["image_height", "image_width",
image = image[:, :, :3]
return image

def get_data(self, image_idx) -> Dict:
def get_data(self, image_idx: int) -> Dict:
"""Returns the ImageDataset data as a dictionary.
Args:
Expand All @@ -80,13 +82,23 @@ def get_data(self, image_idx) -> Dict:
image = self.get_image(image_idx)
data = {"image_idx": image_idx}
data["image"] = image
for _, data_func_dict in self.dataparser_outputs.additional_inputs.items():
assert "func" in data_func_dict, "Missing function to process data: specify `func` in `additional_inputs`"
func = data_func_dict["func"]
assert "kwargs" in data_func_dict, "No data to process: specify `kwargs` in `additional_inputs`"
data.update(func(image_idx, **data_func_dict["kwargs"]))
if self.has_masks:
mask_filepath = self.dataparser_outputs.mask_filenames[image_idx]
data["mask"] = get_image_mask_tensor_from_path(filepath=mask_filepath)
metadata = self.get_metadata(data)
data.update(metadata)
return data

def __getitem__(self, image_idx):
# pylint: disable=no-self-use
def get_metadata(self, data: Dict) -> Dict:
    """Method that can be used to process any additional metadata that may be part of the model inputs.

    The base implementation contributes nothing; the returned dictionary is merged into
    the per-image data by `get_data`.

    Args:
        data: The data dictionary assembled so far for one image (e.g. "image_idx" and "image").

    Returns:
        An empty dictionary.
    """
    del data  # unused by the base implementation
    return {}

def __getitem__(self, image_idx: int) -> Dict:
    """Returns the data dictionary that `get_data` produces for `image_idx`."""
    return self.get_data(image_idx)
49 changes: 49 additions & 0 deletions nerfstudio/data/datasets/semantic_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import torch

from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs, Semantics
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import get_semantics_and_mask_tensors_from_path


class SemanticDataset(InputDataset):
    """Dataset that returns images and semantics and masks.

    Args:
        dataparser_outputs: description of where and how to read input images. Its
            `metadata` must contain a `Semantics` instance under the "semantics" key.
    """

    def __init__(self, dataparser_outputs: DataparserOutputs):
        super().__init__(dataparser_outputs)
        assert "semantics" in dataparser_outputs.metadata and isinstance(
            dataparser_outputs.metadata["semantics"], Semantics
        ), "SemanticDataset requires a `Semantics` instance in `dataparser_outputs.metadata['semantics']`"
        self.semantics = dataparser_outputs.metadata["semantics"]
        # Class indices to mask out, shaped (1, 1, num_mask_classes).
        # dtype is pinned because torch.tensor([]) would infer float32 when mask_classes is
        # empty; for a non-empty list of ints int64 is what torch infers anyway.
        self.mask_indices = torch.tensor(
            [self.semantics.classes.index(mask_class) for mask_class in self.semantics.mask_classes],
            dtype=torch.int64,
        ).view(1, 1, -1)

    def get_metadata(self, data: Dict) -> Dict:
        """Loads the semantic labels and mask for the image described by `data`.

        Args:
            data: The data dictionary assembled so far; must contain "image_idx" and may
                contain a "mask" that was already loaded.

        Returns:
            Dictionary with "mask" and "semantics" entries.
        """
        # handle mask
        filepath = self.semantics.filenames[data["image_idx"]]
        # NOTE(review): the mask presumably marks pixels that are NOT in `mask_classes` --
        # confirm against get_semantics_and_mask_tensors_from_path.
        semantic_label, mask = get_semantics_and_mask_tensors_from_path(
            filepath=filepath, mask_indices=self.mask_indices
        )
        # Intersect with a mask that is already present in `data`, if any.
        if "mask" in data:
            mask = mask & data["mask"]
        return {"mask": mask, "semantics": semantic_label}
Loading

0 comments on commit 89a2c94

Please sign in to comment.