diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index a6881d4244..f934b12639 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -25,7 +25,6 @@ import numpy as np import torch -from gsplat.cuda_legacy._torch_impl import quat_to_rotmat try: from gsplat.rendering import rasterization @@ -47,6 +46,26 @@ from nerfstudio.utils.rich_utils import CONSOLE +def quat_to_rotmat(quat): + assert quat.shape[-1] == 4, quat.shape + w, x, y, z = torch.unbind(quat, dim=-1) + mat = torch.stack( + [ + 1 - 2 * (y**2 + z**2), + 2 * (x * y - w * z), + 2 * (x * z + w * y), + 2 * (x * y + w * z), + 1 - 2 * (x**2 + z**2), + 2 * (y * z - w * x), + 2 * (x * z - w * y), + 2 * (y * z + w * x), + 1 - 2 * (x**2 + y**2), + ], + dim=-1, + ) + return mat.reshape(quat.shape[:-1] + (3, 3)) + + def random_quat_tensor(N): """ Defines a random quaternion tensor of shape (N, 4)