Skip to content

Commit

Permalink
Save out dataparser transforms (#1105)
Browse files Browse the repository at this point in the history
* Save out dataparser transforms

* Make missing folder
  • Loading branch information
tancik authored Dec 11, 2022
1 parent 90447ab commit b41ce80
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 175 deletions.
12 changes: 7 additions & 5 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def rotation_matrix(a: TensorType[3], b: TensorType[3]) -> TensorType[3, 3]:

def auto_orient_and_center_poses(
poses: TensorType["num_poses":..., 4, 4], method: Literal["pca", "up", "none"] = "up", center_poses: bool = True
) -> TensorType["num_poses":..., 3, 4]:
) -> Tuple[TensorType["num_poses":..., 3, 4], TensorType[4, 4]]:
"""Orients and centers the poses. We provide two methods for orientation: pca and up.
pca: Orient the poses so that the principal component of the points is aligned with the axes.
Expand All @@ -424,7 +424,7 @@ def auto_orient_and_center_poses(
center_poses: If True, the poses are centered around the origin.
Returns:
The oriented poses.
Tuple of the oriented poses and the transform matrix.
"""

translation = poses[..., :3, 3]
Expand Down Expand Up @@ -457,7 +457,9 @@ def auto_orient_and_center_poses(
transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1)
oriented_poses = transform @ poses
elif method == "none":
oriented_poses = poses[:, :3]
oriented_poses[..., 3] -= translation
transform = torch.eye(4)
transform[:3, 3] = -translation
transform = transform[:3, :]
oriented_poses = transform @ poses

return oriented_poses
return oriented_poses, transform
7 changes: 4 additions & 3 deletions nerfstudio/data/datamanagers/base_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from nerfstudio.cameras.cameras import CameraType
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
from nerfstudio.data.dataparsers.friends_dataparser import FriendsDataParserConfig
Expand All @@ -46,7 +47,6 @@
from nerfstudio.data.dataparsers.phototourism_dataparser import (
PhototourismDataParserConfig,
)
from nerfstudio.data.dataparsers.record3d_dataparser import Record3DDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PixelSampler
from nerfstudio.data.utils.dataloaders import (
Expand All @@ -69,7 +69,6 @@
"friends-data": FriendsDataParserConfig(),
"instant-ngp-data": InstantNGPDataParserConfig(),
"nuscenes-data": NuScenesDataParserConfig(),
"record3d-data": Record3DDataParserConfig(),
"dnerf-data": DNeRFDataParserConfig(),
"phototourism-data": PhototourismDataParserConfig(),
},
Expand Down Expand Up @@ -296,6 +295,7 @@ class VanillaDataManager(DataManager): # pylint: disable=abstract-method
config: VanillaDataManagerConfig
train_dataset: InputDataset
eval_dataset: InputDataset
train_dataparser_outputs: DataparserOutputs

def __init__(
self,
Expand All @@ -321,8 +321,9 @@ def __init__(

def create_train_dataset(self) -> InputDataset:
"""Sets up the data loaders for training"""
self.train_dataparser_outputs = self.dataparser.get_dataparser_outputs(split="train")
return InputDataset(
dataparser_outputs=self.dataparser.get_dataparser_outputs(split="train"),
dataparser_outputs=self.train_dataparser_outputs,
scale_factor=self.config.camera_res_scale_factor,
)

Expand Down
21 changes: 21 additions & 0 deletions nerfstudio/data/dataparsers/base_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import json
from abc import abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
Expand Down Expand Up @@ -63,11 +64,31 @@ class DataparserOutputs:
"""Dictionary of any metadata that may be required for the given experiment.
Will be processed by the InputDataset to create any additional tensors that may be required.
"""
dataparser_transform: TensorType[3, 4] = torch.eye(4)[:3, :]
"""Transform applied by the dataparser."""
dataparser_scale: float = 1.0
"""Scale applied by the dataparser."""

def as_dict(self) -> dict:
    """Expose this object's attributes as a dictionary.

    Returns:
        The instance attribute dictionary itself (not a copy), so changes made
        to the returned mapping are reflected on the instance.
    """
    attributes = self.__dict__
    return attributes

def save_dataparser_transform(self, path: Path):
    """Save the dataparser transform and scale to a json file.

    Some dataparsers apply a transform to the poses; this method writes that
    transform out so that it can be used in other applications.

    Args:
        path: path to save transform to. Missing parent folders are created.
    """
    data = {
        "transform": self.dataparser_transform.tolist(),
        # Dataparsers may hand over a tensor/NumPy scalar rather than a Python
        # float; json cannot serialize those, so coerce explicitly.
        "scale": float(self.dataparser_scale),
    }
    # exist_ok avoids the check-then-create race when the folder appears
    # between the existence test and the mkdir call.
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="UTF-8") as file:
        json.dump(data, file, indent=4)


@dataclass
class DataParserConfig(cfg.InstantiateConfig):
Expand Down
1 change: 1 addition & 0 deletions nerfstudio/data/dataparsers/blender_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _generate_dataparser_outputs(self, split="train"):
cameras=cameras,
alpha_color=alpha_color_tensor,
scene_box=scene_box,
dataparser_scale=self.scale_factor,
)

return dataparser_outputs
6 changes: 5 additions & 1 deletion nerfstudio/data/dataparsers/dnerf_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ def _generate_dataparser_outputs(self, split="train"):
)

dataparser_outputs = DataparserOutputs(
image_filenames=image_filenames, cameras=cameras, alpha_color=alpha_color_tensor, scene_box=scene_box
image_filenames=image_filenames,
cameras=cameras,
alpha_color=alpha_color_tensor,
scene_box=scene_box,
dataparser_scale=self.scale_factor,
)

return dataparser_outputs
1 change: 1 addition & 0 deletions nerfstudio/data/dataparsers/friends_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,5 +146,6 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
cameras=cameras,
scene_box=scene_box,
metadata={"semantics": semantics} if self.config.include_semantics else {},
dataparser_scale=scale,
)
return dataparser_outputs
1 change: 1 addition & 0 deletions nerfstudio/data/dataparsers/instant_ngp_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _generate_dataparser_outputs(self, split="train"):
image_filenames=image_filenames,
cameras=cameras,
scene_box=scene_box,
dataparser_scale=self.config.scene_scale,
)

return dataparser_outputs
Expand Down
9 changes: 6 additions & 3 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _generate_dataparser_outputs(self, split="train"):
orientation_method = self.config.orientation_method

poses = torch.from_numpy(np.array(poses).astype(np.float32))
poses = camera_utils.auto_orient_and_center_poses(
poses, transform_matrix = camera_utils.auto_orient_and_center_poses(
poses,
method=orientation_method,
center_poses=self.config.center_poses,
Expand All @@ -190,9 +190,10 @@ def _generate_dataparser_outputs(self, split="train"):
# Scale poses
scale_factor = 1.0
if self.config.auto_scale_poses:
scale_factor /= torch.max(torch.abs(poses[:, :3, 3]))
scale_factor /= float(torch.max(torch.abs(poses[:, :3, 3])))
scale_factor *= self.config.scale_factor

poses[:, :3, 3] *= scale_factor * self.config.scale_factor
poses[:, :3, 3] *= scale_factor

# Choose image_filenames and poses based on split, but after auto orient and scaling the poses.
image_filenames = [image_filenames[i] for i in indices]
Expand Down Expand Up @@ -252,6 +253,8 @@ def _generate_dataparser_outputs(self, split="train"):
cameras=cameras,
scene_box=scene_box,
mask_filenames=mask_filenames if len(mask_filenames) > 0 else None,
dataparser_scale=scale_factor,
dataparser_transform=transform_matrix,
)
return dataparser_outputs

Expand Down
9 changes: 6 additions & 3 deletions nerfstudio/data/dataparsers/phototourism_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,17 @@ def _generate_dataparser_outputs(self, split="train"):
else:
raise ValueError(f"Unknown dataparser split {split}")

poses = camera_utils.auto_orient_and_center_poses(
poses, transform_matrix = camera_utils.auto_orient_and_center_poses(
poses, method=self.config.orientation_method, center_poses=self.config.center_poses
)

# Scale poses
scale_factor = 1.0
if self.config.auto_scale_poses:
scale_factor /= torch.max(torch.abs(poses[:, :3, 3]))
scale_factor /= float(torch.max(torch.abs(poses[:, :3, 3])))
scale_factor *= self.config.scale_factor

poses[:, :3, 3] *= scale_factor * self.config.scale_factor
poses[:, :3, 3] *= scale_factor

# in x,y,z order
# assumes that the scene is centered at the origin
Expand Down Expand Up @@ -175,6 +176,8 @@ def _generate_dataparser_outputs(self, split="train"):
image_filenames=image_filenames,
cameras=cameras,
scene_box=scene_box,
dataparser_scale=scale_factor,
dataparser_transform=transform_matrix,
)

return dataparser_outputs
160 changes: 0 additions & 160 deletions nerfstudio/data/dataparsers/record3d_dataparser.py

This file was deleted.

4 changes: 4 additions & 0 deletions nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def train(self) -> None:
"""Train the model."""
assert self.pipeline.datamanager.train_dataset is not None, "Missing DatsetInputs"

self.pipeline.datamanager.train_dataparser_outputs.save_dataparser_transform(
self.base_dir / "dataparser_transforms.json"
)

self._init_viewer_state()
with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME):
num_iterations = self.config.trainer.max_num_iterations
Expand Down

0 comments on commit b41ce80

Please sign in to comment.