Skip to content

Commit 7ea6bdd

Browse files
jkulhanekArpegorPSGH
authored andcommitted
1 parent 2ed1e28 commit 7ea6bdd

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

nerfstudio/fields/nerfacto_field.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def __init__(
112112
self.appearance_embedding_dim = appearance_embedding_dim
113113
if self.appearance_embedding_dim > 0:
114114
self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim)
115+
else:
116+
self.embedding_appearance = None
115117
self.use_average_appearance_embedding = use_average_appearance_embedding
116118
self.use_transient_embedding = use_transient_embedding
117119
self.use_semantics = use_semantics
@@ -241,7 +243,7 @@ def get_outputs(
241243

242244
# appearance
243245
embedded_appearance = None
244-
if self.appearance_embedding_dim > 0:
246+
if self.embedding_appearance is not None:
245247
if self.training:
246248
embedded_appearance = self.embedding_appearance(camera_indices)
247249
else:

nerfstudio/fields/nerfw_field.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def __init__(
7777

7878
if self.appearance_embedding_dim > 0:
7979
self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim)
80+
else:
81+
self.embedding_appearance = None
8082
self.embedding_transient = Embedding(self.num_images, self.transient_embedding_dim)
8183

8284
self.mlp_base = MLP(
@@ -97,7 +99,7 @@ def __init__(
9799
self.mlp_head = MLP(
98100
in_dim=self.mlp_base.get_out_dim()
99101
+ self.direction_encoding.get_out_dim()
100-
+ (self.embedding_appearance.get_out_dim() if self.appearance_embedding_dim > 0 else 0),
102+
+ (self.embedding_appearance.get_out_dim() if self.embedding_appearance is not None else 0),
101103
num_layers=head_mlp_num_layers,
102104
layer_width=head_mlp_layer_width,
103105
out_activation=nn.ReLU(),
@@ -136,7 +138,7 @@ def get_outputs(
136138
raise AttributeError("Camera indices are not provided.")
137139
camera_indices = ray_samples.camera_indices.squeeze().to(ray_samples.frustums.origins.device)
138140
mlp_in = [density_embedding, encoded_dir]
139-
if self.appearance_embedding_dim > 0:
141+
if self.embedding_appearance is not None:
140142
embedded_appearance = self.embedding_appearance(camera_indices)
141143
mlp_in.append(embedded_appearance)
142144
mlp_head_out = self.mlp_head(torch.cat(mlp_in, dim=-1))

0 commit comments

Comments
 (0)