|
8 | 8 | import json
|
9 | 9 | import logging
|
10 | 10 | import os
|
| 11 | +import sys |
11 | 12 | from dataclasses import dataclass
|
12 | 13 | from pathlib import Path
|
13 | 14 | from typing import Literal, Tuple, Union
|
@@ -48,6 +49,14 @@ def _load_checkpoint(config: cfg.TrainerConfig, pipeline: Pipeline) -> Path:
|
48 | 49 | if config.load_step is None:
|
49 | 50 | console.print("Loading latest checkpoint from load_dir")
|
50 | 51 | # NOTE: this is specific to the checkpoint name format
|
| 52 | + if not os.path.exists(config.load_dir): |
| 53 | + console.rule("Error", style="red") |
| 54 | + console.print(f"No checkpoint directory found at {config.load_dir}, ", justify="center") |
| 55 | + console.print( |
| 56 | + "Please make sure the checkpoint exists, they should be generated periodically during training", |
| 57 | + justify="center", |
| 58 | + ) |
| 59 | + sys.exit(1) |
51 | 60 | load_step = sorted(int(x[x.find("-") + 1 : x.find(".")]) for x in os.listdir(config.load_dir))[-1]
|
52 | 61 | else:
|
53 | 62 | load_step = config.load_step
|
@@ -93,6 +102,11 @@ def _render_trajectory_video(
|
93 | 102 | camera_ray_bundle = cameras.generate_rays(camera_indices=camera_idx).to(pipeline.device)
|
94 | 103 | with torch.no_grad():
|
95 | 104 | outputs = pipeline.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle)
|
| 105 | + if rendered_output_name not in outputs: |
| 106 | + console.rule("Error", style="red") |
| 107 | + console.print(f"Could not find {rendered_output_name} in the model outputs", justify="center") |
| 108 | + console.print(f"Please set --rendered_output_name to one of: {outputs.keys()}", justify="center") |
| 109 | + sys.exit(1) |
96 | 110 | image = outputs[rendered_output_name].cpu().numpy()
|
97 | 111 | images.append(image)
|
98 | 112 |
|
|
0 commit comments