Skip to content

Commit

Permalink
Only load+undistort splatfacto images when they're used (nerfstudio-project#3043)
Browse files Browse the repository at this point in the history

* Only load splatfacto dataset when it's used

* ruff

* typo fix

* Fix nerfstudio-project#2817; don't undistort for `ns-viewer --viewer.max-num-display-images 0`

* ruff
  • Loading branch information
brentyi authored and Michael-Spleenlab committed Apr 26, 2024
1 parent d6384c8 commit 0ebca65
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 70 deletions.
116 changes: 47 additions & 69 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import torch
from rich.progress import track
from torch.nn import Parameter
from typing_extensions import assert_never

from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper
from nerfstudio.cameras.cameras import Cameras, CameraType
Expand Down Expand Up @@ -115,7 +116,6 @@ def __init__(
style="bold yellow",
)
self.config.cache_images = "cpu"
self.cached_train, self.cached_eval = self.cache_images(self.config.cache_images)
self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device
if self.config.masks_on_gpu is True:
self.exclude_batch_keys_from_device.remove("mask")
Expand All @@ -129,37 +129,34 @@ def __init__(

super().__init__()

def cache_images(self, cache_images_option):
cached_train = []
cached_eval = []

def process_train_data(idx):
# cv2.undistort the images / cameras
data = self.train_dataset.get_data(idx, image_type=self.config.cache_images_type)
camera = self.train_dataset.cameras[idx].reshape(())
K = camera.get_intrinsics_matrices().numpy()
if camera.distortion_params is None:
return data
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask
@cached_property
def cached_train(self) -> List[Dict[str, torch.Tensor]]:
    """Training images, produced lazily.

    Nothing is loaded or undistorted until this property is read for the
    first time; the decorator then keeps the resulting list so later reads
    reuse it.
    """
    # The cache placement is taken from the datamanager config.
    target_device = self.config.cache_images
    return self._load_images("train", cache_images_device=target_device)

self.train_dataset.cameras.fx[idx] = float(K[0, 0])
self.train_dataset.cameras.fy[idx] = float(K[1, 1])
self.train_dataset.cameras.cx[idx] = float(K[0, 2])
self.train_dataset.cameras.cy[idx] = float(K[1, 2])
self.train_dataset.cameras.width[idx] = image.shape[1]
self.train_dataset.cameras.height[idx] = image.shape[0]
return data
@cached_property
def cached_eval(self) -> List[Dict[str, torch.Tensor]]:
    """Eval images, produced lazily.

    Nothing is loaded or undistorted until this property is read for the
    first time; the decorator then keeps the resulting list so later reads
    reuse it.
    """
    # The cache placement is taken from the datamanager config.
    target_device = self.config.cache_images
    return self._load_images("eval", cache_images_device=target_device)

def _load_images(
self, split: Literal["train", "eval"], cache_images_device: Literal["cpu", "gpu"]
) -> List[Dict[str, torch.Tensor]]:
undistorted_images: List[Dict[str, torch.Tensor]] = []

# Which dataset?
if split == "train":
dataset = self.train_dataset
elif split == "eval":
dataset = self.eval_dataset
else:
assert_never(split)

def process_eval_data(idx):
# cv2.undistort the images / cameras
data = self.eval_dataset.get_data(idx, image_type=self.config.cache_images_type)
camera = self.eval_dataset.cameras[idx].reshape(())
def undistort_idx(idx: int) -> Dict[str, torch.Tensor]:
data = dataset.get_data(idx, image_type=self.config.cache_images_type)
camera = dataset.cameras[idx].reshape(())
K = camera.get_intrinsics_matrices().numpy()
if camera.distortion_params is None:
return data
Expand All @@ -171,62 +168,43 @@ def process_eval_data(idx):
if mask is not None:
data["mask"] = mask

self.eval_dataset.cameras.fx[idx] = float(K[0, 0])
self.eval_dataset.cameras.fy[idx] = float(K[1, 1])
self.eval_dataset.cameras.cx[idx] = float(K[0, 2])
self.eval_dataset.cameras.cy[idx] = float(K[1, 2])
self.eval_dataset.cameras.width[idx] = image.shape[1]
self.eval_dataset.cameras.height[idx] = image.shape[0]
dataset.cameras.fx[idx] = float(K[0, 0])
dataset.cameras.fy[idx] = float(K[1, 1])
dataset.cameras.cx[idx] = float(K[0, 2])
dataset.cameras.cy[idx] = float(K[1, 2])
dataset.cameras.width[idx] = image.shape[1]
dataset.cameras.height[idx] = image.shape[0]
return data

CONSOLE.log("Caching / undistorting train images")
CONSOLE.log(f"Caching / undistorting {split} images")
with ThreadPoolExecutor(max_workers=2) as executor:
cached_train = list(
undistorted_images = list(
track(
executor.map(
process_train_data,
range(len(self.train_dataset)),
undistort_idx,
range(len(dataset)),
),
description="Caching / undistorting train images",
description=f"Caching / undistorting {split} images",
transient=True,
total=len(self.train_dataset),
total=len(dataset),
)
)

CONSOLE.log("Caching / undistorting eval images")
with ThreadPoolExecutor(max_workers=2) as executor:
cached_eval = list(
track(
executor.map(
process_eval_data,
range(len(self.eval_dataset)),
),
description="Caching / undistorting eval images",
transient=True,
total=len(self.eval_dataset),
)
)

if cache_images_option == "gpu":
for cache in cached_train:
# Move to device.
if cache_images_device == "gpu":
for cache in undistorted_images:
cache["image"] = cache["image"].to(self.device)
if "mask" in cache:
cache["mask"] = cache["mask"].to(self.device)
for cache in cached_eval:
cache["image"] = cache["image"].to(self.device)
if "mask" in cache:
cache["mask"] = cache["mask"].to(self.device)
else:
for cache in cached_train:
cache["image"] = cache["image"].pin_memory()
if "mask" in cache:
cache["mask"] = cache["mask"].pin_memory()
for cache in cached_eval:
elif cache_images_device == "cpu":
for cache in undistorted_images:
cache["image"] = cache["image"].pin_memory()
if "mask" in cache:
cache["mask"] = cache["mask"].pin_memory()
else:
assert_never(cache_images_device)

return cached_train, cached_eval
return undistorted_images

def create_train_dataset(self) -> TDataset:
"""Sets up the data loaders for training"""
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/viewer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def add(ret: List[Tuple[str, Any]], ts: str, v: Any):
return []
ret = []
# get a list of the properties of the object, sorted by whether things are instances of type_check
obj_props = [(k, getattr(obj, k)) for k in dir(obj)]
obj_props = [(k, v) for k, v in vars(obj).items()]
for k, v in obj_props:
if k[0] == "_":
continue
Expand Down

0 comments on commit 0ebca65

Please sign in to comment.