Skip to content

Commit

Permalink
Allowing inference on multiple videos via sleap-track (#1784)
Browse files Browse the repository at this point in the history
* implementing proposed code changes from issue #1777

* comments

* configuring output_path to support multiple video inputs

* fixing errors from preexisting test cases

* Test case / code fixes

* extending test cases for mp4 folders

* test case for output directory

* black and code rabbit fixes

* code rabbit fixes

* as_posix errors resolved

* syntax error

* adding test data

* black

* output error resolved

* edited for push to dev branch

* black

* errors fixed, test cases implemented

* invalid output test and invalid input test

* deleting debugging statements

* deleting print statements

* black

* deleting unnecessary test case

* implemented tmpdir

* deleting extraneous file

* fixing broken test case

* fixing test_sleap_track_invalid_output

* removing support for multiple slp files

* implementing talmo's comments

* adding comments
  • Loading branch information
emdavis02 authored Jul 19, 2024
1 parent 14c21b4 commit 3e2bd25
Show file tree
Hide file tree
Showing 2 changed files with 486 additions and 87 deletions.
299 changes: 213 additions & 86 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5288,46 +5288,83 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
that was specified in the args.
"""

# Figure out which input path to use.
labels_path = getattr(args, "labels", None)
if labels_path is not None:
data_path = labels_path
else:
data_path = args.data_path
data_path = args.data_path

if data_path is None or data_path == "":
raise ValueError(
"You must specify a path to a video or a labels dataset. "
"Run 'sleap-track -h' to see full command documentation."
)

if data_path.endswith(".slp"):
labels = sleap.load_file(data_path)

if args.only_labeled_frames:
provider = LabelsReader.from_user_labeled_frames(labels)
elif args.only_suggested_frames:
provider = LabelsReader.from_unlabeled_suggestions(labels)
elif getattr(args, "video.index") != "":
provider = VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
else:
provider = LabelsReader(labels)
data_path_obj = Path(data_path)

# Check that input value is valid
if not data_path_obj.exists():
raise ValueError("Path to data_path does not exist")

# Check for multiple video inputs
# Compile file(s) into a list for later itteration
if data_path_obj.is_dir():
data_path_list = []
for file_path in data_path_obj.iterdir():
if file_path.is_file():
data_path_list.append(Path(file_path))
elif data_path_obj.is_file():
data_path_list = [data_path_obj]

# Provider list to accomodate multiple video inputs
output_provider_list = []
output_data_path_list = []
for file_path in data_path_list:
# Create a provider for each file
if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
print(f"slp file skipped: {file_path.as_posix()}")

elif file_path.as_posix().endswith(".slp"):
labels = sleap.load_file(file_path.as_posix())

if args.only_labeled_frames:
output_provider_list.append(
LabelsReader.from_user_labeled_frames(labels)
)
elif args.only_suggested_frames:
output_provider_list.append(
LabelsReader.from_unlabeled_suggestions(labels)
)
elif getattr(args, "video.index") != "":
output_provider_list.append(
VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
)
else:
output_provider_list.append(LabelsReader(labels))

else:
print(f"Video: {data_path}")
# TODO: Clean this up.
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
provider = VideoReader.from_filepath(
filename=data_path, example_indices=frame_list(args.frames), **video_kwargs
)
output_data_path_list.append(file_path)

return provider, data_path
else:
try:
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
output_provider_list.append(
VideoReader.from_filepath(
filename=file_path.as_posix(),
example_indices=frame_list(args.frames),
**video_kwargs,
)
)
print(f"Video: {file_path.as_posix()}")
output_data_path_list.append(file_path)
# TODO: Clean this up.
except Exception:
print(f"Error reading file: {file_path.as_posix()}")

return output_provider_list, output_data_path_list


def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Expand Down Expand Up @@ -5422,8 +5459,6 @@ def main(args: Optional[list] = None):
pprint(vars(args))
print()

output_path = args.output

# Setup devices.
if args.cpu or not sleap.nn.system.is_gpu_system():
sleap.nn.system.use_cpu_only()
Expand Down Expand Up @@ -5461,43 +5496,168 @@ def main(args: Optional[list] = None):
print()

# Setup data loader.
provider, data_path = _make_provider_from_cli(args)
provider_list, data_path_list = _make_provider_from_cli(args)

output_path = args.output

# check if output_path is valid before running inference
if (
output_path is not None
and Path(output_path).is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

# Setup tracker.
tracker = _make_tracker_from_cli(args)

if args.models is not None and "movenet" in args.models[0]:
args.models = args.models[0]

# Either run inference (and tracking) or just run tracking
# Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run)
if args.models is not None:
# Setup models.
predictor = _make_predictor_from_cli(args)
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)
# Run inference on all files inputed
for data_path, provider in zip(data_path_list, provider_list):
# Setup models.
data_path_obj = Path(data_path)
predictor = _make_predictor_from_cli(args)
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)

if output_path is None:
output_path = data_path + ".predictions.slp"
# if output path was not provided, create an output path
if output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
else:
output_path_obj = Path(output_path)

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
output_path = (
output_path_obj
/ data_path_obj.with_suffix(".predictions.slp").name
)
output_path_obj = Path(output_path)
# Create the containing directory if needed.
output_path_obj.parent.mkdir(exist_ok=True, parents=True)

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path_obj.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

# Reset output_path for next iteration
output_path = args.output

# running tracking on existing prediction file
elif getattr(args, "tracking.tracker") is not None:
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(args.data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
for data_path, provider in zip(data_path_list, provider_list):
# Load predictions
data_path_obj = Path(data_path)
print("Loading predictions...")
labels_pr = sleap.load_file(data_path_obj.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)

labels_pr = Labels(labeled_frames=frames)

if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
output_path_obj = Path(output_path)

else:
output_path_obj = Path(output_path)
if (
output_path_obj.exists()
and output_path_obj.is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)
elif not output_path_obj.exists() and len(data_path_list) > 1:
output_path = output_path_obj / data_path_obj.with_suffix(
".predictions.slp"
)
output_path_obj = Path(output_path)
output_path_obj.parent.mkdir(exist_ok=True, parents=True)

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path_obj.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

labels_pr = Labels(labeled_frames=frames)
# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

# Reset output_path for next iteration
output_path = args.output

else:
raise ValueError(
Expand All @@ -5506,36 +5666,3 @@ def main(args: Optional[list] = None):
"To retrack on predictions, must specify tracker. "
"Use \"sleap-track --tracking.tracker ...' to specify tracker to use."
)

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
Loading

0 comments on commit 3e2bd25

Please sign in to comment.