27
27
from torch import Tensor , nn
28
28
29
29
from nerfstudio .field_components .base_field_component import FieldComponent
30
- from nerfstudio .utils .math import components_from_spherical_harmonics , expected_sin
30
+ from nerfstudio .utils .external import TCNN_EXISTS , tcnn
31
+ from nerfstudio .utils .math import (
32
+ components_from_spherical_harmonics ,
33
+ expected_sin ,
34
+ generate_polyhedron_basis ,
35
+ )
31
36
from nerfstudio .utils .printing import print_tcnn_speed_warning
32
- from nerfstudio .utils .external import tcnn , TCNN_EXISTS
33
37
34
38
35
39
class Encoding (FieldComponent ):
@@ -153,7 +157,7 @@ def pytorch_fwd(
153
157
Output values will be between -1 and 1
154
158
"""
155
159
scaled_in_tensor = 2 * torch .pi * in_tensor # scale to [0, 2pi]
156
- freqs = 2 ** torch .linspace (self .min_freq , self .max_freq , self .num_frequencies ). to ( in_tensor .device )
160
+ freqs = 2 ** torch .linspace (self .min_freq , self .max_freq , self .num_frequencies , device = in_tensor .device )
157
161
scaled_inputs = scaled_in_tensor [..., None ] * freqs # [..., "input_dim", "num_scales"]
158
162
scaled_inputs = scaled_inputs .view (* scaled_inputs .shape [:- 2 ], - 1 ) # [..., "input_dim" * "num_scales"]
159
163
@@ -178,34 +182,40 @@ def forward(
178
182
return self .pytorch_fwd (in_tensor , covs )
179
183
180
184
181
- class RFFEncoding (Encoding ):
182
- """Random Fourier Feature encoding. Supports integrated encodings.
185
+ class FFEncoding (Encoding ):
186
+ """Fourier Feature encoding. Supports integrated encodings.
183
187
184
188
Args:
185
189
in_dim: Input dimension of tensor
186
- num_frequencies: Number of encoding frequencies
187
- scale: Std of Gaussian to sample frequencies. Must be greater than zero
190
+ basis: Basis matrix from which to construct the Fourier features.
191
+ num_frequencies: Number of encoded frequencies per axis
192
+ min_freq_exp: Minimum frequency exponent
193
+ max_freq_exp: Maximum frequency exponent
188
194
include_input: Append the input coordinate to the encoding
189
195
"""
190
196
191
- def __init__ (self , in_dim : int , num_frequencies : int , scale : float , include_input : bool = False ) -> None :
197
def __init__(
    self,
    in_dim: int,
    basis: Float[Tensor, "M N"],
    num_frequencies: int,
    min_freq_exp: float,
    max_freq_exp: float,
    include_input: bool = False,
) -> None:
    """Store the encoding configuration and register the projection basis.

    Args:
        in_dim: Input dimension of tensor
        basis: Matrix the inputs are multiplied by before the sinusoids are applied.
            Kept on the module as the ``b_matrix`` buffer.
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        include_input: Append the input coordinate to the encoding
    """
    super().__init__(in_dim)
    # Plain configuration values; none of these assignments depends on another.
    self.include_input = include_input
    self.min_freq = min_freq_exp
    self.max_freq = max_freq_exp
    self.num_frequencies = num_frequencies
    # Registered as a buffer (not a parameter) under the name the forward pass reads.
    self.register_buffer("b_matrix", basis)
203
212
204
213
def get_out_dim(self) -> int:
    """Return the size of the last dimension produced by this encoding.

    The width is ``2 * num_frequencies * b_matrix.shape[1]``, plus ``in_dim``
    when the raw input is appended to the encoding.

    Raises:
        ValueError: If the input dimension has not been set.
    """
    if self.in_dim is None:
        raise ValueError("Input dimension has not been set")
    # Narrows the buffer's type for static checkers before its shape is read.
    assert isinstance(self.b_matrix, Tensor)
    num_projections = self.b_matrix.shape[1]
    encoded_width = num_projections * self.num_frequencies * 2
    return encoded_width + self.in_dim if self.include_input else encoded_width
211
221
@@ -214,7 +224,7 @@ def forward(
214
224
in_tensor : Float [Tensor , "*bs input_dim" ],
215
225
covs : Optional [Float [Tensor , "*bs input_dim input_dim" ]] = None ,
216
226
) -> Float [Tensor , "*bs output_dim" ]:
217
- """Calculates RFF encoding. If covariances are provided the encodings will be integrated as proposed
227
+ """Calculates FF encoding. If covariances are provided the encodings will be integrated as proposed
218
228
in mip-NeRF.
219
229
220
230
Args:
@@ -226,11 +236,16 @@ def forward(
226
236
"""
227
237
scaled_in_tensor = 2 * torch .pi * in_tensor # scale to [0, 2pi]
228
238
scaled_inputs = scaled_in_tensor @ self .b_matrix # [..., "num_frequencies"]
239
+ freqs = 2 ** torch .linspace (self .min_freq , self .max_freq , self .num_frequencies , device = in_tensor .device )
240
+ scaled_inputs = scaled_inputs [..., None ] * freqs # [..., "input_dim", "num_scales"]
241
+ scaled_inputs = scaled_inputs .view (* scaled_inputs .shape [:- 2 ], - 1 ) # [..., "input_dim" * "num_scales"]
229
242
230
243
if covs is None :
231
244
encoded_inputs = torch .sin (torch .cat ([scaled_inputs , scaled_inputs + torch .pi / 2.0 ], dim = - 1 ))
232
245
else :
233
246
input_var = torch .sum ((covs @ self .b_matrix ) * self .b_matrix , - 2 )
247
+ input_var = input_var [..., :, None ] * freqs [None , :] ** 2
248
+ input_var = input_var .reshape ((* input_var .shape [:- 2 ], - 1 ))
234
249
encoded_inputs = expected_sin (
235
250
torch .cat ([scaled_inputs , scaled_inputs + torch .pi / 2.0 ], dim = - 1 ), torch .cat (2 * [input_var ], dim = - 1 )
236
251
)
@@ -241,6 +256,49 @@ def forward(
241
256
return encoded_inputs
242
257
243
258
259
class RFFEncoding(FFEncoding):
    """Fourier Feature encoding with a randomly sampled basis. Supports integrated encodings.

    A ``(in_dim, num_frequencies)`` matrix is sampled from a zero-mean Gaussian when the
    encoding is constructed and handed to the parent class as its basis. The parent is
    configured with a single frequency band whose minimum and maximum exponents are both 0.

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoding frequencies (columns of the sampled matrix)
        scale: Std of Gaussian to sample frequencies. Must be greater than zero
        include_input: Append the input coordinate to the encoding

    Raises:
        ValueError: If ``scale`` is not greater than zero.
    """

    def __init__(self, in_dim: int, num_frequencies: int, scale: float, include_input: bool = False) -> None:
        # Written as "not (scale > 0)" so a NaN scale is rejected as well.
        scale_is_positive = scale > 0
        if not scale_is_positive:
            raise ValueError("RFF encoding scale should be greater than zero")

        random_basis = torch.normal(mean=0, std=scale, size=(in_dim, num_frequencies))
        super().__init__(
            in_dim=in_dim,
            basis=random_basis,
            num_frequencies=1,
            min_freq_exp=0.0,
            max_freq_exp=0.0,
            include_input=include_input,
        )
275
+
276
+
277
class PolyhedronFFEncoding(FFEncoding):
    """Fourier Feature encoding using polyhedron basis as proposed by mip-NeRF360. Supports integrated encodings.

    The input dimension is fixed to 3. The basis is the transpose of whatever
    ``generate_polyhedron_basis`` returns for the requested shape and subdivision count.

    Args:
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        basis_shape: Shape of polyhedron basis. Either "octahedron" or "icosahedron"
        basis_subdivisions: Number of times to tessellate the polyhedron.
        include_input: Append the input coordinate to the encoding
    """

    def __init__(
        self,
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        basis_shape: Literal["octahedron", "icosahedron"] = "octahedron",
        basis_subdivisions: int = 1,
        include_input: bool = False,
    ) -> None:
        polyhedron_basis = generate_polyhedron_basis(basis_shape, basis_subdivisions)
        # Transposed before being handed to the parent, which right-multiplies inputs by the basis.
        super().__init__(
            in_dim=3,
            basis=polyhedron_basis.T,
            num_frequencies=num_frequencies,
            min_freq_exp=min_freq_exp,
            max_freq_exp=max_freq_exp,
            include_input=include_input,
        )
300
+
301
+
244
302
class HashEncoding (Encoding ):
245
303
"""Hash encoding
246
304
0 commit comments