Skip to content

Commit d3ad226

Browse files
authored
Allow csv and text file support on sleap track (#1875)
* initial changes * csv support and test case * increased code coverage * Error fixing, black, deletion of (self-written) unused code * final edits * black * documentation changes * documentation changes
1 parent 28c34e2 commit d3ad226

File tree

3 files changed

+328
-125
lines changed

3 files changed

+328
-125
lines changed

docs/guides/cli.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,10 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [-
138138
[data_path]
139139
140140
positional arguments:
141-
data_path Path to data to predict on. This can be a labels (.slp) file or any supported video format.
141+
data_path Path to data to predict on. This can be one of the following: A .slp file containing labeled data; A folder containing multiple
142+
video files in supported formats; An individual video file in a supported format; A CSV file with a column of video file paths.
143+
If more than one column is provided in the CSV file, the first will be used for the input data paths and the next column will be
144+
used as the output paths; A text file with a path to a video file on each line
142145
143146
optional arguments:
144147
-h, --help show this help message and exit
@@ -153,7 +156,7 @@ optional arguments:
153156
Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for
154157
initialization during labeling.
155158
-o OUTPUT, --output OUTPUT
156-
The output filename to use for the predicted data. If not provided, defaults to '[data_path].predictions.slp'.
159+
The output filename or directory path to use for the predicted data. If not provided, defaults to '[data_path].predictions.slp'.
157160
--no-empty-frames Clear any empty frames that did not have any detected instances before saving to output.
158161
--verbosity {none,rich,json}
159162
Verbosity of inference progress reporting. 'none' does not output anything during inference, 'rich' displays an updating

sleap/nn/inference.py

+147-109
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import atexit
3434
import subprocess
3535
import rich.progress
36+
import pandas as pd
3637
from rich.pretty import pprint
3738
from collections import deque
3839
import json
@@ -5285,8 +5286,10 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
52855286
args: Parsed CLI namespace.
52865287
52875288
Returns:
5288-
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
5289-
that was specified in the args.
5289+
`(provider_list, data_path_list, output_path_list)` where `provider_list` contains the data providers,
5290+
`data_path_list` contains the paths to the specified data, and the `output_path_list` contains the list
5291+
of output paths if a CSV file with a column of output paths was provided; otherwise, `output_path_list`
5292+
defaults to None
52905293
"""
52915294

52925295
# Figure out which input path to use.
@@ -5300,71 +5303,115 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
53005303

53015304
data_path_obj = Path(data_path)
53025305

5306+
# Set output_path_list to None as a default to return later
5307+
output_path_list = None
5308+
53035309
# Check that input value is valid
53045310
if not data_path_obj.exists():
53055311
raise ValueError("Path to data_path does not exist")
53065312

5307-
# Check for multiple video inputs
5308-
# Compile file(s) into a list for later itteration
5309-
if data_path_obj.is_dir():
5310-
data_path_list = []
5311-
for file_path in data_path_obj.iterdir():
5312-
if file_path.is_file():
5313-
data_path_list.append(Path(file_path))
53145313
elif data_path_obj.is_file():
5315-
data_path_list = [data_path_obj]
5314+
# If the file is a CSV file, check for data_paths and output_paths
5315+
if data_path_obj.suffix.lower() == ".csv":
5316+
try:
5317+
data_path_column = None
5318+
# Read the CSV file
5319+
df = pd.read_csv(data_path)
5320+
5321+
# collect data_paths from column
5322+
for col_index in range(df.shape[1]):
5323+
path_str = df.iloc[0, col_index]
5324+
if Path(path_str).exists():
5325+
data_path_column = df.columns[col_index]
5326+
break
5327+
if data_path_column is None:
5328+
raise ValueError(
5329+
f"Column containing valid data_paths does not exist in the CSV file: {data_path}"
5330+
)
5331+
raw_data_path_list = df[data_path_column].tolist()
5332+
5333+
# optional output_path column to specify multiple output_paths
5334+
output_path_column_index = df.columns.get_loc(data_path_column) + 1
5335+
if (
5336+
output_path_column_index < df.shape[1]
5337+
and df.iloc[:, output_path_column_index].dtype == object
5338+
):
5339+
# Ensure the next column exists
5340+
output_path_list = df.iloc[:, output_path_column_index].tolist()
5341+
else:
5342+
output_path_list = None
5343+
5344+
except pd.errors.EmptyDataError as e:
5345+
raise ValueError(f"CSV file is empty: {data_path}. Error: {e}") from e
5346+
5347+
# If the file is a text file, collect data_paths
5348+
elif data_path_obj.suffix.lower() == ".txt":
5349+
try:
5350+
with open(data_path_obj, "r") as file:
5351+
raw_data_path_list = [line.strip() for line in file.readlines()]
5352+
except Exception as e:
5353+
raise ValueError(
5354+
f"Error reading text file: {data_path}. Error: {e}"
5355+
) from e
5356+
else:
5357+
raw_data_path_list = [data_path_obj.as_posix()]
5358+
5359+
raw_data_path_list = [Path(p) for p in raw_data_path_list]
5360+
5361+
# Check for multiple video inputs
5362+
# Compile file(s) into a list for later iteration
5363+
elif data_path_obj.is_dir():
5364+
raw_data_path_list = [
5365+
file_path for file_path in data_path_obj.iterdir() if file_path.is_file()
5366+
]
53165367

53175368
# Provider list to accomodate multiple video inputs
5318-
output_provider_list = []
5319-
output_data_path_list = []
5320-
for file_path in data_path_list:
5369+
provider_list = []
5370+
data_path_list = []
5371+
for file_path in raw_data_path_list:
53215372
# Create a provider for each file
5322-
if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
5373+
if file_path.as_posix().endswith(".slp") and len(raw_data_path_list) > 1:
53235374
print(f"slp file skipped: {file_path.as_posix()}")
53245375

53255376
elif file_path.as_posix().endswith(".slp"):
53265377
labels = sleap.load_file(file_path.as_posix())
53275378

53285379
if args.only_labeled_frames:
5329-
output_provider_list.append(
5330-
LabelsReader.from_user_labeled_frames(labels)
5331-
)
5380+
provider_list.append(LabelsReader.from_user_labeled_frames(labels))
53325381
elif args.only_suggested_frames:
5333-
output_provider_list.append(
5334-
LabelsReader.from_unlabeled_suggestions(labels)
5335-
)
5382+
provider_list.append(LabelsReader.from_unlabeled_suggestions(labels))
53365383
elif getattr(args, "video.index") != "":
5337-
output_provider_list.append(
5384+
provider_list.append(
53385385
VideoReader(
53395386
video=labels.videos[int(getattr(args, "video.index"))],
53405387
example_indices=frame_list(args.frames),
53415388
)
53425389
)
53435390
else:
5344-
output_provider_list.append(LabelsReader(labels))
5391+
provider_list.append(LabelsReader(labels))
53455392

5346-
output_data_path_list.append(file_path)
5393+
data_path_list.append(file_path)
53475394

53485395
else:
53495396
try:
53505397
video_kwargs = dict(
53515398
dataset=vars(args).get("video.dataset"),
53525399
input_format=vars(args).get("video.input_format"),
53535400
)
5354-
output_provider_list.append(
5401+
provider_list.append(
53555402
VideoReader.from_filepath(
53565403
filename=file_path.as_posix(),
53575404
example_indices=frame_list(args.frames),
53585405
**video_kwargs,
53595406
)
53605407
)
53615408
print(f"Video: {file_path.as_posix()}")
5362-
output_data_path_list.append(file_path)
5409+
data_path_list.append(file_path)
53635410
# TODO: Clean this up.
53645411
except Exception:
53655412
print(f"Error reading file: {file_path.as_posix()}")
53665413

5367-
return output_provider_list, output_data_path_list
5414+
return provider_list, data_path_list, output_path_list
53685415

53695416

53705417
def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
@@ -5496,19 +5543,20 @@ def main(args: Optional[list] = None):
54965543
print()
54975544

54985545
# Setup data loader.
5499-
provider_list, data_path_list = _make_provider_from_cli(args)
5546+
provider_list, data_path_list, output_path_list = _make_provider_from_cli(args)
55005547

5501-
output_path = args.output
5548+
output_path = None
55025549

5503-
# check if output_path is valid before running inference
5504-
if (
5505-
output_path is not None
5506-
and Path(output_path).is_file()
5507-
and len(data_path_list) > 1
5508-
):
5509-
raise ValueError(
5510-
"output_path argument must be a directory if multiple video inputs are given"
5511-
)
5550+
# if output_path has not been extracted from a csv file yet
5551+
if output_path_list is None and args.output is not None:
5552+
output_path = args.output
5553+
output_path_obj = Path(output_path)
5554+
5555+
# check if output_path is valid before running inference
5556+
if Path(output_path).is_file() and len(data_path_list) > 1:
5557+
raise ValueError(
5558+
"output_path argument must be a directory if multiple video inputs are given"
5559+
)
55125560

55135561
# Setup tracker.
55145562
tracker = _make_tracker_from_cli(args)
@@ -5520,7 +5568,7 @@ def main(args: Optional[list] = None):
55205568
if args.models is not None:
55215569

55225570
# Run inference on all files inputed
5523-
for data_path, provider in zip(data_path_list, provider_list):
5571+
for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)):
55245572
# Setup models.
55255573
data_path_obj = Path(data_path)
55265574
predictor = _make_predictor_from_cli(args)
@@ -5531,21 +5579,25 @@ def main(args: Optional[list] = None):
55315579

55325580
# if output path was not provided, create an output path
55335581
if output_path is None:
5534-
output_path = f"{data_path.as_posix()}.predictions.slp"
5535-
output_path_obj = Path(output_path)
5582+
# if output path was not provided, create an output path
5583+
if output_path_list:
5584+
output_path = output_path_list[i]
5585+
5586+
else:
5587+
output_path = data_path_obj.with_suffix(".predictions.slp")
55365588

5537-
else:
55385589
output_path_obj = Path(output_path)
55395590

5540-
# if output_path was provided and multiple inputs were provided, create a directory to store outputs
5541-
if len(data_path_list) > 1:
5542-
output_path = (
5543-
output_path_obj
5544-
/ data_path_obj.with_suffix(".predictions.slp").name
5545-
)
5546-
output_path_obj = Path(output_path)
5547-
# Create the containing directory if needed.
5548-
output_path_obj.parent.mkdir(exist_ok=True, parents=True)
5591+
# if output_path was provided and multiple inputs were provided, create a directory to store outputs
5592+
elif len(data_path_list) > 1:
5593+
output_path_obj = Path(output_path)
5594+
output_path = (
5595+
output_path_obj
5596+
/ (data_path_obj.with_suffix(".predictions.slp")).name
5597+
)
5598+
output_path_obj = Path(output_path)
5599+
# Create the containing directory if needed.
5600+
output_path_obj.parent.mkdir(exist_ok=True, parents=True)
55495601

55505602
labels_pr.provenance["model_paths"] = predictor.model_paths
55515603
labels_pr.provenance["predictor"] = type(predictor).__name__
@@ -5577,7 +5629,12 @@ def main(args: Optional[list] = None):
55775629
labels_pr.provenance["args"] = vars(args)
55785630

55795631
# Save results.
5580-
labels_pr.save(output_path)
5632+
try:
5633+
labels_pr.save(output_path)
5634+
except Exception:
5635+
print("WARNING: Provided output path invalid.")
5636+
fallback_path = data_path_obj.with_suffix(".predictions.slp")
5637+
labels_pr.save(fallback_path)
55815638
print("Saved output:", output_path)
55825639

55835640
if args.open_in_gui:
@@ -5588,76 +5645,57 @@ def main(args: Optional[list] = None):
55885645

55895646
# running tracking on existing prediction file
55905647
elif getattr(args, "tracking.tracker") is not None:
5591-
for data_path, provider in zip(data_path_list, provider_list):
5592-
# Load predictions
5593-
data_path_obj = Path(data_path)
5594-
print("Loading predictions...")
5595-
labels_pr = sleap.load_file(data_path_obj.as_posix())
5596-
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
5648+
provider = provider_list[0]
5649+
data_path = data_path_list[0]
55975650

5598-
print("Starting tracker...")
5599-
frames = run_tracker(frames=frames, tracker=tracker)
5600-
tracker.final_pass(frames)
5651+
# Load predictions
5652+
data_path = args.data_path
5653+
print("Loading predictions...")
5654+
labels_pr = sleap.load_file(data_path)
5655+
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
56015656

5602-
labels_pr = Labels(labeled_frames=frames)
5657+
print("Starting tracker...")
5658+
frames = run_tracker(frames=frames, tracker=tracker)
5659+
tracker.final_pass(frames)
56035660

5604-
if output_path is None:
5605-
output_path = f"{data_path}.{tracker.get_name()}.slp"
5606-
output_path_obj = Path(output_path)
5661+
labels_pr = Labels(labeled_frames=frames)
56075662

5608-
else:
5609-
output_path_obj = Path(output_path)
5610-
if (
5611-
output_path_obj.exists()
5612-
and output_path_obj.is_file()
5613-
and len(data_path_list) > 1
5614-
):
5615-
raise ValueError(
5616-
"output_path argument must be a directory if multiple video inputs are given"
5617-
)
5663+
if output_path is None:
5664+
output_path = f"{data_path}.{tracker.get_name()}.slp"
56185665

5619-
elif not output_path_obj.exists() and len(data_path_list) > 1:
5620-
output_path = output_path_obj / data_path_obj.with_suffix(
5621-
".predictions.slp"
5622-
)
5623-
output_path_obj = Path(output_path)
5624-
output_path_obj.parent.mkdir(exist_ok=True, parents=True)
5666+
if args.no_empty_frames:
5667+
# Clear empty frames if specified.
5668+
labels_pr.remove_empty_frames()
56255669

5626-
if args.no_empty_frames:
5627-
# Clear empty frames if specified.
5628-
labels_pr.remove_empty_frames()
5670+
finish_timestamp = str(datetime.now())
5671+
total_elapsed = time() - t0
5672+
print("Finished inference at:", finish_timestamp)
5673+
print(f"Total runtime: {total_elapsed} secs")
5674+
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
56295675

5630-
finish_timestamp = str(datetime.now())
5631-
total_elapsed = time() - t0
5632-
print("Finished inference at:", finish_timestamp)
5633-
print(f"Total runtime: {total_elapsed} secs")
5634-
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
5676+
# Add provenance metadata to predictions.
5677+
labels_pr.provenance["sleap_version"] = sleap.__version__
5678+
labels_pr.provenance["platform"] = platform.platform()
5679+
labels_pr.provenance["command"] = " ".join(sys.argv)
5680+
labels_pr.provenance["data_path"] = data_path
5681+
labels_pr.provenance["output_path"] = output_path
5682+
labels_pr.provenance["total_elapsed"] = total_elapsed
5683+
labels_pr.provenance["start_timestamp"] = start_timestamp
5684+
labels_pr.provenance["finish_timestamp"] = finish_timestamp
56355685

5636-
# Add provenance metadata to predictions.
5637-
labels_pr.provenance["sleap_version"] = sleap.__version__
5638-
labels_pr.provenance["platform"] = platform.platform()
5639-
labels_pr.provenance["command"] = " ".join(sys.argv)
5640-
labels_pr.provenance["data_path"] = data_path_obj.as_posix()
5641-
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
5642-
labels_pr.provenance["total_elapsed"] = total_elapsed
5643-
labels_pr.provenance["start_timestamp"] = start_timestamp
5644-
labels_pr.provenance["finish_timestamp"] = finish_timestamp
5645-
5646-
print("Provenance:")
5647-
pprint(labels_pr.provenance)
5648-
print()
5686+
print("Provenance:")
5687+
pprint(labels_pr.provenance)
5688+
print()
56495689

5650-
labels_pr.provenance["args"] = vars(args)
5690+
labels_pr.provenance["args"] = vars(args)
56515691

5652-
# Save results.
5653-
labels_pr.save(output_path)
5654-
print("Saved output:", output_path)
5692+
# Save results.
5693+
labels_pr.save(output_path)
56555694

5656-
if args.open_in_gui:
5657-
subprocess.call(["sleap-label", output_path])
5695+
print("Saved output:", output_path)
56585696

5659-
# Reset output_path for next iteration
5660-
output_path = args.output
5697+
if args.open_in_gui:
5698+
subprocess.call(["sleap-label", output_path])
56615699

56625700
else:
56635701
raise ValueError(

0 commit comments

Comments
 (0)