diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index a76ccd7ce..8b3ae9eaa 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -5320,28 +5320,28 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: raw_data_path_list = df["data_path"].tolist() else: raise ValueError( - "Column 'data_path' does not exist in the CSV file." + f"Column 'data_path' does not exist in the CSV file: {data_path}" ) # optional output_path column to specify multiple output_paths if "output_path" in df.columns: output_path_list = df["output_path"].tolist() - except FileNotFoundError as e: - raise ValueError(f"CSV file not found: {data_path}") from e except pandas.errors.EmptyDataError as e: - raise ValueError(f"CSV file is empty: {data_path}") from e - except pandas.errors.ParserError as e: - raise ValueError(f"Error parsing CSV file: {data_path}") from e + raise ValueError(f"CSV file is empty: {data_path}. Error: {e}") from e + # If the file is a text file, collect data_paths elif data_path_obj.suffix.lower() == ".txt": - with open(data_path_obj, "r") as file: - raw_data_path_list = [line.strip() for line in file.readlines()] - - # Else, the file is a single data_path + try: + with open(data_path_obj, "r") as file: + raw_data_path_list = [line.strip() for line in file.readlines()] + except Exception as e: + raise ValueError( + f"Error reading text file: {data_path}. Error: {e}" + ) from e else: - raw_data_path_list = [data_path_obj] + raw_data_path_list = [str(data_path_obj)] raw_data_path_list = [Path(p) for p in raw_data_path_list] @@ -5535,18 +5535,18 @@ def main(args: Optional[list] = None): output_path = None # if output_path has not been extracted from a csv file yet - if output_path_list is None: + if output_path_list is None and args.output is not None: output_path = args.output + output_path_obj = Path(output_path) - # check if output_path is valid before running inference - if ( - output_path is not None - and Path(output_path).is_file() - and len(data_path_list) > 1 - ): - raise ValueError( - "output_path argument must be a directory if multiple video inputs are given" - ) + # check if output_path is valid before running inference + if ( + Path(output_path).is_file() + and len(data_path_list) > 1 + ): + raise ValueError( + "output_path argument must be a directory if multiple video inputs are given" + ) # Setup tracker. tracker = _make_tracker_from_cli(args) @@ -5576,10 +5576,11 @@ def main(args: Optional[list] = None): else: output_path = f"{data_path.as_posix()}.predictions.slp" - output_path_obj = Path(output_path) + output_path_obj = Path(output_path) # if output_path was provided and multiple inputs were provided, create a directory to store outputs - if len(data_path_list) > 1: + elif len(data_path_list) > 1: + output_path_obj = Path(output_path) output_path = ( output_path_obj / data_path_obj.with_suffix(".predictions.slp").name ) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 0678b1eee..646049256 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -8,6 +8,7 @@ import numpy as np import pytest +import pandas import tensorflow as tf import tensorflow_hub as hub from numpy.testing import assert_array_equal, assert_allclose @@ -1686,12 +1687,17 @@ def test_sleap_track_output_mult( sleap_track(args=args) slp_path = Path(slp_path) + print(f"Contents of the directory {slp_path_obj}:") + for file in slp_path_obj.iterdir(): + print(file) + # Check if there are any files in the directory for file_path in slp_path_list: if file_path.suffix == ".mp4": expected_output_file = output_path_obj / ( file_path.stem + ".predictions.slp" ) + print(f"expected output: {expected_output_file}") assert Path(expected_output_file).exists() @@ -1747,6 +1753,20 @@ def test_sleap_track_invalid_input( with pytest.raises(ValueError): sleap_track(args=args) + # Test with a non-existent path + slp_path = "/path/to/nonexistent/file.mp4" + + # Create sleap-track command for non-existent path + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + # Run inference and expect a ValueError for non-existent path + with pytest.raises(ValueError): + sleap_track(args=args) + def test_sleap_track_csv_input( min_centroid_model_path: str, @@ -1801,6 +1821,94 @@ def test_sleap_track_csv_input( assert Path(expected_output_file).exists() +def test_sleap_track_invalid_csv( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + tmpdir, +): + + # Create a CSV file with missing 'data_path' column + csv_missing_column_path = tmpdir / "missing_column.csv" + df_missing_column = pandas.DataFrame( + {"some_other_column": ["video1.mp4", "video2.mp4", "video3.mp4"]} + ) + df_missing_column.to_csv(csv_missing_column_path, index=False) + + # Create an empty CSV file + csv_empty_path = tmpdir / "empty.csv" + open(csv_empty_path, "w").close() + + # Create sleap-track command for missing 'data_path' column + args_missing_column = ( + f"{csv_missing_column_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + # Run inference and expect ValueError for missing 'data_path' column + with pytest.raises( + ValueError, match="Column 'data_path' does not exist in the CSV file." + ): + sleap_track(args=args_missing_column) + + # Create sleap-track command for empty CSV file + args_empty = ( + f"{csv_empty_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + # Run inference and expect ValueError for empty CSV file + with pytest.raises(ValueError, match=f"CSV file is empty: {csv_empty_path}"): + sleap_track(args=args_empty) + + +def test_sleap_track_text_file_input( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + centered_pair_vid_path, + tmpdir, +): + + # Create temporary directory with the structured video files + slp_path = Path(tmpdir.mkdir("mp4_directory")) + + # Copy and paste the video into the temp dir multiple times + num_copies = 3 + file_paths = [] + for i in range(num_copies): + # Construct the destination path with a unique name + dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + file_paths.append(dest_path) + + # Create a text file with the file paths + txt_file_path = slp_path / "file_paths.txt" + with open(txt_file_path, mode="w") as txt_file: + for file_path in file_paths: + txt_file.write(f"{file_path}\n") + + slp_path_obj = Path(slp_path) + + # Create sleap-track command + args = ( + f"{txt_file_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + # Assert predictions file exists + for file_path in slp_path_list: + if file_path.suffix == ".mp4": + expected_output_file = f"{file_path}.predictions.slp" + assert Path(expected_output_file).exists() + + def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): """Test flow tracker instances are pruned.""" labels: Labels = centered_pair_predictions