Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Computing and predicting normals #905

Merged
merged 28 commits into from
Nov 20, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7a06e40
starting normals PR
ethanweber Nov 4, 2022
c767eff
code runs but visualization looks bad
ethanweber Nov 6, 2022
94b3124
Camera and world coordinate conventions (#889)
ethanweber Nov 5, 2022
f2d3525
Start viewer with more accurate speed calculation (#899)
tancik Nov 5, 2022
46a661f
Minor doc updates (#903)
tancik Nov 6, 2022
99053a1
fix conflicts
ethanweber Nov 6, 2022
528807b
code runs but visualization looks bad
ethanweber Nov 6, 2022
3c8d425
normals implemented and vis fixed. still debugging quality
ethanweber Nov 6, 2022
0be1fc5
minor name change
ethanweber Nov 6, 2022
62b20bd
comment
ethanweber Nov 6, 2022
89a2c94
Merge branch 'main' into ethan/adding_normals
ethanweber Nov 6, 2022
2b53638
fix black
ethanweber Nov 6, 2022
8a5e434
lint
ethanweber Nov 6, 2022
20e9734
predicting normals
ethanweber Nov 7, 2022
be83a2e
adding detach
ethanweber Nov 7, 2022
b7667f7
predicted normals
ethanweber Nov 8, 2022
4360411
Merge branch 'main' of https://github.com/plenoptix/nerfactory into e…
terrancewang Nov 10, 2022
a9d8ab0
improved normals prediction performance
terrancewang Nov 17, 2022
97601a8
Merge branch 'main' of https://github.com/plenoptix/nerfactory into e…
terrancewang Nov 17, 2022
71e0b46
cleaning code
terrancewang Nov 17, 2022
c2fb388
cleaning code
terrancewang Nov 19, 2022
a09e924
cleaning code
terrancewang Nov 20, 2022
7cacc47
adding type to function param
terrancewang Nov 20, 2022
c5bd54c
fixing tensorf forward func
terrancewang Nov 20, 2022
1970493
reverting tensorf changes
terrancewang Nov 20, 2022
432dcfc
fix style with tensorf field
terrancewang Nov 20, 2022
55e0ef5
adding normals check to tensorffield
terrancewang Nov 20, 2022
b2ff175
Merge branch 'main' into ethan/adding_normals
terrancewang Nov 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions nerfstudio/field_components/field_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from enum import Enum
from typing import Callable, Optional, Union

import torch
from torch import nn
from torchtyping import TensorType

Expand All @@ -30,12 +31,12 @@ class FieldHeadNames(Enum):
RGB = "rgb"
SH = "sh"
DENSITY = "density"
NORMALS = "normals"
PRED_NORMALS = "pred_normals"
UNCERTAINTY = "uncertainty"
TRANSIENT_RGB = "transient_rgb"
TRANSIENT_DENSITY = "transient_density"
SEMANTICS = "semantics"
SEMANTICS_STUFF = "semantics_stuff"
SEMANTICS_THING = "semantics_thing"


class FieldHead(FieldComponent):
Expand Down Expand Up @@ -171,7 +172,7 @@ def __init__(self, in_dim: Optional[int] = None, activation: Optional[nn.Module]


class SemanticFieldHead(FieldHead):
"""Semantic stuff output
"""Semantic output

Args:
num_classes: Number of semantic classes
Expand All @@ -181,3 +182,21 @@ class SemanticFieldHead(FieldHead):

def __init__(self, num_classes: int, in_dim: Optional[int] = None) -> None:
super().__init__(in_dim=in_dim, out_dim=num_classes, field_head_name=FieldHeadNames.SEMANTICS, activation=None)


class PredNormalsFieldHead(FieldHead):
    """Field head that emits predicted normals as 3-vectors of unit length.

    Args:
        in_dim: input dimension. If not defined in constructor, it must be set later.
        activation: output head activation
    """

    def __init__(self, in_dim: Optional[int] = None, activation: Optional[nn.Module] = nn.Tanh()) -> None:
        super().__init__(
            in_dim=in_dim,
            out_dim=3,
            field_head_name=FieldHeadNames.PRED_NORMALS,
            activation=activation,
        )

    def forward(self, in_tensor: TensorType["bs":..., "in_dim"]) -> TensorType["bs":..., "out_dim"]:
        """Run the parent head, then normalize each output vector along the last dimension."""
        # The parent's output is not unit length on its own, so it is
        # normalized here before being returned as a normal.
        raw_prediction = super().forward(in_tensor)
        return torch.nn.functional.normalize(raw_prediction, dim=-1)
38 changes: 35 additions & 3 deletions nerfstudio/fields/base_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
class Field(nn.Module):
"""Base class for fields."""

def __init__(self) -> None:
super().__init__()
self._sample_locations = None
self._density_before_activation = None

def density_fn(self, positions: TensorType["bs":..., 3]) -> TensorType["bs":..., 1]:
"""Returns only the density. Used primarily with the density grid.

Expand Down Expand Up @@ -57,6 +62,24 @@ def get_density(self, ray_samples: RaySamples) -> Tuple[TensorType[..., 1], Tens
ray_samples: Samples locations to compute density.
"""

def get_normals(self) -> TensorType[..., 3]:
    """Compute normals as the negated, normalized gradient of the cached density.

    Takes no arguments: the method reads ``self._sample_locations`` and
    ``self._density_before_activation``, both of which must have been stored
    on the field before this is called.

    Returns:
        A tensor shaped like ``self._sample_locations`` holding, for every
        sample, the negative gradient of the pre-activation density with
        respect to that sample's location, normalized along the last dimension.

    Raises:
        AssertionError: If either cached tensor is missing, or if their shapes
            differ anywhere other than the last dimension.
    """
    assert self._sample_locations is not None, "Sample locations must be set before calling get_normals."
    assert self._density_before_activation is not None, "Density must be set before calling get_normals."
    assert (
        self._sample_locations.shape[:-1] == self._density_before_activation.shape[:-1]
    ), "Sample locations and density must have the same shape besides the last dimension."

    # torch.autograd.grad returns the gradient directly instead of accumulating
    # it into ``self._sample_locations.grad`` the way Tensor.backward does, so
    # repeated calls stay independent and no state is left on the sample tensor.
    (density_gradient,) = torch.autograd.grad(
        outputs=self._density_before_activation,
        inputs=self._sample_locations,
        grad_outputs=torch.ones_like(self._density_before_activation),
        retain_graph=True,  # keep the graph alive so it can still be used after this call
    )

    # Negate the normalized gradient to obtain the normal.
    return -torch.nn.functional.normalize(density_gradient, dim=-1)

@abstractmethod
def get_outputs(
self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None
Expand All @@ -68,14 +91,23 @@ def get_outputs(
density_embedding: Density embeddings to condition on.
"""

def forward(self, ray_samples: RaySamples):
def forward(self, ray_samples: RaySamples, compute_normals: bool = False):
"""Evaluates the field at points along the ray.

Args:
ray_samples: Samples to evaluate field on.
"""
density, density_embedding = self.get_density(ray_samples)
field_outputs = self.get_outputs(ray_samples, density_embedding=density_embedding)
if compute_normals:
with torch.enable_grad():
density, density_embedding = self.get_density(ray_samples)
else:
density, density_embedding = self.get_density(ray_samples)

field_outputs = self.get_outputs(ray_samples, density_embedding=density_embedding)
field_outputs[FieldHeadNames.DENSITY] = density # type: ignore

if compute_normals:
with torch.enable_grad():
normals = self.get_normals()
field_outputs[FieldHeadNames.NORMALS] = normals # type: ignore
return field_outputs
50 changes: 45 additions & 5 deletions nerfstudio/fields/nerfacto_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DensityFieldHead,
FieldHead,
FieldHeadNames,
PredNormalsFieldHead,
RGBFieldHead,
SemanticFieldHead,
TransientDensityFieldHead,
Expand Down Expand Up @@ -95,6 +96,7 @@ def __init__(
use_transient_embedding: bool = False,
use_semantics: bool = False,
num_semantic_classes: int = 100,
use_pred_normals: bool = False,
use_average_appearance_embedding: bool = False,
spatial_distortion: Optional[SpatialDistortion] = None,
) -> None:
Expand All @@ -110,6 +112,7 @@ def __init__(
self.use_average_appearance_embedding = use_average_appearance_embedding
self.use_transient_embedding = use_transient_embedding
self.use_semantics = use_semantics
self.use_pred_normals = use_pred_normals

num_levels = 16
max_res = 1024
Expand All @@ -126,6 +129,11 @@ def __init__(
},
)

self.position_encoding = tcnn.Encoding(
n_input_dims=3,
encoding_config={"otype": "Frequency", "n_frequencies": 2},
)

self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=3,
n_output_dims=1 + self.geo_feat_dim,
Expand Down Expand Up @@ -182,6 +190,21 @@ def __init__(
in_dim=self.mlp_semantics.n_output_dims, num_classes=num_semantic_classes
)

# predicted normals
if self.use_pred_normals:
self.mlp_pred_normals = tcnn.Network(
n_input_dims=self.geo_feat_dim + self.position_encoding.n_output_dims,
n_output_dims=hidden_dim_transient,
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": 64,
"n_hidden_layers": 2,
},
)
self.field_head_pred_normals = PredNormalsFieldHead(in_dim=self.mlp_pred_normals.n_output_dims)

self.mlp_head = tcnn.Network(
n_input_dims=self.direction_encoding.n_output_dims + self.geo_feat_dim + self.appearance_embedding_dim,
n_output_dims=3,
Expand All @@ -202,9 +225,13 @@ def get_density(self, ray_samples: RaySamples):
positions = (positions + 2.0) / 4.0
else:
positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
self._sample_locations = positions
if not self._sample_locations.requires_grad:
self._sample_locations.requires_grad = True
positions_flat = positions.view(-1, 3)
h = self.mlp_base(positions_flat).view(*ray_samples.frustums.shape, -1)
density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1)
self._density_before_activation = density_before_activation

# Rectifying the density with an exponential is much more stable than a ReLU or
# softplus, because it enables high post-activation (float32) density outputs
Expand All @@ -222,6 +249,8 @@ def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[Tenso
directions_flat = directions.view(-1, 3)
d = self.direction_encoding(directions_flat)

outputs_shape = ray_samples.frustums.directions.shape[:-1]

# appearance
if self.training:
embedded_appearance = self.embedding_appearance(camera_indices)
Expand All @@ -245,7 +274,7 @@ def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[Tenso
],
dim=-1,
)
x = self.mlp_transient(transient_input).view(*ray_samples.frustums.directions.shape[:-1], -1).to(directions)
x = self.mlp_transient(transient_input).view(*outputs_shape, -1).to(directions)
outputs[FieldHeadNames.UNCERTAINTY] = self.field_head_transient_uncertainty(x)
outputs[FieldHeadNames.TRANSIENT_RGB] = self.field_head_transient_rgb(x)
outputs[FieldHeadNames.TRANSIENT_DENSITY] = self.field_head_transient_density(x)
Expand All @@ -259,10 +288,19 @@ def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[Tenso
],
dim=-1,
)
# print(semantics_input)
x = self.mlp_semantics(semantics_input).view(*ray_samples.frustums.directions.shape[:-1], -1).to(directions)
x = self.mlp_semantics(semantics_input).view(*outputs_shape, -1).to(directions)
outputs[FieldHeadNames.SEMANTICS] = self.field_head_semantics(x)

# predicted normals
if self.use_pred_normals:
positions = ray_samples.frustums.get_positions()

positions_flat = self.position_encoding(positions.view(-1, 3))
pred_normals_inp = torch.cat([positions_flat, density_embedding.view(-1, self.geo_feat_dim)], dim=-1)

x = self.mlp_pred_normals(pred_normals_inp).view(*outputs_shape, -1).to(directions)
outputs[FieldHeadNames.PRED_NORMALS] = self.field_head_pred_normals(x)

h = torch.cat(
[
d,
Expand All @@ -271,7 +309,7 @@ def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[Tenso
],
dim=-1,
)
rgb = self.mlp_head(h).view(*ray_samples.frustums.directions.shape[:-1], -1).to(directions)
rgb = self.mlp_head(h).view(*outputs_shape, -1).to(directions)
outputs.update({FieldHeadNames.RGB: rgb})

return outputs
Expand Down Expand Up @@ -342,14 +380,16 @@ def get_outputs(
self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None
) -> Dict[FieldHeadNames, TensorType]:

outputs_shape = ray_samples.frustums.directions.shape[:-1]

if ray_samples.camera_indices is None:
raise AttributeError("Camera indices are not provided.")
camera_indices = ray_samples.camera_indices.squeeze()
if self.training:
embedded_appearance = self.embedding_appearance(camera_indices)
else:
embedded_appearance = torch.zeros(
(*ray_samples.frustums.directions.shape[:-1], self.appearance_embedding_dim),
(*outputs_shape, self.appearance_embedding_dim),
device=ray_samples.frustums.directions.device,
)

Expand Down
6 changes: 5 additions & 1 deletion nerfstudio/fields/tensorf_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[Tenso
return rgb

def forward(
self, ray_samples: RaySamples, mask: Optional[TensorType] = None, bg_color: Optional[TensorType] = None
self,
ray_samples: RaySamples,
compute_normals: bool = False,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or raise an error if compute_normals is True.

mask: Optional[TensorType] = None,
bg_color: Optional[TensorType] = None,
):
if mask is not None and bg_color is not None:
base_density = torch.zeros(ray_samples.shape)[:, :, None].to(mask.device)
Expand Down
24 changes: 24 additions & 0 deletions nerfstudio/model_components/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,27 @@ def nerfstudio_distortion_loss(
loss = loss + 1 / 3.0 * torch.sum(weights**2 * (ends - starts), dim=-2)

return loss


def orientation_loss(
    weights: TensorType["bs":..., "num_samples", 1],
    normals: TensorType["bs":..., "num_samples", 3],
    viewdirs: TensorType["bs":..., 3],
) -> TensorType["bs":...]:
    """Orientation loss proposed in Ref-NeRF.

    Loss that encourages that all visible normals are facing towards the camera.
    For every sample the dot product between its normal and the ray's view
    direction is computed; only negative dot products are penalized (squared),
    and the penalties are summed along the ray using the sample weights.

    Args:
        weights: Per-sample weights; only channel 0 of the last dimension is used.
        normals: Per-sample normals.
        viewdirs: One direction per ray, broadcast over that ray's samples.
            NOTE(review): because only negative dot products are penalized, the
            loss matches its stated goal only if this vector points from the
            sample towards the camera - confirm the sign convention at call sites.

    Returns:
        One loss value per ray (the sample dimension is summed out).
    """
    # Broadcast the per-ray direction across the sample dimension.
    normal_dot_view = (normals * viewdirs[..., None, :]).sum(dim=-1)
    # fmin against zero keeps the negative dot products and zeroes the rest.
    back_facing = torch.fmin(torch.zeros_like(normal_dot_view), normal_dot_view)
    return (weights[..., 0] * back_facing**2).sum(dim=-1)


def pred_normal_loss(
    weights: TensorType["bs":..., "num_samples", 1],
    normals: TensorType["bs":..., "num_samples", 3],
    pred_normals: TensorType["bs":..., "num_samples", 3],
):
    """Loss between normals calculated from density and normals from prediction network.

    For every sample the penalty is ``1 - <normals, pred_normals>``; penalties
    are multiplied by channel 0 of ``weights`` and summed over the sample dimension.
    """
    alignment = torch.sum(normals * pred_normals, dim=-1)
    per_sample_penalty = weights[..., 0] * (1.0 - alignment)
    return per_sample_penalty.sum(dim=-1)
16 changes: 15 additions & 1 deletion nerfstudio/model_components/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,20 @@ def forward(
semantics: TensorType["bs":..., "num_samples", "num_classes"],
weights: TensorType["bs":..., "num_samples", 1],
) -> TensorType["bs":..., "num_classes"]:
"""_summary_"""
"""Calculate semantics along the ray."""
sem = torch.sum(weights * semantics, dim=-2)
return sem


class NormalsRenderer(nn.Module):
    """Calculate normals along the ray."""

    @classmethod
    def forward(
        cls,
        normals: TensorType["bs":..., "num_samples", 3],
        weights: TensorType["bs":..., "num_samples", 1],
    ) -> TensorType["bs":..., 3]:
        """Collapse per-sample normals into one vector per ray.

        Each sample's normal is multiplied by its weight and the products are
        summed over the sample dimension. The result is returned as-is, with
        no re-normalization.
        """
        weighted_normals = weights * normals
        return weighted_normals.sum(dim=-2)
Loading