Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save out dataparser transforms #1105

Merged
merged 2 commits into from
Dec 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def rotation_matrix(a: TensorType[3], b: TensorType[3]) -> TensorType[3, 3]:

def auto_orient_and_center_poses(
poses: TensorType["num_poses":..., 4, 4], method: Literal["pca", "up", "none"] = "up", center_poses: bool = True
) -> TensorType["num_poses":..., 3, 4]:
) -> Tuple[TensorType["num_poses":..., 3, 4], TensorType[4, 4]]:
"""Orients and centers the poses. We provide two methods for orientation: pca and up.
pca: Orient the poses so that the principal component of the points is aligned with the axes.
Expand All @@ -424,7 +424,7 @@ def auto_orient_and_center_poses(
center_poses: If True, the poses are centered around the origin.
Returns:
The oriented poses.
Tuple of the oriented poses and the transform matrix.
"""

translation = poses[..., :3, 3]
Expand Down Expand Up @@ -457,7 +457,9 @@ def auto_orient_and_center_poses(
transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1)
oriented_poses = transform @ poses
elif method == "none":
oriented_poses = poses[:, :3]
oriented_poses[..., 3] -= translation
transform = torch.eye(4)
transform[:3, 3] = -translation
transform = transform[:3, :]
oriented_poses = transform @ poses

return oriented_poses
return oriented_poses, transform
7 changes: 4 additions & 3 deletions nerfstudio/data/datamanagers/base_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from nerfstudio.cameras.cameras import CameraType
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
from nerfstudio.data.dataparsers.friends_dataparser import FriendsDataParserConfig
Expand All @@ -46,7 +47,6 @@
from nerfstudio.data.dataparsers.phototourism_dataparser import (
PhototourismDataParserConfig,
)
from nerfstudio.data.dataparsers.record3d_dataparser import Record3DDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PixelSampler
from nerfstudio.data.utils.dataloaders import (
Expand All @@ -69,7 +69,6 @@
"friends-data": FriendsDataParserConfig(),
"instant-ngp-data": InstantNGPDataParserConfig(),
"nuscenes-data": NuScenesDataParserConfig(),
"record3d-data": Record3DDataParserConfig(),
"dnerf-data": DNeRFDataParserConfig(),
"phototourism-data": PhototourismDataParserConfig(),
},
Expand Down Expand Up @@ -296,6 +295,7 @@ class VanillaDataManager(DataManager): # pylint: disable=abstract-method
config: VanillaDataManagerConfig
train_dataset: InputDataset
eval_dataset: InputDataset
train_dataparser_outputs: DataparserOutputs

def __init__(
self,
Expand All @@ -321,8 +321,9 @@ def __init__(

def create_train_dataset(self) -> InputDataset:
"""Sets up the data loaders for training"""
self.train_dataparser_outputs = self.dataparser.get_dataparser_outputs(split="train")
return InputDataset(
dataparser_outputs=self.dataparser.get_dataparser_outputs(split="train"),
dataparser_outputs=self.train_dataparser_outputs,
scale_factor=self.config.camera_res_scale_factor,
)

Expand Down
21 changes: 21 additions & 0 deletions nerfstudio/data/dataparsers/base_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import json
from abc import abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
Expand Down Expand Up @@ -63,11 +64,31 @@ class DataparserOutputs:
"""Dictionary of any metadata that be required for the given experiment.
Will be processed by the InputDataset to create any additional tensors that may be required.
"""
dataparser_transform: TensorType[3, 4] = torch.eye(4)[:3, :]
"""Transform applied by the dataparser."""
dataparser_scale: float = 1.0
"""Scale applied by the dataparser."""

def as_dict(self) -> dict:
    """Expose the attributes of this instance as a dictionary.

    Returns:
        The instance's own attribute mapping as produced by ``vars`` (the live
        mapping, not a copy).
    """
    attributes = vars(self)
    return attributes

def save_dataparser_transform(self, path: Path):
    """Save the dataparser transform and scale to a json file.

    Some dataparsers will apply a transform to the poses; this method allows the transform
    to be saved so that it can be used in other applications. The file holds two keys:
    "transform" (nested lists from ``dataparser_transform.tolist()``) and "scale".

    Args:
        path: path to save transform to. Missing parent directories are created.
    """
    data = {
        "transform": self.dataparser_transform.tolist(),
        # Coerce to a builtin float: a scale that arrives as a 0-d tensor or numpy scalar
        # is not JSON serializable as-is.
        "scale": float(self.dataparser_scale),
    }
    # exist_ok=True replaces the exists()-then-mkdir() check, which could raise
    # FileExistsError if another process created the directory in between.
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="UTF-8") as file:
        json.dump(data, file, indent=4)


@dataclass
class DataParserConfig(cfg.InstantiateConfig):
Expand Down
1 change: 1 addition & 0 deletions nerfstudio/data/dataparsers/blender_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _generate_dataparser_outputs(self, split="train"):
cameras=cameras,
alpha_color=alpha_color_tensor,
scene_box=scene_box,
dataparser_scale=self.scale_factor,
)

return dataparser_outputs
6 changes: 5 additions & 1 deletion nerfstudio/data/dataparsers/dnerf_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ def _generate_dataparser_outputs(self, split="train"):
)

dataparser_outputs = DataparserOutputs(
image_filenames=image_filenames, cameras=cameras, alpha_color=alpha_color_tensor, scene_box=scene_box
image_filenames=image_filenames,
cameras=cameras,
alpha_color=alpha_color_tensor,
scene_box=scene_box,
dataparser_scale=self.scale_factor,
)

return dataparser_outputs
1 change: 1 addition & 0 deletions nerfstudio/data/dataparsers/friends_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,5 +146,6 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
cameras=cameras,
scene_box=scene_box,
metadata={"semantics": semantics} if self.config.include_semantics else {},
dataparser_scale=scale,
)
return dataparser_outputs
1 change: 1 addition & 0 deletions nerfstudio/data/dataparsers/instant_ngp_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _generate_dataparser_outputs(self, split="train"):
image_filenames=image_filenames,
cameras=cameras,
scene_box=scene_box,
dataparser_scale=self.config.scene_scale,
)

return dataparser_outputs
Expand Down
9 changes: 6 additions & 3 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _generate_dataparser_outputs(self, split="train"):
orientation_method = self.config.orientation_method

poses = torch.from_numpy(np.array(poses).astype(np.float32))
poses = camera_utils.auto_orient_and_center_poses(
poses, transform_matrix = camera_utils.auto_orient_and_center_poses(
poses,
method=orientation_method,
center_poses=self.config.center_poses,
Expand All @@ -190,9 +190,10 @@ def _generate_dataparser_outputs(self, split="train"):
# Scale poses
scale_factor = 1.0
if self.config.auto_scale_poses:
scale_factor /= torch.max(torch.abs(poses[:, :3, 3]))
scale_factor /= float(torch.max(torch.abs(poses[:, :3, 3])))
scale_factor *= self.config.scale_factor

poses[:, :3, 3] *= scale_factor * self.config.scale_factor
poses[:, :3, 3] *= scale_factor

# Choose image_filenames and poses based on split, but after auto orient and scaling the poses.
image_filenames = [image_filenames[i] for i in indices]
Expand Down Expand Up @@ -252,6 +253,8 @@ def _generate_dataparser_outputs(self, split="train"):
cameras=cameras,
scene_box=scene_box,
mask_filenames=mask_filenames if len(mask_filenames) > 0 else None,
dataparser_scale=scale_factor,
dataparser_transform=transform_matrix,
)
return dataparser_outputs

Expand Down
9 changes: 6 additions & 3 deletions nerfstudio/data/dataparsers/phototourism_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,17 @@ def _generate_dataparser_outputs(self, split="train"):
else:
raise ValueError(f"Unknown dataparser split {split}")

poses = camera_utils.auto_orient_and_center_poses(
poses, transform_matrix = camera_utils.auto_orient_and_center_poses(
poses, method=self.config.orientation_method, center_poses=self.config.center_poses
)

# Scale poses
scale_factor = 1.0
if self.config.auto_scale_poses:
scale_factor /= torch.max(torch.abs(poses[:, :3, 3]))
scale_factor /= float(torch.max(torch.abs(poses[:, :3, 3])))
scale_factor *= self.config.scale_factor

poses[:, :3, 3] *= scale_factor * self.config.scale_factor
poses[:, :3, 3] *= scale_factor

# in x,y,z order
# assumes that the scene is centered at the origin
Expand Down Expand Up @@ -175,6 +176,8 @@ def _generate_dataparser_outputs(self, split="train"):
image_filenames=image_filenames,
cameras=cameras,
scene_box=scene_box,
dataparser_scale=scale_factor,
dataparser_transform=transform_matrix,
)

return dataparser_outputs
160 changes: 0 additions & 160 deletions nerfstudio/data/dataparsers/record3d_dataparser.py

This file was deleted.

4 changes: 4 additions & 0 deletions nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def train(self) -> None:
"""Train the model."""
assert self.pipeline.datamanager.train_dataset is not None, "Missing DatsetInputs"

self.pipeline.datamanager.train_dataparser_outputs.save_dataparser_transform(
self.base_dir / "dataparser_transforms.json"
)

self._init_viewer_state()
with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME):
num_iterations = self.config.trainer.max_num_iterations
Expand Down