From a12faff3d1e54204972eed2ac31ec57610acf02e Mon Sep 17 00:00:00 2001 From: machenmusik Date: Thu, 9 Feb 2023 18:32:51 -0500 Subject: [PATCH] fix instant-ngp while preserving robust for nerfacto --- nerfstudio/configs/method_configs.py | 8 ++-- .../data/datamanagers/base_datamanager.py | 4 +- .../data/datamanagers/patch_datamanager.py | 44 +++++++++++++++++++ 3 files changed, 49 insertions(+), 7 deletions(-) create mode 100644 nerfstudio/data/datamanagers/patch_datamanager.py diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index 6884e70540..d436cd15e8 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -27,6 +27,7 @@ from nerfstudio.configs.base_config import ViewerConfig from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig from nerfstudio.data.datamanagers.depth_datamanager import DepthDataManagerConfig +from nerfstudio.data.datamanagers.patch_datamanager import PatchDataManagerConfig from nerfstudio.data.datamanagers.semantic_datamanager import SemanticDataManagerConfig from nerfstudio.data.datamanagers.variable_res_datamanager import ( VariableResDataManagerConfig, @@ -90,10 +91,7 @@ mode="SO3xR3", optimizer=AdamOptimizerConfig(lr=6e-4, eps=1e-8, weight_decay=1e-2) ), ), - model=NerfactoModelConfig( - eval_num_rays_per_chunk=1 << 15, - robust=False, - ), + model=NerfactoModelConfig(eval_num_rays_per_chunk=1 << 15), ), optimizers={ "proposal_networks": { @@ -116,7 +114,7 @@ max_num_iterations=30000, mixed_precision=True, pipeline=VanillaPipelineConfig( - datamanager=VanillaDataManagerConfig( + datamanager=PatchDataManagerConfig( dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=17000, eval_num_rays_per_batch=4096, diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index e9dd0f3fe8..d032c4d483 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -50,7 +50,7 @@ PhototourismDataParserConfig, ) from nerfstudio.data.datasets.base_dataset import InputDataset -from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PixelSampler,PatchPixelSampler +from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PixelSampler from nerfstudio.data.utils.dataloaders import ( CacheDataloader, FixedIndicesEvalDataloader, @@ -351,7 +351,7 @@ def _get_pixel_sampler( # pylint: disable=no-self-use # Otherwise, use the default pixel sampler if is_equirectangular.any(): CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") - return PatchPixelSampler(*args, **kwargs,patch_size=16) + return PixelSampler(*args, **kwargs) def setup_train(self): """Sets up the data loaders for training""" diff --git a/nerfstudio/data/datamanagers/patch_datamanager.py b/nerfstudio/data/datamanagers/patch_datamanager.py new file mode 100644 index 0000000000..15923f0221 --- /dev/null +++ b/nerfstudio/data/datamanagers/patch_datamanager.py @@ -0,0 +1,44 @@ +# Copyright 2022 The Nerfstudio Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +#from abc import abstractmethod +from dataclasses import dataclass, field +from typing import Any, Type + +from nerfstudio.cameras.cameras import CameraType +from nerfstudio.data.datamanagers import base_datamanager +from nerfstudio.data.datasets.base_dataset import InputDataset +from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PatchPixelSampler + +@dataclass +class PatchDataManagerConfig(base_datamanager.VanillaDataManagerConfig): + _target: Type = field(default_factory=lambda: PatchDataManager) + """Target class to instantiate.""" + +class PatchDataManager(base_datamanager.VanillaDataManager): # pylint: disable=abstract-method + + def _get_pixel_sampler( # pylint: disable=no-self-use + self, dataset: InputDataset, *args: Any, **kwargs: Any + ) -> PixelSampler: + """Infer pixel sampler to use.""" + # If all images are equirectangular, use equirectangular pixel sampler + is_equirectangular = dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value + if is_equirectangular.all(): + return EquirectangularPixelSampler(*args, **kwargs) + # Otherwise, use the default pixel sampler + if is_equirectangular.any(): + CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") + return PatchPixelSampler(*args, **kwargs,patch_size=16)