Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow csv and text file support on sleap track #1875

Merged
merged 10 commits into from
Jul 31, 2024
84 changes: 58 additions & 26 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5285,8 +5285,8 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
args: Parsed CLI namespace.

Returns:
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
that was specified in the args.
`(provider_list, data_path_list, output_path_list)` with the data `Provider`, path to the data
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
that was specified in the args, and the list of output paths if a CSV file was given as input.
"""

# Figure out which input path to use.
Expand All @@ -5299,72 +5299,96 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
)

data_path_obj = Path(data_path)

# Set output_path_list to None as a default to return later
output_path_list = None
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Check that input value is valid
if not data_path_obj.exists():
raise ValueError("Path to data_path does not exist")

elif data_path_obj.suffix.lower() == ".csv":
try:
# Read the CSV file
df = pandas.read_csv(data_path)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Check if the 'data_path' and 'output_path' columns exist
if "data_path" in df.columns:
raw_data_path_list = df["data_path"].tolist()
else:
print("Column 'data_path' does not exist in data_path csv file.")
if "output_path" in df.columns:
output_path_list = df["output_path"].tolist()
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

except FileNotFoundError as e:
raise ValueError(f"CSV file not found: {data_path}") from e
except pandas.errors.EmptyDataError as e:
raise ValueError(f"CSV file is empty: {data_path}") from e
except pandas.errors.ParserError as e:
raise ValueError(f"Error parsing CSV file: {data_path}") from e

# Check for multiple video inputs
# Compile file(s) into a list for later iteration
if data_path_obj.is_dir():
data_path_list = []
elif data_path_obj.is_dir():
raw_data_path_list = []
for file_path in data_path_obj.iterdir():
if file_path.is_file():
data_path_list.append(Path(file_path))
raw_data_path_list.append(Path(file_path))

emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
elif data_path_obj.is_file():
data_path_list = [data_path_obj]
raw_data_path_list = [data_path_obj]

# Provider list to accommodate multiple video inputs
output_provider_list = []
output_data_path_list = []
for file_path in data_path_list:
provider_list = []
data_path_list = []
for file_path in raw_data_path_list:
# Create a provider for each file
if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
if file_path.as_posix().endswith(".slp") and len(raw_data_path_list) > 1:
print(f"slp file skipped: {file_path.as_posix()}")

elif file_path.as_posix().endswith(".slp"):
labels = sleap.load_file(file_path.as_posix())

if args.only_labeled_frames:
output_provider_list.append(
provider_list.append(
LabelsReader.from_user_labeled_frames(labels)
)
elif args.only_suggested_frames:
output_provider_list.append(
provider_list.append(
LabelsReader.from_unlabeled_suggestions(labels)
)
elif getattr(args, "video.index") != "":
output_provider_list.append(
provider_list.append(
VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
)
else:
output_provider_list.append(LabelsReader(labels))
provider_list.append(LabelsReader(labels))

output_data_path_list.append(file_path)
data_path_list.append(file_path)

else:
try:
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
output_provider_list.append(
provider_list.append(
VideoReader.from_filepath(
filename=file_path.as_posix(),
example_indices=frame_list(args.frames),
**video_kwargs,
)
)
print(f"Video: {file_path.as_posix()}")
output_data_path_list.append(file_path)
data_path_list.append(file_path)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Clean this up.
except Exception:
print(f"Error reading file: {file_path.as_posix()}")

return output_provider_list, output_data_path_list
return provider_list, data_path_list, output_path_list


def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Expand Down Expand Up @@ -5496,10 +5520,12 @@ def main(args: Optional[list] = None):
print()

# Setup data loader.
provider_list, data_path_list = _make_provider_from_cli(args)

output_path = args.output
provider_list, data_path_list, output_path_list = _make_provider_from_cli(args)

# if output_path has not been extracted from a csv file yet
if output_path_list is None:
output_path = args.output
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# check if output_path is valid before running inference
if (
output_path is not None
Expand All @@ -5520,7 +5546,7 @@ def main(args: Optional[list] = None):
if args.models is not None:

# Run inference on all input files
for data_path, provider in zip(data_path_list, provider_list):
for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)):
# Setup models.
data_path_obj = Path(data_path)
predictor = _make_predictor_from_cli(args)
Expand All @@ -5531,11 +5557,17 @@ def main(args: Optional[list] = None):

# if output path was not provided, create an output path
if output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)
output_path = data_path + ".predictions.slp"
# if output paths were read from a csv file, use the one matching this input
if output_path_list is not None:
output_path = output_path_list[i]

elif output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

else:
output_path_obj = Path(output_path)
else:
output_path_obj = Path(output_path)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
Expand Down
Loading