Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion tests/multimodal/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
VideoLoader,
)

from .utils import create_video_from_image
from .utils import create_long_gop_video, create_video_from_image

pytestmark = pytest.mark.cpu_test

Expand Down Expand Up @@ -364,6 +364,49 @@ def test_pyav_dynamic_backend_loads_frames(
assert metadata["video_backend"] == "pyav_dynamic"


def test_pyav_backend_returns_target_frames_not_keyframes():
    """Regression test: the PyAV path must return the requested frames.

    container.seek() lands on the nearest keyframe at or before the target.
    The clip used here is encoded with a long GOP, so a decoder that stops
    at the seek position instead of decoding forward to the target PTS
    would hand back the same keyframe for every sampled slot. Each encoded
    frame carries its index on the green channel; the test reads that
    marker back and checks the returned frames are distinct, in ascending
    order, and close to the indices reported in the metadata.
    """
    total_frames, sample_count = 50, 4
    height = width = 64

    clip = create_long_gop_video(
        num_frames=total_frames, width=width, height=height
    )

    video_loader = VIDEO_LOADER_REGISTRY.load("opencv")
    frames, metadata = video_loader.load_bytes(
        clip, num_frames=sample_count, backend="pyav"
    )
    assert frames.shape == (sample_count, height, width, 3)

    requested = list(metadata["frames_indices"])
    assert len(requested) == sample_count

    # Read the green-channel marker from the centre pixel of every frame.
    row, col = height // 2, width // 2
    actual = [int(frame[row, col, 1]) for frame in frames]

    # Every sampled slot must come from a different source frame.
    distinct = len(set(actual))
    assert distinct == sample_count, (
        f"PyAV returned only {distinct} distinct frames for "
        f"{sample_count} requested indices: markers={actual}, "
        f"requested={requested}. Keyframe-snap regression."
    )

    assert actual == sorted(actual), f"Returned frames out of order: markers={actual}"

    # The marker is compared with a tolerance rather than for equality.
    for marker, want_idx in zip(actual, requested):
        drift = abs(marker - want_idx)
        assert drift <= 10, (
            f"Frame mismatch: requested index {want_idx}, "
            f"got marker {marker} (tolerance ±10)"
        )


@pytest.mark.parametrize(
"loader_key, kwargs, expected_num_frames",
[
Expand Down
37 changes: 37 additions & 0 deletions tests/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,43 @@ def create_video_from_image(
return video_path


def create_long_gop_video(
    num_frames: int = 50,
    fps: int = 30,
    width: int = 64,
    height: int = 64,
) -> bytes:
    """Encode an in-memory H.264 MP4 whose green channel stores the frame index.

    The encoder is configured with a GOP size equal to ``num_frames``, no
    B-frames, and scene-cut detection disabled, so the clip is intended to
    contain a single keyframe at the start. Every frame is a solid colour
    with red and blue set to 0 and green set to the frame index, which lets
    a test recover which frame a decoder actually returned, independent of
    any metadata label.

    Args:
        num_frames: Number of frames to encode. Must not exceed 256 because
            the marker is stored in a single uint8 channel.
        fps: Frame rate passed to the video stream.
        width: Frame width in pixels.
        height: Frame height in pixels.

    Returns:
        The encoded MP4 file contents.

    Raises:
        ValueError: If ``num_frames`` is greater than 256, since the marker
            could no longer identify each frame uniquely.
    """
    # Validate first: a uint8 channel only holds 0..255, so larger indices
    # would overflow (or wrap, depending on the NumPy version) and the
    # marker would stop matching the frame index.
    if num_frames > 256:
        raise ValueError(
            f"num_frames must be at most 256 to fit the uint8 marker, "
            f"got {num_frames}"
        )

    import io

    import av

    buf = io.BytesIO()
    with av.open(buf, mode="w", format="mp4") as container:
        stream = container.add_stream("h264", rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"
        # Long GOP: ask for one keyframe interval spanning the whole clip.
        stream.codec_context.gop_size = num_frames
        stream.codec_context.max_b_frames = 0
        stream.codec_context.options = {
            "x264-params": (f"scenecut=0:keyint={num_frames}:min-keyint={num_frames}")
        }
        for i in range(num_frames):
            img = np.zeros((height, width, 3), dtype=np.uint8)
            # Channel 1 of an rgb24 image is green; it carries the marker.
            img[:, :, 1] = i
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)
        # Calling encode() with no frame flushes packets still buffered
        # in the encoder.
        for packet in stream.encode():
            container.mux(packet)
    # Read the buffer only after the container has been closed.
    return buf.getvalue()


def cosine_similarity(A: npt.NDArray, B: npt.NDArray, axis: int = -1) -> npt.NDArray:
"""Compute cosine similarity between two vectors."""
return np.sum(A * B, axis=axis) / (
Expand Down
24 changes: 19 additions & 5 deletions vllm/multimodal/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def decode_frames(
fps: float,
duration: float,
) -> tuple[npt.NDArray, list[int]]:
"""Decode target frames via per-frame seek + keyframe decode."""
"""Decode target frames via per-frame seek + forward decode to PTS."""
stream = container.streams.video[0]
# SLICE parallelizes within a single frame without the
# one-frame-per-thread latency penalty of FRAME threading.
Expand All @@ -402,14 +402,28 @@ def decode_frames(
frame_interval = 1.0 / fps if fps > 0 else 0.1
max_ts = max(0.0, duration - frame_interval) if duration > 0 else float("inf")

decoder = None
last_pts = None
for idx in frame_indices:
ts = min(idx / fps, max_ts) if fps > 0 else 0.0
pts = int(ts / time_base)
container.seek(pts, stream=stream)
frame = next(container.decode(video=0), None)
if frame is not None:
frames_list.append(frame.to_ndarray(format="rgb24"))
# seek() snaps backward to a keyframe; reuse the running decoder
# while targets advance monotonically to avoid re-decoding the
# GOP prefix once per requested frame.
if decoder is None or last_pts is None or pts <= last_pts:
container.seek(pts, stream=stream)
decoder = container.decode(video=0)
chosen = None
for frame in decoder:
if frame.pts is not None and frame.pts >= pts:
chosen = frame
last_pts = frame.pts
break
if chosen is not None:
frames_list.append(chosen.to_ndarray(format="rgb24"))
valid_indices.append(idx)
else:
decoder = None

if not frames_list:
return np.empty((0,), dtype=np.uint8), valid_indices
Expand Down
Loading