diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py index e82883ece338..7c024052a439 100644 --- a/tests/multimodal/test_video.py +++ b/tests/multimodal/test_video.py @@ -13,7 +13,7 @@ VideoLoader, ) -from .utils import create_video_from_image +from .utils import create_long_gop_video, create_video_from_image pytestmark = pytest.mark.cpu_test @@ -364,6 +364,49 @@ def test_pyav_dynamic_backend_loads_frames( assert metadata["video_backend"] == "pyav_dynamic" +def test_pyav_backend_returns_target_frames_not_keyframes(): + """Regression test: PyAV must decode forward past the seek keyframe. + + container.seek() snaps backward to the nearest keyframe. With a long GOP + (here: one keyframe at frame 0), a decoder that does not advance forward + to the target PTS collapses every sampled slot onto the keyframe. This + test encodes a per-frame marker on the green channel and verifies the + returned frames are distinct, ordered, and match the requested indices. + """ + num_frames = 50 + num_sampled = 4 + height, width = 64, 64 + + video_bytes = create_long_gop_video( + num_frames=num_frames, width=width, height=height + ) + + loader = VIDEO_LOADER_REGISTRY.load("opencv") + frames, metadata = loader.load_bytes( + video_bytes, num_frames=num_sampled, backend="pyav" + ) + assert frames.shape == (num_sampled, height, width, 3) + + requested = list(metadata["frames_indices"]) + assert len(requested) == num_sampled + + actual = [int(f[height // 2, width // 2, 1]) for f in frames] + + assert len(set(actual)) == num_sampled, ( + f"PyAV returned only {len(set(actual))} distinct frames for " + f"{num_sampled} requested indices: markers={actual}, " + f"requested={requested}. Keyframe-snap regression." + ) + + assert actual == sorted(actual), f"Returned frames out of order: markers={actual}" + + for marker, want_idx in zip(actual, requested): + assert abs(marker - want_idx) <= 10, ( + f"Frame mismatch: requested index {want_idx}, " + f"got marker {marker} (tolerance ±10)" + ) + + @pytest.mark.parametrize( "loader_key, kwargs, expected_num_frames", [ diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py index 485bde939f69..32f3ec0e4233 100644 --- a/tests/multimodal/utils.py +++ b/tests/multimodal/utils.py @@ -66,6 +66,43 @@ def create_video_from_image( return video_path +def create_long_gop_video( + num_frames: int = 50, + fps: int = 30, + width: int = 64, + height: int = 64, +) -> bytes: + """Encode an H.264 clip with one keyframe and green-channel = frame index. + + The marker lets a test recover which frame the decoder actually returned, + independent of any metadata label. + """ + import io + + import av + + buf = io.BytesIO() + with av.open(buf, mode="w", format="mp4") as container: + stream = container.add_stream("h264", rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + stream.codec_context.gop_size = num_frames + stream.codec_context.max_b_frames = 0 + stream.codec_context.options = { + "x264-params": (f"scenecut=0:keyint={num_frames}:min-keyint={num_frames}") + } + for i in range(num_frames): + img = np.zeros((height, width, 3), dtype=np.uint8) + img[:, :, 1] = i + frame = av.VideoFrame.from_ndarray(img, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) + return buf.getvalue() + + def cosine_similarity(A: npt.NDArray, B: npt.NDArray, axis: int = -1) -> npt.NDArray: """Compute cosine similarity between two vectors.""" return np.sum(A * B, axis=axis) / ( diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 5b118af8fc53..697156a5b4dc 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -390,7 +390,7 @@ def decode_frames( fps: float, duration: float, ) -> tuple[npt.NDArray, list[int]]: - """Decode target frames via per-frame seek + keyframe decode.""" + """Decode target frames via per-frame seek + forward decode to PTS.""" stream = container.streams.video[0] # SLICE parallelizes within a single frame without the # one-frame-per-thread latency penalty of FRAME threading. @@ -402,14 +402,28 @@ def decode_frames( frame_interval = 1.0 / fps if fps > 0 else 0.1 max_ts = max(0.0, duration - frame_interval) if duration > 0 else float("inf") + decoder = None + last_pts = None for idx in frame_indices: ts = min(idx / fps, max_ts) if fps > 0 else 0.0 pts = int(ts / time_base) - container.seek(pts, stream=stream) - frame = next(container.decode(video=0), None) - if frame is not None: - frames_list.append(frame.to_ndarray(format="rgb24")) + # seek() snaps backward to a keyframe; reuse the running decoder + # while targets advance monotonically to avoid re-decoding the + # GOP prefix once per requested frame. + if decoder is None or last_pts is None or pts <= last_pts: + container.seek(pts, stream=stream) + decoder = container.decode(video=0) + chosen = None + for frame in decoder: + if frame.pts is not None and frame.pts >= pts: + chosen = frame + last_pts = frame.pts + break + if chosen is not None: + frames_list.append(chosen.to_ndarray(format="rgb24")) valid_indices.append(idx) + else: + decoder = None if not frames_list: return np.empty((0,), dtype=np.uint8), valid_indices