Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow csv and text file support on sleap track #1875

Merged
merged 10 commits into from
Jul 31, 2024
254 changes: 145 additions & 109 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import atexit
import subprocess
import rich.progress
import pandas as pd
from rich.pretty import pprint
from collections import deque
import json
Expand Down Expand Up @@ -5285,8 +5286,8 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
args: Parsed CLI namespace.

Returns:
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
that was specified in the args.
`(provider_list, data_path_list, output_path_list)` with the data `Provider`, path to the data
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
that was specified in the args, and a list of output paths if a CSV file was given as input.
"""

# Figure out which input path to use.
Expand All @@ -5300,71 +5301,115 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:

data_path_obj = Path(data_path)

# Set output_path_list to None as a default to return later
output_path_list = None
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Check that input value is valid
if not data_path_obj.exists():
raise ValueError("Path to data_path does not exist")

# Check for multiple video inputs
# Compile file(s) into a list for later itteration
if data_path_obj.is_dir():
data_path_list = []
for file_path in data_path_obj.iterdir():
if file_path.is_file():
data_path_list.append(Path(file_path))
elif data_path_obj.is_file():
data_path_list = [data_path_obj]
# If the file is a CSV file, check for data_paths and output_paths
if data_path_obj.suffix.lower() == ".csv":
try:
data_path_column = None
# Read the CSV file
df = pd.read_csv(data_path)

# collect data_paths from column
for col_index in range(df.shape[1]):
path_str = df.iloc[0, col_index]
if Path(path_str).exists():
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
data_path_column = df.columns[col_index]
break
if data_path_column is None:
raise ValueError(
f"Column containing valid data_paths does not exist in the CSV file: {data_path}"
)
raw_data_path_list = df[data_path_column].tolist()

# optional output_path column to specify multiple output_paths
output_path_column_index = df.columns.get_loc(data_path_column) + 1
if (
output_path_column_index < df.shape[1]
and df.iloc[:, output_path_column_index].dtype == object
):
# Ensure the next column exists
output_path_list = df.iloc[:, output_path_column_index].tolist()
else:
output_path_list = None

except pd.errors.EmptyDataError as e:
raise ValueError(f"CSV file is empty: {data_path}. Error: {e}") from e

emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
# If the file is a text file, collect data_paths
elif data_path_obj.suffix.lower() == ".txt":
try:
with open(data_path_obj, "r") as file:
raw_data_path_list = [line.strip() for line in file.readlines()]
except Exception as e:
raise ValueError(
f"Error reading text file: {data_path}. Error: {e}"
) from e
else:
raw_data_path_list = [data_path_obj.as_posix()]

raw_data_path_list = [Path(p) for p in raw_data_path_list]

# Check for multiple video inputs
# Compile file(s) into a list for later iteration
elif data_path_obj.is_dir():
raw_data_path_list = [
file_path for file_path in data_path_obj.iterdir() if file_path.is_file()
]
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Provider list to accommodate multiple video inputs
output_provider_list = []
output_data_path_list = []
for file_path in data_path_list:
provider_list = []
data_path_list = []
for file_path in raw_data_path_list:
# Create a provider for each file
if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
if file_path.as_posix().endswith(".slp") and len(raw_data_path_list) > 1:
print(f"slp file skipped: {file_path.as_posix()}")

elif file_path.as_posix().endswith(".slp"):
labels = sleap.load_file(file_path.as_posix())

if args.only_labeled_frames:
output_provider_list.append(
LabelsReader.from_user_labeled_frames(labels)
)
provider_list.append(LabelsReader.from_user_labeled_frames(labels))
elif args.only_suggested_frames:
output_provider_list.append(
LabelsReader.from_unlabeled_suggestions(labels)
)
provider_list.append(LabelsReader.from_unlabeled_suggestions(labels))
elif getattr(args, "video.index") != "":
output_provider_list.append(
provider_list.append(
VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
)
else:
output_provider_list.append(LabelsReader(labels))
provider_list.append(LabelsReader(labels))

output_data_path_list.append(file_path)
data_path_list.append(file_path)

else:
try:
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
output_provider_list.append(
provider_list.append(
VideoReader.from_filepath(
filename=file_path.as_posix(),
example_indices=frame_list(args.frames),
**video_kwargs,
)
)
print(f"Video: {file_path.as_posix()}")
output_data_path_list.append(file_path)
data_path_list.append(file_path)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Clean this up.
except Exception:
print(f"Error reading file: {file_path.as_posix()}")

return output_provider_list, output_data_path_list
return provider_list, data_path_list, output_path_list


def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Expand Down Expand Up @@ -5496,19 +5541,20 @@ def main(args: Optional[list] = None):
print()

# Setup data loader.
provider_list, data_path_list = _make_provider_from_cli(args)
provider_list, data_path_list, output_path_list = _make_provider_from_cli(args)

output_path = args.output
output_path = None

# check if output_path is valid before running inference
if (
output_path is not None
and Path(output_path).is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)
# if output_path has not been extracted from a csv file yet
if output_path_list is None and args.output is not None:
output_path = args.output
output_path_obj = Path(output_path)

# check if output_path is valid before running inference
if Path(output_path).is_file() and len(data_path_list) > 1:
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

# Setup tracker.
tracker = _make_tracker_from_cli(args)
Expand All @@ -5520,7 +5566,7 @@ def main(args: Optional[list] = None):
if args.models is not None:

# Run inference on all input files
for data_path, provider in zip(data_path_list, provider_list):
for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)):
# Setup models.
data_path_obj = Path(data_path)
predictor = _make_predictor_from_cli(args)
Expand All @@ -5531,21 +5577,25 @@ def main(args: Optional[list] = None):

# if output path was not provided, create an output path
if output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)
# if output path was not provided, create an output path
if output_path_list:
output_path = output_path_list[i]

else:
output_path = data_path_obj.with_suffix(".predictions.slp")

else:
output_path_obj = Path(output_path)

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
output_path = (
output_path_obj
/ data_path_obj.with_suffix(".predictions.slp").name
)
output_path_obj = Path(output_path)
# Create the containing directory if needed.
output_path_obj.parent.mkdir(exist_ok=True, parents=True)
# if output_path was provided and multiple inputs were provided, create a directory to store outputs
elif len(data_path_list) > 1:
output_path_obj = Path(output_path)
output_path = (
output_path_obj
/ (data_path_obj.with_suffix(".predictions.slp")).name
)
output_path_obj = Path(output_path)
# Create the containing directory if needed.
output_path_obj.parent.mkdir(exist_ok=True, parents=True)

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
Expand Down Expand Up @@ -5577,7 +5627,12 @@ def main(args: Optional[list] = None):
labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
try:
labels_pr.save(output_path)
except Exception:
print("WARNING: Provided output path invalid.")
fallback_path = data_path_obj.with_suffix(".predictions.slp")
labels_pr.save(fallback_path)
print("Saved output:", output_path)

if args.open_in_gui:
Expand All @@ -5588,76 +5643,57 @@ def main(args: Optional[list] = None):

# running tracking on existing prediction file
elif getattr(args, "tracking.tracker") is not None:
for data_path, provider in zip(data_path_list, provider_list):
# Load predictions
data_path_obj = Path(data_path)
print("Loading predictions...")
labels_pr = sleap.load_file(data_path_obj.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
provider = provider_list[0]
data_path = data_path_list[0]

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)
# Load predictions
data_path = args.data_path
print("Loading predictions...")
labels_pr = sleap.load_file(data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

labels_pr = Labels(labeled_frames=frames)
print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)

if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
output_path_obj = Path(output_path)
labels_pr = Labels(labeled_frames=frames)

else:
output_path_obj = Path(output_path)
if (
output_path_obj.exists()
and output_path_obj.is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)
if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"

elif not output_path_obj.exists() and len(data_path_list) > 1:
output_path = output_path_obj / data_path_obj.with_suffix(
".predictions.slp"
)
output_path_obj = Path(output_path)
output_path_obj.parent.mkdir(exist_ok=True, parents=True)
if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()
finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path_obj.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()
print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)
labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)
# Save results.
labels_pr.save(output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
print("Saved output:", output_path)

# Reset output_path for next iteration
output_path = args.output
if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

else:
raise ValueError(
Expand Down
Loading