Skip to content

Commit

Permalink
add time interpolation support for ns-render interpolate (#3481)
Browse files Browse the repository at this point in the history
add time interpolation support
  • Loading branch information
Tavish9 authored Nov 29, 2024
1 parent 4d73c4e commit e8bf472
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 11 deletions.
6 changes: 5 additions & 1 deletion nerfstudio/cameras/camera_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def get_interpolated_camera_path(cameras: Cameras, steps: int, order_poses: bool
"""
Ks = cameras.get_intrinsics_matrices()
poses = cameras.camera_to_worlds
poses, Ks = get_interpolated_poses_many(poses, Ks, steps_per_transition=steps, order_poses=order_poses)
times = cameras.times
poses, Ks, times = get_interpolated_poses_many(
poses, Ks, times, steps_per_transition=steps, order_poses=order_poses
)

cameras = Cameras(
fx=Ks[:, 0, 0],
Expand All @@ -48,6 +51,7 @@ def get_interpolated_camera_path(cameras: Cameras, steps: int, order_poses: bool
cy=Ks[0, 1, 2],
camera_type=cameras.camera_type[0],
camera_to_worlds=poses,
times=times,
)
return cameras

Expand Down
52 changes: 42 additions & 10 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,48 +206,74 @@ def get_interpolated_k(
return Ks


def get_ordered_poses_and_k(
def get_interpolated_time(
    time_a: Float[Tensor, "1"], time_b: Float[Tensor, "1"], steps: int = 10
) -> List[Float[Tensor, "1"]]:
    """
    Linearly interpolate between two camera timestamps.

    Args:
        time_a: camera time 1
        time_b: camera time 2
        steps: number of steps the interpolated pose path should contain

    Returns:
        List of `steps` interpolated times, running from ``time_a`` to ``time_b`` inclusive.
    """
    fractions = np.linspace(0, 1, steps)
    return [time_a * (1.0 - frac) + time_b * frac for frac in fractions]


def get_ordered_poses_and_k_and_time(
    poses: Float[Tensor, "num_poses 3 4"],
    Ks: Float[Tensor, "num_poses 3 3"],
    times: Optional[Float[Tensor, "num_poses 1"]] = None,
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"], Optional[Float[Tensor, "num_poses 1"]]]:
    """
    Returns ordered poses and intrinsics by euclidean distance between poses.

    Greedy nearest-neighbour ordering: starts from the first pose and repeatedly
    appends the remaining pose whose translation is closest to the last chosen one.

    Args:
        poses: list of camera poses
        Ks: list of camera intrinsics
        times: list of camera times

    Returns:
        tuple of ordered poses, intrinsics and times
    """
    num_poses = poses.shape[0]

    # Build the visit order over indices instead of repeatedly re-slicing tensors.
    order = [0]
    remaining = list(range(1, num_poses))
    while remaining:
        last_translation = poses[order[-1]][:, 3]
        candidate_translations = poses[remaining][:, :, 3]
        distances = torch.norm(last_translation - candidate_translations, dim=1)
        nearest = remaining[int(torch.argmin(distances))]
        order.append(nearest)
        remaining.remove(nearest)

    ordered_poses = poses[order]
    ordered_ks = Ks[order]
    ordered_times = times[order] if times is not None else None

    return ordered_poses, ordered_ks, ordered_times


def get_interpolated_poses_many(
    poses: Float[Tensor, "num_poses 3 4"],
    Ks: Float[Tensor, "num_poses 3 3"],
    times: Optional[Float[Tensor, "num_poses 1"]] = None,
    steps_per_transition: int = 10,
    order_poses: bool = False,
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"], Optional[Float[Tensor, "num_poses 1"]]]:
    """Return interpolated poses for many camera poses.

    Args:
        poses: list of camera poses
        Ks: list of camera intrinsics
        times: optional list of camera times, interpolated alongside the poses
        steps_per_transition: number of steps per transition
        order_poses: whether to order poses by euclidean distance

    Returns:
        tuple of new poses, intrinsics, and times (``None`` when ``times`` is ``None``)
    """
    traj = []
    k_interp = []
    time_interp = [] if times is not None else None

    if order_poses:
        poses, Ks, times = get_ordered_poses_and_k_and_time(poses, Ks, times)

    for idx in range(poses.shape[0] - 1):
        pose_a = poses[idx].cpu().numpy()
        pose_b = poses[idx + 1].cpu().numpy()
        traj += get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition)
        k_interp += get_interpolated_k(Ks[idx], Ks[idx + 1], steps=steps_per_transition)
        if times is not None:
            time_interp += get_interpolated_time(times[idx], times[idx + 1], steps=steps_per_transition)  # type: ignore

    # Stack/convert exactly once. Calling torch.tensor() on an existing Tensor
    # (as the previous version did on the stacked k_interp/time_interp) emits a
    # copy-construct UserWarning and makes a redundant copy; use Tensor.to() instead.
    poses_out = torch.from_numpy(np.stack(traj, axis=0)).to(torch.float32)
    ks_out = torch.stack(k_interp, dim=0).to(torch.float32)
    times_out = torch.stack(time_interp, dim=0).to(torch.float32) if time_interp is not None else None
    return poses_out, ks_out, times_out


def normalize(x: torch.Tensor) -> Float[Tensor, "*batch"]:
Expand Down

0 comments on commit e8bf472

Please sign in to comment.