Skip to content

Commit

Permalink
Cleanup splatfacto
Browse files Browse the repository at this point in the history
  • Loading branch information
jancoveeden committed Sep 26, 2024
1 parent 2805efe commit 2fba911
Showing 1 changed file with 1 addition and 75 deletions.
76 changes: 1 addition & 75 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,78 +802,4 @@ def get_image_metrics_and_images(

images_dict = {"img": combined_rgb}

return metrics_dict, images_dict

def get_rgbsigma(self, xyz, camera: Cameras, res):
"""
Args:
xyz: (1, N, 3) feature grid positions
camera: Input ray_bundle
res: (3) List of xyz dimensions e.g. [256, 256, 256]
Returns:
rgb_mean, density
Notes:
N = number of samples
"""
device = self.device

rgbs = torch.zeros((res[0] * res[1] * res[2], 3)).to(device) # e.g. size: [16777216, 3]
depths = torch.zeros((xyz.shape[1], 1), device=device) # e.g. size: [16777216, 1]

for i in range(xyz.shape[1]):
point = xyz[:, i, :] # (1, 3) tensor representing the 3D position

# 1. Calculate the Gaussian density
means = self.means # (N, 3) tensor of Gaussian means
scales = torch.exp(self.scales) # (N, 3) tensor of Gaussian scales
quats = self.quats # (N, 4) tensor of Gaussian quaternions
rotmats = quat_to_rotmat(quats) # (N, 3, 3) tensor of rotation matrices
features_dc_crop = self.features_dc # (N, 3)
features_rest_crop = self.features_rest # (N, 16 - 1, 3)

# 1.1 Calculate squared Mahalanobis distances & density
# diffs = point - means # (N, 3)
# diffs = torch.sum(diffs * torch.bmm(diffs[:, None, :], rotmats) / (scales ** 2), dim=-1) # (N,)

# # # 1.2 Calculate density at the given point
# densities = torch.exp(-0.5 * diffs) # (N,) tensor of Gaussian densities
# total_density = torch.sum(densities * torch.sigmoid(self.opacities)) # Weighted sum of densities

# from gsplat._torch_impl.project_gaussians_forward
# p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
# depths = p_view[..., 2]
# depths = torch.where(~mask, 0, depths)
colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1) # [N, 1, 3] + [N, 15, 3]
#CONSOLE.print(f"[bold blue]colors_crop.shape: {colors_crop.shape}")

### Sampling the RGB values
if self.config.sh_degree > 0:
# viewdirs = means.to(device) - camera.camera_to_worlds[..., :3, 3].to(device) # (N, 3)
viewdirs = point.to(device) - camera.camera_to_worlds[0, ..., :3, 3].to(device) # (3,)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True) # (3,)
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
# rgb = spherical_harmonics(n, viewdirs[None, :], self.shs_0, self.shs_rest)[0] # (3,)
# rgb = spherical_harmonics(n, viewdirs, torch.cat([self.shs_0[:, None], self.shs_rest], dim=1))

# OR
#rgb = spherical_harmonics(n, viewdirs, colors_crop) # [384213, 3]
#rgb = torch.clamp(rgb + 0.5, min=0.0) # [384213, 3]

rgb = torch.sigmoid(colors_crop[:, 0, :])

#CONSOLE.print(f"[bold blue]rgb.shape: {rgb.shape}")
#CONSOLE.print(f"[bold blue]rgb: \n{rgb}")
#exit()
#else:
# Ignore
# rgb = self.colors[densities.argmax()] # Sample RGB from Gaussian with max density
# OR
# rgb = torch.sigmoid(colors_crop[:, 0, :])

rgbs[i, :] = rgb
#depths[i, 0] = total_density

CONSOLE.print(f"[bold blue]rgbs.shape: {rgbs.shape}")
exit()

return rgbs, depths
return metrics_dict, images_dict

0 comments on commit 2fba911

Please sign in to comment.