Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/entrypoints/openai_api/test_video_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_frame_interpolator_runs_actual_torch_tensor_path(monkeypatch):
assert torch.isfinite(output_video).all()


def test_frame_interpolator_prefers_input_tensor_device(monkeypatch):
def test_frame_interpolator_uses_platform_device_when_tensor_is_cpu(monkeypatch):
chosen_devices = []
model = rife_interpolator.Model().eval()

Expand All @@ -83,10 +83,11 @@ def _fake_ensure_model_loaded(*, preferred_device=None):
interpolator = rife_interpolator.FrameInterpolator()
monkeypatch.setattr(interpolator, "_ensure_model_loaded", _fake_ensure_model_loaded)
monkeypatch.setattr(model.flownet, "to", lambda device: model.flownet)
monkeypatch.setattr(rife_interpolator, "_select_torch_device", lambda: torch.device("cuda"))

video = torch.zeros(1, 3, 2, 32, 32)
output_video, multiplier = interpolator.interpolate_tensor(video, exp=1, scale=1.0)

assert chosen_devices == [video.device]
assert chosen_devices == [torch.device("cuda")]
assert multiplier == 2
assert output_video.shape == (1, 3, 3, 32, 32)
9 changes: 6 additions & 3 deletions vllm_omni/diffusion/postprocess/rife_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,12 @@ def interpolate_tensor(
return restore_layout(video), 1

video, restore_range = _normalize_video_tensor_range(video)
# Prefer the decoded video's current device so CPU-offloaded requests do
# not move the tensor back to GPU just for interpolation.
model = self._ensure_model_loaded(preferred_device=video.device)
# A CPU tensor may be transport/offload state rather than an execution
# choice, so only trust it when it is already on an accelerator.
preferred_device = video.device
if preferred_device.type == "cpu":
preferred_device = _select_torch_device()
model = self._ensure_model_loaded(preferred_device=preferred_device)
video = video.to(model.device())
intermediates_per_pair = 2**exp // 2

Expand Down
Loading