Skip to content

Commit

Permalink
Merge branch 'nerfstudio-project:main' into convert-cameras-json
Browse files Browse the repository at this point in the history
  • Loading branch information
VasuAgrawal authored Jan 11, 2024
2 parents e96f6db + 3b6ca91 commit aa2bc6b
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 74 deletions.
15 changes: 5 additions & 10 deletions nerfstudio/configs/experiment_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@

import yaml

from nerfstudio.configs.base_config import (
InstantiateConfig,
LoggingConfig,
MachineConfig,
ViewerConfig,
)
from nerfstudio.configs.base_config import InstantiateConfig, LoggingConfig, MachineConfig, ViewerConfig
from nerfstudio.configs.config_utils import to_immutable_dict
from nerfstudio.engine.optimizers import OptimizerConfig
from nerfstudio.engine.schedulers import SchedulerConfig
Expand All @@ -51,13 +46,13 @@ class ExperimentConfig(InstantiateConfig):
"""Project name."""
timestamp: str = "{timestamp}"
"""Experiment timestamp."""
machine: MachineConfig = field(default_factory=lambda: MachineConfig())
machine: MachineConfig = field(default_factory=MachineConfig)
"""Machine configuration"""
logging: LoggingConfig = field(default_factory=lambda: LoggingConfig())
logging: LoggingConfig = field(default_factory=LoggingConfig)
"""Logging configuration"""
viewer: ViewerConfig = field(default_factory=lambda: ViewerConfig())
viewer: ViewerConfig = field(default_factory=ViewerConfig)
"""Viewer configuration"""
pipeline: VanillaPipelineConfig = field(default_factory=lambda: VanillaPipelineConfig())
pipeline: VanillaPipelineConfig = field(default_factory=VanillaPipelineConfig)
"""Pipeline configuration"""
optimizers: Dict[str, Any] = to_immutable_dict(
{
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/data/datamanagers/base_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ class VanillaDataManagerConfig(DataManagerConfig):

_target: Type = field(default_factory=lambda: VanillaDataManager)
"""Target class to instantiate."""
dataparser: AnnotatedDataParserUnion = field(default_factory=lambda: BlenderDataParserConfig())
dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)
"""Specifies the dataparser used to unpack the data."""
train_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per training iteration."""
Expand Down Expand Up @@ -344,7 +344,7 @@ class VanillaDataManagerConfig(DataManagerConfig):
"""Size of patch to sample from. If > 1, patch-based sampling will be used."""
camera_optimizer: Optional[CameraOptimizerConfig] = field(default=None)
"""Deprecated, has been moved to the model config."""
pixel_sampler: PixelSamplerConfig = field(default_factory=lambda: PixelSamplerConfig())
pixel_sampler: PixelSamplerConfig = field(default_factory=PixelSamplerConfig)
"""Specifies the pixel sampler used to sample pixels from images."""

def __post_init__(self):
Expand Down
30 changes: 18 additions & 12 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import annotations

import random
from copy import deepcopy
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
Expand All @@ -30,7 +31,6 @@
import cv2
import numpy as np
import torch
from copy import deepcopy
from torch.nn import Parameter
from tqdm import tqdm

Expand All @@ -47,7 +47,7 @@
@dataclass
class FullImageDatamanagerConfig(DataManagerConfig):
_target: Type = field(default_factory=lambda: FullImageDatamanager)
dataparser: AnnotatedDataParserUnion = NerfstudioDataParserConfig()
dataparser: AnnotatedDataParserUnion = field(default_factory=NerfstudioDataParserConfig)
camera_res_scale_factor: float = 1.0
"""The scale factor for scaling spatial data such as images, mask, semantics
along with relevant information about camera intrinsics
Expand Down Expand Up @@ -133,7 +133,6 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()

if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
Expand All @@ -147,13 +146,15 @@ def cache_images(self, cache_images_option):
0,
]
)
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
if "mask" in data:
data["mask"] = data["mask"][y : y + h, x : x + w]
if "depth_image" in data:
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
# update the width, height
Expand All @@ -162,7 +163,8 @@ def cache_images(self, cache_images_option):
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK
Expand Down Expand Up @@ -206,7 +208,6 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()

if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
Expand All @@ -220,8 +221,12 @@ def cache_images(self, cache_images_option):
0,
]
)
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
Expand All @@ -231,7 +236,8 @@ def cache_images(self, cache_images_option):
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK
Expand Down
16 changes: 3 additions & 13 deletions nerfstudio/models/base_surface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,9 @@
from nerfstudio.fields.nerfacto_field import NerfactoField
from nerfstudio.fields.sdf_field import SDFFieldConfig
from nerfstudio.fields.vanilla_nerf_field import NeRFField
from nerfstudio.model_components.losses import (
L1Loss,
MSELoss,
ScaleAndShiftInvariantLoss,
monosdf_normal_loss,
)
from nerfstudio.model_components.losses import L1Loss, MSELoss, ScaleAndShiftInvariantLoss, monosdf_normal_loss
from nerfstudio.model_components.ray_samplers import LinearDisparitySampler
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
DepthRenderer,
RGBRenderer,
SemanticRenderer,
)
from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, RGBRenderer, SemanticRenderer
from nerfstudio.model_components.scene_colliders import AABBBoxCollider, NearFarCollider
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils import colormaps
Expand Down Expand Up @@ -79,7 +69,7 @@ class SurfaceModelConfig(ModelConfig):
"""Monocular normal consistency loss multiplier."""
mono_depth_loss_mult: float = 0.0
"""Monocular depth consistency loss multiplier."""
sdf_field: SDFFieldConfig = field(default_factory=lambda: SDFFieldConfig())
sdf_field: SDFFieldConfig = field(default_factory=SDFFieldConfig)
"""Config for SDF Field"""
background_model: Literal["grid", "mlp", "none"] = "mlp"
"""background models"""
Expand Down
30 changes: 14 additions & 16 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,33 @@

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union
from nerfstudio.data.scene_box import OrientedBox

import numpy as np
import torch
import torchvision.transforms.functional as TF
from gsplat._torch_impl import quat_to_rotmat
from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
from gsplat.project_gaussians import ProjectGaussians
from gsplat.rasterize import RasterizeGaussians
from gsplat.sh import SphericalHarmonics, num_sh_bases
from pytorch_msssim import SSIM
from sklearn.neighbors import NearestNeighbors
from torch.nn import Parameter
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import torchvision.transforms.functional as TF

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
from gsplat._torch_impl import quat_to_rotmat
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
from nerfstudio.engine.optimizers import Optimizers
from nerfstudio.models.base_model import Model, ModelConfig
import math
import numpy as np
from sklearn.neighbors import NearestNeighbors
from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig

from gsplat.rasterize import RasterizeGaussians
from gsplat.project_gaussians import ProjectGaussians
from gsplat.sh import SphericalHarmonics, num_sh_bases

from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
from pytorch_msssim import SSIM

# need following import for background color override
from nerfstudio.model_components import renderers
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils.rich_utils import CONSOLE


Expand Down Expand Up @@ -149,7 +147,7 @@ class GaussianSplattingModelConfig(ModelConfig):
"""stop splitting at this step"""
sh_degree: int = 3
"""maximum degree of spherical harmonics to use"""
camera_optimizer: CameraOptimizerConfig = CameraOptimizerConfig(mode="off")
camera_optimizer: CameraOptimizerConfig = field(default_factory=CameraOptimizerConfig)
"""camera optimizer config"""
use_scale_regularization: bool = False
"""If enabled, a scale regularization introduced in PhysGauss (https://xpandora.github.io/PhysGaussian/) is used for reducing huge spikey gaussians."""
Expand Down
16 changes: 6 additions & 10 deletions nerfstudio/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,10 @@
from torch.nn import Parameter
from torch.nn.parallel import DistributedDataParallel as DDP

from nerfstudio.configs import base_config as cfg
from nerfstudio.data.datamanagers.base_datamanager import (
DataManager,
DataManagerConfig,
VanillaDataManager,
)
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, VanillaDataManager
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils import profiler
Expand Down Expand Up @@ -213,14 +209,14 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:


@dataclass
class VanillaPipelineConfig(cfg.InstantiateConfig):
class VanillaPipelineConfig(InstantiateConfig):
"""Configuration for pipeline instantiation"""

_target: Type = field(default_factory=lambda: VanillaPipeline)
"""target class to instantiate"""
datamanager: DataManagerConfig = field(default_factory=lambda: DataManagerConfig())
datamanager: DataManagerConfig = field(default_factory=DataManagerConfig)
"""specifies the datamanager config"""
model: ModelConfig = field(default_factory=lambda: ModelConfig())
model: ModelConfig = field(default_factory=ModelConfig)
"""specifies the model config"""


Expand Down
26 changes: 21 additions & 5 deletions nerfstudio/viewer/render_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@

import contextlib
import threading
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, get_args

import numpy as np
import torch
import torch.nn.functional as F
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.model_components.renderers import background_color_override_context
from nerfstudio.models.gaussian_splatting import GaussianSplattingModel
from nerfstudio.utils import colormaps, writer
from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName, TimeWriter
from nerfstudio.viewer_legacy.server import viewer_utils
from nerfstudio.viewer.utils import CameraState, get_camera
from nerfstudio.viewer_legacy.server import viewer_utils

from viser import ClientHandle

if TYPE_CHECKING:
Expand Down Expand Up @@ -166,11 +169,21 @@ def _render_img(self, camera_state: CameraState):
num_rays = (camera.height * camera.width).item()
if self.viewer.control_panel.layer_depth:
if isinstance(self.viewer.get_model(), GaussianSplattingModel):
# TODO: sending depth at high resolution lags the network a lot, figure out how to do this more efficiently
# outputs["gl_z_buf_depth"] = outputs["depth"]
pass
# Gaussians render much faster than we can send depth images, so we do some downsampling.
assert len(outputs["depth"].shape) == 3
assert outputs["depth"].shape[-1] == 1

desired_depth_pixels = {"low_move": 128, "low_static": 128, "high": 512}[self.state] ** 2
current_depth_pixels = outputs["depth"].shape[0] * outputs["depth"].shape[1]
scale = min(desired_depth_pixels / current_depth_pixels, 1.0)

outputs["gl_z_buf_depth"] = F.interpolate(
outputs["depth"].squeeze(dim=-1)[None, None, ...],
size=(int(outputs["depth"].shape[0] * scale), int(outputs["depth"].shape[1] * scale)),
mode="bilinear",
)[0, 0, :, :, None]
else:
# convert to z_depth if depth compositing is enabled
# Convert to z_depth if depth compositing is enabled.
R = camera.camera_to_worlds[0, 0:3, 0:3].T
camera_ray_bundle = camera.generate_rays(camera_indices=0, obb_box=obb)
pts = camera_ray_bundle.directions * outputs["depth"]
Expand All @@ -186,6 +199,9 @@ def _render_img(self, camera_state: CameraState):
def run(self):
"""Main loop for the render thread"""
while self.running:
if not self.viewer.ready:
time.sleep(0.1)
continue
if not self.render_trigger.wait(0.2):
# if we haven't received a trigger in a while, send a static action
self.action(RenderAction(action="static", camera_state=self.viewer.get_camera_state(self.client)))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"msgpack_numpy>=0.4.8",
"nerfacc==0.5.2",
"open3d>=0.16.0",
"opencv-python==4.6.0.66",
"opencv-python==4.8.0.76",
"Pillow>=9.3.0",
"plotly>=5.7.0",
"protobuf<=3.20.3,!=3.20.0",
Expand Down
9 changes: 4 additions & 5 deletions tests/plugins/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
"""
import os
import sys
from dataclasses import dataclass
from dataclasses import dataclass, field

from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig
from nerfstudio.plugins import registry
from nerfstudio.plugins import registry, registry_dataparser
from nerfstudio.plugins.registry_dataparser import DataParserConfig, DataParserSpecification, discover_dataparsers
from nerfstudio.plugins.types import MethodSpecification
from nerfstudio.plugins import registry_dataparser
from nerfstudio.plugins.registry_dataparser import DataParserSpecification, discover_dataparsers, DataParserConfig

if sys.version_info < (3, 10):
import importlib_metadata
Expand Down Expand Up @@ -100,7 +99,7 @@ def test_discover_methods_from_environment_variable_instance():

@dataclass
class TestDataparserConfigClass(DataParserSpecification):
config: DataParserConfig = DataParserConfig()
config: DataParserConfig = field(default_factory=DataParserConfig)
description: str = "Test description"


Expand Down

0 comments on commit aa2bc6b

Please sign in to comment.