Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Keep visualizations checkbox to training GUI #1824

Merged
merged 15 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ optional arguments:

```none
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
hajin-park marked this conversation as resolved.
Show resolved Hide resolved
[--suffix SUFFIX]
training_job_path [labels_path]

Expand Down Expand Up @@ -68,6 +68,8 @@ optional arguments:
--save_viz Enable saving of prediction visualizations to the run
folder if not already specified in the training job
config.
--keep_viz Keep prediction visualization images in the run
folder after training if --save_viz is enabled.
--zmq Enable ZMQ logging (for GUI) if not already specified
in the training job config.
--run_name RUN_NAME Run name to use when saving file, overrides other run
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@
" \"runs_folder\": \"models\",\n",
" \"tags\": [],\n",
" \"save_visualizations\": true,\n",
" \"delete_viz_images\": true,\n",
" \"keep_viz_images\": true,\n",
" \"zip_outputs\": false,\n",
" \"log_to_csv\": true,\n",
" \"checkpointing\": {\n",
Expand Down Expand Up @@ -727,7 +727,7 @@
" \"runs_folder\": \"models\",\n",
" \"tags\": [],\n",
" \"save_visualizations\": true,\n",
" \"delete_viz_images\": true,\n",
" \"keep_viz_images\": true,\n",
" \"zip_outputs\": false,\n",
" \"log_to_csv\": true,\n",
" \"checkpointing\": {\n",
Expand Down
5 changes: 5 additions & 0 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ training:
type: bool
default: true

- name: _keep_viz
label: Keep Prediction Visualization Images After Training
type: bool
default: false

- name: _predict_frames
label: Predict On
type: list
Expand Down
17 changes: 14 additions & 3 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Run training/inference in background process via CLI."""

hajin-park marked this conversation as resolved.
Show resolved Hide resolved
import abc
import attr
import os
Expand Down Expand Up @@ -500,9 +501,11 @@ def write_pipeline_files(
"data_path": os.path.basename(data_path),
"models": [Path(p).as_posix() for p in new_cfg_filenames],
"output_path": prediction_output_path,
"type": "labels"
if type(item_for_inference) == DatasetItemForInference
else "video",
"type": (
"labels"
if type(item_for_inference) == DatasetItemForInference
else "video"
),
"only_suggested_frames": only_suggested_frames,
"tracking": tracking_args,
}
Expand Down Expand Up @@ -544,6 +547,7 @@ def run_learning_pipeline(
"""

save_viz = inference_params.get("_save_viz", False)
keep_viz = inference_params.get("_keep_viz", False)

if "movenet" in inference_params["_pipeline"]:
trained_job_paths = [inference_params["_pipeline"]]
Expand All @@ -557,6 +561,7 @@ def run_learning_pipeline(
inference_params=inference_params,
gui=True,
save_viz=save_viz,
keep_viz=keep_viz,
)

# Check that all the models were trained
Expand Down Expand Up @@ -585,6 +590,7 @@ def run_gui_training(
inference_params: Dict[str, Any],
gui: bool = True,
save_viz: bool = False,
keep_viz: bool = False,
) -> Dict[Text, Text]:
hajin-park marked this conversation as resolved.
Show resolved Hide resolved
"""
Runs training for each training job.
Expand All @@ -594,6 +600,7 @@ def run_gui_training(
config_info_list: List of ConfigFileInfo with configs for training.
gui: Whether to show gui windows and process gui events.
save_viz: Whether to save visualizations from training.
keep_viz: Whether to keep prediction visualization images after training.

Returns:
Dictionary, keys are head name, values are path to trained config.
Expand Down Expand Up @@ -683,6 +690,7 @@ def waiting():
video_paths=video_path_list,
waiting_callback=waiting,
save_viz=save_viz,
keep_viz=keep_viz,
)

if ret == "success":
Expand Down Expand Up @@ -825,6 +833,7 @@ def train_subprocess(
video_paths: Optional[List[Text]] = None,
waiting_callback: Optional[Callable] = None,
save_viz: bool = False,
keep_viz: bool = False,
):
"""Runs training inside subprocess."""
run_path = job_config.outputs.run_path
Expand Down Expand Up @@ -853,6 +862,8 @@ def train_subprocess(

if save_viz:
cli_args.append("--save_viz")
if keep_viz:
cli_args.append("--keep_viz")

# Use cli arg since cli ignores setting in config
if job_config.outputs.tensorboard.write_logs:
Expand Down
6 changes: 3 additions & 3 deletions sleap/nn/config/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ class OutputsConfig:
save_visualizations: If True, will render and save visualizations of the model
predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the
split is one of "train", "validation", "test".
delete_viz_images: If True, delete the saved visualizations after training
completes. This is useful to reduce the model folder size if you do not need
keep_viz_images: If True, keep the saved visualization images after training
completes. This is useful unchecked to reduce the model folder size if you do not need
to keep the visualization images.
zip_outputs: If True, compress the run folder to a zip file. This will be named
"{run_folder}.zip".
Expand All @@ -170,7 +170,7 @@ class OutputsConfig:
runs_folder: Text = "models"
tags: List[Text] = attr.ib(factory=list)
save_visualizations: bool = True
delete_viz_images: bool = True
keep_viz_images: bool = False
hajin-park marked this conversation as resolved.
Show resolved Hide resolved
zip_outputs: bool = False
log_to_csv: bool = True
checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig)
Expand Down
13 changes: 11 additions & 2 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def train(self):
if self.config.outputs.save_outputs:
if (
self.config.outputs.save_visualizations
and self.config.outputs.delete_viz_images
and not self.config.outputs.keep_viz_images
):
self.cleanup()

Expand Down Expand Up @@ -997,7 +997,7 @@ def cleanup(self):

def package(self):
"""Package model folder into a zip file for portability."""
if self.config.outputs.delete_viz_images:
if not self.config.outputs.keep_viz_images:
self.cleanup()
logger.info(f"Packaging results to: {self.run_path}.zip")
shutil.make_archive(
Expand Down Expand Up @@ -1864,6 +1864,14 @@ def create_trainer_using_cli(args: Optional[List] = None):
"already specified in the training job config."
),
)
parser.add_argument(
"--keep_viz",
action="store_true",
help=(
"Keep prediction visualization images in the run folder after training when "
"--save_viz is enabled."
),
)
parser.add_argument(
"--zmq",
action="store_true",
Expand Down Expand Up @@ -1949,6 +1957,7 @@ def create_trainer_using_cli(args: Optional[List] = None):
if args.suffix != "":
job_config.outputs.run_name_suffix = args.suffix
job_config.outputs.save_visualizations |= args.save_viz
job_config.outputs.keep_viz_images = args.keep_viz
if args.labels_path == "":
args.labels_path = None
args.video_paths = args.video_paths.split(",")
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline.centroid.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.centroid.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"delete_viz_images": true,
"keep_viz_images": false,
"zip_outputs": false,
"log_to_csv": true,
"checkpointing": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"delete_viz_images": true,
"keep_viz_images": false,
"zip_outputs": false,
"log_to_csv": true,
"checkpointing": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
""
],
"save_visualizations": false,
"keep_viz_images": true,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
""
],
"save_visualizations": false,
"keep_viz_images": true,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion tests/gui/test_dialogs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Module to test the dialogs of the GUI (contained in sleap/gui/dialogs)."""


import os
from pathlib import Path
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import pathlib.Path.

The Path import from pathlib is unused and can be safely removed.

- from pathlib import Path
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from pathlib import Path
Tools
Ruff

4-4: pathlib.Path imported but unused

Remove unused import: pathlib.Path

(F401)


Expand Down
Loading