Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions tests/entrypoints/openai_api/test_video_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo
from vllm_omni.entrypoints.openai.storage import LocalStorageManager
from vllm_omni.entrypoints.openai.stores import AsyncDictStore, TaskRegistry
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]

Expand Down Expand Up @@ -60,6 +61,7 @@ def custom_output(self):
class FakeAsyncOmni:
def __init__(self):
self.stage_configs = [SimpleNamespace(stage_type="diffusion")]
self.default_sampling_params_list = [OmniDiffusionSamplingParams()]
self.captured_prompt = None
self.captured_sampling_params_list = None

Expand Down Expand Up @@ -442,6 +444,87 @@ def test_frame_interpolation_params_pass_to_diffusion_sampling_params(test_clien
assert captured.frame_interpolation_model_path == "local-rife"


def test_default_sampling_params_apply_to_video_requests(test_client, mocker: MockerFixture):
    """Check that engine-level default sampling params reach the engine.

    The request to ``/v1/videos`` carries only a prompt, so every sampling
    value captured by the fake engine is expected to equal the defaults
    installed on ``default_sampling_params_list``.
    """
    # Replace the encoder with a stub that returns fixed bytes.
    mocker.patch(
        "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
        return_value=b"fake-video",
    )
    fake_engine = test_client.app.state.openai_serving_video._engine_client
    stage_defaults = OmniDiffusionSamplingParams(
        num_inference_steps=4,
        guidance_scale=7.5,
        generator_device="cpu",
        enable_frame_interpolation=True,
        frame_interpolation_exp=2,
        frame_interpolation_scale=0.5,
        frame_interpolation_model_path="default-rife",
    )
    fake_engine.default_sampling_params_list = [stage_defaults]

    resp = test_client.post("/v1/videos", data={"prompt": "default param pass-through"})

    assert resp.status_code == 200
    # The job runs asynchronously; block until it reports completion before
    # inspecting what the engine received.
    job_id = resp.json()["id"]
    _wait_for_status(test_client, job_id, VideoGenerationStatus.COMPLETED.value)

    seen = fake_engine.captured_sampling_params_list[0]
    assert seen.num_inference_steps == 4
    assert seen.guidance_scale == 7.5
    assert seen.generator_device == "cpu"
    assert seen.enable_frame_interpolation is True
    assert seen.frame_interpolation_exp == 2
    assert seen.frame_interpolation_scale == 0.5
    assert seen.frame_interpolation_model_path == "default-rife"


def test_request_params_override_default_video_sampling_params(test_client, mocker: MockerFixture):
    """Check that explicitly submitted form fields win over engine defaults.

    Fields present in the form (steps and the four frame-interpolation
    settings) must replace the defaults, while ``guidance_scale`` - which the
    request omits - must keep its default value.
    """
    # Replace the encoder with a stub that returns fixed bytes.
    mocker.patch(
        "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
        return_value=b"fake-video",
    )
    fake_engine = test_client.app.state.openai_serving_video._engine_client
    stage_defaults = OmniDiffusionSamplingParams(
        num_inference_steps=4,
        guidance_scale=7.5,
        enable_frame_interpolation=True,
        frame_interpolation_exp=2,
        frame_interpolation_scale=0.5,
        frame_interpolation_model_path="default-rife",
    )
    fake_engine.default_sampling_params_list = [stage_defaults]

    form_fields = {
        "prompt": "explicit override",
        "num_inference_steps": "8",
        "enable_frame_interpolation": "false",
        "frame_interpolation_exp": "1",
        "frame_interpolation_scale": "1.0",
        "frame_interpolation_model_path": "custom-rife",
    }
    resp = test_client.post("/v1/videos", data=form_fields)

    assert resp.status_code == 200
    # The job runs asynchronously; block until it reports completion before
    # inspecting what the engine received.
    job_id = resp.json()["id"]
    _wait_for_status(test_client, job_id, VideoGenerationStatus.COMPLETED.value)

    seen = fake_engine.captured_sampling_params_list[0]
    assert seen.num_inference_steps == 8
    # Not supplied in the form, so the default must survive.
    assert seen.guidance_scale == 7.5
    assert seen.enable_frame_interpolation is False
    assert seen.frame_interpolation_exp == 1
    assert seen.frame_interpolation_scale == 1.0
    assert seen.frame_interpolation_model_path == "custom-rife"


def test_worker_fps_multiplier_is_applied_to_async_encoding(test_client, mocker: MockerFixture):
fps_values = []
engine = test_client.app.state.openai_serving_video._engine_client
Expand Down Expand Up @@ -1149,6 +1232,39 @@ def test_sync_frame_interpolation_params_pass_to_sampling_params(test_client, mo
assert kwargs["fps"] == 8


def test_sync_default_sampling_params_apply_to_video_requests(test_client, mocker: MockerFixture):
    """Check that engine-level defaults reach the engine on the sync endpoint.

    The request to ``/v1/videos/sync`` supplies only a prompt and ``fps``, so
    the step count, guidance scale and frame-interpolation settings captured
    by the fake engine are expected to equal the installed defaults.
    """
    _mock_encode_video_bytes(mocker)
    engine = test_client.app.state.openai_serving_video._engine_client
    engine.default_sampling_params_list = [
        OmniDiffusionSamplingParams(
            num_inference_steps=4,
            guidance_scale=7.5,
            enable_frame_interpolation=True,
            frame_interpolation_exp=2,
            frame_interpolation_scale=0.5,
            frame_interpolation_model_path="default-rife",
        )
    ]

    response = test_client.post(
        "/v1/videos/sync",
        data={
            "prompt": "sync default param pass-through",
            "fps": "8",
        },
    )

    assert response.status_code == 200
    # ``engine`` is still the object bound above; it is not looked up again
    # from app state before reading the captured params.
    captured = engine.captured_sampling_params_list[0]
    assert captured.num_inference_steps == 4
    assert captured.guidance_scale == 7.5
    assert captured.enable_frame_interpolation is True
    assert captured.frame_interpolation_exp == 2
    assert captured.frame_interpolation_scale == 0.5
    assert captured.frame_interpolation_model_path == "default-rife"


def test_worker_fps_multiplier_is_applied_to_sync_encoding(test_client, mocker: MockerFixture):
engine = test_client.app.state.openai_serving_video._engine_client
fps_values = []
Expand Down
6 changes: 3 additions & 3 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,9 +2361,9 @@ async def _parse_video_form(
true_cfg_scale: float | None = Form(default=None),
seed: int | None = Form(default=None),
negative_prompt: str | None = Form(default=None),
enable_frame_interpolation: bool = Form(default=False),
frame_interpolation_exp: int = Form(default=1, ge=1),
frame_interpolation_scale: float = Form(default=1.0, gt=0.0),
enable_frame_interpolation: bool | None = Form(default=None),
frame_interpolation_exp: int | None = Form(default=None, ge=1),
frame_interpolation_scale: float | None = Form(default=None, gt=0.0),
frame_interpolation_model_path: str | None = Form(default=None),
lora: str | None = Form(default=None),
extra_params: str | None = Form(default=None),
Expand Down
43 changes: 30 additions & 13 deletions vllm_omni/entrypoints/openai/serving_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import copy
import time
from dataclasses import dataclass
from http import HTTPStatus
Expand Down Expand Up @@ -95,7 +96,7 @@ async def _run_and_extract(
if request.negative_prompt is not None:
prompt["negative_prompt"] = request.negative_prompt

gen_params = OmniDiffusionSamplingParams()
gen_params = self._resolve_default_sampling_params()

input_image = None if reference_image is None else reference_image.data
vp = request.resolve_video_params()
Expand All @@ -113,30 +114,35 @@ async def _run_and_extract(
if vp.fps is not None:
gen_params.fps = vp.fps
gen_params.frame_rate = float(vp.fps)
gen_params.enable_frame_interpolation = request.enable_frame_interpolation
gen_params.frame_interpolation_exp = request.frame_interpolation_exp
gen_params.frame_interpolation_scale = request.frame_interpolation_scale
gen_params.frame_interpolation_model_path = request.frame_interpolation_model_path

if request.num_inference_steps is not None:
provided_fields = request.model_fields_set
if "enable_frame_interpolation" in provided_fields:
gen_params.enable_frame_interpolation = request.enable_frame_interpolation
if "frame_interpolation_exp" in provided_fields:
gen_params.frame_interpolation_exp = request.frame_interpolation_exp
if "frame_interpolation_scale" in provided_fields:
gen_params.frame_interpolation_scale = request.frame_interpolation_scale
if "frame_interpolation_model_path" in provided_fields:
gen_params.frame_interpolation_model_path = request.frame_interpolation_model_path

if "num_inference_steps" in provided_fields and request.num_inference_steps is not None:
gen_params.num_inference_steps = request.num_inference_steps
if request.guidance_scale is not None:
if "guidance_scale" in provided_fields and request.guidance_scale is not None:
gen_params.guidance_scale = request.guidance_scale
if request.guidance_scale_2 is not None:
if "guidance_scale_2" in provided_fields and request.guidance_scale_2 is not None:
gen_params.guidance_scale_2 = request.guidance_scale_2
if request.true_cfg_scale is not None:
if "true_cfg_scale" in provided_fields and request.true_cfg_scale is not None:
gen_params.true_cfg_scale = request.true_cfg_scale
if request.seed is not None:
if "seed" in provided_fields and request.seed is not None:
gen_params.seed = request.seed
if request.boundary_ratio is not None:
if "boundary_ratio" in provided_fields and request.boundary_ratio is not None:
gen_params.boundary_ratio = request.boundary_ratio

logger.info(
"Boundary ratio parse: request=%s gen_params=%s",
request.boundary_ratio,
gen_params.boundary_ratio,
)
if request.flow_shift is not None:
if "flow_shift" in provided_fields and request.flow_shift is not None:
gen_params.extra_args["flow_shift"] = request.flow_shift

# Apply model-specific extra parameters
Expand Down Expand Up @@ -263,6 +269,17 @@ def _resolve_video_fps_multiplier(result: Any) -> int:
return int(multiplier)
return 1

def _resolve_default_sampling_params(self) -> OmniDiffusionSamplingParams:
    """Build the starting sampling params for one request.

    Returns:
        A deep copy of the first ``OmniDiffusionSamplingParams`` entry found
        in the engine client's ``default_sampling_params_list``. If that
        attribute is missing, empty, or contains no such entry, a freshly
        constructed ``OmniDiffusionSamplingParams`` is returned instead.
    """
    configured_defaults = getattr(self._engine_client, "default_sampling_params_list", None)
    diffusion_defaults = next(
        (
            entry
            for entry in configured_defaults or ()
            if isinstance(entry, OmniDiffusionSamplingParams)
        ),
        None,
    )
    if diffusion_defaults is None:
        return OmniDiffusionSamplingParams()
    # Requests mutate sampling params in-place, including nested dict fields
    # like extra_args. Hand out a deep copy so one request cannot leak state
    # into another through the shared defaults.
    return copy.deepcopy(diffusion_defaults)

@staticmethod
def _apply_lora(lora_body: Any, gen_params: OmniDiffusionSamplingParams) -> None:
try:
Expand Down
Loading