|
6 | 6 | import pytest
|
7 | 7 | import torch
|
8 | 8 | import torchvision.io as io
|
9 |
| -from common_utils import assert_equal |
| 9 | +from common_utils import assert_equal, cpu_and_cuda |
10 | 10 | from torchvision import get_video_backend
|
11 | 11 |
|
12 | 12 |
|
@@ -255,22 +255,19 @@ def test_read_video_partially_corrupted_file(self):
|
255 | 255 | assert_equal(video, data)
|
256 | 256 |
|
257 | 257 | @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows")
|
258 |
| - @pytest.mark.parametrize("device", ["cpu", "cuda"]) |
| 258 | + @pytest.mark.parametrize("device", cpu_and_cuda()) |
259 | 259 | def test_write_video_with_audio(self, device, tmpdir):
|
260 | 260 | f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
|
261 | 261 | video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
|
262 | 262 |
|
263 |
| - video_tensor = video_tensor.to(device) |
264 |
| - audio_tensor = audio_tensor.to(device) |
265 |
| - |
266 | 263 | out_f_name = os.path.join(tmpdir, "testing.mp4")
|
267 | 264 | io.video.write_video(
|
268 | 265 | out_f_name,
|
269 |
| - video_tensor, |
| 266 | + video_tensor.to(device), |
270 | 267 | round(info["video_fps"]),
|
271 | 268 | video_codec="libx264rgb",
|
272 | 269 | options={"crf": "0"},
|
273 |
| - audio_array=audio_tensor, |
| 270 | + audio_array=audio_tensor.to(device), |
274 | 271 | audio_fps=info["audio_fps"],
|
275 | 272 | audio_codec="aac",
|
276 | 273 | )
|
|
0 commit comments