Skip to content

Commit

Permalink
Merge branch 'develop' into andrew/more-cases-for-copying-instance
Browse files Browse the repository at this point in the history
  • Loading branch information
7174Andy authored Jul 16, 2024
2 parents 09f2f82 + 14c21b4 commit 30918d2
Show file tree
Hide file tree
Showing 34 changed files with 159 additions and 46 deletions.
6 changes: 4 additions & 2 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ optional arguments:

```none
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--suffix SUFFIX]
training_job_path [labels_path]
Expand Down Expand Up @@ -68,6 +68,8 @@ optional arguments:
--save_viz Enable saving of prediction visualizations to the run
folder if not already specified in the training job
config.
--keep_viz Keep prediction visualization images in the run
folder after training if --save_viz is enabled.
--zmq Enable ZMQ logging (for GUI) if not already specified
in the training job config.
--run_name RUN_NAME Run name to use when saving file, overrides other run
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@
" \"runs_folder\": \"models\",\n",
" \"tags\": [],\n",
" \"save_visualizations\": true,\n",
" \"delete_viz_images\": true,\n",
" \"keep_viz_images\": true,\n",
" \"zip_outputs\": false,\n",
" \"log_to_csv\": true,\n",
" \"checkpointing\": {\n",
Expand Down Expand Up @@ -727,7 +727,7 @@
" \"runs_folder\": \"models\",\n",
" \"tags\": [],\n",
" \"save_visualizations\": true,\n",
" \"delete_viz_images\": true,\n",
" \"keep_viz_images\": true,\n",
" \"zip_outputs\": false,\n",
" \"log_to_csv\": true,\n",
" \"checkpointing\": {\n",
Expand Down
5 changes: 5 additions & 0 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ training:
type: bool
default: true

- name: _keep_viz
label: Keep Prediction Visualization Images After Training
type: bool
default: false

- name: _predict_frames
label: Predict On
type: list
Expand Down
25 changes: 24 additions & 1 deletion sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
from logging import getLogger
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import sys
import subprocess

from qtpy import QtCore, QtGui
from qtpy.QtCore import QEvent, Qt
Expand Down Expand Up @@ -84,7 +86,7 @@
from sleap.io.video import available_video_exts
from sleap.prefs import prefs
from sleap.skeleton import Skeleton
from sleap.util import parse_uri_path
from sleap.util import parse_uri_path, get_config_file


logger = getLogger(__name__)
Expand Down Expand Up @@ -515,6 +517,13 @@ def add_submenu_choices(menu, title, options, key):
fileMenu, "reset prefs", "Reset preferences to defaults...", self.resetPrefs
)

add_menu_item(
fileMenu,
"open preference directory",
"Open Preferences Directory...",
self.openPrefs,
)

fileMenu.addSeparator()
add_menu_item(fileMenu, "close", "Quit", self.close)

Expand Down Expand Up @@ -1330,6 +1339,20 @@ def resetPrefs(self):
)
msg.exec_()

def openPrefs(self):
"""Open preference file directory"""
pref_path = get_config_file("preferences.yaml")
# Make sure the pref_path is a directory rather than a file
if pref_path.is_file():
pref_path = pref_path.parent
# Open the file explorer at the folder containing the preferences.yaml file
if sys.platform == "win32":
subprocess.Popen(["explorer", str(pref_path)])
elif sys.platform == "darwin":
subprocess.Popen(["open", str(pref_path)])
else:
subprocess.Popen(["xdg-open", str(pref_path)])

def _update_track_menu(self):
"""Updates track menu options."""
self.track_menu.clear()
Expand Down
17 changes: 14 additions & 3 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Run training/inference in background process via CLI."""

import abc
import attr
import os
Expand Down Expand Up @@ -500,9 +501,11 @@ def write_pipeline_files(
"data_path": os.path.basename(data_path),
"models": [Path(p).as_posix() for p in new_cfg_filenames],
"output_path": prediction_output_path,
"type": "labels"
if type(item_for_inference) == DatasetItemForInference
else "video",
"type": (
"labels"
if type(item_for_inference) == DatasetItemForInference
else "video"
),
"only_suggested_frames": only_suggested_frames,
"tracking": tracking_args,
}
Expand Down Expand Up @@ -544,6 +547,7 @@ def run_learning_pipeline(
"""

save_viz = inference_params.get("_save_viz", False)
keep_viz = inference_params.get("_keep_viz", False)

if "movenet" in inference_params["_pipeline"]:
trained_job_paths = [inference_params["_pipeline"]]
Expand All @@ -557,6 +561,7 @@ def run_learning_pipeline(
inference_params=inference_params,
gui=True,
save_viz=save_viz,
keep_viz=keep_viz,
)

# Check that all the models were trained
Expand Down Expand Up @@ -585,6 +590,7 @@ def run_gui_training(
inference_params: Dict[str, Any],
gui: bool = True,
save_viz: bool = False,
keep_viz: bool = False,
) -> Dict[Text, Text]:
"""
Runs training for each training job.
Expand All @@ -594,6 +600,7 @@ def run_gui_training(
config_info_list: List of ConfigFileInfo with configs for training.
gui: Whether to show gui windows and process gui events.
save_viz: Whether to save visualizations from training.
keep_viz: Whether to keep prediction visualization images after training.
Returns:
Dictionary, keys are head name, values are path to trained config.
Expand Down Expand Up @@ -683,6 +690,7 @@ def waiting():
video_paths=video_path_list,
waiting_callback=waiting,
save_viz=save_viz,
keep_viz=keep_viz,
)

if ret == "success":
Expand Down Expand Up @@ -825,6 +833,7 @@ def train_subprocess(
video_paths: Optional[List[Text]] = None,
waiting_callback: Optional[Callable] = None,
save_viz: bool = False,
keep_viz: bool = False,
):
"""Runs training inside subprocess."""
run_path = job_config.outputs.run_path
Expand Down Expand Up @@ -853,6 +862,8 @@ def train_subprocess(

if save_viz:
cli_args.append("--save_viz")
if keep_viz:
cli_args.append("--keep_viz")

# Use cli arg since cli ignores setting in config
if job_config.outputs.tensorboard.write_logs:
Expand Down
19 changes: 11 additions & 8 deletions sleap/gui/overlays/tracks.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Track trail and track list overlays."""

from typing import Dict, Iterable, List, Optional, Tuple

import attr
from qtpy import QtCore, QtGui

from sleap.gui.overlays.base import BaseOverlay
from sleap.gui.widgets.video import QtTextWithBackground
from sleap.instance import Track
from sleap.io.dataset import Labels
from sleap.io.video import Video
from sleap.prefs import prefs
from sleap.gui.widgets.video import QtTextWithBackground

import attr

from typing import Iterable, List, Optional, Dict

from qtpy import QtCore, QtGui


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -58,7 +57,9 @@ def get_shade_options(cls):

return {"Dark": 0.6, "Normal": 1.0, "Light": 1.25}

def get_track_trails(self, frame_selection: Iterable["LabeledFrame"]):
def get_track_trails(
self, frame_selection: Iterable["LabeledFrame"]
) -> Optional[Dict[Track, List[List[Tuple[float, float]]]]]:
"""Get data needed to draw track trail.
Args:
Expand Down Expand Up @@ -154,6 +155,8 @@ def add_to_scene(self, video: Video, frame_idx: int):
frame_selection = self.get_frame_selection(video, frame_idx)

all_track_trails = self.get_track_trails(frame_selection)
if all_track_trails is None:
return

for track, trails in all_track_trails.items():
trail_color = tuple(
Expand Down
6 changes: 3 additions & 3 deletions sleap/nn/config/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ class OutputsConfig:
save_visualizations: If True, will render and save visualizations of the model
predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the
split is one of "train", "validation", "test".
delete_viz_images: If True, delete the saved visualizations after training
completes. This is useful to reduce the model folder size if you do not need
keep_viz_images: If True, keep the saved visualization images after training
completes. This is useful unchecked to reduce the model folder size if you do not need
to keep the visualization images.
zip_outputs: If True, compress the run folder to a zip file. This will be named
"{run_folder}.zip".
Expand All @@ -170,7 +170,7 @@ class OutputsConfig:
runs_folder: Text = "models"
tags: List[Text] = attr.ib(factory=list)
save_visualizations: bool = True
delete_viz_images: bool = True
keep_viz_images: bool = False
zip_outputs: bool = False
log_to_csv: bool = True
checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig)
Expand Down
13 changes: 11 additions & 2 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def train(self):
if self.config.outputs.save_outputs:
if (
self.config.outputs.save_visualizations
and self.config.outputs.delete_viz_images
and not self.config.outputs.keep_viz_images
):
self.cleanup()

Expand Down Expand Up @@ -997,7 +997,7 @@ def cleanup(self):

def package(self):
"""Package model folder into a zip file for portability."""
if self.config.outputs.delete_viz_images:
if not self.config.outputs.keep_viz_images:
self.cleanup()
logger.info(f"Packaging results to: {self.run_path}.zip")
shutil.make_archive(
Expand Down Expand Up @@ -1864,6 +1864,14 @@ def create_trainer_using_cli(args: Optional[List] = None):
"already specified in the training job config."
),
)
parser.add_argument(
"--keep_viz",
action="store_true",
help=(
"Keep prediction visualization images in the run folder after training when "
"--save_viz is enabled."
),
)
parser.add_argument(
"--zmq",
action="store_true",
Expand Down Expand Up @@ -1949,6 +1957,7 @@ def create_trainer_using_cli(args: Optional[List] = None):
if args.suffix != "":
job_config.outputs.run_name_suffix = args.suffix
job_config.outputs.save_visualizations |= args.save_viz
job_config.outputs.keep_viz_images = args.keep_viz
if args.labels_path == "":
args.labels_path = None
args.video_paths = args.video_paths.split(",")
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline.centroid.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.centroid.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Loading

0 comments on commit 30918d2

Please sign in to comment.