diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index a88a306ced..66113fc7ed 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -46,6 +46,7 @@ from nerfstudio.model_components import renderers from nerfstudio.models.base_model import Model, ModelConfig from nerfstudio.utils.colors import get_color +from nerfstudio.utils.misc import torch_compile from nerfstudio.utils.rich_utils import CONSOLE @@ -99,7 +100,7 @@ def resize_image(image: torch.Tensor, d: int): return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0) -@torch.compile() +@torch_compile() def get_viewmat(optimized_camera_to_world): """ function that converts c2w to gsplat world2camera matrix, using compile for some speed