Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add per frame intrinsic support to nerfstudio data #1049

Merged
merged 2 commits into from
Nov 29, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 77 additions & 17 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,64 @@ def _generate_dataparser_outputs(self, split="train"):
poses = []
num_skipped_image_filenames = 0

fx_fixed = "fl_x" in meta
fy_fixed = "fl_y" in meta
cx_fixed = "cx" in meta
cy_fixed = "cy" in meta
height_fixed = "h" in meta
width_fixed = "w" in meta
distort_fixed = False
for distort_key in ["k1", "k2", "k3", "p1", "p2"]:
if distort_key in meta:
distort_fixed = True
break
fx = []
fy = []
cx = []
cy = []
height = []
width = []
distort = []

for frame in meta["frames"]:
filepath = PurePath(frame["file_path"])
fname = self._get_fname(filepath)
if not fname.exists():
num_skipped_image_filenames += 1
else:
image_filenames.append(fname)
poses.append(np.array(frame["transform_matrix"]))
continue

if not fx_fixed:
assert "fl_x" in frame, "fx not specified in frame"
fx.append(float(frame["fl_x"]))
if not fy_fixed:
assert "fl_y" in frame, "fy not specified in frame"
fy.append(float(frame["fl_y"]))
if not cx_fixed:
assert "cx" in frame, "cx not specified in frame"
cx.append(float(frame["cx"]))
if not cy_fixed:
assert "cy" in frame, "cy not specified in frame"
cy.append(float(frame["cy"]))
if not height_fixed:
assert "h" in frame, "height not specified in frame"
height.append(int(frame["h"]))
if not width_fixed:
assert "w" in frame, "width not specified in frame"
width.append(int(frame["w"]))
if not distort_fixed:
distort.append(
camera_utils.get_distortion_params(
k1=float(meta["k1"]) if "k1" in meta else 0.0,
k2=float(meta["k2"]) if "k2" in meta else 0.0,
k3=float(meta["k3"]) if "k3" in meta else 0.0,
k4=float(meta["k4"]) if "k4" in meta else 0.0,
p1=float(meta["p1"]) if "p1" in meta else 0.0,
p2=float(meta["p2"]) if "p2" in meta else 0.0,
)
)

image_filenames.append(fname)
poses.append(np.array(frame["transform_matrix"]))
if "mask_path" in frame:
mask_filepath = PurePath(frame["mask_path"])
mask_fname = self._get_fname(mask_filepath, downsample_folder_prefix="masks_")
Expand Down Expand Up @@ -162,23 +212,33 @@ def _generate_dataparser_outputs(self, split="train"):
else:
camera_type = CameraType.PERSPECTIVE

distortion_params = camera_utils.get_distortion_params(
k1=float(meta["k1"]) if "k1" in meta else 0.0,
k2=float(meta["k2"]) if "k2" in meta else 0.0,
k3=float(meta["k3"]) if "k3" in meta else 0.0,
k4=float(meta["k4"]) if "k4" in meta else 0.0,
p1=float(meta["p1"]) if "p1" in meta else 0.0,
p2=float(meta["p2"]) if "p2" in meta else 0.0,
)
idx_tensor = torch.tensor(indices)
fx = float(meta["fl_x"]) if fx_fixed else torch.tensor(fx, dtype=torch.float32)[idx_tensor]
fy = float(meta["fl_y"]) if fy_fixed else torch.tensor(fy, dtype=torch.float32)[idx_tensor]
cx = float(meta["cx"]) if cx_fixed else torch.tensor(cx, dtype=torch.float32)[idx_tensor]
cy = float(meta["cy"]) if cy_fixed else torch.tensor(cy, dtype=torch.float32)[idx_tensor]
height = int(meta["h"]) if height_fixed else torch.tensor(height, dtype=torch.int32)[idx_tensor]
width = int(meta["w"]) if width_fixed else torch.tensor(width, dtype=torch.int32)[idx_tensor]
if distort_fixed:
distortion_params = camera_utils.get_distortion_params(
k1=float(meta["k1"]) if "k1" in meta else 0.0,
k2=float(meta["k2"]) if "k2" in meta else 0.0,
k3=float(meta["k3"]) if "k3" in meta else 0.0,
k4=float(meta["k4"]) if "k4" in meta else 0.0,
p1=float(meta["p1"]) if "p1" in meta else 0.0,
p2=float(meta["p2"]) if "p2" in meta else 0.0,
)
else:
distortion_params = torch.stack(distort, dim=0)[idx_tensor]

cameras = Cameras(
fx=float(meta["fl_x"]),
fy=float(meta["fl_y"]),
cx=float(meta["cx"]),
cy=float(meta["cy"]),
fx=fx,
fy=fy,
cx=cx,
cy=cy,
distortion_params=distortion_params,
height=int(meta["h"]),
width=int(meta["w"]),
height=height,
width=width,
camera_to_worlds=poses[:, :3, :4],
camera_type=camera_type,
)
Expand Down