Skip to content

Commit c75a0ef

Browse files
committed
Reverted �iew_ back to save_ and changed new training checkbox to Keep visualization images after training.
1 parent ad06cf5 commit c75a0ef

29 files changed

+91
-91
lines changed

docs/guides/cli.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ optional arguments:
3636

3737
```none
3838
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
39-
[--test_labels TEST_LABELS] [--tensorboard] [--view_viz]
40-
[--delete_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
39+
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
40+
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
4141
[--suffix SUFFIX]
4242
training_job_path [labels_path]
4343
@@ -65,11 +65,11 @@ optional arguments:
6565
to resume training from.
6666
--tensorboard Enable TensorBoard logging to the run path if not
6767
already specified in the training job config.
68-
--view_viz Enable saving of prediction visualizations to the run
68+
--save_viz Enable saving of prediction visualizations to the run
6969
folder if not already specified in the training job
7070
config.
71-
--delete_viz Delete prediction visualizations in the run
72-
folder after training if view_viz is enabled.
71+
--keep_viz Keep prediction visualization images in the run
72+
folder after training when save_viz is enabled.
7373
--zmq Enable ZMQ logging (for GUI) if not already specified
7474
in the training job config.
7575
--run_name RUN_NAME Run name to use when saving file, overrides other run

docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb

+4-4
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@
190190
" \"test_labels\": null,\n",
191191
" \"base_checkpoint\": null,\n",
192192
" \"tensorboard\": false,\n",
193-
" \"view_viz\": false,\n",
193+
" \"save_viz\": false,\n",
194194
" \"zmq\": false,\n",
195195
" \"run_name\": \"courtship.centroid\",\n",
196196
" \"prefix\": \"\",\n",
@@ -335,7 +335,7 @@
335335
" \"runs_folder\": \"models\",\n",
336336
" \"tags\": [],\n",
337337
" \"save_visualizations\": true,\n",
338-
" \"delete_viz_images\": true,\n",
338+
" \"keep_viz_images\": true,\n",
339339
" \"zip_outputs\": false,\n",
340340
" \"log_to_csv\": true,\n",
341341
" \"checkpointing\": {\n",
@@ -581,7 +581,7 @@
581581
" \"test_labels\": null,\n",
582582
" \"base_checkpoint\": null,\n",
583583
" \"tensorboard\": false,\n",
584-
" \"view_viz\": false,\n",
584+
" \"save_viz\": false,\n",
585585
" \"zmq\": false,\n",
586586
" \"run_name\": \"courtship.topdown_confmaps\",\n",
587587
" \"prefix\": \"\",\n",
@@ -727,7 +727,7 @@
727727
" \"runs_folder\": \"models\",\n",
728728
" \"tags\": [],\n",
729729
" \"save_visualizations\": true,\n",
730-
" \"delete_viz_images\": true,\n",
730+
" \"keep_viz_images\": true,\n",
731731
" \"zip_outputs\": false,\n",
732732
" \"log_to_csv\": true,\n",
733733
" \"checkpointing\": {\n",

sleap/config/pipeline_form.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -281,15 +281,15 @@ training:
281281
name: outputs.tensorboard.write_logs
282282
type: bool
283283

284-
- name: _view_viz
284+
- name: _save_viz
285285
label: Visualize Predictions During Training
286286
type: bool
287287
default: true
288288

289-
- name: _delete_viz
290-
label: Delete Prediction Visualizations After Training
289+
- name: _keep_viz
290+
label: Keep Prediction Visualization Images After Training
291291
type: bool
292-
default: true
292+
default: false
293293

294294
- name: _predict_frames
295295
label: Predict On

sleap/gui/learning/runners.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -546,8 +546,8 @@ def run_learning_pipeline(
546546
547547
"""
548548

549-
view_viz = inference_params.get("_view_viz", False)
550-
delete_viz = inference_params.get("_delete_viz", False)
549+
save_viz = inference_params.get("_save_viz", False)
550+
keep_viz = inference_params.get("_keep_viz", False)
551551

552552
if "movenet" in inference_params["_pipeline"]:
553553
trained_job_paths = [inference_params["_pipeline"]]
@@ -560,8 +560,8 @@ def run_learning_pipeline(
560560
config_info_list=config_info_list,
561561
inference_params=inference_params,
562562
gui=True,
563-
view_viz=view_viz,
564-
delete_viz=delete_viz,
563+
save_viz=save_viz,
564+
keep_viz=keep_viz,
565565
)
566566

567567
# Check that all the models were trained
@@ -589,8 +589,8 @@ def run_gui_training(
589589
config_info_list: List[ConfigFileInfo],
590590
inference_params: Dict[str, Any],
591591
gui: bool = True,
592-
view_viz: bool = False,
593-
delete_viz: bool = True,
592+
save_viz: bool = False,
593+
keep_viz: bool = False,
594594
) -> Dict[Text, Text]:
595595
"""
596596
Runs training for each training job.
@@ -599,8 +599,8 @@ def run_gui_training(
599599
labels: Labels object from which we'll get training data.
600600
config_info_list: List of ConfigFileInfo with configs for training.
601601
gui: Whether to show gui windows and process gui events.
602-
view_viz: Whether to save visualizations from training.
603-
delete_viz: Whether to delete prediction visualizations after training.
602+
save_viz: Whether to save visualizations from training.
603+
keep_viz: Whether to keep prediction visualization images after training.
604604
605605
Returns:
606606
Dictionary, keys are head name, values are path to trained config.
@@ -667,7 +667,7 @@ def run_gui_training(
667667
win.reset(what=str(model_type), config=job)
668668
win.setWindowTitle(f"Training Model - {str(model_type)}")
669669
win.set_message(f"Preparing to run training...")
670-
if view_viz:
670+
if save_viz:
671671
viz_window = QtImageDirectoryWidget.make_training_vizualizer(
672672
job.outputs.run_path
673673
)
@@ -689,8 +689,8 @@ def waiting():
689689
labels_filename=labels_filename,
690690
video_paths=video_path_list,
691691
waiting_callback=waiting,
692-
view_viz=view_viz,
693-
delete_viz=delete_viz,
692+
save_viz=save_viz,
693+
keep_viz=keep_viz,
694694
)
695695

696696
if ret == "success":
@@ -832,8 +832,8 @@ def train_subprocess(
832832
inference_params: Dict[str, Any],
833833
video_paths: Optional[List[Text]] = None,
834834
waiting_callback: Optional[Callable] = None,
835-
view_viz: bool = False,
836-
delete_viz: bool = True,
835+
save_viz: bool = False,
836+
keep_viz: bool = False,
837837
):
838838
"""Runs training inside subprocess."""
839839
run_path = job_config.outputs.run_path
@@ -860,10 +860,10 @@ def train_subprocess(
860860
str(inference_params["publish_port"]),
861861
]
862862

863-
if view_viz:
864-
cli_args.append("--view_viz")
865-
if delete_viz:
866-
cli_args.append("--delete_viz")
863+
if save_viz:
864+
cli_args.append("--save_viz")
865+
if keep_viz:
866+
cli_args.append("--keep_viz")
867867

868868
# Use cli arg since cli ignores setting in config
869869
if job_config.outputs.tensorboard.write_logs:

sleap/nn/config/outputs.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,11 @@ class OutputsConfig:
148148
tags: A list of strings to use as "tags" that can be used to organize multiple
149149
runs. These are not used for anything during training or inference, so they
150150
can be used to store arbitrary user-specified metadata.
151-
view_visualizations: If True, will render and save visualizations of the model
151+
save_visualizations: If True, will render and save visualizations of the model
152152
predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the
153153
split is one of "train", "validation", "test".
154-
delete_viz_images: If True, delete the saved visualizations after training
155-
completes. This is useful to reduce the model folder size if you do not need
154+
keep_viz_images: If True, keep the saved visualization images after training
155+
completes. This is useful unchecked to reduce the model folder size if you do not need
156156
to keep the visualization images.
157157
zip_outputs: If True, compress the run folder to a zip file. This will be named
158158
"{run_folder}.zip".
@@ -169,8 +169,8 @@ class OutputsConfig:
169169
run_name_suffix: Optional[Text] = None
170170
runs_folder: Text = "models"
171171
tags: List[Text] = attr.ib(factory=list)
172-
view_visualizations: bool = True
173-
delete_viz_images: bool = True
172+
save_visualizations: bool = True
173+
keep_viz_images: bool = False
174174
zip_outputs: bool = False
175175
log_to_csv: bool = True
176176
checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig)

sleap/nn/training.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def setup_visualization(
515515
"This probably means Qt is running headless."
516516
)
517517

518-
if config.view_visualizations and config.save_outputs:
518+
if config.save_visualizations and config.save_outputs:
519519
callbacks.append(
520520
MatplotlibSaver(
521521
save_folder=os.path.join(run_path, "viz"), plot_fn=viz_fn, prefix=name
@@ -945,8 +945,8 @@ def train(self):
945945
# Run post-training actions.
946946
if self.config.outputs.save_outputs:
947947
if (
948-
self.config.outputs.view_visualizations
949-
and self.config.outputs.delete_viz_images
948+
self.config.outputs.save_visualizations
949+
and not self.config.outputs.keep_viz_images
950950
):
951951
self.cleanup()
952952

@@ -997,7 +997,7 @@ def cleanup(self):
997997

998998
def package(self):
999999
"""Package model folder into a zip file for portability."""
1000-
if self.config.outputs.delete_viz_images:
1000+
if not self.config.outputs.keep_viz_images:
10011001
self.cleanup()
10021002
logger.info(f"Packaging results to: {self.run_path}.zip")
10031003
shutil.make_archive(
@@ -1857,19 +1857,19 @@ def create_trainer_using_cli(args: Optional[List] = None):
18571857
),
18581858
)
18591859
parser.add_argument(
1860-
"--view_viz",
1860+
"--save_viz",
18611861
action="store_true",
18621862
help=(
18631863
"Enable saving of prediction visualizations to the run folder if not "
18641864
"already specified in the training job config."
18651865
),
18661866
)
18671867
parser.add_argument(
1868-
"--delete_viz",
1868+
"--keep_viz",
18691869
action="store_true",
18701870
help=(
1871-
"Delete prediction visualizations in the run folder after training if "
1872-
"view_viz is enabled."
1871+
"Keep prediction visualization images in the run folder after training when "
1872+
"save_viz is enabled."
18731873
),
18741874
)
18751875
parser.add_argument(
@@ -1956,8 +1956,8 @@ def create_trainer_using_cli(args: Optional[List] = None):
19561956
job_config.outputs.run_name_prefix = args.prefix
19571957
if args.suffix != "":
19581958
job_config.outputs.run_name_suffix = args.suffix
1959-
job_config.outputs.view_visualizations |= args.view_viz
1960-
job_config.outputs.delete_viz_images |= args.delete_viz
1959+
job_config.outputs.save_visualizations |= args.save_viz
1960+
job_config.outputs.keep_viz_images |= args.keep_viz
19611961
if args.labels_path == "":
19621962
args.labels_path = None
19631963
args.video_paths = args.video_paths.split(",")

sleap/training_profiles/baseline.centroid.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@
115115
"run_name_suffix": null,
116116
"runs_folder": "models",
117117
"tags": [],
118-
"view_visualizations": true,
119-
"delete_viz_images": true,
118+
"save_visualizations": true,
119+
"keep_viz_images": false,
120120
"log_to_csv": true,
121121
"checkpointing": {
122122
"initial_model": false,

sleap/training_profiles/baseline_large_rf.bottomup.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@
124124
"run_name_suffix": null,
125125
"runs_folder": "models",
126126
"tags": [],
127-
"view_visualizations": true,
128-
"delete_viz_images": true,
127+
"save_visualizations": true,
128+
"keep_viz_images": false,
129129
"log_to_csv": true,
130130
"checkpointing": {
131131
"initial_model": false,

sleap/training_profiles/baseline_large_rf.single.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@
115115
"run_name_suffix": null,
116116
"runs_folder": "models",
117117
"tags": [],
118-
"view_visualizations": true,
119-
"delete_viz_images": true,
118+
"save_visualizations": true,
119+
"keep_viz_images": false,
120120
"log_to_csv": true,
121121
"checkpointing": {
122122
"initial_model": false,

sleap/training_profiles/baseline_large_rf.topdown.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@
116116
"run_name_suffix": null,
117117
"runs_folder": "models",
118118
"tags": [],
119-
"view_visualizations": true,
120-
"delete_viz_images": true,
119+
"save_visualizations": true,
120+
"keep_viz_images": false,
121121
"log_to_csv": true,
122122
"checkpointing": {
123123
"initial_model": false,

sleap/training_profiles/baseline_medium_rf.bottomup.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@
124124
"run_name_suffix": null,
125125
"runs_folder": "models",
126126
"tags": [],
127-
"view_visualizations": true,
128-
"delete_viz_images": true,
127+
"save_visualizations": true,
128+
"keep_viz_images": false,
129129
"log_to_csv": true,
130130
"checkpointing": {
131131
"initial_model": false,

sleap/training_profiles/baseline_medium_rf.single.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@
115115
"run_name_suffix": null,
116116
"runs_folder": "models",
117117
"tags": [],
118-
"view_visualizations": true,
119-
"delete_viz_images": true,
118+
"save_visualizations": true,
119+
"keep_viz_images": false,
120120
"log_to_csv": true,
121121
"checkpointing": {
122122
"initial_model": false,

sleap/training_profiles/baseline_medium_rf.topdown.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@
116116
"run_name_suffix": null,
117117
"runs_folder": "models",
118118
"tags": [],
119-
"view_visualizations": true,
120-
"delete_viz_images": true,
119+
"save_visualizations": true,
120+
"keep_viz_images": false,
121121
"log_to_csv": true,
122122
"checkpointing": {
123123
"initial_model": false,

sleap/training_profiles/pretrained.bottomup.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@
121121
"run_name_suffix": null,
122122
"runs_folder": "models",
123123
"tags": [],
124-
"view_visualizations": true,
125-
"delete_viz_images": true,
124+
"save_visualizations": true,
125+
"keep_viz_images": false,
126126
"log_to_csv": true,
127127
"checkpointing": {
128128
"initial_model": false,

sleap/training_profiles/pretrained.centroid.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@
112112
"run_name_suffix": null,
113113
"runs_folder": "models",
114114
"tags": [],
115-
"view_visualizations": true,
116-
"delete_viz_images": true,
115+
"save_visualizations": true,
116+
"keep_viz_images": false,
117117
"log_to_csv": true,
118118
"checkpointing": {
119119
"initial_model": false,

sleap/training_profiles/pretrained.single.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@
112112
"run_name_suffix": null,
113113
"runs_folder": "models",
114114
"tags": [],
115-
"view_visualizations": true,
116-
"delete_viz_images": true,
115+
"save_visualizations": true,
116+
"keep_viz_images": false,
117117
"log_to_csv": true,
118118
"checkpointing": {
119119
"initial_model": false,

sleap/training_profiles/pretrained.topdown.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@
113113
"run_name_suffix": null,
114114
"runs_folder": "models",
115115
"tags": [],
116-
"view_visualizations": true,
117-
"delete_viz_images": true,
116+
"save_visualizations": true,
117+
"keep_viz_images": false,
118118
"log_to_csv": true,
119119
"checkpointing": {
120120
"initial_model": false,

tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@
127127
"run_name_suffix": null,
128128
"runs_folder": "models",
129129
"tags": [],
130-
"view_visualizations": false,
131-
"delete_viz_images": false,
130+
"save_visualizations": false,
131+
"keep_viz_images": true,
132132
"log_to_csv": true,
133133
"checkpointing": {
134134
"initial_model": false,

0 commit comments

Comments
 (0)