Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow csv and text file support on sleap track #1875

Merged
merged 10 commits into from
Jul 31, 2024
124 changes: 82 additions & 42 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import atexit
import subprocess
import rich.progress
import pandas
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
from rich.pretty import pprint
from collections import deque
import json
Expand Down Expand Up @@ -5285,8 +5286,8 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
args: Parsed CLI namespace.

Returns:
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
that was specified in the args.
`(provider_list, data_path_list, output_path_list)` with the data `Provider`, path to the data
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
that was specified in the args, and a list of output paths if a CSV file was provided as input.
"""

# Figure out which input path to use.
Expand All @@ -5300,71 +5301,104 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:

data_path_obj = Path(data_path)

# Set output_path_list to None as a default to return later
output_path_list = None
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Check that input value is valid
if not data_path_obj.exists():
raise ValueError("Path to data_path does not exist")

# Check for multiple video inputs
# Compile file(s) into a list for later iteration
if data_path_obj.is_dir():
data_path_list = []
for file_path in data_path_obj.iterdir():
if file_path.is_file():
data_path_list.append(Path(file_path))
elif data_path_obj.is_file():
data_path_list = [data_path_obj]
# If the file is a CSV file, check for data_paths and output_paths
if data_path_obj.suffix.lower() == ".csv":
try:
# Read the CSV file
df = pandas.read_csv(data_path)

# collect data_paths from column
if "data_path" in df.columns:
raw_data_path_list = df["data_path"].tolist()
else:
raise ValueError(
"Column 'data_path' does not exist in the CSV file."
)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# optional output_path column to specify multiple output_paths
if "output_path" in df.columns:
output_path_list = df["output_path"].tolist()
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

except FileNotFoundError as e:
raise ValueError(f"CSV file not found: {data_path}") from e
except pandas.errors.EmptyDataError as e:
raise ValueError(f"CSV file is empty: {data_path}") from e
except pandas.errors.ParserError as e:
raise ValueError(f"Error parsing CSV file: {data_path}") from e
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# If the file is a text file, collect data_paths
elif data_path_obj.suffix.lower() == ".txt":
with open(data_path_obj, "r") as file:
raw_data_path_list = [line.strip() for line in file.readlines()]
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Else, the file is a single data_path
else:
raw_data_path_list = [data_path_obj]
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

raw_data_path_list = [Path(p) for p in raw_data_path_list]

# Check for multiple video inputs
# Compile file(s) into a list for later iteration
elif data_path_obj.is_dir():
raw_data_path_list = [
file_path for file_path in data_path_obj.iterdir() if file_path.is_file()
]
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Provider list to accommodate multiple video inputs
output_provider_list = []
output_data_path_list = []
for file_path in data_path_list:
provider_list = []
data_path_list = []
for file_path in raw_data_path_list:
# Create a provider for each file
if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
if file_path.as_posix().endswith(".slp") and len(raw_data_path_list) > 1:
print(f"slp file skipped: {file_path.as_posix()}")

elif file_path.as_posix().endswith(".slp"):
labels = sleap.load_file(file_path.as_posix())

if args.only_labeled_frames:
output_provider_list.append(
LabelsReader.from_user_labeled_frames(labels)
)
provider_list.append(LabelsReader.from_user_labeled_frames(labels))
elif args.only_suggested_frames:
output_provider_list.append(
LabelsReader.from_unlabeled_suggestions(labels)
)
provider_list.append(LabelsReader.from_unlabeled_suggestions(labels))
elif getattr(args, "video.index") != "":
output_provider_list.append(
provider_list.append(
VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
)
else:
output_provider_list.append(LabelsReader(labels))
provider_list.append(LabelsReader(labels))

output_data_path_list.append(file_path)
data_path_list.append(file_path)

else:
try:
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
output_provider_list.append(
provider_list.append(
VideoReader.from_filepath(
filename=file_path.as_posix(),
example_indices=frame_list(args.frames),
**video_kwargs,
)
)
print(f"Video: {file_path.as_posix()}")
output_data_path_list.append(file_path)
data_path_list.append(file_path)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Clean this up.
except Exception:
print(f"Error reading file: {file_path.as_posix()}")

return output_provider_list, output_data_path_list
return provider_list, data_path_list, output_path_list


def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Expand Down Expand Up @@ -5496,9 +5530,13 @@ def main(args: Optional[list] = None):
print()

# Setup data loader.
provider_list, data_path_list = _make_provider_from_cli(args)
provider_list, data_path_list, output_path_list = _make_provider_from_cli(args)

output_path = None

output_path = args.output
# if output_path has not been extracted from a csv file yet
if output_path_list is None:
output_path = args.output
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# check if output_path is valid before running inference
if (
Expand All @@ -5520,7 +5558,7 @@ def main(args: Optional[list] = None):
if args.models is not None:

# Run inference on all input files
for data_path, provider in zip(data_path_list, provider_list):
for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)):
# Setup models.
data_path_obj = Path(data_path)
predictor = _make_predictor_from_cli(args)
Expand All @@ -5531,21 +5569,23 @@ def main(args: Optional[list] = None):

# if output path was not provided, create an output path
if output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)
# if output path was not provided, create an output path
if output_path_list:
output_path = output_path_list[i]

else:
output_path_obj = Path(output_path)
else:
output_path = f"{data_path.as_posix()}.predictions.slp"

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
output_path = (
output_path_obj
/ data_path_obj.with_suffix(".predictions.slp").name
)
output_path_obj = Path(output_path)
# Create the containing directory if needed.
output_path_obj.parent.mkdir(exist_ok=True, parents=True)
output_path_obj = Path(output_path)

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
output_path = (
output_path_obj / data_path_obj.with_suffix(".predictions.slp").name
)
output_path_obj = Path(output_path)
# Create the containing directory if needed.
output_path_obj.parent.mkdir(exist_ok=True, parents=True)

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
Expand Down
54 changes: 54 additions & 0 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import cast
import shutil
import csv

import numpy as np
import pytest
Expand Down Expand Up @@ -1747,6 +1748,59 @@ def test_sleap_track_invalid_input(
sleap_track(args=args)


def test_sleap_track_csv_input(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure all expected output files are checked.

The current implementation only checks for .TESTpredictions.slp files. Ensure that all expected output files are checked, regardless of the input file extension.

- for file_path in slp_path_list:
-     if file_path.suffix in expected_extensions:
-         expected_output_file = file_path.with_suffix(".TESTpredictions.slp")
-         assert Path(expected_output_file).exists()
+ for file_path in file_paths:
+     expected_output_file = file_path.with_suffix(".TESTpredictions.slp")
+     assert Path(expected_output_file).exists()
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def test_sleap_track_csv_input(
def test_sleap_track_csv_input(
for file_path in file_paths:
expected_output_file = file_path.with_suffix(".TESTpredictions.slp")
assert Path(expected_output_file).exists()

min_centroid_model_path: str,
min_centered_instance_model_path: str,
centered_pair_vid_path,
tmpdir,
):

# Create temporary directory with the structured video files
slp_path = Path(tmpdir.mkdir("mp4_directory"))

# Copy and paste the video into the temp dir multiple times
num_copies = 3
file_paths = []
for i in range(num_copies):
# Construct the destination path with a unique name
dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
shutil.copy(centered_pair_vid_path, dest_path)
file_paths.append(dest_path)

# Generate output paths for each data_path
output_paths = [
file_path.with_suffix(".TESTpredictions.slp") for file_path in file_paths
]

# Create a CSV file with the file paths
csv_file_path = slp_path / "file_paths.csv"
with open(csv_file_path, mode="w", newline="") as csv_file:
csv_writer = csv.writer(csv_file)
csv_writer.writerow(["data_path", "output_path"])
for data_path, output_path in zip(file_paths, output_paths):
csv_writer.writerow([data_path, output_path])

slp_path_obj = Path(slp_path)

# Create sleap-track command
args = (
f"{csv_file_path} --model {min_centroid_model_path} "
f"--tracking.tracker simple "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

# Run inference
sleap_track(args=args)

# Assert predictions file exists
for file_path in slp_path_list:
if file_path.suffix == ".mp4":
expected_output_file = file_path.with_suffix(".TESTpredictions.slp")
assert Path(expected_output_file).exists()
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved


def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
"""Test flow tracker instances are pruned."""
labels: Labels = centered_pair_predictions
Expand Down