Skip to content

Commit

Permalink
Merge branch 'main' into feat/improve-starting
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored Dec 2, 2024
2 parents 93f8843 + e8bf472 commit 848c6a7
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 12 deletions.
6 changes: 5 additions & 1 deletion nerfstudio/cameras/camera_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def get_interpolated_camera_path(cameras: Cameras, steps: int, order_poses: bool
"""
Ks = cameras.get_intrinsics_matrices()
poses = cameras.camera_to_worlds
poses, Ks = get_interpolated_poses_many(poses, Ks, steps_per_transition=steps, order_poses=order_poses)
times = cameras.times
poses, Ks, times = get_interpolated_poses_many(
poses, Ks, times, steps_per_transition=steps, order_poses=order_poses
)

cameras = Cameras(
fx=Ks[:, 0, 0],
Expand All @@ -48,6 +51,7 @@ def get_interpolated_camera_path(cameras: Cameras, steps: int, order_poses: bool
cy=Ks[0, 1, 2],
camera_type=cameras.camera_type[0],
camera_to_worlds=poses,
times=times,
)
return cameras

Expand Down
52 changes: 42 additions & 10 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,48 +206,74 @@ def get_interpolated_k(
return Ks


def get_ordered_poses_and_k(
def get_interpolated_time(
time_a: Float[Tensor, "1"], time_b: Float[Tensor, "1"], steps: int = 10
) -> List[Float[Tensor, "1"]]:
"""
Returns interpolated time between two camera poses with specified number of steps.
Args:
time_a: camera time 1
time_b: camera time 2
steps: number of steps the interpolated pose path should contain
"""
times: List[Float[Tensor, "1"]] = []
ts = np.linspace(0, 1, steps)
for t in ts:
new_t = time_a * (1.0 - t) + time_b * t
times.append(new_t)
return times


def get_ordered_poses_and_k_and_time(
poses: Float[Tensor, "num_poses 3 4"],
Ks: Float[Tensor, "num_poses 3 3"],
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
times: Optional[Float[Tensor, "num_poses 1"]] = None,
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"], Optional[Float[Tensor, "num_poses 1"]]]:
"""
Returns ordered poses and intrinsics by euclidian distance between poses.
Args:
poses: list of camera poses
Ks: list of camera intrinsics
times: list of camera times
Returns:
tuple of ordered poses and intrinsics
tuple of ordered poses, intrinsics and times
"""

poses_num = len(poses)

ordered_poses = torch.unsqueeze(poses[0], 0)
ordered_ks = torch.unsqueeze(Ks[0], 0)
ordered_times = torch.unsqueeze(times[0], 0) if times is not None else None

# remove the first pose from poses
poses = poses[1:]
Ks = Ks[1:]
times = times[1:] if times is not None else None

for _ in range(poses_num - 1):
distances = torch.norm(ordered_poses[-1][:, 3] - poses[:, :, 3], dim=1)
idx = torch.argmin(distances)
ordered_poses = torch.cat((ordered_poses, torch.unsqueeze(poses[idx], 0)), dim=0)
ordered_ks = torch.cat((ordered_ks, torch.unsqueeze(Ks[idx], 0)), dim=0)
ordered_times = torch.cat((ordered_times, torch.unsqueeze(times[idx], 0)), dim=0) if times is not None else None # type: ignore
poses = torch.cat((poses[0:idx], poses[idx + 1 :]), dim=0)
Ks = torch.cat((Ks[0:idx], Ks[idx + 1 :]), dim=0)
times = torch.cat((times[0:idx], times[idx + 1 :]), dim=0) if times is not None else None

return ordered_poses, ordered_ks
return ordered_poses, ordered_ks, ordered_times


def get_interpolated_poses_many(
poses: Float[Tensor, "num_poses 3 4"],
Ks: Float[Tensor, "num_poses 3 3"],
times: Optional[Float[Tensor, "num_poses 1"]] = None,
steps_per_transition: int = 10,
order_poses: bool = False,
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"], Optional[Float[Tensor, "num_poses 1"]]]:
"""Return interpolated poses for many camera poses.
Args:
Expand All @@ -261,21 +287,27 @@ def get_interpolated_poses_many(
"""
traj = []
k_interp = []
time_interp = [] if times is not None else None

if order_poses:
poses, Ks = get_ordered_poses_and_k(poses, Ks)
poses, Ks, times = get_ordered_poses_and_k_and_time(poses, Ks, times)

for idx in range(poses.shape[0] - 1):
pose_a = poses[idx].cpu().numpy()
pose_b = poses[idx + 1].cpu().numpy()
poses_ab = get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition)
traj += poses_ab
traj += get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition)
k_interp += get_interpolated_k(Ks[idx], Ks[idx + 1], steps=steps_per_transition)
if times is not None:
time_interp += get_interpolated_time(times[idx], times[idx + 1], steps=steps_per_transition) # type: ignore

traj = np.stack(traj, axis=0)
k_interp = torch.stack(k_interp, dim=0)

return torch.tensor(traj, dtype=torch.float32), torch.tensor(k_interp, dtype=torch.float32)
time_interp = torch.stack(time_interp, dim=0) if time_interp is not None else None
return (
torch.tensor(traj, dtype=torch.float32),
torch.tensor(k_interp, dtype=torch.float32),
torch.tensor(time_interp, dtype=torch.float32) if time_interp is not None else None,
)


def normalize(x: torch.Tensor) -> Float[Tensor, "*batch"]:
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def train(self) -> None:

self._init_viewer_state()
with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME):
num_iterations = self.config.max_num_iterations
num_iterations = self.config.max_num_iterations - self._start_step
step = 0
self.stop_training = False
for step in range(self._start_step, self._start_step + num_iterations):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ lint.ignore = [
"PLW0603", # Globa statement updates are discouraged.
"PLW2901", # For loop variable overwritten.
"PLR1730", # Replace if statement with min/max
"PLC0206", # Extracting value from dictionary without calling `.items()`
]

[tool.ruff.lint.isort]
Expand Down

0 comments on commit 848c6a7

Please sign in to comment.