Skip to content

Commit 31eb3fb

Browse files
committed
final edits
1 parent a69a65c commit 31eb3fb

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

sleap/nn/inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5629,7 +5629,7 @@ def main(args: Optional[list] = None):
56295629
# Save results.
56305630
try:
56315631
labels_pr.save(output_path)
5632-
except Exception as e:
5632+
except Exception:
56335633
print("WARNING: Provided output path invalid.")
56345634
fallback_path = data_path_obj.with_suffix(".predictions.slp")
56355635
labels_pr.save(fallback_path)

tests/nn/test_inference.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tensorflow as tf
1313
import tensorflow_hub as hub
1414
from numpy.testing import assert_array_equal, assert_allclose
15+
from sleap.io.video import available_video_exts
1516

1617
import sleap
1718
from sleap.gui.learning import runners
@@ -1556,12 +1557,10 @@ def test_sleap_track_mult_input_slp(
15561557
sleap_track(args=args)
15571558

15581559
# Assert predictions file exists
1559-
expected_extensions = {
1560-
".mp4",
1561-
} # Add other video formats if necessary
1560+
expected_extensions = available_video_exts()
15621561

15631562
for file_path in slp_path_list:
1564-
if file_path in expected_extensions:
1563+
if file_path.suffix in expected_extensions:
15651564
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
15661565
assert Path(expected_output_file).exists()
15671566

@@ -1602,9 +1601,10 @@ def test_sleap_track_mult_input_slp_mp4(
16021601
# Run inference
16031602
sleap_track(args=args)
16041603

1605-
# Assert predictions file exists
1604+
expected_extensions = available_video_exts()
1605+
16061606
for file_path in slp_path_list:
1607-
if file_path.suffix == ".mp4":
1607+
if file_path.suffix in expected_extensions:
16081608
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
16091609
assert Path(expected_output_file).exists()
16101610

@@ -1643,8 +1643,10 @@ def test_sleap_track_mult_input_mp4(
16431643
sleap_track(args=args)
16441644

16451645
# Assert predictions file exists
1646+
expected_extensions = available_video_exts()
1647+
16461648
for file_path in slp_path_list:
1647-
if file_path.suffix == ".mp4":
1649+
if file_path.suffix in expected_extensions:
16481650
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
16491651
assert Path(expected_output_file).exists()
16501652

@@ -1686,8 +1688,10 @@ def test_sleap_track_output_mult(
16861688
slp_path = Path(slp_path)
16871689

16881690
# Check if there are any files in the directory
1691+
expected_extensions = available_video_exts()
1692+
16891693
for file_path in slp_path_list:
1690-
if file_path.suffix == ".mp4":
1694+
if file_path.suffix in expected_extensions:
16911695
expected_output_file = output_path_obj / (
16921696
file_path.stem + ".predictions.slp"
16931697
)
@@ -1808,8 +1812,10 @@ def test_sleap_track_csv_input(
18081812
sleap_track(args=args)
18091813

18101814
# Assert predictions file exists
1815+
expected_extensions = available_video_exts()
1816+
18111817
for file_path in slp_path_list:
1812-
if file_path.suffix == ".mp4":
1818+
if file_path.suffix in expected_extensions:
18131819
expected_output_file = file_path.with_suffix(".TESTpredictions.slp")
18141820
assert Path(expected_output_file).exists()
18151821

@@ -1839,7 +1845,7 @@ def test_sleap_track_invalid_csv(
18391845
).split()
18401846

18411847
# Run inference and expect ValueError for missing 'data_path' column
1842-
with pytest.raises(ValueError):
1848+
with pytest.raises(ValueError, match=f"Column containing valid data_paths does not exist in the CSV file: {csv_missing_column_path}"):
18431849
sleap_track(args=args_missing_column)
18441850

18451851
# Create sleap-track command for empty CSV file
@@ -1850,7 +1856,7 @@ def test_sleap_track_invalid_csv(
18501856
).split()
18511857

18521858
# Run inference and expect ValueError for empty CSV file
1853-
with pytest.raises(ValueError):
1859+
with pytest.raises(ValueError, match = f"CSV file is empty: {csv_empty_path}"):
18541860
sleap_track(args=args_empty)
18551861

18561862

@@ -1894,8 +1900,10 @@ def test_sleap_track_text_file_input(
18941900
sleap_track(args=args)
18951901

18961902
# Assert predictions file exists
1903+
expected_extensions = available_video_exts()
1904+
18971905
for file_path in slp_path_list:
1898-
if file_path.suffix == ".mp4":
1906+
if file_path.suffix in expected_extensions:
18991907
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
19001908
assert Path(expected_output_file).exists()
19011909

0 commit comments

Comments
 (0)