Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add direct render+export and notifications #3220

Open
wants to merge 41 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
b48fa2f
change render button to run ns-render directly
ginazhouhuiwu Jun 4, 2024
2f5e5ca
nit
ginazhouhuiwu Jun 4, 2024
0933c82
nits
ginazhouhuiwu Jun 4, 2024
c273891
Merge branch 'main' into gina/export
ginazhouhuiwu Jun 10, 2024
c43e9e0
add direct render via viewer and notifications
ginazhouhuiwu Jun 13, 2024
7e92ca8
add direct export button and notifications
ginazhouhuiwu Jun 13, 2024
2f98a97
Merge branch 'main' into gina/export
ginazhouhuiwu Jun 13, 2024
44ae194
viser notification api changes
ginazhouhuiwu Jun 13, 2024
ec04421
export bug fixes
ginazhouhuiwu Jun 14, 2024
1aad57b
add (kind of?) error message for checkpoint isn't found during export
ginazhouhuiwu Jun 14, 2024
14839dd
message
ginazhouhuiwu Jun 14, 2024
df13472
add export warning message
ginazhouhuiwu Jun 15, 2024
6fbccf1
wip
ginazhouhuiwu Jun 16, 2024
006f5f3
update to match viser notification api
ginazhouhuiwu Jun 21, 2024
56e6139
add local render download
ginazhouhuiwu Jun 22, 2024
d39a729
wip for export local download
ginazhouhuiwu Jun 22, 2024
b6210fd
update icons
ginazhouhuiwu Jun 27, 2024
f86e690
Merge branch 'main' into gina/export
ginazhouhuiwu Aug 26, 2024
0ddabed
better rendering functionality and command palette
ginazhouhuiwu Aug 26, 2024
8a5bd0b
add button disable settings for render requirements and in-progress
ginazhouhuiwu Aug 27, 2024
8027a49
fix button disabling for renders
ginazhouhuiwu Aug 27, 2024
b5e9f8c
add render cancellation option
ginazhouhuiwu Aug 27, 2024
fef8461
disable button feature on exports and clean export panel
ginazhouhuiwu Aug 27, 2024
15a1fbb
Merge branch 'main' into gina/export
ginazhouhuiwu Aug 27, 2024
2423510
ruff format
ginazhouhuiwu Aug 27, 2024
f5f9fdb
ruff format
ginazhouhuiwu Aug 27, 2024
2ff7da4
ruff format
ginazhouhuiwu Aug 27, 2024
12c454a
RUFF FORMAT
ginazhouhuiwu Aug 27, 2024
e1ac0e5
some pyright type fixes
ginazhouhuiwu Aug 28, 2024
2308450
pyright
ginazhouhuiwu Aug 29, 2024
58a4c3a
pyright
ginazhouhuiwu Aug 29, 2024
8aa49c3
ruff format
ginazhouhuiwu Aug 29, 2024
6b4afed
ruff format exports
ginazhouhuiwu Aug 29, 2024
5f4ba8d
ruff again
ginazhouhuiwu Aug 29, 2024
61974c3
Merge branch 'main' into gina/export
ginazhouhuiwu Aug 30, 2024
c976bf6
Merge branch 'main' into gina/export
kerrj Sep 5, 2024
b838598
nit
ginazhouhuiwu Nov 5, 2024
c0e57ab
Merge branch 'gina/export' of https://github.com/nerfstudio-project/n…
ginazhouhuiwu Nov 5, 2024
2dc15df
Merge branch 'main' into gina/export
ginazhouhuiwu Nov 5, 2024
f51f355
error message and load checkpoint fixes
ginazhouhuiwu Nov 6, 2024
02a3f7a
ruff format import blocks
ginazhouhuiwu Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions nerfstudio/scripts/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class Exporter:
"""Path to the config YAML file."""
output_dir: Path
"""Path to the output directory."""
complete: bool = False
"""Set to True when export is finished."""


def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pipeline) -> None:
Expand All @@ -74,7 +76,10 @@ def validate_pipeline(normal_method: str, normal_output_name: str, pipeline: Pip
pixel_area = torch.ones_like(origins[..., :1])
camera_indices = torch.zeros_like(origins[..., :1])
ray_bundle = RayBundle(
origins=origins, directions=directions, pixel_area=pixel_area, camera_indices=camera_indices
origins=origins,
directions=directions,
pixel_area=pixel_area,
camera_indices=camera_indices,
)
outputs = pipeline.model(ray_bundle)
if normal_output_name not in outputs:
Expand Down Expand Up @@ -135,7 +140,12 @@ def main(self) -> None:
# Increase the batchsize to speed up the evaluation.
assert isinstance(
pipeline.datamanager,
(VanillaDataManager, ParallelDataManager, FullImageDatamanager, RandomCamerasDataManager),
(
VanillaDataManager,
ParallelDataManager,
FullImageDatamanager,
RandomCamerasDataManager,
),
)
assert pipeline.datamanager.train_pixel_sampler is not None
pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch
Expand All @@ -153,7 +163,7 @@ def main(self) -> None:
estimate_normals=estimate_normals,
rgb_output_name=self.rgb_output_name,
depth_output_name=self.depth_output_name,
normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None,
normal_output_name=(self.normal_output_name if self.normal_method == "model_output" else None),
crop_obb=crop_obb,
std_ratio=self.std_ratio,
)
Expand All @@ -180,6 +190,8 @@ def main(self) -> None:
print("\033[A\033[A")
CONSOLE.print("[bold green]:white_check_mark: Saving Point Cloud")

self.complete = True


@dataclass
class ExportTSDFMesh(Exporter):
Expand Down Expand Up @@ -241,18 +253,21 @@ def main(self) -> None:
if self.texture_method == "nerf":
# load the mesh from the tsdf export
mesh = get_mesh_from_filename(
str(self.output_dir / "tsdf_mesh.ply"), target_num_faces=self.target_num_faces
str(self.output_dir / "tsdf_mesh.ply"),
target_num_faces=self.target_num_faces,
)
CONSOLE.print("Texturing mesh with NeRF")
texture_utils.export_textured_mesh(
mesh,
pipeline,
self.output_dir,
px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None,
px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None),
unwrap_method=self.unwrap_method,
num_pixels_per_side=self.num_pixels_per_side,
)

self.complete = True


@dataclass
class ExportPoissonMesh(Exporter):
Expand Down Expand Up @@ -316,7 +331,12 @@ def main(self) -> None:
# Increase the batchsize to speed up the evaluation.
assert isinstance(
pipeline.datamanager,
(VanillaDataManager, ParallelDataManager, FullImageDatamanager, RandomCamerasDataManager),
(
VanillaDataManager,
ParallelDataManager,
FullImageDatamanager,
RandomCamerasDataManager,
),
)
assert pipeline.datamanager.train_pixel_sampler is not None
pipeline.datamanager.train_pixel_sampler.num_rays_per_batch = self.num_rays_per_batch
Expand All @@ -336,7 +356,7 @@ def main(self) -> None:
estimate_normals=estimate_normals,
rgb_output_name=self.rgb_output_name,
depth_output_name=self.depth_output_name,
normal_output_name=self.normal_output_name if self.normal_method == "model_output" else None,
normal_output_name=(self.normal_output_name if self.normal_method == "model_output" else None),
crop_obb=crop_obb,
std_ratio=self.std_ratio,
)
Expand Down Expand Up @@ -366,18 +386,21 @@ def main(self) -> None:
if self.texture_method == "nerf":
# load the mesh from the poisson reconstruction
mesh = get_mesh_from_filename(
str(self.output_dir / "poisson_mesh.ply"), target_num_faces=self.target_num_faces
str(self.output_dir / "poisson_mesh.ply"),
target_num_faces=self.target_num_faces,
)
CONSOLE.print("Texturing mesh with NeRF")
texture_utils.export_textured_mesh(
mesh,
pipeline,
self.output_dir,
px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None,
px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None),
unwrap_method=self.unwrap_method,
num_pixels_per_side=self.num_pixels_per_side,
)

self.complete = True


@dataclass
class ExportMarchingCubesMesh(Exporter):
Expand Down Expand Up @@ -439,11 +462,13 @@ def main(self) -> None:
mesh,
pipeline,
self.output_dir,
px_per_uv_triangle=self.px_per_uv_triangle if self.unwrap_method == "custom" else None,
px_per_uv_triangle=(self.px_per_uv_triangle if self.unwrap_method == "custom" else None),
unwrap_method=self.unwrap_method,
num_pixels_per_side=self.num_pixels_per_side,
)

self.complete = True


@dataclass
class ExportCameraPoses(Exporter):
Expand All @@ -460,7 +485,10 @@ def main(self) -> None:
assert isinstance(pipeline, VanillaPipeline)
train_frames, eval_frames = collect_camera_poses(pipeline)

for file_name, frames in [("transforms_train.json", train_frames), ("transforms_eval.json", eval_frames)]:
for file_name, frames in [
("transforms_train.json", train_frames),
("transforms_eval.json", eval_frames),
]:
if len(frames) == 0:
CONSOLE.print(f"[bold yellow]No frames found for {file_name}. Skipping.")
continue
Expand Down Expand Up @@ -629,6 +657,8 @@ def main(self) -> None:

ExportGaussianSplat.write_ply(str(filename), count, map_to_tensors)

self.complete = True


Commands = tyro.conf.FlagConversionOff[
Union[
Expand Down
80 changes: 66 additions & 14 deletions nerfstudio/scripts/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def _render_trajectory_video(
depth_near_plane: Optional[float] = None,
depth_far_plane: Optional[float] = None,
colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(),
render_nearest_camera=False,
render_nearest_camera: bool = False,
check_occlusions: bool = False,
) -> None:
kill_flag: List[bool] = [False],
) -> bool:
"""Helper function to create a video of the spiral trajectory.

Args:
Expand Down Expand Up @@ -137,6 +138,9 @@ def _render_trajectory_video(

with progress:
for camera_idx in progress.track(range(cameras.size), description=""):
if kill_flag[0]:
return False

obb_box = None
if crop_data is not None:
obb_box = crop_data.obb
Expand Down Expand Up @@ -202,9 +206,13 @@ def _render_trajectory_video(
for rendered_output_name in rendered_output_names:
if rendered_output_name not in outputs:
CONSOLE.rule("Error", style="red")
CONSOLE.print(f"Could not find {rendered_output_name} in the model outputs", justify="center")
CONSOLE.print(
f"Please set --rendered_output_name to one of: {outputs.keys()}", justify="center"
f"Could not find {rendered_output_name} in the model outputs",
justify="center",
)
CONSOLE.print(
f"Please set --rendered_output_name to one of: {outputs.keys()}",
justify="center",
)
sys.exit(1)
output_image = outputs[rendered_output_name]
Expand Down Expand Up @@ -256,10 +264,17 @@ def _render_trajectory_video(
render_image = np.concatenate(render_image, axis=1)
if output_format == "images":
if image_format == "png":
media.write_image(output_image_dir / f"{camera_idx:05d}.png", render_image, fmt="png")
media.write_image(
output_image_dir / f"{camera_idx:05d}.png",
render_image,
fmt="png",
)
if image_format == "jpeg":
media.write_image(
output_image_dir / f"{camera_idx:05d}.jpg", render_image, fmt="jpeg", quality=jpeg_quality
output_image_dir / f"{camera_idx:05d}.jpg",
render_image,
fmt="jpeg",
quality=jpeg_quality,
)
if output_format == "video":
if writer is None:
Expand Down Expand Up @@ -287,7 +302,15 @@ def _render_trajectory_video(
table.add_row("Video", str(output_filename))
else:
table.add_row("Images", str(output_image_dir))
CONSOLE.print(Panel(table, title="[bold][green]:tada: Render Complete :tada:[/bold]", expand=False))
CONSOLE.print(
Panel(
table,
title="[bold][green]:tada: Render Complete :tada:[/bold]",
expand=False,
)
)

return True


def insert_spherical_metadata_into_file(
Expand Down Expand Up @@ -432,6 +455,11 @@ class BaseRender:
"""If true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other"""
camera_idx: Optional[int] = None
"""Index of the training camera to render."""
kill_flag: List[bool] = field(default_factory=lambda: [False])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not expose this in the config? it feels like it should be an invisible class self._kill_flag variable, then the kill() function sets that to true.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still a hack but one option is:

# `tyro.conf.Suppress` hides this field from the generated `ns-render` CLI.
_kill_flag: tyro.conf.Suppress[List[bool]] = field(default_factory=lambda: [False])

"""Stop execution of render if set to True."""

def kill(self) -> None:
self.kill_flag[0] = True


@dataclass
Expand All @@ -442,6 +470,8 @@ class RenderCameraPath(BaseRender):
"""Filename of the camera path to render."""
output_format: Literal["images", "video"] = "video"
"""How to save output data."""
complete: bool = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with this, it seems like complete should be a @property that gets managed internally

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not really following the connection to @property

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can hack around this too by suppressing? I think as a field it does need to still exist for the progress(notification) update.

"""Set to True when render is finished."""

def main(self) -> None:
"""Main function."""
Expand Down Expand Up @@ -485,7 +515,7 @@ def main(self) -> None:
if self.camera_idx is not None:
camera_path.metadata = {"cam_idx": self.camera_idx}

_render_trajectory_video(
self.complete = _render_trajectory_video(
pipeline,
camera_path,
output_filename=self.output_path,
Expand All @@ -501,6 +531,7 @@ def main(self) -> None:
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
kill_flag=self.kill_flag,
)

if (
Expand Down Expand Up @@ -536,6 +567,7 @@ def main(self) -> None:
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
kill_flag=self.kill_flag,
)

self.output_path = Path(str(left_eye_path.parent)[:-5] + ".mp4")
Expand Down Expand Up @@ -639,6 +671,7 @@ def main(self) -> None:
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
kill_flag=self.kill_flag,
)


Expand Down Expand Up @@ -694,6 +727,7 @@ def main(self) -> None:
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
kill_flag=self.kill_flag,
)


Expand Down Expand Up @@ -731,7 +765,10 @@ def main(self):

def update_config(config: TrainerConfig) -> TrainerConfig:
data_manager_config = config.pipeline.datamanager
assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig))
assert isinstance(
data_manager_config,
(VanillaDataManagerConfig, FullImageDatamanagerConfig),
)
data_manager_config.eval_num_images_to_sample_from = -1
data_manager_config.eval_num_times_to_repeat_images = -1
if isinstance(data_manager_config, VanillaDataManagerConfig):
Expand All @@ -741,7 +778,11 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
data_manager_config.data = self.data
if self.downscale_factor is not None:
assert hasattr(data_manager_config.dataparser, "downscale_factor")
setattr(data_manager_config.dataparser, "downscale_factor", self.downscale_factor)
setattr(
data_manager_config.dataparser,
"downscale_factor",
self.downscale_factor,
)
return config

config, pipeline, _, _ = eval_setup(
Expand Down Expand Up @@ -806,10 +847,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
if rendered_output_name not in all_outputs:
CONSOLE.rule("Error", style="red")
CONSOLE.print(
f"Could not find {rendered_output_name} in the model outputs", justify="center"
f"Could not find {rendered_output_name} in the model outputs",
justify="center",
)
CONSOLE.print(
f"Please set --rendered-output-name to one of: {all_outputs}", justify="center"
f"Please set --rendered-output-name to one of: {all_outputs}",
justify="center",
)
sys.exit(1)

Expand Down Expand Up @@ -876,7 +919,10 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
media.write_image(output_path.with_suffix(".png"), output_image, fmt="png")
elif self.image_format == "jpeg":
media.write_image(
output_path.with_suffix(".jpg"), output_image, fmt="jpeg", quality=self.jpeg_quality
output_path.with_suffix(".jpg"),
output_image,
fmt="jpeg",
quality=self.jpeg_quality,
)
else:
raise ValueError(f"Unknown image format {self.image_format}")
Expand All @@ -889,7 +935,13 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
)
for split in self.split.split("+"):
table.add_row(f"Outputs {split}", str(self.output_path / split))
CONSOLE.print(Panel(table, title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]", expand=False))
CONSOLE.print(
Panel(
table,
title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]",
expand=False,
)
)


Commands = tyro.conf.FlagConversionOff[
Expand Down
Loading
Loading