Encoder + MLP combo #2063

Merged · 8 commits · Dec 1, 2023
88 changes: 62 additions & 26 deletions nerfstudio/field_components/encodings.py
@@ -48,6 +48,11 @@ def __init__(self, in_dim: int) -> None:
raise ValueError("Input dimension should be greater than zero")
super().__init__(in_dim=in_dim)

@classmethod
def get_tcnn_encoding_config(cls) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
raise NotImplementedError("Encoding does not have a TCNN implementation")

@abstractmethod
def forward(self, in_tensor: Shaped[Tensor, "*bs input_dim"]) -> Shaped[Tensor, "*bs output_dim"]:
"""Call forward and returns and processed tensor
@@ -126,14 +131,20 @@ def __init__(
if implementation == "tcnn" and not TCNN_EXISTS:
print_tcnn_speed_warning("NeRFEncoding")
elif implementation == "tcnn":
encoding_config = {"otype": "Frequency", "n_frequencies": num_frequencies}
assert min_freq_exp == 0, "tcnn only supports min_freq_exp = 0"
assert max_freq_exp == num_frequencies - 1, "tcnn only supports max_freq_exp = num_frequencies - 1"
encoding_config = self.get_tcnn_encoding_config(num_frequencies=self.num_frequencies)
self.tcnn_encoding = tcnn.Encoding(
n_input_dims=in_dim,
encoding_config=encoding_config,
)

@classmethod
def get_tcnn_encoding_config(cls, num_frequencies) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {"otype": "Frequency", "n_frequencies": num_frequencies}
return encoding_config

def get_out_dim(self) -> int:
if self.in_dim is None:
raise ValueError("Input dimension has not been set")
@@ -327,48 +338,67 @@ def __init__(
) -> None:
super().__init__(in_dim=3)
self.num_levels = num_levels
self.min_res = min_res
self.features_per_level = features_per_level
self.hash_init_scale = hash_init_scale
self.log2_hashmap_size = log2_hashmap_size
self.hash_table_size = 2**log2_hashmap_size

levels = torch.arange(num_levels)
growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1)) if num_levels > 1 else 1
self.scalings = torch.floor(min_res * growth_factor**levels)
self.growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1)) if num_levels > 1 else 1
self.scalings = torch.floor(min_res * self.growth_factor**levels)

self.hash_offset = levels * self.hash_table_size

self.tcnn_encoding = None
self.hash_table = torch.empty(0)
if implementation == "tcnn" and not TCNN_EXISTS:
if implementation == "torch":
self.build_nn_modules()
elif implementation == "tcnn" and not TCNN_EXISTS:
print_tcnn_speed_warning("HashEncoding")
implementation = "torch"

if implementation == "tcnn":
encoding_config = {
"otype": "HashGrid",
"n_levels": self.num_levels,
"n_features_per_level": self.features_per_level,
"log2_hashmap_size": self.log2_hashmap_size,
"base_resolution": min_res,
"per_level_scale": growth_factor,
}
if interpolation is not None:
encoding_config["interpolation"] = interpolation

self.build_nn_modules()
elif implementation == "tcnn":
encoding_config = self.get_tcnn_encoding_config(
num_levels=self.num_levels,
features_per_level=self.features_per_level,
log2_hashmap_size=self.log2_hashmap_size,
min_res=self.min_res,
growth_factor=self.growth_factor,
interpolation=interpolation,
)
self.tcnn_encoding = tcnn.Encoding(
n_input_dims=3,
encoding_config=encoding_config,
)
elif implementation == "torch":
self.hash_table = torch.rand(size=(self.hash_table_size * num_levels, features_per_level)) * 2 - 1
self.hash_table *= hash_init_scale
self.hash_table = nn.Parameter(self.hash_table)

if self.tcnn_encoding is None:
assert (
interpolation is None or interpolation == "Linear"
), f"interpolation '{interpolation}' is not supported for torch encoding backend"

def build_nn_modules(self) -> None:
"""Initialize the torch version of the hash encoding."""
self.hash_table = torch.rand(size=(self.hash_table_size * self.num_levels, self.features_per_level)) * 2 - 1
self.hash_table *= self.hash_init_scale
self.hash_table = nn.Parameter(self.hash_table)

@classmethod
def get_tcnn_encoding_config(
cls, num_levels, features_per_level, log2_hashmap_size, min_res, growth_factor, interpolation=None
) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {
"otype": "HashGrid",
"n_levels": num_levels,
"n_features_per_level": features_per_level,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": min_res,
"per_level_scale": growth_factor,
}
if interpolation is not None:
encoding_config["interpolation"] = interpolation
return encoding_config

def get_out_dim(self) -> int:
return self.num_levels * self.features_per_level
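
For reference, a quick sketch of what the new classmethod returns with the hash-grid defaults used by `MLPWithHashEncoding` later in this PR (the numeric values below are illustrative, not part of the diff):

```python
import numpy as np

from nerfstudio.field_components.encodings import HashEncoding

# 16 levels growing geometrically from resolution 16 to 1024:
# growth_factor = (1024 / 16) ** (1 / 15), roughly 1.32
growth_factor = np.exp((np.log(1024) - np.log(16)) / (16 - 1))

config = HashEncoding.get_tcnn_encoding_config(
    num_levels=16,
    features_per_level=2,
    log2_hashmap_size=19,
    min_res=16,
    growth_factor=growth_factor,
    interpolation=None,
)
# -> {"otype": "HashGrid", "n_levels": 16, "n_features_per_level": 2,
#     "log2_hashmap_size": 19, "base_resolution": 16, "per_level_scale": 1.3195...}
```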

@@ -745,15 +775,21 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "
if implementation == "tcnn" and not TCNN_EXISTS:
print_tcnn_speed_warning("SHEncoding")
elif implementation == "tcnn":
encoding_config = {
"otype": "SphericalHarmonics",
"degree": levels,
}
encoding_config = self.get_tcnn_encoding_config(levels=self.levels)
self.tcnn_encoding = tcnn.Encoding(
n_input_dims=3,
encoding_config=encoding_config,
)

@classmethod
def get_tcnn_encoding_config(cls, levels) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {
"otype": "SphericalHarmonics",
"degree": levels,
}
return encoding_config

def get_out_dim(self) -> int:
return self.levels**2
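
As a small check of the refactor, the `SHEncoding` classmethod can also be called without instantiating the encoder (hypothetical usage):

```python
from nerfstudio.field_components.encodings import SHEncoding

# Degree-4 spherical harmonics; the encoder itself outputs levels**2 = 16 features.
config = SHEncoding.get_tcnn_encoding_config(levels=4)
assert config == {"otype": "SphericalHarmonics", "degree": 4}
```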

176 changes: 149 additions & 27 deletions nerfstudio/field_components/mlp.py
@@ -17,12 +17,14 @@
"""
from typing import Literal, Optional, Set, Tuple, Union

import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor, nn

from nerfstudio.field_components.base_field_component import FieldComponent
from nerfstudio.utils.printing import print_tcnn_speed_warning
from nerfstudio.field_components.encodings import HashEncoding

from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
@@ -66,6 +68,7 @@ class MLP(FieldComponent):
out_dim: Output layer dimension. Uses layer_width if None.
activation: intermediate layer activation function.
out_activation: output activation function.
implementation: Implementation of the MLP. Fallback to torch if tcnn not available.
"""

def __init__(
@@ -98,39 +101,47 @@ def __init__(
print_tcnn_speed_warning("MLP")
self.build_nn_modules()
Collaborator Author: I fixed this in this PR too!

elif implementation == "tcnn":
activation_str = activation_to_tcnn_string(activation)
output_activation_str = activation_to_tcnn_string(out_activation)
if layer_width in [16, 32, 64, 128]:
network_config = {
"otype": "FullyFusedMLP",
"activation": activation_str,
"output_activation": output_activation_str,
"n_neurons": layer_width,
"n_hidden_layers": num_layers - 1,
}
else:
CONSOLE.line()
CONSOLE.print("[bold yellow]WARNING: Using slower TCNN CutlassMLP instead of TCNN FullyFusedMLP")
CONSOLE.print(
"[bold yellow]Use layer width of 16, 32, 64, or 128 to use the faster TCNN FullyFusedMLP."
)
CONSOLE.line()
network_config = {
"otype": "CutlassMLP",
"activation": activation_str,
"output_activation": output_activation_str,
"n_neurons": layer_width,
"n_hidden_layers": num_layers - 1,
}

network_config = self.get_tcnn_network_config(
activation=self.activation,
out_activation=self.out_activation,
layer_width=self.layer_width,
num_layers=self.num_layers,
)
self.tcnn_encoding = tcnn.Network(
n_input_dims=in_dim,
n_output_dims=out_dim,
n_output_dims=self.out_dim,
network_config=network_config,
)

@classmethod
def get_tcnn_network_config(cls, activation, out_activation, layer_width, num_layers) -> dict:
"""Get the network configuration for tcnn if implemented"""
activation_str = activation_to_tcnn_string(activation)
output_activation_str = activation_to_tcnn_string(out_activation)
if layer_width in [16, 32, 64, 128]:
network_config = {
"otype": "FullyFusedMLP",
"activation": activation_str,
"output_activation": output_activation_str,
"n_neurons": layer_width,
"n_hidden_layers": num_layers - 1,
}
else:
CONSOLE.line()
CONSOLE.print("[bold yellow]WARNING: Using slower TCNN CutlassMLP instead of TCNN FullyFusedMLP")
CONSOLE.print("[bold yellow]Use layer width of 16, 32, 64, or 128 to use the faster TCNN FullyFusedMLP.")
CONSOLE.line()
network_config = {
"otype": "CutlassMLP",
"activation": activation_str,
"output_activation": output_activation_str,
"n_neurons": layer_width,
"n_hidden_layers": num_layers - 1,
}
return network_config

def build_nn_modules(self) -> None:
"""Initialize multi-layer perceptron."""
"""Initialize the torch version of the multi-layer perceptron."""
layers = []
if self.num_layers == 1:
layers.append(nn.Linear(self.in_dim, self.out_dim))
@@ -171,3 +182,114 @@ def forward(self, in_tensor: Float[Tensor, "*bs in_dim"]) -> Float[Tensor, "*bs
if self.tcnn_encoding is not None:
return self.tcnn_encoding(in_tensor)
return self.pytorch_fwd(in_tensor)
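
A quick illustration of the width check inside the new `get_tcnn_network_config` (hypothetical values; the non-fused branch also prints the CutlassMLP warning shown above):

```python
from torch import nn

from nerfstudio.field_components.mlp import MLP

# Widths of 16/32/64/128 map to tcnn's FullyFusedMLP; anything else falls back to CutlassMLP.
fast = MLP.get_tcnn_network_config(activation=nn.ReLU(), out_activation=None, layer_width=64, num_layers=3)
slow = MLP.get_tcnn_network_config(activation=nn.ReLU(), out_activation=None, layer_width=256, num_layers=3)
assert fast["otype"] == "FullyFusedMLP" and fast["n_hidden_layers"] == 2
assert slow["otype"] == "CutlassMLP"
```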


class MLPWithHashEncoding(FieldComponent):
"""Multilayer perceptron with hash encoding

Args:
num_levels: Number of feature grids.
min_res: Resolution of smallest feature grid.
max_res: Resolution of largest feature grid.
log2_hashmap_size: Size of hash map is 2^log2_hashmap_size.
features_per_level: Number of features per level.
hash_init_scale: Value to initialize hash grid.
interpolation: Interpolation override for tcnn hashgrid. Not supported for torch unless linear.
num_layers: Number of network layers
layer_width: Width of each MLP layer
out_dim: Output layer dimension. Uses layer_width if None.
activation: intermediate layer activation function.
out_activation: output activation function.
implementation: Implementation of the hash encoding and MLP. Fallback to torch if tcnn not available.
"""

def __init__(
self,
num_levels: int = 16,
min_res: int = 16,
max_res: int = 1024,
log2_hashmap_size: int = 19,
features_per_level: int = 2,
hash_init_scale: float = 0.001,
interpolation: Optional[Literal["Nearest", "Linear", "Smoothstep"]] = None,
num_layers: int = 2,
layer_width: int = 64,
out_dim: Optional[int] = None,
skip_connections: Optional[Tuple[int]] = None,
activation: Optional[nn.Module] = nn.ReLU(),
out_activation: Optional[nn.Module] = None,
implementation: Literal["tcnn", "torch"] = "torch",
) -> None:
super().__init__()
self.in_dim = 3

self.num_levels = num_levels
self.min_res = min_res
self.max_res = max_res
self.features_per_level = features_per_level
self.hash_init_scale = hash_init_scale
self.log2_hashmap_size = log2_hashmap_size
self.hash_table_size = 2**log2_hashmap_size

self.growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1)) if num_levels > 1 else 1

self.out_dim = out_dim if out_dim is not None else layer_width
self.num_layers = num_layers
self.layer_width = layer_width
self.skip_connections = skip_connections
self._skip_connections: Set[int] = set(skip_connections) if skip_connections else set()
self.activation = activation
self.out_activation = out_activation
self.net = None

self.tcnn_encoding = None
if implementation == "torch":
self.build_nn_modules()
elif implementation == "tcnn" and not TCNN_EXISTS:
print_tcnn_speed_warning("MLPWithHashEncoding")
self.build_nn_modules()
elif implementation == "tcnn":
self.model = tcnn.NetworkWithInputEncoding(
n_input_dims=self.in_dim,
n_output_dims=self.out_dim,
encoding_config=HashEncoding.get_tcnn_encoding_config(
num_levels=self.num_levels,
features_per_level=self.features_per_level,
log2_hashmap_size=self.log2_hashmap_size,
min_res=self.min_res,
growth_factor=self.growth_factor,
interpolation=interpolation,
),
network_config=MLP.get_tcnn_network_config(
activation=self.activation,
out_activation=self.out_activation,
layer_width=self.layer_width,
num_layers=self.num_layers,
),
)

def build_nn_modules(self) -> None:
"""Initialize the torch version of the MLP with hash encoding."""
encoder = HashEncoding(
num_levels=self.num_levels,
min_res=self.min_res,
max_res=self.max_res,
log2_hashmap_size=self.log2_hashmap_size,
features_per_level=self.features_per_level,
hash_init_scale=self.hash_init_scale,
implementation="torch",
)
mlp = MLP(
in_dim=encoder.get_out_dim(),
num_layers=self.num_layers,
layer_width=self.layer_width,
out_dim=self.out_dim,
skip_connections=self.skip_connections,
activation=self.activation,
out_activation=self.out_activation,
implementation="torch",
)
self.model = torch.nn.Sequential(encoder, mlp)

def forward(self, in_tensor: Float[Tensor, "*bs in_dim"]) -> Float[Tensor, "*bs out_dim"]:
return self.model(in_tensor)
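
A minimal usage sketch of the new combined module (parameter values are hypothetical; implementation="torch" keeps it runnable without tiny-cuda-nn, while "tcnn" fuses the encoder and MLP into a single NetworkWithInputEncoding):

```python
import torch

from nerfstudio.field_components.mlp import MLPWithHashEncoding

model = MLPWithHashEncoding(
    num_levels=16,
    min_res=16,
    max_res=1024,
    log2_hashmap_size=19,
    features_per_level=2,
    num_layers=2,
    layer_width=64,
    out_dim=16,
    implementation="torch",  # pure-torch fallback: HashEncoding followed by MLP in an nn.Sequential
)
positions = torch.rand(4096, 3)  # normalized 3D points in [0, 1]^3
features = model(positions)      # -> shape (4096, 16)
```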