Skip to content

Commit 30cafe8

Browse files
Move util functions out of splatfacto
Nothing else currently uses some of the SH utils, but it might still make sense to move them out of splatfacto. I also moved the k-nearest-neighbors helper to utils, since it doesn't depend on the model class.
1 parent e8bf472 commit 30cafe8

File tree

6 files changed

+186
-160
lines changed

6 files changed

+186
-160
lines changed

nerfstudio/field_components/encodings.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828

2929
from nerfstudio.field_components.base_field_component import FieldComponent
3030
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
31-
from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin, generate_polyhedron_basis
31+
from nerfstudio.utils.math import expected_sin, generate_polyhedron_basis
3232
from nerfstudio.utils.printing import print_tcnn_speed_warning
33+
from nerfstudio.utils.spherical_harmonics import MAX_SH_DEGREE, components_from_spherical_harmonics
3334

3435

3536
class Encoding(FieldComponent):
@@ -762,8 +763,10 @@ class SHEncoding(Encoding):
762763
def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
763764
super().__init__(in_dim=3)
764765

765-
if levels <= 0 or levels > 4:
766-
raise ValueError(f"Spherical harmonic encoding only supports 1 to 4 levels, requested {levels}")
766+
if levels <= 0 or levels > MAX_SH_DEGREE:
767+
raise ValueError(
768+
f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE} levels, requested {levels}"
769+
)
767770

768771
self.levels = levels
769772

nerfstudio/model_components/renderers.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838

3939
from nerfstudio.cameras.rays import RaySamples
4040
from nerfstudio.utils import colors
41-
from nerfstudio.utils.math import components_from_spherical_harmonics, safe_normalize
41+
from nerfstudio.utils.math import safe_normalize
42+
from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics
4243

4344
BackgroundColor = Union[Literal["random", "last_sample", "black", "white"], Float[Tensor, "3"], Float[Tensor, "*bs 3"]]
4445
BACKGROUND_COLOR_OVERRIDE: Optional[Float[Tensor, "3"]] = None

nerfstudio/models/splatfacto.py

+3-86
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919

2020
from __future__ import annotations
2121

22-
import math
2322
from dataclasses import dataclass, field
2423
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
2524

26-
import numpy as np
2725
import torch
2826
from gsplat.strategy import DefaultStrategy
2927

@@ -42,70 +40,10 @@
4240
from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
4341
from nerfstudio.models.base_model import Model, ModelConfig
4442
from nerfstudio.utils.colors import get_color
43+
from nerfstudio.utils.math import k_nearest_sklearn, random_quat_tensor
4544
from nerfstudio.utils.misc import torch_compile
4645
from nerfstudio.utils.rich_utils import CONSOLE
47-
48-
49-
def num_sh_bases(degree: int) -> int:
    """
    Returns the number of spherical harmonic bases for a given degree.

    Args:
        degree: Spherical harmonic degree, between 0 and 4 inclusive.

    Returns:
        The number of bases, computed as ``(degree + 1) ** 2``.

    Raises:
        AssertionError: If ``degree`` is negative or greater than 4.
    """
    # Without a lower bound, a negative degree squares to a plausible-looking positive count.
    assert degree >= 0, "We don't support negative degrees."
    assert degree <= 4, "We don't support degree greater than 4."
    return (degree + 1) ** 2
55-
56-
57-
def quat_to_rotmat(quat):
    """
    Converts quaternions stored as (w, x, y, z) into 3x3 matrices.

    Args:
        quat: Tensor whose last dimension has size 4. It is used as-is; no normalization is applied here.

    Returns:
        Tensor of shape ``quat.shape[:-1] + (3, 3)``.
    """
    assert quat.shape[-1] == 4, quat.shape
    qw, qx, qy, qz = (quat[..., i] for i in range(4))

    # Assemble the matrix one row at a time, then stack the rows along a new second-to-last axis.
    top = torch.stack(
        [1 - 2 * (qy**2 + qz**2), 2 * (qx * qy - qw * qz), 2 * (qx * qz + qw * qy)],
        dim=-1,
    )
    middle = torch.stack(
        [2 * (qx * qy + qw * qz), 1 - 2 * (qx**2 + qz**2), 2 * (qy * qz - qw * qx)],
        dim=-1,
    )
    bottom = torch.stack(
        [2 * (qx * qz - qw * qy), 2 * (qy * qz + qw * qx), 1 - 2 * (qx**2 + qy**2)],
        dim=-1,
    )
    return torch.stack([top, middle, bottom], dim=-2)
75-
76-
77-
def random_quat_tensor(N):
    """
    Samples a random quaternion tensor of shape (N, 4)
    """
    # Three separate draws, made in this order so results under a fixed seed are unchanged.
    radial, first_angle, second_angle = (torch.rand(N) for _ in range(3))

    two_pi = 2 * math.pi
    phi = two_pi * first_angle
    psi = two_pi * second_angle
    inner = torch.sqrt(1 - radial)
    outer = torch.sqrt(radial)

    parts = (
        inner * torch.sin(phi),
        inner * torch.cos(phi),
        outer * torch.sin(psi),
        outer * torch.cos(psi),
    )
    return torch.stack(parts, dim=-1)
93-
94-
95-
def RGB2SH(rgb):
    """
    Maps RGB values in [0,1] to the 0th spherical harmonic coefficient.

    The value is shifted by 0.5 and then divided by the constant below.
    """
    sh_c0 = 0.28209479177387814
    centered = rgb - 0.5
    return centered / sh_c0
101-
102-
103-
def SH2RGB(sh):
    """
    Maps the 0th spherical harmonic coefficient back to RGB values in [0,1].

    The coefficient is multiplied by the constant below and then shifted by 0.5.
    """
    sh_c0 = 0.28209479177387814
    scaled = sh * sh_c0
    return scaled + 0.5
46+
from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases
10947

11048

11149
def resize_image(image: torch.Tensor, d: int):
@@ -243,8 +181,7 @@ def populate_modules(self):
243181
means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color)
244182
else:
245183
means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
246-
distances, _ = self.k_nearest_sklearn(means.data, 3)
247-
distances = torch.from_numpy(distances)
184+
distances, _ = k_nearest_sklearn(means.data, 3)
248185
# find the average of the three nearest neighbors for each point and use that as the scale
249186
avg_dist = distances.mean(dim=-1, keepdim=True)
250187
scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3)))
@@ -392,26 +329,6 @@ def load_state_dict(self, dict, **kwargs): # type: ignore
392329
self.gauss_params[name] = torch.nn.Parameter(torch.zeros(new_shape, device=self.device))
393330
super().load_state_dict(dict, **kwargs)
394331

395-
def k_nearest_sklearn(self, x: torch.Tensor, k: int):
    """
    Find k-nearest neighbors using sklearn's NearestNeighbors.

    Args:
        x: The data tensor of shape [num_samples, num_features]
        k: The number of neighbors to retrieve

    Returns:
        A tuple ``(distances, indices)`` of numpy arrays with one row per sample and k columns:
        float32 distances and int64 indices of the neighbors.
    """
    # Detach first: Tensor.numpy() raises on tensors that require grad.
    x_np = x.detach().cpu().numpy()

    # Build the nearest neighbors model
    from sklearn.neighbors import NearestNeighbors

    # Ask for k + 1 neighbors because the query set is the fitted set, so the first column is dropped below.
    nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np)

    # Find the k-nearest neighbors
    distances, indices = nn_model.kneighbors(x_np)

    # Exclude the point itself from the result and return.
    # Indices stay integral (int64); a float32 cast cannot represent every index above 2**24.
    return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.int64)
414-
415332
def set_crop(self, crop_box: Optional[OrientedBox]):
416333
self.crop_box = crop_box
417334

nerfstudio/utils/math.py

+63-69
Original file line numberDiff line numberDiff line change
@@ -20,78 +20,12 @@
2020
from typing import Literal, Tuple
2121

2222
import torch
23-
from jaxtyping import Bool, Float
23+
from jaxtyping import Bool, Float, Int
2424
from torch import Tensor
2525

2626
from nerfstudio.data.scene_box import OrientedBox
2727

2828

29-
def components_from_spherical_harmonics(
    levels: int, directions: Float[Tensor, "*batch 3"]
) -> Float[Tensor, "*batch components"]:
    """
    Returns value for each component of spherical harmonics.

    Args:
        levels: Number of spherical harmonic levels to compute, between 1 and 5 inclusive.
        directions: Direction vectors (x, y, z) at which the harmonics are evaluated.
            NOTE(review): the polynomial forms below presumably assume unit-length directions - confirm with callers.

    Returns:
        Tensor with ``levels**2`` components along the last dimension, allocated on the device of ``directions``.

    Raises:
        AssertionError: If ``levels`` is outside [1, 5] or the last dimension of ``directions`` is not 3.
    """
    # Validate before allocating so that bad input fails without building the output tensor.
    assert 1 <= levels <= 5, f"SH levels must be in [1,5], got {levels}"
    assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}"

    num_components = levels**2
    components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device)

    x = directions[..., 0]
    y = directions[..., 1]
    z = directions[..., 2]

    xx = x**2
    yy = y**2
    zz = z**2

    # l0
    components[..., 0] = 0.28209479177387814

    # l1
    if levels > 1:
        components[..., 1] = 0.4886025119029199 * y
        components[..., 2] = 0.4886025119029199 * z
        components[..., 3] = 0.4886025119029199 * x

    # l2
    if levels > 2:
        components[..., 4] = 1.0925484305920792 * x * y
        components[..., 5] = 1.0925484305920792 * y * z
        components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999
        components[..., 7] = 1.0925484305920792 * x * z
        components[..., 8] = 0.5462742152960396 * (xx - yy)

    # l3
    if levels > 3:
        components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy)
        components[..., 10] = 2.890611442640554 * x * y * z
        components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1)
        components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3)
        components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1)
        components[..., 14] = 1.445305721320277 * z * (xx - yy)
        components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy)

    # l4
    if levels > 4:
        components[..., 16] = 2.5033429417967046 * x * y * (xx - yy)
        components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy)
        components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1)
        components[..., 19] = 0.6690465435572892 * y * z * (7 * zz - 3)
        components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3)
        components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3)
        components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1)
        components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy)
        components[..., 24] = 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))

    return components
93-
94-
9529
@dataclass
9630
class Gaussians:
9731
"""Stores Gaussians
@@ -323,7 +257,9 @@ def masked_reduction(
323257

324258

325259
def normalized_depth_scale_and_shift(
326-
prediction: Float[Tensor, "1 32 mult"], target: Float[Tensor, "1 32 mult"], mask: Bool[Tensor, "1 32 mult"]
260+
prediction: Float[Tensor, "1 32 mult"],
261+
target: Float[Tensor, "1 32 mult"],
262+
mask: Bool[Tensor, "1 32 mult"],
327263
):
328264
"""
329265
More info here: https://arxiv.org/pdf/2206.00665.pdf supplementary section A2 Depth Consistency Loss
@@ -405,7 +341,10 @@ def _compute_tesselation_weights(v: int) -> Tensor:
405341

406342

407343
def _tesselate_geodesic(
408-
vertices: Float[Tensor, "N 3"], faces: Float[Tensor, "M 3"], v: int, eps: float = 1e-4
344+
vertices: Float[Tensor, "N 3"],
345+
faces: Float[Tensor, "M 3"],
346+
v: int,
347+
eps: float = 1e-4,
409348
) -> Tensor:
410349
"""Tesselate the vertices of a geodesic polyhedron.
411350
@@ -518,3 +457,58 @@ def generate_polyhedron_basis(
518457

519458
basis = verts.flip(-1)
520459
return basis
460+
461+
462+
def random_quat_tensor(N: int) -> Float[Tensor, "*batch 4"]:
    """
    Draws a batch of random quaternions.

    Args:
        N: Number of quaternions to generate

    Returns:
        a random quaternion tensor of shape (N, 4)
    """
    # Three independent draws from torch.rand; the u, v, w order is kept so seeded output is unchanged.
    u = torch.rand(N)
    v = torch.rand(N)
    w = torch.rand(N)

    full_turn = 2 * math.pi
    root_one_minus_u = torch.sqrt(1 - u)
    root_u = torch.sqrt(u)

    components = [
        root_one_minus_u * torch.sin(full_turn * v),
        root_one_minus_u * torch.cos(full_turn * v),
        root_u * torch.sin(full_turn * w),
        root_u * torch.cos(full_turn * w),
    ]
    return torch.stack(components, dim=-1)
485+
486+
487+
def k_nearest_sklearn(
    x: torch.Tensor, k: int, metric: str = "euclidean"
) -> Tuple[Float[Tensor, "*batch k"], Int[Tensor, "*batch k"]]:
    """
    Find k-nearest neighbors using sklearn's NearestNeighbors.

    Args:
        x: input tensor of shape [num_samples, num_features]
        k: number of neighbors to find
        metric: metric to use for distance computation

    Returns:
        distances: float32 tensor of distances to the k-nearest neighbors, one row per sample
        indices: int64 tensor of indices of the k-nearest neighbors, one row per sample
    """
    # Detach first: Tensor.numpy() raises on tensors that require grad.
    x_np = x.detach().cpu().numpy()

    # Build the nearest neighbors model
    from sklearn.neighbors import NearestNeighbors

    # Ask for k + 1 neighbors because the query set is the fitted set, so the first column is dropped below.
    nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric=metric).fit(x_np)

    # Find the k-nearest neighbors
    distances, indices = nn_model.kneighbors(x_np)

    # Exclude the point itself from the result and return
    return torch.tensor(distances[:, 1:], dtype=torch.float32), torch.tensor(indices[:, 1:], dtype=torch.int64)

0 commit comments

Comments
 (0)