Skip to content

Commit

Permalink
Fix #2843 - safer typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jkulhanek committed Jan 30, 2024
1 parent 4779f81 commit 4d3c475
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion nerfstudio/fields/nerfacto_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def __init__(
self.appearance_embedding_dim = appearance_embedding_dim
if self.appearance_embedding_dim > 0:
self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim)
else:
self.embedding_appearance = None
self.use_average_appearance_embedding = use_average_appearance_embedding
self.use_transient_embedding = use_transient_embedding
self.use_semantics = use_semantics
Expand Down Expand Up @@ -241,7 +243,7 @@ def get_outputs(

# appearance
embedded_appearance = None
if self.appearance_embedding_dim > 0:
if self.embedding_appearance is not None:
if self.training:
embedded_appearance = self.embedding_appearance(camera_indices)
else:
Expand Down
4 changes: 3 additions & 1 deletion nerfstudio/fields/nerfw_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def __init__(

if self.appearance_embedding_dim > 0:
self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim)
else:
self.embedding_appearance = None
self.embedding_transient = Embedding(self.num_images, self.transient_embedding_dim)

self.mlp_base = MLP(
Expand Down Expand Up @@ -136,7 +138,7 @@ def get_outputs(
raise AttributeError("Camera indices are not provided.")
camera_indices = ray_samples.camera_indices.squeeze().to(ray_samples.frustums.origins.device)
mlp_in = [density_embedding, encoded_dir]
if self.appearance_embedding_dim > 0:
if self.embedding_appearance is not None:
embedded_appearance = self.embedding_appearance(camera_indices)
mlp_in.append(embedded_appearance)
mlp_head_out = self.mlp_head(torch.cat(mlp_in, dim=-1))
Expand Down

0 comments on commit 4d3c475

Please sign in to comment.