Skip to content

Commit

Permalink
fix instant-ngp while preserving robust for nerfacto
Browse files Browse the repository at this point in the history
  • Loading branch information
machenmusik committed Feb 9, 2023
1 parent 6417164 commit a12faff
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
8 changes: 3 additions & 5 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nerfstudio.configs.base_config import ViewerConfig
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig
from nerfstudio.data.datamanagers.depth_datamanager import DepthDataManagerConfig
from nerfstudio.data.datamanagers.patch_datamanager import PatchDataManagerConfig
from nerfstudio.data.datamanagers.semantic_datamanager import SemanticDataManagerConfig
from nerfstudio.data.datamanagers.variable_res_datamanager import (
VariableResDataManagerConfig,
Expand Down Expand Up @@ -90,10 +91,7 @@
mode="SO3xR3", optimizer=AdamOptimizerConfig(lr=6e-4, eps=1e-8, weight_decay=1e-2)
),
),
model=NerfactoModelConfig(
eval_num_rays_per_chunk=1 << 15,
robust=False,
),
model=NerfactoModelConfig(eval_num_rays_per_chunk=1 << 15),
),
optimizers={
"proposal_networks": {
Expand All @@ -116,7 +114,7 @@
max_num_iterations=30000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=VanillaDataManagerConfig(
datamanager=PatchDataManagerConfig(
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=17000,
eval_num_rays_per_batch=4096,
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/data/datamanagers/base_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
PhototourismDataParserConfig,
)
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PixelSampler,PatchPixelSampler
from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PixelSampler
from nerfstudio.data.utils.dataloaders import (
CacheDataloader,
FixedIndicesEvalDataloader,
Expand Down Expand Up @@ -351,7 +351,7 @@ def _get_pixel_sampler( # pylint: disable=no-self-use
# Otherwise, use the default pixel sampler
if is_equirectangular.any():
CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.")
return PatchPixelSampler(*args, **kwargs,patch_size=16)
return PixelSampler(*args, **kwargs)

def setup_train(self):
"""Sets up the data loaders for training"""
Expand Down
44 changes: 44 additions & 0 deletions nerfstudio/data/datamanagers/patch_datamanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

#from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any, Type

from nerfstudio.cameras.cameras import CameraType
from nerfstudio.data.datamanagers import base_datamanager
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import EquirectangularPixelSampler, PatchPixelSampler

@dataclass
class PatchDataManagerConfig(base_datamanager.VanillaDataManagerConfig):
_target: Type = field(default_factory=lambda: PatchDataManager)
"""Target class to instantiate."""

class PatchDataManager(base_datamanager.VanillaDataManager): # pylint: disable=abstract-method

def _get_pixel_sampler( # pylint: disable=no-self-use
self, dataset: InputDataset, *args: Any, **kwargs: Any
) -> PixelSampler:
"""Infer pixel sampler to use."""
# If all images are equirectangular, use equirectangular pixel sampler
is_equirectangular = dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value
if is_equirectangular.all():
return EquirectangularPixelSampler(*args, **kwargs)
# Otherwise, use the default pixel sampler
if is_equirectangular.any():
CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.")
return PatchPixelSampler(*args, **kwargs,patch_size=16)

0 comments on commit a12faff

Please sign in to comment.