Skip to content

Commit

Permalink
Error fixing, black, deletion of (self-written) unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
emdavis02 committed Jul 23, 2024
1 parent f41ea2a commit a69a65c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 95 deletions.
147 changes: 71 additions & 76 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import atexit
import subprocess
import rich.progress
import pandas
import pandas as pd
from rich.pretty import pprint
from collections import deque
import json
Expand Down Expand Up @@ -5312,25 +5312,36 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
# If the file is a CSV file, check for data_paths and output_paths
if data_path_obj.suffix.lower() == ".csv":
try:
data_path_column = None
# Read the CSV file
df = pandas.read_csv(data_path)
df = pd.read_csv(data_path)

# collect data_paths from column
if "data_path" in df.columns:
raw_data_path_list = df["data_path"].tolist()
else:
for col_index in range(df.shape[1]):
path_str = df.iloc[0, col_index]
if Path(path_str).exists():
data_path_column = df.columns[col_index]
break
if data_path_column is None:
raise ValueError(
f"Column 'data_path' does not exist in the CSV file: {data_path}"
f"Column containing valid data_paths does not exist in the CSV file: {data_path}"
)
raw_data_path_list = df[data_path_column].tolist()

# optional output_path column to specify multiple output_paths
if "output_path" in df.columns:
output_path_list = df["output_path"].tolist()
output_path_column_index = df.columns.get_loc(data_path_column) + 1
if (
output_path_column_index < df.shape[1]
and df.iloc[:, output_path_column_index].dtype == object
):
# The next column exists and has object dtype, so use it as the output paths
output_path_list = df.iloc[:, output_path_column_index].tolist()
else:
output_path_list = None

except pandas.errors.EmptyDataError as e:
except pd.errors.EmptyDataError as e:
raise ValueError(f"CSV file is empty: {data_path}. Error: {e}") from e


# If the file is a text file, collect data_paths
elif data_path_obj.suffix.lower() == ".txt":
try:
Expand All @@ -5341,7 +5352,7 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
f"Error reading text file: {data_path}. Error: {e}"
) from e
else:
raw_data_path_list = [str(data_path_obj)]
raw_data_path_list = [data_path_obj.as_posix()]

raw_data_path_list = [Path(p) for p in raw_data_path_list]

Expand Down Expand Up @@ -5540,10 +5551,7 @@ def main(args: Optional[list] = None):
output_path_obj = Path(output_path)

# check if output_path is valid before running inference
if (
Path(output_path).is_file()
and len(data_path_list) > 1
):
if Path(output_path).is_file() and len(data_path_list) > 1:
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)
Expand Down Expand Up @@ -5574,15 +5582,16 @@ def main(args: Optional[list] = None):
output_path = output_path_list[i]

else:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path = data_path_obj.with_suffix(".predictions.slp")

output_path_obj = Path(output_path)

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
elif len(data_path_list) > 1:
output_path_obj = Path(output_path)
output_path = (
output_path_obj / data_path_obj.with_suffix(".predictions.slp").name
output_path_obj
/ (data_path_obj.with_suffix(".predictions.slp")).name
)
output_path_obj = Path(output_path)
# Create the containing directory if needed.
Expand Down Expand Up @@ -5618,7 +5627,12 @@ def main(args: Optional[list] = None):
labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
try:
labels_pr.save(output_path)
except Exception as e:
print("WARNING: Provided output path invalid.")
fallback_path = data_path_obj.with_suffix(".predictions.slp")
labels_pr.save(fallback_path)
print("Saved output:", output_path)

if args.open_in_gui:
Expand All @@ -5629,76 +5643,57 @@ def main(args: Optional[list] = None):

# running tracking on existing prediction file
elif getattr(args, "tracking.tracker") is not None:
for data_path, provider in zip(data_path_list, provider_list):
# Load predictions
data_path_obj = Path(data_path)
print("Loading predictions...")
labels_pr = sleap.load_file(data_path_obj.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
provider = provider_list[0]
data_path = data_path_list[0]

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)
# Load predictions
data_path = args.data_path
print("Loading predictions...")
labels_pr = sleap.load_file(data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

labels_pr = Labels(labeled_frames=frames)
print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)

if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
output_path_obj = Path(output_path)
labels_pr = Labels(labeled_frames=frames)

else:
output_path_obj = Path(output_path)
if (
output_path_obj.exists()
and output_path_obj.is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)
if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"

elif not output_path_obj.exists() and len(data_path_list) > 1:
output_path = output_path_obj / data_path_obj.with_suffix(
".predictions.slp"
)
output_path_obj = Path(output_path)
output_path_obj.parent.mkdir(exist_ok=True, parents=True)
if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path_obj.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp
# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()
print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)
labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)
# Save results.
labels_pr.save(output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
print("Saved output:", output_path)

# Reset output_path for next iteration
output_path = args.output
if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

else:
raise ValueError(
Expand Down
29 changes: 10 additions & 19 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np
import pytest
import pandas
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
from numpy.testing import assert_array_equal, assert_allclose
Expand Down Expand Up @@ -1511,7 +1511,7 @@ def test_sleap_track_single_input(
sleap_track(args=args)

# Assert predictions file exists
output_path = f"{slp_path}.predictions.slp"
output_path = Path(slp_path).with_suffix(".predictions.slp")
assert Path(output_path).exists()

# Create invalid sleap-track command
Expand Down Expand Up @@ -1539,8 +1539,6 @@ def test_sleap_track_mult_input_slp(
# Copy the SLP file into the temp dir multiple times
num_copies = 3
for i in range(num_copies):
# Construct the destination path with a unique name for the video

# Construct the destination path with a unique name for the SLP file
slp_dest_path = slp_path / f"old_slp_copy_{i}.slp"
shutil.copy(slp_file, slp_dest_path)
Expand All @@ -1563,8 +1561,8 @@ def test_sleap_track_mult_input_slp(
} # Add other video formats if necessary

for file_path in slp_path_list:
if file_path.suffix in expected_extensions:
expected_output_file = f"{file_path}.predictions.slp"
if file_path in expected_extensions:
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
assert Path(expected_output_file).exists()


Expand Down Expand Up @@ -1607,7 +1605,7 @@ def test_sleap_track_mult_input_slp_mp4(
# Assert predictions file exists
for file_path in slp_path_list:
if file_path.suffix == ".mp4":
expected_output_file = f"{file_path}.predictions.slp"
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
assert Path(expected_output_file).exists()


Expand Down Expand Up @@ -1647,7 +1645,7 @@ def test_sleap_track_mult_input_mp4(
# Assert predictions file exists
for file_path in slp_path_list:
if file_path.suffix == ".mp4":
expected_output_file = f"{file_path}.predictions.slp"
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
assert Path(expected_output_file).exists()


Expand Down Expand Up @@ -1687,17 +1685,12 @@ def test_sleap_track_output_mult(
sleap_track(args=args)
slp_path = Path(slp_path)

print(f"Contents of the directory {slp_path_obj}:")
for file in slp_path_obj.iterdir():
print(file)

# Check that each video input produced a predictions file in the output directory
for file_path in slp_path_list:
if file_path.suffix == ".mp4":
expected_output_file = output_path_obj / (
file_path.stem + ".predictions.slp"
)
print(f"expected output: {expected_output_file}")
assert Path(expected_output_file).exists()


Expand Down Expand Up @@ -1829,7 +1822,7 @@ def test_sleap_track_invalid_csv(

# Create a CSV file with missing 'data_path' column
csv_missing_column_path = tmpdir / "missing_column.csv"
df_missing_column = pandas.DataFrame(
df_missing_column = pd.DataFrame(
{"some_other_column": ["video1.mp4", "video2.mp4", "video3.mp4"]}
)
df_missing_column.to_csv(csv_missing_column_path, index=False)
Expand All @@ -1846,9 +1839,7 @@ def test_sleap_track_invalid_csv(
).split()

# Run inference and expect ValueError for missing 'data_path' column
with pytest.raises(
ValueError, match="Column 'data_path' does not exist in the CSV file."
):
with pytest.raises(ValueError):
sleap_track(args=args_missing_column)

# Create sleap-track command for empty CSV file
Expand All @@ -1859,7 +1850,7 @@ def test_sleap_track_invalid_csv(
).split()

# Run inference and expect ValueError for empty CSV file
with pytest.raises(ValueError, match=f"CSV file is empty: {csv_empty_path}"):
with pytest.raises(ValueError):
sleap_track(args=args_empty)


Expand Down Expand Up @@ -1905,7 +1896,7 @@ def test_sleap_track_text_file_input(
# Assert predictions file exists
for file_path in slp_path_list:
if file_path.suffix == ".mp4":
expected_output_file = f"{file_path}.predictions.slp"
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
assert Path(expected_output_file).exists()


Expand Down

0 comments on commit a69a65c

Please sign in to comment.