Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update exporter.py to export sh_degree 0 case #3371 #3374

Merged
merged 12 commits into from
Aug 28, 2024
5 changes: 4 additions & 1 deletion nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@ def colors(self):

@property
def shs_0(self):
return self.features_dc
if self.config.sh_degree > 0:
return self.features_dc
else:
return RGB2SH(torch.sigmoid(self.features_dc))

@property
def shs_rest(self):
Expand Down
33 changes: 22 additions & 11 deletions nerfstudio/scripts/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ class ExportGaussianSplat(Exporter):
"""Rotation of the oriented bounding box. Expressed as RPY Euler angles in radians"""
obb_scale: Optional[Tuple[float, float, float]] = None
"""Scale of the oriented bounding box along each axis."""
ply_color_mode: Literal["sh_coeffs", "rgb"] = "sh_coeffs"
"""If "rgb", export colors as red/green/blue fields. Otherwise, export colors as
spherical harmonics coefficients."""

@staticmethod
def write_ply(
Expand All @@ -504,7 +507,7 @@ def write_ply(
"""

# Ensure count matches the length of all tensors
if not all(len(tensor) == count for tensor in map_to_tensors.values()):
if not all(tensor.size == count for tensor in map_to_tensors.values()):
raise ValueError("Count does not match the length of all tensors")

# Type check for numpy arrays of type float or uint8 and non-empty
Expand Down Expand Up @@ -552,7 +555,6 @@ def main(self) -> None:

filename = self.output_dir / "splat.ply"

count = 0
map_to_tensors = OrderedDict()

with torch.no_grad():
Expand All @@ -566,19 +568,28 @@ def main(self) -> None:
map_to_tensors["ny"] = np.zeros(n, dtype=np.float32)
map_to_tensors["nz"] = np.zeros(n, dtype=np.float32)

if model.config.sh_degree > 0:
if self.ply_color_mode == "rgb":
colors = torch.clamp(model.colors.clone(), 0.0, 1.0).data.cpu().numpy()
colors = (colors * 255).astype(np.uint8)
map_to_tensors["red"] = colors[:, 0]
map_to_tensors["green"] = colors[:, 1]
map_to_tensors["blue"] = colors[:, 2]
elif self.ply_color_mode == "sh_coeffs":
shs_0 = model.shs_0.contiguous().cpu().numpy()
for i in range(shs_0.shape[1]):
map_to_tensors[f"f_dc_{i}"] = shs_0[:, i, None]

# transpose(1, 2) was needed to match the sh order in Inria version
shs_rest = model.shs_rest.transpose(1, 2).contiguous().cpu().numpy()
shs_rest = shs_rest.reshape((n, -1))
for i in range(shs_rest.shape[-1]):
map_to_tensors[f"f_rest_{i}"] = shs_rest[:, i, None]
else:
colors = torch.clamp(model.colors.clone(), 0.0, 1.0).data.cpu().numpy()
map_to_tensors["colors"] = (colors * 255).astype(np.uint8)
if model.config.sh_degree > 0:
if self.ply_color_mode == "rgb":
CONSOLE.print(
"Warning: model has higher level of spherical harmonics, ignoring them and only export rgb."
)
elif self.ply_color_mode == "sh_coeffs":
# transpose(1, 2) was needed to match the sh order in Inria version
shs_rest = model.shs_rest.transpose(1, 2).contiguous().cpu().numpy()
shs_rest = shs_rest.reshape((n, -1))
for i in range(shs_rest.shape[-1]):
map_to_tensors[f"f_rest_{i}"] = shs_rest[:, i, None]

map_to_tensors["opacity"] = model.opacities.data.cpu().numpy()

Expand Down
Loading