-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add direct render+export and notifications #3220
base: main
Are you sure you want to change the base?
Changes from 35 commits
b48fa2f
2f5e5ca
0933c82
c273891
c43e9e0
7e92ca8
2f98a97
44ae194
ec04421
1aad57b
14839dd
df13472
6fbccf1
006f5f3
56e6139
d39a729
b6210fd
f86e690
0ddabed
8a5bd0b
8027a49
b5e9f8c
fef8461
15a1fbb
2423510
f5f9fdb
2ff7da4
12c454a
e1ac0e5
2308450
58a4c3a
8aa49c3
6b4afed
5f4ba8d
61974c3
c976bf6
b838598
c0e57ab
2dc15df
f51f355
02a3f7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,9 +75,10 @@ def _render_trajectory_video( | |
depth_near_plane: Optional[float] = None, | ||
depth_far_plane: Optional[float] = None, | ||
colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(), | ||
render_nearest_camera=False, | ||
render_nearest_camera: bool = False, | ||
check_occlusions: bool = False, | ||
) -> None: | ||
kill_flag: List[bool] = [False], | ||
) -> bool: | ||
"""Helper function to create a video of the spiral trajectory. | ||
|
||
Args: | ||
|
@@ -137,6 +138,9 @@ def _render_trajectory_video( | |
|
||
with progress: | ||
for camera_idx in progress.track(range(cameras.size), description=""): | ||
if kill_flag[0]: | ||
return False | ||
|
||
obb_box = None | ||
if crop_data is not None: | ||
obb_box = crop_data.obb | ||
|
@@ -202,9 +206,13 @@ def _render_trajectory_video( | |
for rendered_output_name in rendered_output_names: | ||
if rendered_output_name not in outputs: | ||
CONSOLE.rule("Error", style="red") | ||
CONSOLE.print(f"Could not find {rendered_output_name} in the model outputs", justify="center") | ||
CONSOLE.print( | ||
f"Please set --rendered_output_name to one of: {outputs.keys()}", justify="center" | ||
f"Could not find {rendered_output_name} in the model outputs", | ||
justify="center", | ||
) | ||
CONSOLE.print( | ||
f"Please set --rendered_output_name to one of: {outputs.keys()}", | ||
justify="center", | ||
) | ||
sys.exit(1) | ||
output_image = outputs[rendered_output_name] | ||
|
@@ -256,10 +264,17 @@ def _render_trajectory_video( | |
render_image = np.concatenate(render_image, axis=1) | ||
if output_format == "images": | ||
if image_format == "png": | ||
media.write_image(output_image_dir / f"{camera_idx:05d}.png", render_image, fmt="png") | ||
media.write_image( | ||
output_image_dir / f"{camera_idx:05d}.png", | ||
render_image, | ||
fmt="png", | ||
) | ||
if image_format == "jpeg": | ||
media.write_image( | ||
output_image_dir / f"{camera_idx:05d}.jpg", render_image, fmt="jpeg", quality=jpeg_quality | ||
output_image_dir / f"{camera_idx:05d}.jpg", | ||
render_image, | ||
fmt="jpeg", | ||
quality=jpeg_quality, | ||
) | ||
if output_format == "video": | ||
if writer is None: | ||
|
@@ -287,7 +302,15 @@ def _render_trajectory_video( | |
table.add_row("Video", str(output_filename)) | ||
else: | ||
table.add_row("Images", str(output_image_dir)) | ||
CONSOLE.print(Panel(table, title="[bold][green]:tada: Render Complete :tada:[/bold]", expand=False)) | ||
CONSOLE.print( | ||
Panel( | ||
table, | ||
title="[bold][green]:tada: Render Complete :tada:[/bold]", | ||
expand=False, | ||
) | ||
) | ||
|
||
return True | ||
|
||
|
||
def insert_spherical_metadata_into_file( | ||
|
@@ -432,6 +455,11 @@ class BaseRender: | |
"""If true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other""" | ||
camera_idx: Optional[int] = None | ||
"""Index of the training camera to render.""" | ||
kill_flag: List[bool] = field(default_factory=lambda: [False]) | ||
"""Stop execution of render if set to True.""" | ||
|
||
def kill(self) -> None: | ||
self.kill_flag[0] = True | ||
|
||
|
||
@dataclass | ||
|
@@ -442,6 +470,8 @@ class RenderCameraPath(BaseRender): | |
"""Filename of the camera path to render.""" | ||
output_format: Literal["images", "video"] = "video" | ||
"""How to save output data.""" | ||
complete: bool = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same with this, it seems like complete should be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not really following the connection to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can hack around this too by suppressing? I think as a field it does need to still exist for the progress(notification) update. |
||
"""Set to True when render is finished.""" | ||
|
||
def main(self) -> None: | ||
"""Main function.""" | ||
|
@@ -485,7 +515,7 @@ def main(self) -> None: | |
if self.camera_idx is not None: | ||
camera_path.metadata = {"cam_idx": self.camera_idx} | ||
|
||
_render_trajectory_video( | ||
self.complete = _render_trajectory_video( | ||
pipeline, | ||
camera_path, | ||
output_filename=self.output_path, | ||
|
@@ -501,6 +531,7 @@ def main(self) -> None: | |
colormap_options=self.colormap_options, | ||
render_nearest_camera=self.render_nearest_camera, | ||
check_occlusions=self.check_occlusions, | ||
kill_flag=self.kill_flag, | ||
) | ||
|
||
if ( | ||
|
@@ -536,6 +567,7 @@ def main(self) -> None: | |
colormap_options=self.colormap_options, | ||
render_nearest_camera=self.render_nearest_camera, | ||
check_occlusions=self.check_occlusions, | ||
kill_flag=self.kill_flag, | ||
) | ||
|
||
self.output_path = Path(str(left_eye_path.parent)[:-5] + ".mp4") | ||
|
@@ -639,6 +671,7 @@ def main(self) -> None: | |
colormap_options=self.colormap_options, | ||
render_nearest_camera=self.render_nearest_camera, | ||
check_occlusions=self.check_occlusions, | ||
kill_flag=self.kill_flag, | ||
) | ||
|
||
|
||
|
@@ -694,6 +727,7 @@ def main(self) -> None: | |
colormap_options=self.colormap_options, | ||
render_nearest_camera=self.render_nearest_camera, | ||
check_occlusions=self.check_occlusions, | ||
kill_flag=self.kill_flag, | ||
) | ||
|
||
|
||
|
@@ -731,7 +765,10 @@ def main(self): | |
|
||
def update_config(config: TrainerConfig) -> TrainerConfig: | ||
data_manager_config = config.pipeline.datamanager | ||
assert isinstance(data_manager_config, (VanillaDataManagerConfig, FullImageDatamanagerConfig)) | ||
assert isinstance( | ||
data_manager_config, | ||
(VanillaDataManagerConfig, FullImageDatamanagerConfig), | ||
) | ||
data_manager_config.eval_num_images_to_sample_from = -1 | ||
data_manager_config.eval_num_times_to_repeat_images = -1 | ||
if isinstance(data_manager_config, VanillaDataManagerConfig): | ||
|
@@ -741,7 +778,11 @@ def update_config(config: TrainerConfig) -> TrainerConfig: | |
data_manager_config.data = self.data | ||
if self.downscale_factor is not None: | ||
assert hasattr(data_manager_config.dataparser, "downscale_factor") | ||
setattr(data_manager_config.dataparser, "downscale_factor", self.downscale_factor) | ||
setattr( | ||
data_manager_config.dataparser, | ||
"downscale_factor", | ||
self.downscale_factor, | ||
) | ||
return config | ||
|
||
config, pipeline, _, _ = eval_setup( | ||
|
@@ -806,10 +847,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig: | |
if rendered_output_name not in all_outputs: | ||
CONSOLE.rule("Error", style="red") | ||
CONSOLE.print( | ||
f"Could not find {rendered_output_name} in the model outputs", justify="center" | ||
f"Could not find {rendered_output_name} in the model outputs", | ||
justify="center", | ||
) | ||
CONSOLE.print( | ||
f"Please set --rendered-output-name to one of: {all_outputs}", justify="center" | ||
f"Please set --rendered-output-name to one of: {all_outputs}", | ||
justify="center", | ||
) | ||
sys.exit(1) | ||
|
||
|
@@ -876,7 +919,10 @@ def update_config(config: TrainerConfig) -> TrainerConfig: | |
media.write_image(output_path.with_suffix(".png"), output_image, fmt="png") | ||
elif self.image_format == "jpeg": | ||
media.write_image( | ||
output_path.with_suffix(".jpg"), output_image, fmt="jpeg", quality=self.jpeg_quality | ||
output_path.with_suffix(".jpg"), | ||
output_image, | ||
fmt="jpeg", | ||
quality=self.jpeg_quality, | ||
) | ||
else: | ||
raise ValueError(f"Unknown image format {self.image_format}") | ||
|
@@ -889,7 +935,13 @@ def update_config(config: TrainerConfig) -> TrainerConfig: | |
) | ||
for split in self.split.split("+"): | ||
table.add_row(f"Outputs {split}", str(self.output_path / split)) | ||
CONSOLE.print(Panel(table, title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]", expand=False)) | ||
CONSOLE.print( | ||
Panel( | ||
table, | ||
title="[bold][green]:tada: Render on split {} Complete :tada:[/bold]", | ||
expand=False, | ||
) | ||
) | ||
|
||
|
||
Commands = tyro.conf.FlagConversionOff[ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we not expose this in the config? it feels like it should be an invisible class
self._kill_flag
variable, then the kill() function sets that to true.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still a hack but one option is: