diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py
index 8b3ae9eaa..28319f792 100644
--- a/sleap/nn/inference.py
+++ b/sleap/nn/inference.py
@@ -33,7 +33,7 @@
 import atexit
 import subprocess
 import rich.progress
-import pandas
+import pandas as pd
 from rich.pretty import pprint
 from collections import deque
 import json
@@ -5312,25 +5312,36 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
     # If the file is a CSV file, check for data_paths and output_paths
     if data_path_obj.suffix.lower() == ".csv":
         try:
+            data_path_column = None
             # Read the CSV file
-            df = pandas.read_csv(data_path)
+            df = pd.read_csv(data_path)

             # collect data_paths from column
-            if "data_path" in df.columns:
-                raw_data_path_list = df["data_path"].tolist()
-            else:
+            for col_index in range(df.shape[1]):
+                path_str = df.iloc[0, col_index]
+                if Path(path_str).exists():
+                    data_path_column = df.columns[col_index]
+                    break
+            if data_path_column is None:
                 raise ValueError(
-                    f"Column 'data_path' does not exist in the CSV file: {data_path}"
+                    f"Column containing valid data_paths does not exist in the CSV file: {data_path}"
                 )
+            raw_data_path_list = df[data_path_column].tolist()

             # optional output_path column to specify multiple output_paths
-            if "output_path" in df.columns:
-                output_path_list = df["output_path"].tolist()
+            output_path_column_index = df.columns.get_loc(data_path_column) + 1
+            if (
+                output_path_column_index < df.shape[1]
+                and df.iloc[:, output_path_column_index].dtype == object
+            ):
+                # Ensure the next column exists
+                output_path_list = df.iloc[:, output_path_column_index].tolist()
+            else:
+                output_path_list = None

-        except pandas.errors.EmptyDataError as e:
+        except pd.errors.EmptyDataError as e:
             raise ValueError(f"CSV file is empty: {data_path}. Error: {e}") from e
-
     # If the file is a text file, collect data_paths
     elif data_path_obj.suffix.lower() == ".txt":
         try:
@@ -5341,7 +5352,7 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
                 f"Error reading text file: {data_path}. Error: {e}"
             ) from e
     else:
-        raw_data_path_list = [str(data_path_obj)]
+        raw_data_path_list = [data_path_obj.as_posix()]

     raw_data_path_list = [Path(p) for p in raw_data_path_list]

@@ -5540,10 +5551,7 @@ def main(args: Optional[list] = None):
         output_path_obj = Path(output_path)

         # check if output_path is valid before running inference
-        if (
-            Path(output_path).is_file()
-            and len(data_path_list) > 1
-        ):
+        if Path(output_path).is_file() and len(data_path_list) > 1:
             raise ValueError(
                 "output_path argument must be a directory if multiple video inputs are given"
             )
@@ -5574,7 +5582,7 @@ def main(args: Optional[list] = None):
                 output_path = output_path_list[i]

             else:
-                output_path = f"{data_path.as_posix()}.predictions.slp"
+                output_path = data_path_obj.with_suffix(".predictions.slp")
                 output_path_obj = Path(output_path)

@@ -5582,7 +5590,8 @@ def main(args: Optional[list] = None):
             elif len(data_path_list) > 1:
                 output_path_obj = Path(output_path)
                 output_path = (
-                    output_path_obj / data_path_obj.with_suffix(".predictions.slp").name
+                    output_path_obj
+                    / (data_path_obj.with_suffix(".predictions.slp")).name
                 )
                 output_path_obj = Path(output_path)
                 # Create the containing directory if needed.
@@ -5618,7 +5627,12 @@ def main(args: Optional[list] = None):
             labels_pr.provenance["args"] = vars(args)

             # Save results.
-            labels_pr.save(output_path)
+            try:
+                labels_pr.save(output_path)
+            except Exception as e:
+                print("WARNING: Provided output path invalid.")
+                fallback_path = data_path_obj.with_suffix(".predictions.slp")
+                labels_pr.save(fallback_path)
             print("Saved output:", output_path)

             if args.open_in_gui:
@@ -5629,76 +5643,57 @@ def main(args: Optional[list] = None):

     # running tracking on existing prediction file
     elif getattr(args, "tracking.tracker") is not None:
-        for data_path, provider in zip(data_path_list, provider_list):
-            # Load predictions
-            data_path_obj = Path(data_path)
-            print("Loading predictions...")
-            labels_pr = sleap.load_file(data_path_obj.as_posix())
-            frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
+        provider = provider_list[0]
+        data_path = data_path_list[0]

-            print("Starting tracker...")
-            frames = run_tracker(frames=frames, tracker=tracker)
-            tracker.final_pass(frames)
+        # Load predictions
+        data_path = args.data_path
+        print("Loading predictions...")
+        labels_pr = sleap.load_file(data_path)
+        frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

-            labels_pr = Labels(labeled_frames=frames)
+        print("Starting tracker...")
+        frames = run_tracker(frames=frames, tracker=tracker)
+        tracker.final_pass(frames)

-            if output_path is None:
-                output_path = f"{data_path}.{tracker.get_name()}.slp"
-                output_path_obj = Path(output_path)
+        labels_pr = Labels(labeled_frames=frames)

-            else:
-                output_path_obj = Path(output_path)
-                if (
-                    output_path_obj.exists()
-                    and output_path_obj.is_file()
-                    and len(data_path_list) > 1
-                ):
-                    raise ValueError(
-                        "output_path argument must be a directory if multiple video inputs are given"
-                    )
+        if output_path is None:
+            output_path = f"{data_path}.{tracker.get_name()}.slp"

-                elif not output_path_obj.exists() and len(data_path_list) > 1:
-                    output_path = output_path_obj / data_path_obj.with_suffix(
-                        ".predictions.slp"
-                    )
-                    output_path_obj = Path(output_path)
-                    output_path_obj.parent.mkdir(exist_ok=True, parents=True)
+        if args.no_empty_frames:
+            # Clear empty frames if specified.
+            labels_pr.remove_empty_frames()

-            if args.no_empty_frames:
-                # Clear empty frames if specified.
-                labels_pr.remove_empty_frames()
-
-            finish_timestamp = str(datetime.now())
-            total_elapsed = time() - t0
-            print("Finished inference at:", finish_timestamp)
-            print(f"Total runtime: {total_elapsed} secs")
-            print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
+        finish_timestamp = str(datetime.now())
+        total_elapsed = time() - t0
+        print("Finished inference at:", finish_timestamp)
+        print(f"Total runtime: {total_elapsed} secs")
+        print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

-            # Add provenance metadata to predictions.
-            labels_pr.provenance["sleap_version"] = sleap.__version__
-            labels_pr.provenance["platform"] = platform.platform()
-            labels_pr.provenance["command"] = " ".join(sys.argv)
-            labels_pr.provenance["data_path"] = data_path_obj.as_posix()
-            labels_pr.provenance["output_path"] = output_path_obj.as_posix()
-            labels_pr.provenance["total_elapsed"] = total_elapsed
-            labels_pr.provenance["start_timestamp"] = start_timestamp
-            labels_pr.provenance["finish_timestamp"] = finish_timestamp
+        # Add provenance metadata to predictions.
+ labels_pr.provenance["sleap_version"] = sleap.__version__ + labels_pr.provenance["platform"] = platform.platform() + labels_pr.provenance["command"] = " ".join(sys.argv) + labels_pr.provenance["data_path"] = data_path + labels_pr.provenance["output_path"] = output_path + labels_pr.provenance["total_elapsed"] = total_elapsed + labels_pr.provenance["start_timestamp"] = start_timestamp + labels_pr.provenance["finish_timestamp"] = finish_timestamp - print("Provenance:") - pprint(labels_pr.provenance) - print() + print("Provenance:") + pprint(labels_pr.provenance) + print() - labels_pr.provenance["args"] = vars(args) + labels_pr.provenance["args"] = vars(args) - # Save results. - labels_pr.save(output_path) - print("Saved output:", output_path) + # Save results. + labels_pr.save(output_path) - if args.open_in_gui: - subprocess.call(["sleap-label", output_path]) + print("Saved output:", output_path) - # Reset output_path for next iteration - output_path = args.output + if args.open_in_gui: + subprocess.call(["sleap-label", output_path]) else: raise ValueError( diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 646049256..cdd56da09 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -8,7 +8,7 @@ import numpy as np import pytest -import pandas +import pandas as pd import tensorflow as tf import tensorflow_hub as hub from numpy.testing import assert_array_equal, assert_allclose @@ -1511,7 +1511,7 @@ def test_sleap_track_single_input( sleap_track(args=args) # Assert predictions file exists - output_path = f"{slp_path}.predictions.slp" + output_path = Path(slp_path).with_suffix(".predictions.slp") assert Path(output_path).exists() # Create invalid sleap-track command @@ -1539,8 +1539,6 @@ def test_sleap_track_mult_input_slp( # Copy and paste the video into the temp dir multiple times num_copies = 3 for i in range(num_copies): - # Construct the destination path with a unique name for the video - # Construct the destination path with a unique name for the SLP file slp_dest_path = slp_path / f"old_slp_copy_{i}.slp" shutil.copy(slp_file, slp_dest_path) @@ -1563,8 +1561,8 @@ def test_sleap_track_mult_input_slp( } # Add other video formats if necessary for file_path in slp_path_list: - if file_path.suffix in expected_extensions: - expected_output_file = f"{file_path}.predictions.slp" + if file_path in expected_extensions: + expected_output_file = Path(file_path).with_suffix(".predictions.slp") assert Path(expected_output_file).exists() @@ -1607,7 +1605,7 @@ def test_sleap_track_mult_input_slp_mp4( # Assert predictions file exists for file_path in slp_path_list: if file_path.suffix == ".mp4": - expected_output_file = f"{file_path}.predictions.slp" + expected_output_file = Path(file_path).with_suffix(".predictions.slp") assert Path(expected_output_file).exists() @@ -1647,7 +1645,7 @@ def test_sleap_track_mult_input_mp4( # Assert predictions file exists for file_path in slp_path_list: if file_path.suffix == ".mp4": - expected_output_file = f"{file_path}.predictions.slp" + expected_output_file = Path(file_path).with_suffix(".predictions.slp") assert Path(expected_output_file).exists() @@ -1687,17 +1685,12 @@ def test_sleap_track_output_mult( sleap_track(args=args) slp_path = Path(slp_path) - print(f"Contents of the directory {slp_path_obj}:") - for file in slp_path_obj.iterdir(): - print(file) - # Check if there are any files in the directory for file_path in slp_path_list: if file_path.suffix == ".mp4": expected_output_file = output_path_obj / ( file_path.stem 
+ ".predictions.slp" ) - print(f"expected output: {expected_output_file}") assert Path(expected_output_file).exists() @@ -1829,7 +1822,7 @@ def test_sleap_track_invalid_csv( # Create a CSV file with missing 'data_path' column csv_missing_column_path = tmpdir / "missing_column.csv" - df_missing_column = pandas.DataFrame( + df_missing_column = pd.DataFrame( {"some_other_column": ["video1.mp4", "video2.mp4", "video3.mp4"]} ) df_missing_column.to_csv(csv_missing_column_path, index=False) @@ -1846,9 +1839,7 @@ def test_sleap_track_invalid_csv( ).split() # Run inference and expect ValueError for missing 'data_path' column - with pytest.raises( - ValueError, match="Column 'data_path' does not exist in the CSV file." - ): + with pytest.raises(ValueError): sleap_track(args=args_missing_column) # Create sleap-track command for empty CSV file @@ -1859,7 +1850,7 @@ def test_sleap_track_invalid_csv( ).split() # Run inference and expect ValueError for empty CSV file - with pytest.raises(ValueError, match=f"CSV file is empty: {csv_empty_path}"): + with pytest.raises(ValueError): sleap_track(args=args_empty) @@ -1905,7 +1896,7 @@ def test_sleap_track_text_file_input( # Assert predictions file exists for file_path in slp_path_list: if file_path.suffix == ".mp4": - expected_output_file = f"{file_path}.predictions.slp" + expected_output_file = Path(file_path).with_suffix(".predictions.slp") assert Path(expected_output_file).exists()