Skip to content

Commit

Permalink
AABB normalized_positions to [-1, 1]
Browse files Browse the repository at this point in the history
Previously, SceneBox.get_normalized_positions() returned a value in 0,1.
This was fine for InstantNGP and Nerfacto, which use these values directly in an MLP.
For TensoRF though, this was an error because grid sample expects values in the range [-1,1].
Thus, I think only 1/4th of the encoding planes were actually being used, and half of the
encoding lines were.
  • Loading branch information
JulianKnodt committed Nov 21, 2022
1 parent d767cd5 commit 330195d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
7 changes: 7 additions & 0 deletions nerfstudio/field_components/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,13 @@ def get_out_dim(self) -> int:
return self.num_components * 3

def forward(self, in_tensor: TensorType["bs":..., "input_dim"]) -> TensorType["bs":..., "output_dim"]:
"""Compute encoding for each position in in_positions
Args:
in_tensor: position inside bounds in range [-1,1],
Returns: Encoded position
"""
plane_coord = torch.stack([in_tensor[..., [0, 1]], in_tensor[..., [0, 2]], in_tensor[..., [1, 2]]]) # [3,...,2]
line_coord = torch.stack([in_tensor[..., 2], in_tensor[..., 1], in_tensor[..., 0]]) # [3, ...]
line_coord = torch.stack([torch.zeros_like(line_coord), line_coord], dim=-1) # [3, ...., 2]
Expand Down
2 changes: 2 additions & 0 deletions nerfstudio/fields/tensorf_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(

def get_density(self, ray_samples: RaySamples):
positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
positions = positions * 2 - 1
density = self.density_encoding(positions)
density_enc = torch.sum(density, dim=-1)[:, :, None]
relu = torch.nn.ReLU()
Expand All @@ -94,6 +95,7 @@ def get_density(self, ray_samples: RaySamples):
def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None) -> TensorType:
d = ray_samples.frustums.directions
positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
positions = positions * 2 - 1
rgb_features = self.color_encoding(positions)
rgb_features = self.B(rgb_features)

Expand Down

0 comments on commit 330195d

Please sign in to comment.