diff --git a/nerfstudio/fields/nerfacto_field.py b/nerfstudio/fields/nerfacto_field.py index f06073c8d6..543955d745 100644 --- a/nerfstudio/fields/nerfacto_field.py +++ b/nerfstudio/fields/nerfacto_field.py @@ -110,7 +110,8 @@ def __init__( self.spatial_distortion = spatial_distortion self.num_images = num_images self.appearance_embedding_dim = appearance_embedding_dim - self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim) + if self.appearance_embedding_dim > 0: + self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim) self.use_average_appearance_embedding = use_average_appearance_embedding self.use_transient_embedding = use_transient_embedding self.use_semantics = use_semantics @@ -239,17 +240,19 @@ def get_outputs( outputs_shape = ray_samples.frustums.directions.shape[:-1] # appearance - if self.training: - embedded_appearance = self.embedding_appearance(camera_indices) - else: - if self.use_average_appearance_embedding: - embedded_appearance = torch.ones( - (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device - ) * self.embedding_appearance.mean(dim=0) + embedded_appearance = None + if self.appearance_embedding_dim > 0: + if self.training: + embedded_appearance = self.embedding_appearance(camera_indices) else: - embedded_appearance = torch.zeros( - (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device - ) + if self.use_average_appearance_embedding: + embedded_appearance = torch.ones( + (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device + ) * self.embedding_appearance.mean(dim=0) + else: + embedded_appearance = torch.zeros( + (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device + ) # transients if self.use_transient_embedding and self.training: @@ -289,8 +292,10 @@ def get_outputs( [ d, density_embedding.view(-1, self.geo_feat_dim), - embedded_appearance.view(-1, self.appearance_embedding_dim), - ], + ] + + ( + [embedded_appearance.view(-1, self.appearance_embedding_dim)] if embedded_appearance is not None else [] + ), dim=-1, ) rgb = self.mlp_head(h).view(*outputs_shape, -1).to(directions) diff --git a/nerfstudio/fields/nerfw_field.py b/nerfstudio/fields/nerfw_field.py index 4f385a3710..9b47de881f 100644 --- a/nerfstudio/fields/nerfw_field.py +++ b/nerfstudio/fields/nerfw_field.py @@ -75,7 +75,8 @@ def __init__( self.appearance_embedding_dim = appearance_embedding_dim self.transient_embedding_dim = transient_embedding_dim - self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim) + if self.appearance_embedding_dim > 0: + self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim) self.embedding_transient = Embedding(self.num_images, self.transient_embedding_dim) self.mlp_base = MLP( @@ -96,7 +97,7 @@ def __init__( self.mlp_head = MLP( in_dim=self.mlp_base.get_out_dim() + self.direction_encoding.get_out_dim() - + self.embedding_appearance.get_out_dim(), + + (self.embedding_appearance.get_out_dim() if self.appearance_embedding_dim > 0 else 0), num_layers=head_mlp_num_layers, layer_width=head_mlp_layer_width, out_activation=nn.ReLU(), @@ -134,9 +135,11 @@ def get_outputs( if ray_samples.camera_indices is None: raise AttributeError("Camera indices are not provided.") camera_indices = ray_samples.camera_indices.squeeze().to(ray_samples.frustums.origins.device) - embedded_appearance = self.embedding_appearance(camera_indices) - mlp_in = torch.cat([density_embedding, encoded_dir, embedded_appearance], dim=-1) # type: ignore - mlp_head_out = self.mlp_head(mlp_in) + mlp_in = [density_embedding, encoded_dir] + if self.appearance_embedding_dim > 0: + embedded_appearance = self.embedding_appearance(camera_indices) + mlp_in.append(embedded_appearance) + mlp_head_out = self.mlp_head(torch.cat(mlp_in, dim=-1)) outputs[self.field_head_rgb.field_head_name] = self.field_head_rgb(mlp_head_out) # static rgb embedded_transient = self.embedding_transient(camera_indices) transient_mlp_in = torch.cat([density_embedding, embedded_transient], dim=-1) # type: ignore diff --git a/nerfstudio/models/nerfacto.py b/nerfstudio/models/nerfacto.py index 47fc42e31f..bfccfd8797 100644 --- a/nerfstudio/models/nerfacto.py +++ b/nerfstudio/models/nerfacto.py @@ -106,6 +106,8 @@ class NerfactoModelConfig(ModelConfig): """Predicted normal loss multiplier.""" use_proposal_weight_anneal: bool = True """Whether to use proposal weight annealing.""" + use_appearance_embedding: bool = True + """Whether to use an appearance embedding.""" use_average_appearance_embedding: bool = True """Whether to use average appearance embedding or zeros for inference.""" proposal_weights_anneal_slope: float = 10.0 @@ -148,6 +150,8 @@ def populate_modules(self): else: scene_contraction = SceneContraction(order=float("inf")) + appearance_embedding_dim = self.config.appearance_embed_dim if self.config.use_appearance_embedding else 0 + # Fields self.field = NerfactoField( self.scene_box.aabb, @@ -163,7 +167,7 @@ def populate_modules(self): num_images=self.num_train_data, use_pred_normals=self.config.predict_normals, use_average_appearance_embedding=self.config.use_average_appearance_embedding, - appearance_embedding_dim=self.config.appearance_embed_dim, + appearance_embedding_dim=appearance_embedding_dim, average_init_density=self.config.average_init_density, implementation=self.config.implementation, )