increased code coverage
emdavis02 committed Jul 22, 2024
1 parent 4e873b8 commit f41ea2a
Showing 2 changed files with 132 additions and 23 deletions.
47 changes: 24 additions & 23 deletions sleap/nn/inference.py
@@ -5320,28 +5320,28 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
raw_data_path_list = df["data_path"].tolist()
else:
raise ValueError(
"Column 'data_path' does not exist in the CSV file."
f"Column 'data_path' does not exist in the CSV file: {data_path}"
)

# optional output_path column to specify multiple output_paths
if "output_path" in df.columns:
output_path_list = df["output_path"].tolist()

except FileNotFoundError as e:
raise ValueError(f"CSV file not found: {data_path}") from e
except pandas.errors.EmptyDataError as e:
raise ValueError(f"CSV file is empty: {data_path}") from e
except pandas.errors.ParserError as e:
raise ValueError(f"Error parsing CSV file: {data_path}") from e
raise ValueError(f"CSV file is empty: {data_path}. Error: {e}") from e


# If the file is a text file, collect data_paths
elif data_path_obj.suffix.lower() == ".txt":
with open(data_path_obj, "r") as file:
raw_data_path_list = [line.strip() for line in file.readlines()]

# Else, the file is a single data_path
try:
with open(data_path_obj, "r") as file:
raw_data_path_list = [line.strip() for line in file.readlines()]
except Exception as e:
raise ValueError(
f"Error reading text file: {data_path}. Error: {e}"
) from e
else:
raw_data_path_list = [data_path_obj]
raw_data_path_list = [str(data_path_obj)]

raw_data_path_list = [Path(p) for p in raw_data_path_list]
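
For reference, the CSV branch above only requires a "data_path" column; an "output_path" column is optional and, when present, supplies one output file per video. A minimal sketch of such a file (video and output names are hypothetical, not part of the commit):

import pandas

# Hypothetical paths; only the column headers matter to the parser above.
pandas.DataFrame(
    {
        "data_path": ["clips/session1.mp4", "clips/session2.mp4"],
        "output_path": ["preds/session1.predictions.slp", "preds/session2.predictions.slp"],
    }
).to_csv("batch.csv", index=False)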

@@ -5535,18 +5535,18 @@ def main(args: Optional[list] = None):
output_path = None

# if output_path has not been extracted from a csv file yet
if output_path_list is None:
if output_path_list is None and args.output is not None:
output_path = args.output
output_path_obj = Path(output_path)

# check if output_path is valid before running inference
if (
output_path is not None
and Path(output_path).is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)
# check if output_path is valid before running inference
if (
Path(output_path).is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

# Setup tracker.
tracker = _make_tracker_from_cli(args)
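
The check above requires --output to name a directory whenever multiple input videos are resolved (from a directory, CSV, or text file). A hypothetical invocation that satisfies it, using the same sleap_track entry point the tests call (all paths are placeholders):

from sleap.nn.inference import sleap_track

# "clips/" holds several videos, so "preds_dir" must be a directory.
sleap_track(
    "clips/ --model models/centroid --model models/centered_instance "
    "--output preds_dir --frames 1-3 --cpu".split()
)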
@@ -5576,10 +5576,11 @@
else:
output_path = f"{data_path.as_posix()}.predictions.slp"

output_path_obj = Path(output_path)
output_path_obj = Path(output_path)

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
elif len(data_path_list) > 1:
output_path_obj = Path(output_path)
output_path = (
output_path_obj / data_path_obj.with_suffix(".predictions.slp").name
)
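
The expression above builds each per-video output name by replacing the video suffix with ".predictions.slp" and dropping the result into the output directory; a small sketch with hypothetical paths:

from pathlib import Path

output_path_obj = Path("preds_dir")         # directory passed via --output
data_path_obj = Path("clips/session1.mp4")  # one of the resolved videos
print(output_path_obj / data_path_obj.with_suffix(".predictions.slp").name)
# preds_dir/session1.predictions.slp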
108 changes: 108 additions & 0 deletions tests/nn/test_inference.py
@@ -8,6 +8,7 @@

import numpy as np
import pytest
import pandas
import tensorflow as tf
import tensorflow_hub as hub
from numpy.testing import assert_array_equal, assert_allclose
@@ -1686,12 +1687,17 @@ def test_sleap_track_output_mult(
sleap_track(args=args)
slp_path = Path(slp_path)

print(f"Contents of the directory {slp_path_obj}:")
for file in slp_path_obj.iterdir():
print(file)

    # Check that a predictions file exists for each copied video
for file_path in slp_path_list:
if file_path.suffix == ".mp4":
expected_output_file = output_path_obj / (
file_path.stem + ".predictions.slp"
)
print(f"expected output: {expected_output_file}")
assert Path(expected_output_file).exists()


@@ -1747,6 +1753,20 @@ def test_sleap_track_invalid_input(
with pytest.raises(ValueError):
sleap_track(args=args)

# Test with a non-existent path
slp_path = "/path/to/nonexistent/file.mp4"

# Create sleap-track command for non-existent path
args = (
f"{slp_path} --model {min_centroid_model_path} "
f"--tracking.tracker simple "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

# Run inference and expect a ValueError for non-existent path
with pytest.raises(ValueError):
sleap_track(args=args)


def test_sleap_track_csv_input(
min_centroid_model_path: str,
@@ -1801,6 +1821,94 @@ def test_sleap_track_csv_input(
assert Path(expected_output_file).exists()


def test_sleap_track_invalid_csv(
min_centroid_model_path: str,
min_centered_instance_model_path: str,
tmpdir,
):

# Create a CSV file with missing 'data_path' column
csv_missing_column_path = tmpdir / "missing_column.csv"
df_missing_column = pandas.DataFrame(
{"some_other_column": ["video1.mp4", "video2.mp4", "video3.mp4"]}
)
df_missing_column.to_csv(csv_missing_column_path, index=False)

# Create an empty CSV file
csv_empty_path = tmpdir / "empty.csv"
open(csv_empty_path, "w").close()

# Create sleap-track command for missing 'data_path' column
args_missing_column = (
f"{csv_missing_column_path} --model {min_centroid_model_path} "
f"--tracking.tracker simple "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

# Run inference and expect ValueError for missing 'data_path' column
with pytest.raises(
ValueError, match="Column 'data_path' does not exist in the CSV file."
):
sleap_track(args=args_missing_column)

# Create sleap-track command for empty CSV file
args_empty = (
f"{csv_empty_path} --model {min_centroid_model_path} "
f"--tracking.tracker simple "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

# Run inference and expect ValueError for empty CSV file
with pytest.raises(ValueError, match=f"CSV file is empty: {csv_empty_path}"):
sleap_track(args=args_empty)
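
Note that pytest.raises(match=...) applies re.search to the formatted exception message, so the first pattern above still matches even though the production message now ends with ": <path>" rather than a period (the trailing "." is a regex wildcard). A standalone illustration with a hypothetical path:

import re

msg = "Column 'data_path' does not exist in the CSV file: /tmp/missing_column.csv"
assert re.search("Column 'data_path' does not exist in the CSV file.", msg)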


def test_sleap_track_text_file_input(
min_centroid_model_path: str,
min_centered_instance_model_path: str,
centered_pair_vid_path,
tmpdir,
):

# Create temporary directory with the structured video files
slp_path = Path(tmpdir.mkdir("mp4_directory"))

# Copy and paste the video into the temp dir multiple times
num_copies = 3
file_paths = []
for i in range(num_copies):
# Construct the destination path with a unique name
dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
shutil.copy(centered_pair_vid_path, dest_path)
file_paths.append(dest_path)

# Create a text file with the file paths
txt_file_path = slp_path / "file_paths.txt"
with open(txt_file_path, mode="w") as txt_file:
for file_path in file_paths:
txt_file.write(f"{file_path}\n")

slp_path_obj = Path(slp_path)

# Create sleap-track command
args = (
f"{txt_file_path} --model {min_centroid_model_path} "
f"--tracking.tracker simple "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

# Run inference
sleap_track(args=args)

# Assert predictions file exists
for file_path in slp_path_list:
if file_path.suffix == ".mp4":
expected_output_file = f"{file_path}.predictions.slp"
assert Path(expected_output_file).exists()
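
Because no --output is passed here, each prediction file lands next to its input video with ".predictions.slp" appended to the full filename, which is what the assertion above checks; a small sketch with a hypothetical path:

from pathlib import Path

video = Path("mp4_directory/centered_pair_vid_copy_0.mp4")
print(f"{video.as_posix()}.predictions.slp")
# mp4_directory/centered_pair_vid_copy_0.mp4.predictions.slp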


def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
"""Test flow tracker instances are pruned."""
labels: Labels = centered_pair_predictions
