Skip to content

Commit 0d746b1

Browse files
Mxbonnjkulhanek
andauthored
Fourier Feature encodings and polyhedron encodings (#2463)
* Fourier Feature encodings and polyhedron encodings Rework RFFEncoding to be a subclass of the more general Fourier Feature Encodings and introduce Polyhedron encodings as introduced in mipnerf360. * b_matrix -> basis * Add typing * fighting pyright * use scale argument * continue the fight * ignore * ignore em all * add docstring and rename generate_basis to generate_polyhedron_basis * Try to please pyright with assert * Immediately allocate tensor on correct device * private functions and docstrings update * doc fix continued --------- Co-authored-by: Jonáš Kulhánek <[email protected]>
1 parent c2f5e68 commit 0d746b1

File tree

2 files changed

+237
-22
lines changed

2 files changed

+237
-22
lines changed

nerfstudio/field_components/encodings.py

+78-20
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,13 @@
2727
from torch import Tensor, nn
2828

2929
from nerfstudio.field_components.base_field_component import FieldComponent
30-
from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin
30+
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
31+
from nerfstudio.utils.math import (
32+
components_from_spherical_harmonics,
33+
expected_sin,
34+
generate_polyhedron_basis,
35+
)
3136
from nerfstudio.utils.printing import print_tcnn_speed_warning
32-
from nerfstudio.utils.external import tcnn, TCNN_EXISTS
3337

3438

3539
class Encoding(FieldComponent):
@@ -153,7 +157,7 @@ def pytorch_fwd(
153157
Output values will be between -1 and 1
154158
"""
155159
scaled_in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi]
156-
freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies).to(in_tensor.device)
160+
freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
157161
scaled_inputs = scaled_in_tensor[..., None] * freqs # [..., "input_dim", "num_scales"]
158162
scaled_inputs = scaled_inputs.view(*scaled_inputs.shape[:-2], -1) # [..., "input_dim" * "num_scales"]
159163

@@ -178,34 +182,40 @@ def forward(
178182
return self.pytorch_fwd(in_tensor, covs)
179183

180184

181-
class RFFEncoding(Encoding):
182-
"""Random Fourier Feature encoding. Supports integrated encodings.
185+
class FFEncoding(Encoding):
186+
"""Fourier Feature encoding. Supports integrated encodings.
183187
184188
Args:
185189
in_dim: Input dimension of tensor
186-
num_frequencies: Number of encoding frequencies
187-
scale: Std of Gaussian to sample frequencies. Must be greater than zero
190+
basis: Basis matrix from which to construct the Fourier features.
191+
num_frequencies: Number of encoded frequencies per axis
192+
min_freq_exp: Minimum frequency exponent
193+
max_freq_exp: Maximum frequency exponent
188194
include_input: Append the input coordinate to the encoding
189195
"""
190196

191-
def __init__(self, in_dim: int, num_frequencies: int, scale: float, include_input: bool = False) -> None:
197+
def __init__(
198+
self,
199+
in_dim: int,
200+
basis: Float[Tensor, "M N"],
201+
num_frequencies: int,
202+
min_freq_exp: float,
203+
max_freq_exp: float,
204+
include_input: bool = False,
205+
) -> None:
192206
super().__init__(in_dim)
193-
194207
self.num_frequencies = num_frequencies
195-
if not scale > 0:
196-
raise ValueError("RFF encoding scale should be greater than zero")
197-
self.scale = scale
198-
if self.in_dim is None:
199-
raise ValueError("Input dimension has not been set")
200-
b_matrix = torch.normal(mean=0, std=self.scale, size=(self.in_dim, self.num_frequencies))
201-
self.register_buffer(name="b_matrix", tensor=b_matrix)
208+
self.min_freq = min_freq_exp
209+
self.max_freq = max_freq_exp
210+
self.register_buffer(name="b_matrix", tensor=basis)
202211
self.include_input = include_input
203212

204213
def get_out_dim(self) -> int:
205-
out_dim = self.num_frequencies * 2
214+
if self.in_dim is None:
215+
raise ValueError("Input dimension has not been set")
216+
assert isinstance(self.b_matrix, Tensor)
217+
out_dim = self.b_matrix.shape[1] * self.num_frequencies * 2
206218
if self.include_input:
207-
if self.in_dim is None:
208-
raise ValueError("Input dimension has not been set")
209219
out_dim += self.in_dim
210220
return out_dim
211221

@@ -214,7 +224,7 @@ def forward(
214224
in_tensor: Float[Tensor, "*bs input_dim"],
215225
covs: Optional[Float[Tensor, "*bs input_dim input_dim"]] = None,
216226
) -> Float[Tensor, "*bs output_dim"]:
217-
"""Calculates RFF encoding. If covariances are provided the encodings will be integrated as proposed
227+
"""Calculates FF encoding. If covariances are provided the encodings will be integrated as proposed
218228
in mip-NeRF.
219229
220230
Args:
@@ -226,11 +236,16 @@ def forward(
226236
"""
227237
scaled_in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi]
228238
scaled_inputs = scaled_in_tensor @ self.b_matrix # [..., "num_frequencies"]
239+
freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
240+
scaled_inputs = scaled_inputs[..., None] * freqs # [..., "input_dim", "num_scales"]
241+
scaled_inputs = scaled_inputs.view(*scaled_inputs.shape[:-2], -1) # [..., "input_dim" * "num_scales"]
229242

230243
if covs is None:
231244
encoded_inputs = torch.sin(torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1))
232245
else:
233246
input_var = torch.sum((covs @ self.b_matrix) * self.b_matrix, -2)
247+
input_var = input_var[..., :, None] * freqs[None, :] ** 2
248+
input_var = input_var.reshape((*input_var.shape[:-2], -1))
234249
encoded_inputs = expected_sin(
235250
torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1), torch.cat(2 * [input_var], dim=-1)
236251
)
@@ -241,6 +256,49 @@ def forward(
241256
return encoded_inputs
242257

243258

259+
class RFFEncoding(FFEncoding):
260+
"""Random Fourier Feature encoding. Supports integrated encodings.
261+
262+
Args:
263+
in_dim: Input dimension of tensor
264+
num_frequencies: Number of encoding frequencies
265+
scale: Std of Gaussian to sample frequencies. Must be greater than zero
266+
include_input: Append the input coordinate to the encoding
267+
"""
268+
269+
def __init__(self, in_dim: int, num_frequencies: int, scale: float, include_input: bool = False) -> None:
270+
if not scale > 0:
271+
raise ValueError("RFF encoding scale should be greater than zero")
272+
273+
b_matrix = torch.normal(mean=0, std=scale, size=(in_dim, num_frequencies))
274+
super().__init__(in_dim, b_matrix, 1, 0.0, 0.0, include_input)
275+
276+
277+
class PolyhedronFFEncoding(FFEncoding):
278+
"""Fourier Feature encoding using polyhedron basis as proposed by mip-NeRF360. Supports integrated encodings.
279+
280+
Args:
281+
num_frequencies: Number of encoded frequencies per axis
282+
min_freq_exp: Minimum frequency exponent
283+
max_freq_exp: Maximum frequency exponent
284+
basis_shape: Shape of polyhedron basis. Either "octahedron" or "icosahedron"
285+
basis_subdivisions: Number of times to tesselate the polyhedron.
286+
include_input: Append the input coordinate to the encoding
287+
"""
288+
289+
def __init__(
290+
self,
291+
num_frequencies: int,
292+
min_freq_exp: float,
293+
max_freq_exp: float,
294+
basis_shape: Literal["octahedron", "icosahedron"] = "octahedron",
295+
basis_subdivisions: int = 1,
296+
include_input: bool = False,
297+
) -> None:
298+
basis_t = generate_polyhedron_basis(basis_shape, basis_subdivisions).T
299+
super().__init__(3, basis_t, num_frequencies, min_freq_exp, max_freq_exp, include_input)
300+
301+
244302
class HashEncoding(Encoding):
245303
"""Hash encoding
246304

nerfstudio/utils/math.py

+159-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
""" Math Helper Functions """
1616

17+
import itertools
18+
import math
1719
from dataclasses import dataclass
1820
from typing import Literal, Tuple
1921

@@ -195,7 +197,6 @@ def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
195197
Returns:
196198
torch.Tensor: The expected value of sin.
197199
"""
198-
199200
return torch.exp(-0.5 * x_vars) * torch.sin(x_means)
200201

201202

@@ -360,4 +361,160 @@ def normalized_depth_scale_and_shift(
360361
shift[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
361362

362363
return scale, shift
363-
return scale, shift
364+
365+
366+
def columnwise_squared_l2_distance(
367+
x: Float[Tensor, "*M N"],
368+
y: Float[Tensor, "*M N"],
369+
) -> Float[Tensor, "N N"]:
370+
"""Compute the squared Euclidean distance between all pairs of columns.
371+
Adapted from https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/geopoly.py
372+
373+
Args:
374+
x: tensor of floats, with shape [M, N].
375+
y: tensor of floats, with shape [M, N].
376+
Returns:
377+
sq_dist: tensor of floats, with shape [N, N].
378+
"""
379+
# Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y.
380+
sq_norm_x = torch.sum(x**2, 0)
381+
sq_norm_y = torch.sum(y**2, 0)
382+
sq_dist = sq_norm_x[:, None] + sq_norm_y[None, :] - 2 * x.T @ y
383+
return sq_dist
384+
385+
386+
def _compute_tesselation_weights(v: int) -> Tensor:
387+
"""Tesselate the vertices of a triangle by a factor of `v`.
388+
Adapted from https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/geopoly.py
389+
390+
Args:
391+
v: int, the factor of the tesselation (v==1 is a no-op to the triangle).
392+
393+
Returns:
394+
weights: tesselated weights.
395+
"""
396+
if v < 1:
397+
raise ValueError(f"v {v} must be >= 1")
398+
int_weights = []
399+
for i in range(v + 1):
400+
for j in range(v + 1 - i):
401+
int_weights.append((i, j, v - (i + j)))
402+
int_weights = torch.FloatTensor(int_weights)
403+
weights = int_weights / v # Barycentric weights.
404+
return weights
405+
406+
407+
def _tesselate_geodesic(
408+
vertices: Float[Tensor, "N 3"], faces: Float[Tensor, "M 3"], v: int, eps: float = 1e-4
409+
) -> Tensor:
410+
"""Tesselate the vertices of a geodesic polyhedron.
411+
412+
Adapted from https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/geopoly.py
413+
414+
Args:
415+
vertices: tensor of floats, the vertex coordinates of the geodesic.
416+
faces: tensor of ints, the indices of the vertices of base_verts that
417+
constitute eachface of the polyhedra.
418+
v: int, the factor of the tesselation (v==1 is a no-op).
419+
eps: float, a small value used to determine if two vertices are the same.
420+
421+
Returns:
422+
verts: a tensor of floats, the coordinates of the tesselated vertices.
423+
"""
424+
tri_weights = _compute_tesselation_weights(v)
425+
426+
verts = []
427+
for face in faces:
428+
new_verts = torch.matmul(tri_weights, vertices[face, :])
429+
new_verts /= torch.sqrt(torch.sum(new_verts**2, 1, keepdim=True))
430+
verts.append(new_verts)
431+
verts = torch.concatenate(verts, 0)
432+
433+
sq_dist = columnwise_squared_l2_distance(verts.T, verts.T)
434+
assignment = torch.tensor([torch.min(torch.argwhere(d <= eps)) for d in sq_dist])
435+
unique = torch.unique(assignment)
436+
verts = verts[unique, :]
437+
return verts
438+
439+
440+
def generate_polyhedron_basis(
441+
basis_shape: Literal["icosahedron", "octahedron"],
442+
angular_tesselation: int,
443+
remove_symmetries: bool = True,
444+
eps: float = 1e-4,
445+
) -> Tensor:
446+
"""Generates a 3D basis by tesselating a geometric polyhedron.
447+
Basis is used to construct Fourier features for positional encoding.
448+
See Mip-Nerf360 paper: https://arxiv.org/abs/2111.12077
449+
Adapted from https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/geopoly.py
450+
451+
Args:
452+
base_shape: string, the name of the starting polyhedron, must be either
453+
'icosahedron' or 'octahedron'.
454+
angular_tesselation: int, the number of times to tesselate the polyhedron,
455+
must be >= 1 (a value of 1 is a no-op to the polyhedron).
456+
remove_symmetries: bool, if True then remove the symmetric basis columns,
457+
which is usually a good idea because otherwise projections onto the basis
458+
will have redundant negative copies of each other.
459+
eps: float, a small number used to determine symmetries.
460+
461+
Returns:
462+
basis: a matrix with shape [3, n].
463+
"""
464+
if basis_shape == "icosahedron":
465+
a = (math.sqrt(5) + 1) / 2
466+
verts = torch.FloatTensor(
467+
[
468+
(-1, 0, a),
469+
(1, 0, a),
470+
(-1, 0, -a),
471+
(1, 0, -a),
472+
(0, a, 1),
473+
(0, a, -1),
474+
(0, -a, 1),
475+
(0, -a, -1),
476+
(a, 1, 0),
477+
(-a, 1, 0),
478+
(a, -1, 0),
479+
(-a, -1, 0),
480+
]
481+
) / math.sqrt(a + 2)
482+
faces = torch.tensor(
483+
[
484+
(0, 4, 1),
485+
(0, 9, 4),
486+
(9, 5, 4),
487+
(4, 5, 8),
488+
(4, 8, 1),
489+
(8, 10, 1),
490+
(8, 3, 10),
491+
(5, 3, 8),
492+
(5, 2, 3),
493+
(2, 7, 3),
494+
(7, 10, 3),
495+
(7, 6, 10),
496+
(7, 11, 6),
497+
(11, 0, 6),
498+
(0, 1, 6),
499+
(6, 1, 10),
500+
(9, 0, 11),
501+
(9, 11, 2),
502+
(9, 2, 5),
503+
(7, 2, 11),
504+
]
505+
)
506+
verts = _tesselate_geodesic(verts, faces, angular_tesselation)
507+
elif basis_shape == "octahedron":
508+
verts = torch.FloatTensor([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)])
509+
corners = torch.FloatTensor(list(itertools.product([-1, 1], repeat=3)))
510+
pairs = torch.argwhere(columnwise_squared_l2_distance(corners.T, verts.T) == 2)
511+
faces, _ = torch.sort(torch.reshape(pairs[:, 1], [3, -1]).T, 1)
512+
verts = _tesselate_geodesic(verts, faces, angular_tesselation)
513+
514+
if remove_symmetries:
515+
# Remove elements of `verts` that are reflections of each other.
516+
match = columnwise_squared_l2_distance(verts.T, -verts.T) < eps
517+
verts = verts[torch.any(torch.triu(match), 1), :]
518+
519+
basis = verts.flip(-1)
520+
return basis

0 commit comments

Comments
 (0)