|
12 | 12 |
|
13 | 13 | VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") |
14 | 14 |
|
15 | | -test_videos = [ |
16 | | - "RATRACE_wave_f_nm_np1_fr_goo_37.avi", |
17 | | - "TrumanShow_wave_f_nm_np1_fr_med_26.avi", |
18 | | - "v_SoccerJuggling_g23_c01.avi", |
19 | | - "v_SoccerJuggling_g24_c01.avi", |
20 | | - "R6llTwEh07w.mp4", |
21 | | - "SOX5yA1l24A.mp4", |
22 | | - "WUzgd7C1pWA.mp4", |
23 | | -] |
24 | | - |
25 | 15 |
|
26 | 16 | @pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder") |
27 | 17 | class TestVideoGPUDecoder: |
28 | 18 | @pytest.mark.skipif(av is None, reason="PyAV unavailable") |
29 | | - def test_frame_reading(self): |
30 | | - for test_video in test_videos: |
31 | | - full_path = os.path.join(VIDEO_DIR, test_video) |
32 | | - decoder = VideoReader(full_path, device="cuda:0") |
33 | | - with av.open(full_path) as container: |
34 | | - for av_frame in container.decode(container.streams.video[0]): |
35 | | - av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) |
36 | | - vision_frames = next(decoder)["data"] |
37 | | - mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float())) |
38 | | - assert mean_delta < 0.75 |
| 19 | + @pytest.mark.parametrize( |
| 20 | + "video_file", |
| 21 | + [ |
| 22 | + "RATRACE_wave_f_nm_np1_fr_goo_37.avi", |
| 23 | + "TrumanShow_wave_f_nm_np1_fr_med_26.avi", |
| 24 | + "v_SoccerJuggling_g23_c01.avi", |
| 25 | + "v_SoccerJuggling_g24_c01.avi", |
| 26 | + "R6llTwEh07w.mp4", |
| 27 | + "SOX5yA1l24A.mp4", |
| 28 | + "WUzgd7C1pWA.mp4", |
| 29 | + ], |
| 30 | + ) |
| 31 | + def test_frame_reading(self, video_file): |
| 32 | + full_path = os.path.join(VIDEO_DIR, video_file) |
| 33 | + decoder = VideoReader(full_path, device="cuda:0") |
| 34 | + with av.open(full_path) as container: |
| 35 | + for av_frame in container.decode(container.streams.video[0]): |
| 36 | + av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) |
| 37 | + vision_frames = next(decoder)["data"] |
| 38 | + mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float())) |
| 39 | + assert mean_delta < 0.75 |
39 | 40 |
|
40 | 41 | @pytest.mark.skipif(av is None, reason="PyAV unavailable") |
41 | 42 | @pytest.mark.parametrize("keyframes", [True, False]) |
@@ -65,16 +66,27 @@ def test_seek_reading(self, keyframes, full_path, duration): |
65 | 66 | assert mean_delta < 0.75 |
66 | 67 |
|
67 | 68 | @pytest.mark.skipif(av is None, reason="PyAV unavailable") |
68 | | - def test_metadata(self): |
69 | | - for test_video in test_videos: |
70 | | - full_path = os.path.join(VIDEO_DIR, test_video) |
71 | | - decoder = VideoReader(full_path, device="cuda:0") |
72 | | - video_metadata = decoder.get_metadata()["video"] |
73 | | - with av.open(full_path) as container: |
74 | | - video = container.streams.video[0] |
75 | | - av_duration = float(video.duration * video.time_base) |
76 | | - assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2) |
77 | | - assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2) |
| 69 | + @pytest.mark.parametrize( |
| 70 | + "video_file", |
| 71 | + [ |
| 72 | + "RATRACE_wave_f_nm_np1_fr_goo_37.avi", |
| 73 | + "TrumanShow_wave_f_nm_np1_fr_med_26.avi", |
| 74 | + "v_SoccerJuggling_g23_c01.avi", |
| 75 | + "v_SoccerJuggling_g24_c01.avi", |
| 76 | + "R6llTwEh07w.mp4", |
| 77 | + "SOX5yA1l24A.mp4", |
| 78 | + "WUzgd7C1pWA.mp4", |
| 79 | + ], |
| 80 | + ) |
| 81 | + def test_metadata(self, video_file): |
| 82 | + full_path = os.path.join(VIDEO_DIR, video_file) |
| 83 | + decoder = VideoReader(full_path, device="cuda:0") |
| 84 | + video_metadata = decoder.get_metadata()["video"] |
| 85 | + with av.open(full_path) as container: |
| 86 | + video = container.streams.video[0] |
| 87 | + av_duration = float(video.duration * video.time_base) |
| 88 | + assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2) |
| 89 | + assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2) |
78 | 90 |
|
79 | 91 |
|
80 | 92 | if __name__ == "__main__": |
|
0 commit comments