diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py
index af8ef2c6c..7f9e91ec9 100644
--- a/sleap/nn/inference.py
+++ b/sleap/nn/inference.py
@@ -5288,12 +5288,9 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
         A tuple of `(provider, data_path)` with the data `Provider` and path
         to the data that was specified in the args.
     """
+    # Figure out which input path to use.
-    labels_path = getattr(args, "labels", None)
-    if labels_path is not None:
-        data_path = labels_path
-    else:
-        data_path = args.data_path
+    data_path = args.data_path

     if data_path is None or data_path == "":
         raise ValueError(
@@ -5301,33 +5298,73 @@
             "Run 'sleap-track -h' to see full command documentation."
         )

-    if data_path.endswith(".slp"):
-        labels = sleap.load_file(data_path)
-
-        if args.only_labeled_frames:
-            provider = LabelsReader.from_user_labeled_frames(labels)
-        elif args.only_suggested_frames:
-            provider = LabelsReader.from_unlabeled_suggestions(labels)
-        elif getattr(args, "video.index") != "":
-            provider = VideoReader(
-                video=labels.videos[int(getattr(args, "video.index"))],
-                example_indices=frame_list(args.frames),
-            )
-        else:
-            provider = LabelsReader(labels)
+    data_path_obj = Path(data_path)
+
+    # Check that the input path is valid.
+    if not data_path_obj.exists():
+        raise ValueError("Path to data_path does not exist")
+
+    # Check for multiple video inputs.
+    # Compile file(s) into a list for later iteration.
+    if data_path_obj.is_dir():
+        data_path_list = []
+        for file_path in data_path_obj.iterdir():
+            if file_path.is_file():
+                data_path_list.append(Path(file_path))
+    elif data_path_obj.is_file():
+        data_path_list = [data_path_obj]
+
+    # Provider list to accommodate multiple video inputs.
+    output_provider_list = []
+    output_data_path_list = []
+    for file_path in data_path_list:
+        # Create a provider for each file.
+        if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
+            print(f"slp file skipped: {file_path.as_posix()}")
+
+        elif file_path.as_posix().endswith(".slp"):
+            labels = sleap.load_file(file_path.as_posix())
+
+            if args.only_labeled_frames:
+                output_provider_list.append(
+                    LabelsReader.from_user_labeled_frames(labels)
+                )
+            elif args.only_suggested_frames:
+                output_provider_list.append(
+                    LabelsReader.from_unlabeled_suggestions(labels)
+                )
+            elif getattr(args, "video.index") != "":
+                output_provider_list.append(
+                    VideoReader(
+                        video=labels.videos[int(getattr(args, "video.index"))],
+                        example_indices=frame_list(args.frames),
+                    )
+                )
+            else:
+                output_provider_list.append(LabelsReader(labels))

-    else:
-        print(f"Video: {data_path}")
-        # TODO: Clean this up.
-        video_kwargs = dict(
-            dataset=vars(args).get("video.dataset"),
-            input_format=vars(args).get("video.input_format"),
-        )
-        provider = VideoReader.from_filepath(
-            filename=data_path, example_indices=frame_list(args.frames), **video_kwargs
-        )
+            output_data_path_list.append(file_path)

-    return provider, data_path
+        else:
+            try:
+                video_kwargs = dict(
+                    dataset=vars(args).get("video.dataset"),
+                    input_format=vars(args).get("video.input_format"),
+                )
+                output_provider_list.append(
+                    VideoReader.from_filepath(
+                        filename=file_path.as_posix(),
+                        example_indices=frame_list(args.frames),
+                        **video_kwargs,
+                    )
+                )
+                print(f"Video: {file_path.as_posix()}")
+                output_data_path_list.append(file_path)
+                # TODO: Clean this up.
+            except Exception:
+                print(f"Error reading file: {file_path.as_posix()}")
+
+    return output_provider_list, output_data_path_list


 def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
@@ -5422,8 +5459,6 @@ def main(args: Optional[list] = None):
     pprint(vars(args))
     print()

-    output_path = args.output
-
     # Setup devices.
     if args.cpu or not sleap.nn.system.is_gpu_system():
         sleap.nn.system.use_cpu_only()
@@ -5461,7 +5496,19 @@ def main(args: Optional[list] = None):
     print()

     # Setup data loader.
-    provider, data_path = _make_provider_from_cli(args)
+    provider_list, data_path_list = _make_provider_from_cli(args)
+
+    output_path = args.output
+
+    # Check that output_path is valid before running inference.
+    if (
+        output_path is not None
+        and Path(output_path).is_file()
+        and len(data_path_list) > 1
+    ):
+        raise ValueError(
+            "output_path argument must be a directory if multiple video inputs are given"
+        )

     # Setup tracker.
     tracker = _make_tracker_from_cli(args)
@@ -5469,35 +5516,148 @@ def main(args: Optional[list] = None):
     if args.models is not None and "movenet" in args.models[0]:
         args.models = args.models[0]

-    # Either run inference (and tracking) or just run tracking
+    # Either run inference (and tracking) or just run tracking on an existing prediction file.
     if args.models is not None:
-        # Setup models.
-        predictor = _make_predictor_from_cli(args)
-        predictor.tracker = tracker
-        # Run inference!
-        labels_pr = predictor.predict(provider)
+        # Run inference on all input files.
+        for data_path, provider in zip(data_path_list, provider_list):
+            # Setup models.
+            data_path_obj = Path(data_path)
+            predictor = _make_predictor_from_cli(args)
+            predictor.tracker = tracker
+
+            # Run inference!
+            labels_pr = predictor.predict(provider)

-        if output_path is None:
-            output_path = data_path + ".predictions.slp"
+            # If an output path was not provided, create one.
+            if output_path is None:
+                output_path = f"{data_path.as_posix()}.predictions.slp"
+                output_path_obj = Path(output_path)

-        labels_pr.provenance["model_paths"] = predictor.model_paths
-        labels_pr.provenance["predictor"] = type(predictor).__name__
+            else:
+                output_path_obj = Path(output_path)
+                # If output_path was provided with multiple inputs, store outputs in that directory.
+                if len(data_path_list) > 1:
+                    output_path = (
+                        output_path_obj
+                        / data_path_obj.with_suffix(".predictions.slp").name
+                    )
+                    output_path_obj = Path(output_path)
+                    # Create the containing directory if needed.
+                    output_path_obj.parent.mkdir(exist_ok=True, parents=True)
+
+            labels_pr.provenance["model_paths"] = predictor.model_paths
+            labels_pr.provenance["predictor"] = type(predictor).__name__
+
+            if args.no_empty_frames:
+                # Clear empty frames if specified.
+                labels_pr.remove_empty_frames()
+
+            finish_timestamp = str(datetime.now())
+            total_elapsed = time() - t0
+            print("Finished inference at:", finish_timestamp)
+            print(f"Total runtime: {total_elapsed} secs")
+            print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
+
+            # Add provenance metadata to predictions.
+            labels_pr.provenance["sleap_version"] = sleap.__version__
+            labels_pr.provenance["platform"] = platform.platform()
+            labels_pr.provenance["command"] = " ".join(sys.argv)
+            labels_pr.provenance["data_path"] = data_path_obj.as_posix()
+            labels_pr.provenance["output_path"] = output_path_obj.as_posix()
+            labels_pr.provenance["total_elapsed"] = total_elapsed
+            labels_pr.provenance["start_timestamp"] = start_timestamp
+            labels_pr.provenance["finish_timestamp"] = finish_timestamp
+
+            print("Provenance:")
+            pprint(labels_pr.provenance)
+            print()
+
+            labels_pr.provenance["args"] = vars(args)
+
+            # Save results.
+            labels_pr.save(output_path)
+            print("Saved output:", output_path)
+
+            if args.open_in_gui:
+                subprocess.call(["sleap-label", output_path])
+
+            # Reset output_path for the next iteration.
+            output_path = args.output
+
+    # Run tracking on an existing prediction file.
     elif getattr(args, "tracking.tracker") is not None:
-        # Load predictions
-        print("Loading predictions...")
-        labels_pr = sleap.load_file(args.data_path)
-        frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
+        for data_path, provider in zip(data_path_list, provider_list):
+            # Load predictions
+            data_path_obj = Path(data_path)
+            print("Loading predictions...")
+            labels_pr = sleap.load_file(data_path_obj.as_posix())
+            frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
+
+            print("Starting tracker...")
+            frames = run_tracker(frames=frames, tracker=tracker)
+            tracker.final_pass(frames)
+
+            labels_pr = Labels(labeled_frames=frames)
+
+            if output_path is None:
+                output_path = f"{data_path}.{tracker.get_name()}.slp"
+                output_path_obj = Path(output_path)
+
+            else:
+                output_path_obj = Path(output_path)
+                if (
+                    output_path_obj.exists()
+                    and output_path_obj.is_file()
+                    and len(data_path_list) > 1
+                ):
+                    raise ValueError(
+                        "output_path argument must be a directory if multiple video inputs are given"
+                    )

-        print("Starting tracker...")
-        frames = run_tracker(frames=frames, tracker=tracker)
-        tracker.final_pass(frames)
+                elif not output_path_obj.exists() and len(data_path_list) > 1:
+                    output_path = output_path_obj / data_path_obj.with_suffix(
+                        ".predictions.slp"
+                    )
+                    output_path_obj = Path(output_path)
+                    output_path_obj.parent.mkdir(exist_ok=True, parents=True)
+
+            if args.no_empty_frames:
+                # Clear empty frames if specified.
+                labels_pr.remove_empty_frames()
+
+            finish_timestamp = str(datetime.now())
+            total_elapsed = time() - t0
+            print("Finished inference at:", finish_timestamp)
+            print(f"Total runtime: {total_elapsed} secs")
+            print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
+
+            # Add provenance metadata to predictions.
+            labels_pr.provenance["sleap_version"] = sleap.__version__
+            labels_pr.provenance["platform"] = platform.platform()
+            labels_pr.provenance["command"] = " ".join(sys.argv)
+            labels_pr.provenance["data_path"] = data_path_obj.as_posix()
+            labels_pr.provenance["output_path"] = output_path_obj.as_posix()
+            labels_pr.provenance["total_elapsed"] = total_elapsed
+            labels_pr.provenance["start_timestamp"] = start_timestamp
+            labels_pr.provenance["finish_timestamp"] = finish_timestamp
+
+            print("Provenance:")
+            pprint(labels_pr.provenance)
+            print()
+
+            labels_pr.provenance["args"] = vars(args)

-        labels_pr = Labels(labeled_frames=frames)
+            # Save results.
+            labels_pr.save(output_path)
+            print("Saved output:", output_path)

-        if output_path is None:
-            output_path = f"{data_path}.{tracker.get_name()}.slp"
+            if args.open_in_gui:
+                subprocess.call(["sleap-label", output_path])
+
+            # Reset output_path for the next iteration.
+            output_path = args.output

     else:
         raise ValueError(
@@ -5506,36 +5666,3 @@ def main(args: Optional[list] = None):
             "To retrack on predictions, must specify tracker. "
            "Use \"sleap-track --tracking.tracker ...' to specify tracker to use."
         )
-
-    if args.no_empty_frames:
-        # Clear empty frames if specified.
-        labels_pr.remove_empty_frames()
-
-    finish_timestamp = str(datetime.now())
-    total_elapsed = time() - t0
-    print("Finished inference at:", finish_timestamp)
-    print(f"Total runtime: {total_elapsed} secs")
-    print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
-
-    # Add provenance metadata to predictions.
-    labels_pr.provenance["sleap_version"] = sleap.__version__
-    labels_pr.provenance["platform"] = platform.platform()
-    labels_pr.provenance["command"] = " ".join(sys.argv)
-    labels_pr.provenance["data_path"] = data_path
-    labels_pr.provenance["output_path"] = output_path
-    labels_pr.provenance["total_elapsed"] = total_elapsed
-    labels_pr.provenance["start_timestamp"] = start_timestamp
-    labels_pr.provenance["finish_timestamp"] = finish_timestamp
-
-    print("Provenance:")
-    pprint(labels_pr.provenance)
-    print()
-
-    labels_pr.provenance["args"] = vars(args)
-
-    # Save results.
-    labels_pr.save(output_path)
-    print("Saved output:", output_path)
-
-    if args.open_in_gui:
-        subprocess.call(["sleap-label", output_path])
diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py
index 1b0f88c7c..f99f136ab 100644
--- a/tests/nn/test_inference.py
+++ b/tests/nn/test_inference.py
@@ -3,6 +3,7 @@
 import zipfile
 from pathlib import Path
 from typing import cast
+import shutil

 import numpy as np
 import pytest
@@ -1447,7 +1448,49 @@ def test_make_predictor_from_cli(
     assert predictor.max_instances == 5


-def test_sleap_track(
+def test_make_predictor_from_cli_mult_input(
+    centered_pair_predictions: Labels,
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    min_bottomup_model_path: str,
+    tmpdir,
+):
+    slp_path = tmpdir.mkdir("slp_directory")
+
+    slp_file = slp_path / "old_slp.slp"
+    Labels.save(centered_pair_predictions, slp_file)
+
+    # Copy the SLP file into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name for the SLP file
+        slp_dest_path = slp_path / f"old_slp_copy_{i}.slp"
+        shutil.copy(slp_file, slp_dest_path)
+
+    # Create sleap-track command
+    model_args = [
+        f"--model {min_centroid_model_path} --model {min_centered_instance_model_path}",
+        f"--model {min_bottomup_model_path}",
+    ]
+    for model_arg in model_args:
+        args = (
+            f"{slp_path} {model_arg} --video.index 0 --frames 1-3 "
+            "--cpu --max_instances 5"
+        ).split()
+        parser = _make_cli_parser()
+        args, _ = parser.parse_known_args(args=args)
+
+        # Create predictor
+        predictor = _make_predictor_from_cli(args=args)
+        if isinstance(predictor, TopDownPredictor):
+            assert predictor.inference_model.centroid_crop.max_instances == 5
+        elif isinstance(predictor, BottomUpPredictor):
+            assert predictor.max_instances == 5
+
+
+def test_sleap_track_single_input(
     centered_pair_predictions: Labels,
     min_centroid_model_path: str,
     min_centered_instance_model_path: str,
@@ -1475,6 +1518,235 @@ def test_sleap_track(
     sleap_track(args=args)


+@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
+def test_sleap_track_mult_input_slp(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    tmpdir,
+    centered_pair_predictions: Labels,
+    tracking,
+):
+    # Create temporary directory with the structured video files
+    slp_path = tmpdir.mkdir("slp_directory")
+
+    slp_file = slp_path / "old_slp.slp"
+    Labels.save(centered_pair_predictions, slp_file)
+
+    slp_path_obj = Path(slp_path)
+
+    # Copy the SLP file into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name for the SLP file
+        slp_dest_path = slp_path / f"old_slp_copy_{i}.slp"
+        shutil.copy(slp_file, slp_dest_path)
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker {tracking} "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]
+
+    # Run inference
+    sleap_track(args=args)
+
+    # Assert predictions file exists
+    expected_extensions = {
+        ".mp4",
+    }  # Add other video formats if necessary
+
+    for file_path in slp_path_list:
+        if file_path.suffix in expected_extensions:
+            expected_output_file = f"{file_path}.predictions.slp"
+            assert Path(expected_output_file).exists()
+
+
+@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
+def test_sleap_track_mult_input_slp_mp4(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    tracking,
+    tmpdir,
+    centered_pair_predictions: Labels,
+):
+    # Create temporary directory with the structured video files
+    slp_path = tmpdir.mkdir("slp_mp4_directory")
+
+    slp_file = slp_path / "old_slp.slp"
+    Labels.save(centered_pair_predictions, slp_file)
+
+    # Copy the video into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+
+    slp_path_obj = Path(slp_path)
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker {tracking} "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]
+
+    # Run inference
+    sleap_track(args=args)
+
+    # Assert predictions file exists
+    for file_path in slp_path_list:
+        if file_path.suffix == ".mp4":
+            expected_output_file = f"{file_path}.predictions.slp"
+            assert Path(expected_output_file).exists()
+
+
+@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
+def test_sleap_track_mult_input_mp4(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    tracking,
+    tmpdir,
+):
+
+    # Create temporary directory with the structured video files
+    slp_path = tmpdir.mkdir("mp4_directory")
+
+    # Copy the video into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+
+    slp_path_obj = Path(slp_path)
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker {tracking} "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]
+
+    # Run inference
+    sleap_track(args=args)
+
+    # Assert predictions file exists
+    for file_path in slp_path_list:
+        if file_path.suffix == ".mp4":
+            expected_output_file = f"{file_path}.predictions.slp"
+            assert Path(expected_output_file).exists()
+
+
+def test_sleap_track_output_mult(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    tmpdir,
+):
+
+    output_path = tmpdir.mkdir("output_directory")
+    output_path_obj = Path(output_path)
+
+    # Create temporary directory with the structured video files
+    slp_path = tmpdir.mkdir("mp4_directory")
+
+    # Copy the video into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+
+    slp_path_obj = Path(slp_path)
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"-o {output_path} "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]
+
+    # Run inference
+    sleap_track(args=args)
+    slp_path = Path(slp_path)
+
+    # Check if there are any files in the directory
+    for file_path in slp_path_list:
+        if file_path.suffix == ".mp4":
+            expected_output_file = output_path_obj / (
+                file_path.stem + ".predictions.slp"
+            )
+            assert Path(expected_output_file).exists()
+
+
+def test_sleap_track_invalid_output(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    centered_pair_predictions: Labels,
+    tmpdir,
+):
+
+    output_path = Path(tmpdir, "output_file.slp").as_posix()
+    Labels.save(centered_pair_predictions, output_path)
+
+    # Create temporary directory with the structured video files
+    slp_path = tmpdir.mkdir("mp4_directory")
+
+    # Copy the video into the temp dir multiple times
+    num_copies = 3
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"-o {output_path} "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference
+    with pytest.raises(ValueError):
+        sleap_track(args=args)
+
+
+def test_sleap_track_invalid_input(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+):
+
+    slp_path = ""
+
+    # Create sleap-track command
+    args = (
+        f"{slp_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference
+    with pytest.raises(ValueError):
+        sleap_track(args=args)
+
+
 def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
     """Test flow tracker instances are pruned."""
     labels: Labels = centered_pair_predictions
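
For reviewers, a minimal usage sketch of the new multi-input behavior (not part of the diff; "videos_dir" is a hypothetical directory and the call pattern simply mirrors the tests above):

    # Usage sketch only: exercise the updated provider construction directly.
    # "videos_dir" is a hypothetical directory of video files, not part of the PR.
    from sleap.nn.inference import _make_cli_parser, _make_provider_from_cli

    parser = _make_cli_parser()
    args, _ = parser.parse_known_args("videos_dir --frames 1-3 --cpu".split())

    # The helper now returns parallel lists: one provider and one path per readable
    # file found in the directory (.slp files are skipped when more than one input
    # is present, and unreadable files are reported and skipped).
    provider_list, data_path_list = _make_provider_from_cli(args)
    for provider, path in zip(provider_list, data_path_list):
        print(path.as_posix(), type(provider).__name__)

End to end, the same directory can be passed to sleap-track; when more than one input is found, the -o/--output argument must name a directory, otherwise main() raises a ValueError before inference starts.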