Add 'inference' mode to pipeline (nerfstudio-project#956)
* cleanup test mode

* cleanup eval loading

* change default

* cleanup

* cleanup

* comment
nikmo33 authored Nov 14, 2022
1 parent d459864 commit a6458fc
Showing 7 changed files with 64 additions and 38 deletions.
13 changes: 6 additions & 7 deletions nerfstudio/data/datamanagers/base_datamanager.py
@@ -20,7 +20,7 @@

from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import torch
import tyro
@@ -136,9 +136,9 @@ def __init__(self):
super().__init__()
self.train_count = 0
self.eval_count = 0
if self.train_dataset:
if self.train_dataset and self.test_mode != "inference":
self.setup_train()
if self.eval_dataset:
if self.eval_dataset and self.test_mode != "inference":
self.setup_eval()

def forward(self):
@@ -288,7 +288,7 @@ def __init__(
self,
config: VanillaDataManagerConfig,
device: Union[torch.device, str] = "cpu",
test_mode: bool = False,
test_mode: Literal["test", "val", "inference"] = "val",
world_size: int = 1,
local_rank: int = 0,
**kwargs, # pylint: disable=unused-argument
@@ -299,6 +299,7 @@ def __init__(
self.local_rank = local_rank
self.sampler = None
self.test_mode = test_mode
self.test_split = "test" if test_mode in ["test", "inference"] else "val"

self.train_dataset = self.create_train_dataset()
self.eval_dataset = self.create_eval_dataset()
@@ -310,9 +311,7 @@ def create_train_dataset(self) -> InputDataset:

def create_eval_dataset(self) -> InputDataset:
"""Sets up the data loaders for evaluation"""
return InputDataset(
self.config.dataparser.setup().get_dataparser_outputs(split="val" if not self.test_mode else "test")
)
return InputDataset(self.config.dataparser.setup().get_dataparser_outputs(split=self.test_split))

def setup_train(self):
"""Sets up the data loaders for training"""
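For reference, a minimal runnable sketch of the eval-split selection this change introduces, mirroring the new `self.test_split` assignment in `VanillaDataManager.__init__` above (standalone Python, not a call into nerfstudio):

```python
from typing import Literal

def resolve_eval_split(test_mode: Literal["test", "val", "inference"]) -> str:
    """Mirror of the new test_split logic: 'test' and 'inference' read the test
    split, while the default 'val' reads the val split."""
    return "test" if test_mode in ("test", "inference") else "val"

assert resolve_eval_split("val") == "val"
assert resolve_eval_split("test") == "test"
assert resolve_eval_split("inference") == "test"
```

Note that in "inference" mode the dataset objects are still constructed, but `setup_train()` / `setup_eval()` are skipped, so no dataloaders or ray generators are built.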
4 changes: 1 addition & 3 deletions nerfstudio/data/datamanagers/semantic_datamanager.py
@@ -44,6 +44,4 @@ def create_train_dataset(self) -> SemanticDataset:
return SemanticDataset(self.config.dataparser.setup().get_dataparser_outputs(split="train"))

def create_eval_dataset(self) -> SemanticDataset:
return SemanticDataset(
self.config.dataparser.setup().get_dataparser_outputs(split="val" if not self.test_mode else "test")
)
return SemanticDataset(self.config.dataparser.setup().get_dataparser_outputs(split=self.test_split))
9 changes: 6 additions & 3 deletions nerfstudio/engine/trainer.py
@@ -21,7 +21,7 @@
import functools
import os
import time
from typing import Dict, List, Tuple
from typing import Dict, List, Literal, Tuple

import torch
from rich.console import Console
@@ -102,11 +102,14 @@ def __init__(self, config: cfg.Config, local_rank: int = 0, world_size: int = 1)
writer.put_config(name="config", config_dict=dataclasses.asdict(config), step=0)
profiler.setup_profiler(config.logging)

def setup(self, test_mode=False):
def setup(self, test_mode: Literal["test", "val", "inference"] = "val"):
"""Setup the Trainer by calling other setup functions.
Args:
test_mode: Whether to setup for testing. Defaults to False.
test_mode:
'val': loads train/val datasets into memory
'test': loads train/test dataset into memory
'inference': does not load any dataset into memory
"""
self.pipeline = self.config.pipeline.setup(
device=self.device, test_mode=test_mode, world_size=self.world_size, local_rank=self.local_rank
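A hedged usage sketch of the new `setup` signature; `config` below stands in for a fully populated `cfg.Config` (building one is elided here, and in practice the nerfstudio CLI constructs it):

```python
from nerfstudio.engine.trainer import Trainer

trainer = Trainer(config, local_rank=0, world_size=1)  # `config`: assumed cfg.Config
trainer.setup(test_mode="val")          # default: train/val datasets are loaded
# trainer.setup(test_mode="test")       # load train/test datasets instead
# trainer.setup(test_mode="inference")  # load no datasets (render-only use)
```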
18 changes: 14 additions & 4 deletions nerfstudio/pipelines/base_pipeline.py
@@ -21,7 +21,7 @@
from abc import abstractmethod
from dataclasses import dataclass, field
from time import time
from typing import Any, Dict, List, Optional, Type, Union, cast
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast

import torch
import torch.distributed as dist
@@ -77,7 +77,10 @@ class should be 1:1 with a pipeline that can act as a standardized interface and
Args:
config: configuration to instantiate pipeline
device: location to place model and data
test_mode: if True, loads test datset. if False, loads train/eval datasets
test_mode:
'train': loads train/eval datasets into memory
'test': loads train/test dataset into memory
'inference': does not load any dataset into memory
world_size: total number of machines available
local_rank: rank of current machine
@@ -191,7 +194,10 @@ class VanillaPipeline(Pipeline):
config: configuration to instantiate pipeline
device: location to place model and data
test_mode: if True, loads test datset. if False, loads train/eval datasets
test_mode:
'val': loads train/val datasets into memory
'test': loads train/test dataset into memory
'inference': does not load any dataset into memory
world_size: total number of machines available
local_rank: rank of current machine
@@ -204,12 +210,13 @@ def __init__(
self,
config: VanillaPipelineConfig,
device: str,
test_mode: bool = False,
test_mode: Literal["test", "val", "inference"] = "val",
world_size: int = 1,
local_rank: int = 0,
):
super().__init__()
self.config = config
self.test_mode = test_mode
self.datamanager: VanillaDataManager = config.datamanager.setup(
device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank
)
@@ -351,6 +358,9 @@ def load_pipeline(self, loaded_state: Dict[str, Any]) -> None:
loaded_state: pre-trained model state dict
"""
state = {key.replace("module.", ""): value for key, value in loaded_state.items()}
if self.test_mode == "inference":
state.pop("datamanager.train_ray_generator.image_coords", None)
state.pop("datamanager.eval_ray_generator.image_coords", None)
self.load_state_dict(state) # type: ignore

def get_training_callbacks(
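A runnable sketch of the checkpoint filtering added to `load_pipeline` above: in inference mode the ray generators are never built (because `setup_train()` / `setup_eval()` are skipped), so their `image_coords` buffers are dropped from the loaded state dict. The keys and values below are placeholders, not an actual checkpoint:

```python
loaded_state = {
    "module.datamanager.train_ray_generator.image_coords": "tensor A",  # placeholder
    "module.datamanager.eval_ray_generator.image_coords": "tensor B",   # placeholder
    "module._model.field.mlp_head.weight": "tensor C",                  # placeholder
}
test_mode = "inference"

# Same transformation as VanillaPipeline.load_pipeline: strip DDP "module." prefixes,
# then drop buffers that have no counterpart when no dataloaders were set up.
state = {key.replace("module.", ""): value for key, value in loaded_state.items()}
if test_mode == "inference":
    state.pop("datamanager.train_ray_generator.image_coords", None)
    state.pop("datamanager.eval_ray_generator.image_coords", None)

assert "datamanager.train_ray_generator.image_coords" not in state
assert "_model.field.mlp_head.weight" in state
```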
4 changes: 2 additions & 2 deletions nerfstudio/pipelines/dynamic_batch.py
@@ -17,7 +17,7 @@
"""

from dataclasses import dataclass, field
from typing import Type
from typing import Literal, Type

import torch

@@ -49,7 +49,7 @@ def __init__(
self,
config: DynamicBatchPipelineConfig,
device: str,
test_mode: bool = False,
test_mode: Literal["test", "val", "inference"] = "val",
world_size: int = 1,
local_rank: int = 0,
):
16 changes: 13 additions & 3 deletions nerfstudio/utils/eval_utils.py
@@ -20,7 +20,7 @@
import os
import sys
from pathlib import Path
from typing import Optional, Tuple
from typing import Literal, Optional, Tuple

import torch
import yaml
@@ -63,11 +63,21 @@ def eval_load_checkpoint(config: cfg.TrainerConfig, pipeline: Pipeline) -> Path:
return load_path


def eval_setup(config_path: Path, eval_num_rays_per_chunk: Optional[int] = None) -> Tuple[cfg.Config, Pipeline, Path]:
def eval_setup(
config_path: Path,
eval_num_rays_per_chunk: Optional[int] = None,
test_mode: Literal["test", "val", "inference"] = "test",
) -> Tuple[cfg.Config, Pipeline, Path]:
"""Shared setup for loading a saved pipeline for evaluation.
Args:
config_path: Path to config YAML file.
eval_num_rays_per_chunk: Number of rays per forward pass
test_mode:
'val': loads train/val datasets into memory
'test': loads train/test dataset into memory
'inference': does not load any dataset into memory
Returns:
Loaded config, pipeline module, and corresponding checkpoint.
@@ -86,7 +96,7 @@ def eval_setup(config_path: Path, eval_num_rays_per_chunk: Optional[int] = None)

# setup pipeline (which includes the DataManager)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = config.pipeline.setup(device=device, test_mode=True)
pipeline = config.pipeline.setup(device=device, test_mode=test_mode)
assert isinstance(pipeline, Pipeline)
pipeline.eval()

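A hedged example of loading a trained run for pure inference with the new argument; the config path below is a placeholder for an actual saved run:

```python
from pathlib import Path

from nerfstudio.utils.eval_utils import eval_setup

config, pipeline, checkpoint_path = eval_setup(
    Path("outputs/my-scene/nerfacto/config.yml"),  # placeholder path to a saved run
    eval_num_rays_per_chunk=4096,
    test_mode="inference",  # skip loading train/eval images; only weights are needed
)
# eval_setup already puts the pipeline in eval mode, so it is ready for rendering.
```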
38 changes: 22 additions & 16 deletions scripts/render.py
@@ -4,13 +4,14 @@
"""
from __future__ import annotations

import dataclasses
import json
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import List, Optional

import mediapy as media
import numpy as np
import torch
import tyro
from rich.console import Console
@@ -38,7 +39,7 @@ def _render_trajectory_video(
pipeline: Pipeline,
cameras: Cameras,
output_filename: Path,
rendered_output_name: str,
rendered_output_names: List[str],
rendered_resolution_scaling_factor: float = 1.0,
seconds: float = 5.0,
output_format: Literal["images", "video"] = "video",
@@ -49,7 +50,7 @@
pipeline: Pipeline to evaluate with.
cameras: Cameras to render.
output_filename: Name of the output file.
rendered_output_name: Name of the renderer output to use.
rendered_output_names: List of outputs to visualise.
rendered_resolution_scaling_factor: Scaling factor to apply to the camera image resolution.
seconds: Length of output video.
output_format: How to save output data.
@@ -73,16 +74,20 @@
camera_ray_bundle = cameras.generate_rays(camera_indices=camera_idx).to(pipeline.device)
with torch.no_grad():
outputs = pipeline.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle)
if rendered_output_name not in outputs:
CONSOLE.rule("Error", style="red")
CONSOLE.print(f"Could not find {rendered_output_name} in the model outputs", justify="center")
CONSOLE.print(f"Please set --rendered_output_name to one of: {outputs.keys()}", justify="center")
sys.exit(1)
image = outputs[rendered_output_name].cpu().numpy()
render_image = []
for rendered_output_name in rendered_output_names:
if rendered_output_name not in outputs:
CONSOLE.rule("Error", style="red")
CONSOLE.print(f"Could not find {rendered_output_name} in the model outputs", justify="center")
CONSOLE.print(f"Please set --rendered_output_name to one of: {outputs.keys()}", justify="center")
sys.exit(1)
output_image = outputs[rendered_output_name].cpu().numpy()
render_image.append(output_image)
render_image = np.concatenate(render_image, axis=1)
if output_format == "images":
media.write_image(output_image_dir / f"{camera_idx:05d}.png", image)
media.write_image(output_image_dir / f"{camera_idx:05d}.png", render_image)
else:
images.append(image)
images.append(render_image)

if output_format == "video":
fps = len(images) / seconds
@@ -94,14 +99,14 @@
CONSOLE.print(f"[green]Saved video to {output_filename}", justify="center")


@dataclasses.dataclass
@dataclass
class RenderTrajectory:
"""Load a checkpoint, render a trajectory, and save to a video file."""

# Path to config YAML file.
load_config: Path
# Name of the renderer output to use. rgb, depth, etc.
rendered_output_name: str = "rgb"
# Names of the renderer outputs to use: rgb, depth, etc. Multiple outputs are concatenated side by side.
rendered_output_names: List[str] = field(default_factory=lambda: ["rgb"])
# Trajectory to render.
traj: Literal["spiral", "filename"] = "spiral"
# Scaling factor to apply to the camera image resolution.
@@ -122,6 +127,7 @@ def main(self) -> None:
_, pipeline, _ = eval_setup(
self.load_config,
eval_num_rays_per_chunk=self.eval_num_rays_per_chunk,
test_mode="test" if self.traj == "spiral" else "inference",
)

install_checks.check_ffmpeg_installed()
@@ -145,7 +151,7 @@
pipeline,
camera_path,
output_filename=self.output_path,
rendered_output_name=self.rendered_output_name,
rendered_output_names=self.rendered_output_names,
rendered_resolution_scaling_factor=1.0 / self.downscale_factor,
seconds=seconds,
output_format=self.output_format,
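A runnable sketch of the multi-output stitching that `_render_trajectory_video` now performs: each requested output is rendered to an HxWxC array and the arrays are concatenated side by side (axis=1) before being written to the image or video. The arrays below are stand-ins for model outputs:

```python
import numpy as np

h, w = 4, 6
rgb = np.zeros((h, w, 3))    # stand-in for outputs["rgb"]
depth = np.ones((h, w, 3))   # stand-in for a second rendered output of matching shape
frame = np.concatenate([rgb, depth], axis=1)
assert frame.shape == (h, 2 * w, 3)
```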
