-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathbase_datamanager.py
576 lines (490 loc) · 24.6 KB
/
base_datamanager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Datamanager.
"""
from __future__ import annotations
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
ForwardRef,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
get_origin,
)
import torch
import tyro
from torch import nn
from torch.nn import Parameter
from torch.utils.data.distributed import DistributedSampler
from typing_extensions import TypeVar
from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig
from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.ray_generators import RayGenerator
from nerfstudio.utils.misc import IterableWrapper, get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE
def variable_res_collate(batch: List[Dict]) -> Dict:
    """Collate a batch whose images may have different resolutions.

    Images cannot be stacked when their sizes differ, so the ``"image"`` entries, and any other
    tensor whose first two dimensions match its sample's image, are returned as per-key lists.
    Every remaining entry is collated with ``nerfstudio_collate``.

    Args:
        batch: Batch of samples from the dataset. Each sample must contain an ``"image"`` entry.
            The sample dicts themselves are left unmodified.

    Returns:
        Collated batch. ``"image"`` maps to a list with one image per sample (in batch order).
        Every other image-shaped tensor key likewise maps to a list of per-sample tensors.
    """
    images = []
    imgdata_lists: Dict[str, List[torch.Tensor]] = defaultdict(list)
    remaining = []
    for data in batch:
        # Shallow-copy so that popping keys below does not mutate the caller's sample dicts.
        data = dict(data)
        image = data.pop("image")
        images.append(image)
        # If the value has the same height and width as the image, assume it is per-pixel data
        # that must be kept alongside the image instead of being stacked.
        image_keys = [
            key
            for key, val in data.items()
            if isinstance(val, torch.Tensor) and len(val.shape) >= 2 and val.shape[:2] == image.shape[:2]
        ]
        for key in image_keys:
            imgdata_lists[key].append(data.pop(key))
        remaining.append(data)

    new_batch = nerfstudio_collate(remaining)
    new_batch["image"] = images
    new_batch.update(imgdata_lists)
    return new_batch
@dataclass
class DataManagerConfig(InstantiateConfig):
    """Configuration for data manager instantiation.

    DataManager is in charge of keeping the train/eval dataparsers. After instantiation, the data
    manager holds both the train and eval datasets and is in charge of returning unpacked
    train/eval data at each iteration.
    """

    # NOTE: the string literal directly under each field is that field's docstring (presumably
    # surfaced by tyro as CLI help text); keep each one immediately below its field.

    # The lambda defers the lookup of DataManager, which is defined after this class.
    _target: Type = field(default_factory=lambda: DataManager)
    """Target class to instantiate."""
    data: Optional[Path] = None
    """Source of data, may not be used by all models."""
    masks_on_gpu: bool = False
    """Process masks on GPU for speed at the expense of memory, if True."""
    images_on_gpu: bool = False
    """Process images on GPU for speed at the expense of memory, if True."""
class DataManager(nn.Module):
    """Generic data manager's abstract class

    This version of the data manager is designed to be a monolithic way to load data and latents,
    especially since this may contain learnable parameters which need to be shared across the train
    and test data managers. The idea is that we have setup methods for train and eval separately and
    this can be a combined train/eval if you want.

    Usage:
    To get data, use the next_train and next_eval functions.
    This data manager's next_train and next_eval methods will return 2 things:

    1. 'rays': This will contain the rays or camera we are sampling, with latents and
        conditionals attached (everything needed at inference)
    2. A "batch" of auxiliary information: This will contain the mask, the ground truth
        pixels, etc needed to actually train, score, etc the model

    Rationale:
    Because of this abstraction we've added, we can support more NeRF paradigms beyond the
    vanilla nerf paradigm of single-scene, fixed-images, no-learnt-latents.
    We can now support variable scenes, variable number of images, and arbitrary latents.

    Train Methods:
        setup_train: sets up for being used as train
        iter_train: will be called on __iter__() for the train iterator
        next_train: will be called on __next__() for the training iterator
        get_train_iterable: utility that gets a clean pythonic iterator for your training data

    Eval Methods:
        setup_eval: sets up for being used as eval
        iter_eval: will be called on __iter__() for the eval iterator
        next_eval: will be called on __next__() for the eval iterator
        get_eval_iterable: utility that gets a clean pythonic iterator for your eval data

    Note:
        nn.Module is not an abstract base class, so the ``@abstractmethod`` markers below document
        which methods subclasses must implement but are not enforced when instantiating.

    Attributes:
        train_count (int): the step number of our train iteration, needs to be incremented manually
        eval_count (int): the step number of our eval iteration, needs to be incremented manually
        train_dataset (Dataset): the dataset for the train dataset
        eval_dataset (Dataset): the dataset for the eval dataset
        includes_time (bool): whether the dataset includes time information

        Additional attributes specific to each subclass are defined in the setup_train and setup_eval
        functions.
    """

    train_dataset: Optional[InputDataset] = None
    eval_dataset: Optional[InputDataset] = None
    train_sampler: Optional[DistributedSampler] = None
    eval_sampler: Optional[DistributedSampler] = None
    includes_time: bool = False

    def __init__(self):
        """Constructor for the DataManager class.

        Subclassed DataManagers will likely need to override this constructor.

        If you aren't manually calling the setup_train and setup_eval functions from an overridden
        constructor, make sure that you call super().__init__() BEFORE you initialize any
        nn.Modules or nn.Parameters, but AFTER you've already set all the attributes you need
        for the setup functions (e.g. ``train_dataset``, ``eval_dataset`` and ``test_mode``).

        setup_train / setup_eval are only called when the corresponding dataset is truthy and
        ``test_mode`` is not ``"inference"``.
        """
        super().__init__()
        self.train_count = 0
        self.eval_count = 0
        # ``test_mode`` is provided by subclasses; the base class never defines it. Reading it
        # with a default keeps a subclass that omits it from failing here with AttributeError.
        if self.train_dataset and getattr(self, "test_mode", None) != "inference":
            self.setup_train()
        if self.eval_dataset and getattr(self, "test_mode", None) != "inference":
            self.setup_eval()

    def forward(self):
        """Blank forward method

        This is an nn.Module, and so requires a forward() method normally, although in our case
        we do not need a forward() method"""
        raise NotImplementedError

    def iter_train(self):
        """The __iter__ function for the train iterator.

        This only exists to assist the get_train_iterable function, since we need to pass
        in an __iter__ function for our trivial iterable that we are making. Resets train_count."""
        self.train_count = 0

    def iter_eval(self):
        """The __iter__ function for the eval iterator.

        This only exists to assist the get_eval_iterable function, since we need to pass
        in an __iter__ function for our trivial iterable that we are making. Resets eval_count."""
        self.eval_count = 0

    def get_train_iterable(self, length=-1) -> IterableWrapper:
        """Gets a trivial pythonic iterator that will use the iter_train and next_train functions
        as __iter__ and __next__ methods respectively.

        This basically is just a little utility if you want to do something like:
        |    for ray_bundle, batch in datamanager.get_train_iterable():
        |        <train code here>
        since the returned IterableWrapper is just an iterator with the __iter__ and __next__
        methods (methods bound to our DataManager instance in this case) specified in the constructor.
        """
        return IterableWrapper(self.iter_train, self.next_train, length)

    def get_eval_iterable(self, length=-1) -> IterableWrapper:
        """Gets a trivial pythonic iterator that will use the iter_eval and next_eval functions
        as __iter__ and __next__ methods respectively.

        This basically is just a little utility if you want to do something like:
        |    for ray_bundle, batch in datamanager.get_eval_iterable():
        |        <eval code here>
        since the returned IterableWrapper is just an iterator with the __iter__ and __next__
        methods (methods bound to our DataManager instance in this case) specified in the constructor.
        """
        return IterableWrapper(self.iter_eval, self.next_eval, length)

    @abstractmethod
    def setup_train(self):
        """Sets up the data manager for training.

        Here you will define any subclass-specific object attributes needed for training."""

    @abstractmethod
    def setup_eval(self):
        """Sets up the data manager for evaluation"""

    @abstractmethod
    def next_train(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
        """Returns the next batch of data from the train data manager.

        Args:
            step: the step number of the train batch to retrieve

        Returns:
            A tuple of the ray bundle (or cameras) for the batch, and a dictionary of additional
            batch information such as the groundtruth image.
        """
        raise NotImplementedError

    @abstractmethod
    def next_eval(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
        """Returns the next batch of data from the eval data manager.

        Args:
            step: the step number of the eval image to retrieve

        Returns:
            A tuple of the ray/camera for the image, and a dictionary of additional batch information
            such as the groundtruth image.
        """
        raise NotImplementedError

    @abstractmethod
    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        """Retrieve the next eval image.

        Args:
            step: the step number of the eval image to retrieve

        Returns:
            A tuple of the camera for the image and a dictionary of additional batch information
            such as the groundtruth image.
        """
        raise NotImplementedError

    @abstractmethod
    def get_train_rays_per_batch(self) -> int:
        """Returns the number of rays per batch for training."""
        raise NotImplementedError

    @abstractmethod
    def get_eval_rays_per_batch(self) -> int:
        """Returns the number of rays per batch for evaluation."""
        raise NotImplementedError

    @abstractmethod
    def get_datapath(self) -> Path:
        """Returns the path to the data. This is used to determine where to save camera paths."""

    def get_training_callbacks(
        self, training_callback_attributes: TrainingCallbackAttributes
    ) -> List[TrainingCallback]:
        """Returns a list of callbacks to be used during training (none by default)."""
        return []

    @abstractmethod
    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Get the param groups for the data manager.

        Returns:
            A dictionary mapping param group names to the lists of parameters in each group.
        """
        return {}
@dataclass
class VanillaDataManagerConfig(DataManagerConfig):
    """A basic data manager for a ray-based model"""

    # NOTE: the string literal directly under each field is that field's docstring (presumably
    # surfaced by tyro as CLI help text); keep each one immediately below its field.

    # The lambda defers the lookup of VanillaDataManager, which is defined later in this file.
    _target: Type = field(default_factory=lambda: VanillaDataManager)
    """Target class to instantiate."""
    dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)
    """Specifies the dataparser used to unpack the data."""
    train_num_rays_per_batch: int = 1024
    """Number of rays per batch to use per training iteration."""
    # Forwarded to CacheDataloader; -1 presumably means "use all images" — TODO confirm.
    train_num_images_to_sample_from: int = -1
    """Number of images to sample during training iteration."""
    train_num_times_to_repeat_images: int = -1
    """When not training on all images, number of iterations before picking new
    images. If -1, never pick new images."""
    eval_num_rays_per_batch: int = 1024
    """Number of rays per batch to use per eval iteration."""
    # Forwarded to CacheDataloader; -1 presumably means "use all images" — TODO confirm.
    eval_num_images_to_sample_from: int = -1
    """Number of images to sample during eval iteration."""
    eval_num_times_to_repeat_images: int = -1
    """When not evaluating on all images, number of iterations before picking
    new images. If -1, never pick new images."""
    # NOTE(review): not read anywhere in this file; presumably consumed by other code.
    eval_image_indices: Optional[Tuple[int, ...]] = (0,)
    """Specifies the image indices to use during eval; if None, uses all."""
    # Default wrapped in staticmethod, presumably so the plain function is not bound as a method
    # when looked up through the class. VanillaDataManager may replace it with
    # variable_res_collate when training images differ in resolution.
    collate_fn: Callable[[Any], Any] = cast(Any, staticmethod(nerfstudio_collate))
    """Specifies the collate function to use for the train and eval dataloaders."""
    camera_res_scale_factor: float = 1.0
    """The scale factor for scaling spatial data such as images, mask, semantics
    along with relevant information about camera intrinsics
    """
    patch_size: int = 1
    """Size of patch to sample from. If > 1, patch-based sampling will be used."""
    # tyro.conf.Suppress prevents us from creating CLI arguments for this field.
    camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None)
    """Deprecated, has been moved to the model config."""
    pixel_sampler: PixelSamplerConfig = field(default_factory=PixelSamplerConfig)
    """Specifies the pixel sampler used to sample pixels from images."""

    def __post_init__(self):
        """Warn the user that ``camera_optimizer`` has moved from the data manager config to the model config."""
        if self.camera_optimizer is not None:
            import warnings

            CONSOLE.print(
                "\nCameraOptimizerConfig has been moved from the DataManager to the Model.\n", style="bold yellow"
            )
            # stacklevel=3 skips __post_init__ and the generated dataclass __init__, so the warning
            # is attributed to the code that constructed this config.
            warnings.warn("above message coming from", FutureWarning, stacklevel=3)
# Dataset type created by a VanillaDataManager. The ``default=`` value (supported by
# typing_extensions.TypeVar) is read back via ``TDataset.__default__`` in
# VanillaDataManager.dataset_type when no concrete type argument can be found.
TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset)
class VanillaDataManager(DataManager, Generic[TDataset]):
    """Basic stored data manager implementation.

    This is pretty much a port over from our old dataloading utilities, and is a little jank
    under the hood. We may clean this up a little bit under the hood with more standard dataloading
    components that can be strung together, but it can be just used as a black box for now since
    only the constructor is likely to change in the future, or maybe passing in step number to the
    next_train and next_eval functions.

    Args:
        config: the DataManagerConfig used to instantiate class
        device: device passed to the dataloaders and ray generators
        test_mode: "test" and "inference" evaluate on the dataparser's "test" split, "val" on the
            "val" split. "inference" also sets the dataparser's downscale_factor to 1 (to avoid
            opening images) and skips setup_train/setup_eval.
        world_size: number of processes; each dataloader uses ``world_size * 4`` workers
        local_rank: rank of this process (stored but not otherwise used in this class)
    """

    config: VanillaDataManagerConfig
    train_dataset: TDataset
    eval_dataset: TDataset
    train_dataparser_outputs: DataparserOutputs
    train_pixel_sampler: Optional[PixelSampler] = None
    eval_pixel_sampler: Optional[PixelSampler] = None

    def __init__(
        self,
        config: VanillaDataManagerConfig,
        device: Union[torch.device, str] = "cpu",
        test_mode: Literal["test", "val", "inference"] = "val",
        world_size: int = 1,
        local_rank: int = 0,
        **kwargs,
    ):
        self.config = config
        self.device = device
        self.world_size = world_size
        self.local_rank = local_rank
        self.sampler = None
        self.test_mode = test_mode
        self.test_split = "test" if test_mode in ["test", "inference"] else "val"
        self.dataparser_config = self.config.dataparser
        # Keep config.data and the dataparser's data path in sync; an explicit config.data wins.
        if self.config.data is not None:
            self.config.dataparser.data = Path(self.config.data)
        else:
            self.config.data = self.config.dataparser.data
        self.dataparser = self.dataparser_config.setup()
        if test_mode == "inference":
            self.dataparser.downscale_factor = 1  # Avoid opening images
        self.includes_time = self.dataparser.includes_time
        self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train")

        self.train_dataset = self.create_train_dataset()
        self.eval_dataset = self.create_eval_dataset()
        # Copy the dataset's exclusion list: the removals below must not mutate a list that may be
        # shared with other datasets (e.g. a class-level default on the dataset type).
        self.exclude_batch_keys_from_device = list(self.train_dataset.exclude_batch_keys_from_device)
        if self.config.masks_on_gpu is True and "mask" in self.exclude_batch_keys_from_device:
            self.exclude_batch_keys_from_device.remove("mask")
        if self.config.images_on_gpu is True and "image" in self.exclude_batch_keys_from_device:
            self.exclude_batch_keys_from_device.remove("image")

        # If the training cameras do not all share one resolution, switch to the collate function
        # that keeps per-image tensors in lists instead of stacking them.
        if self.train_dataparser_outputs is not None:
            cameras = self.train_dataparser_outputs.cameras
            if len(cameras) > 1:
                for i in range(1, len(cameras)):
                    if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height:
                        CONSOLE.print("Variable resolution, using variable_res_collate")
                        self.config.collate_fn = variable_res_collate
                        break
        # Called last: DataManager.__init__ runs setup_train/setup_eval, which need the attributes above.
        super().__init__()

    @staticmethod
    def _resolve_forward_ref(ref: ForwardRef, module: str) -> Any:
        """Evaluate a string type argument, defaulting its module to ``module``.

        Uses the private ``ForwardRef._evaluate`` API, whose signature changed in Python 3.12.4
        (``type_params`` was added and ``recursive_guard`` became keyword-only).
        """
        import sys

        if ref.__forward_evaluated__:
            return ref.__forward_value__
        if ref.__forward_module__ is None:
            ref.__forward_module__ = module
        evaluate = getattr(ref, "_evaluate")
        if sys.version_info >= (3, 12, 4):
            return evaluate(None, None, type_params=(), recursive_guard=frozenset())
        return evaluate(None, None, set())

    @cached_property
    def dataset_type(self) -> Type[TDataset]:
        """Returns the dataset type passed as the generic argument.

        Falls back to ``TDataset``'s default (InputDataset) when no InputDataset subclass is
        found among the generic arguments.
        """
        default: Type[TDataset] = cast(TDataset, TDataset.__default__)  # type: ignore
        orig_class: Type[VanillaDataManager] = get_orig_class(self, default=None)  # type: ignore
        if type(self) is VanillaDataManager and orig_class is None:
            return default
        if orig_class is not None and get_origin(orig_class) is VanillaDataManager:
            return get_args(orig_class)[0]

        # For inherited classes, we need to find the correct type to instantiate
        module = type(self).__module__
        for base in getattr(self, "__orig_bases__", []):
            if get_origin(base) is VanillaDataManager:
                for value in get_args(base):
                    if isinstance(value, ForwardRef):
                        value = self._resolve_forward_ref(value, module)
                    assert isinstance(value, type)
                    if issubclass(value, InputDataset):
                        return cast(Type[TDataset], value)
        return default

    def create_train_dataset(self) -> TDataset:
        """Sets up the data loaders for training"""
        return self.dataset_type(
            dataparser_outputs=self.train_dataparser_outputs,
            scale_factor=self.config.camera_res_scale_factor,
        )

    def create_eval_dataset(self) -> TDataset:
        """Sets up the data loaders for evaluation"""
        return self.dataset_type(
            dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split),
            scale_factor=self.config.camera_res_scale_factor,
        )

    def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler:
        """Infer pixel sampler to use."""
        if self.config.patch_size > 1 and type(self.config.pixel_sampler) is PixelSamplerConfig:
            return PatchPixelSamplerConfig().setup(
                patch_size=self.config.patch_size, num_rays_per_batch=num_rays_per_batch
            )
        equirectangular_mask = dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value
        is_equirectangular = equirectangular_mask.all()
        # Only a mix of camera types deserves a warning: when every camera is equirectangular,
        # the sampler is told so via is_equirectangular below.
        if equirectangular_mask.any() and not is_equirectangular:
            CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.")

        fisheye_crop_radius = None
        if dataset.cameras.metadata is not None:
            fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius")

        return self.config.pixel_sampler.setup(
            is_equirectangular=is_equirectangular,
            num_rays_per_batch=num_rays_per_batch,
            fisheye_crop_radius=fisheye_crop_radius,
        )

    def setup_train(self):
        """Sets up the data loaders for training"""
        assert self.train_dataset is not None
        CONSOLE.print("Setting up training dataset...")
        self.train_image_dataloader = CacheDataloader(
            self.train_dataset,
            num_images_to_sample_from=self.config.train_num_images_to_sample_from,
            num_times_to_repeat_images=self.config.train_num_times_to_repeat_images,
            device=self.device,
            num_workers=self.world_size * 4,
            pin_memory=True,
            collate_fn=self.config.collate_fn,
            exclude_batch_keys_from_device=self.exclude_batch_keys_from_device,
        )
        self.iter_train_image_dataloader = iter(self.train_image_dataloader)
        self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch)
        self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device))

    def setup_eval(self):
        """Sets up the data loader for evaluation"""
        assert self.eval_dataset is not None
        CONSOLE.print("Setting up evaluation dataset...")
        self.eval_image_dataloader = CacheDataloader(
            self.eval_dataset,
            num_images_to_sample_from=self.config.eval_num_images_to_sample_from,
            num_times_to_repeat_images=self.config.eval_num_times_to_repeat_images,
            device=self.device,
            num_workers=self.world_size * 4,
            pin_memory=True,
            collate_fn=self.config.collate_fn,
            exclude_batch_keys_from_device=self.exclude_batch_keys_from_device,
        )
        self.iter_eval_image_dataloader = iter(self.eval_image_dataloader)
        self.eval_pixel_sampler = self._get_pixel_sampler(self.eval_dataset, self.config.eval_num_rays_per_batch)
        self.eval_ray_generator = RayGenerator(self.eval_dataset.cameras.to(self.device))
        # for loading full images
        self.fixed_indices_eval_dataloader = FixedIndicesEvalDataloader(
            input_dataset=self.eval_dataset,
            device=self.device,
            num_workers=self.world_size * 4,
        )
        self.eval_dataloader = RandIndicesEvalDataloader(
            input_dataset=self.eval_dataset,
            device=self.device,
            num_workers=self.world_size * 4,
        )

    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the train dataloader (``step`` is not used here)."""
        self.train_count += 1
        image_batch = next(self.iter_train_image_dataloader)
        assert self.train_pixel_sampler is not None
        assert isinstance(image_batch, dict)
        batch = self.train_pixel_sampler.sample(image_batch)
        ray_indices = batch["indices"]
        ray_bundle = self.train_ray_generator(ray_indices)
        return ray_bundle, batch

    def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the eval dataloader (``step`` is not used here)."""
        self.eval_count += 1
        image_batch = next(self.iter_eval_image_dataloader)
        assert self.eval_pixel_sampler is not None
        assert isinstance(image_batch, dict)
        batch = self.eval_pixel_sampler.sample(image_batch)
        ray_indices = batch["indices"]
        ray_bundle = self.eval_ray_generator(ray_indices)
        return ray_bundle, batch

    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        """Return the camera and batch of the first item yielded by ``eval_dataloader``.

        ``eval_dataloader`` is a RandIndicesEvalDataloader, so this is presumably a randomly chosen
        eval image. ``step`` is not used here.

        Raises:
            ValueError: if the eval dataloader yields nothing.
        """
        for camera, batch in self.eval_dataloader:
            assert camera.shape[0] == 1
            return camera, batch
        raise ValueError("No more eval images")

    def get_train_rays_per_batch(self) -> int:
        """Rays per training batch, preferring the pixel sampler's value once it exists."""
        if self.train_pixel_sampler is not None:
            return self.train_pixel_sampler.num_rays_per_batch
        return self.config.train_num_rays_per_batch

    def get_eval_rays_per_batch(self) -> int:
        """Rays per eval batch, preferring the pixel sampler's value once it exists."""
        if self.eval_pixel_sampler is not None:
            return self.eval_pixel_sampler.num_rays_per_batch
        return self.config.eval_num_rays_per_batch

    def get_datapath(self) -> Path:
        """Returns the data path configured on the dataparser."""
        return self.config.dataparser.data

    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Get the param groups for the data manager.

        Returns:
            An empty dictionary; this data manager has no learnable parameters of its own.
        """
        return {}