@@ -77,6 +77,8 @@ def __init__(
77
77
78
78
if self .appearance_embedding_dim > 0 :
79
79
self .embedding_appearance = Embedding (self .num_images , self .appearance_embedding_dim )
80
+ else :
81
+ self .embedding_appearance = None
80
82
self .embedding_transient = Embedding (self .num_images , self .transient_embedding_dim )
81
83
82
84
self .mlp_base = MLP (
@@ -97,7 +99,7 @@ def __init__(
97
99
self .mlp_head = MLP (
98
100
in_dim = self .mlp_base .get_out_dim ()
99
101
+ self .direction_encoding .get_out_dim ()
100
- + (self .embedding_appearance .get_out_dim () if self .appearance_embedding_dim > 0 else 0 ),
102
+ + (self .embedding_appearance .get_out_dim () if self .embedding_appearance is not None else 0 ),
101
103
num_layers = head_mlp_num_layers ,
102
104
layer_width = head_mlp_layer_width ,
103
105
out_activation = nn .ReLU (),
@@ -136,7 +138,7 @@ def get_outputs(
136
138
raise AttributeError ("Camera indices are not provided." )
137
139
camera_indices = ray_samples .camera_indices .squeeze ().to (ray_samples .frustums .origins .device )
138
140
mlp_in = [density_embedding , encoded_dir ]
139
- if self .appearance_embedding_dim > 0 :
141
+ if self .embedding_appearance is not None :
140
142
embedded_appearance = self .embedding_appearance (camera_indices )
141
143
mlp_in .append (embedded_appearance )
142
144
mlp_head_out = self .mlp_head (torch .cat (mlp_in , dim = - 1 ))
0 commit comments