diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index 1cc5616657d..a29f4493c28 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -29,6 +29,7 @@ from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo from vllm_omni.entrypoints.openai.storage import LocalStorageManager from vllm_omni.entrypoints.openai.stores import AsyncDictStore, TaskRegistry +from vllm_omni.inputs.data import OmniDiffusionSamplingParams pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -60,6 +61,7 @@ def custom_output(self): class FakeAsyncOmni: def __init__(self): self.stage_configs = [SimpleNamespace(stage_type="diffusion")] + self.default_sampling_params_list = [OmniDiffusionSamplingParams()] self.captured_prompt = None self.captured_sampling_params_list = None @@ -442,6 +444,87 @@ def test_frame_interpolation_params_pass_to_diffusion_sampling_params(test_clien assert captured.frame_interpolation_model_path == "local-rife" +def test_default_sampling_params_apply_to_video_requests(test_client, mocker: MockerFixture): + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes", + return_value=b"fake-video", + ) + engine = test_client.app.state.openai_serving_video._engine_client + engine.default_sampling_params_list = [ + OmniDiffusionSamplingParams( + num_inference_steps=4, + guidance_scale=7.5, + generator_device="cpu", + enable_frame_interpolation=True, + frame_interpolation_exp=2, + frame_interpolation_scale=0.5, + frame_interpolation_model_path="default-rife", + ) + ] + + response = test_client.post( + "/v1/videos", + data={ + "prompt": "default param pass-through", + }, + ) + + assert response.status_code == 200 + video_id = response.json()["id"] + _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value) + + captured = engine.captured_sampling_params_list[0] + assert captured.num_inference_steps == 4 + assert captured.guidance_scale == 7.5 + assert captured.generator_device == "cpu" + assert captured.enable_frame_interpolation is True + assert captured.frame_interpolation_exp == 2 + assert captured.frame_interpolation_scale == 0.5 + assert captured.frame_interpolation_model_path == "default-rife" + + +def test_request_params_override_default_video_sampling_params(test_client, mocker: MockerFixture): + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes", + return_value=b"fake-video", + ) + engine = test_client.app.state.openai_serving_video._engine_client + engine.default_sampling_params_list = [ + OmniDiffusionSamplingParams( + num_inference_steps=4, + guidance_scale=7.5, + enable_frame_interpolation=True, + frame_interpolation_exp=2, + frame_interpolation_scale=0.5, + frame_interpolation_model_path="default-rife", + ) + ] + + response = test_client.post( + "/v1/videos", + data={ + "prompt": "explicit override", + "num_inference_steps": "8", + "enable_frame_interpolation": "false", + "frame_interpolation_exp": "1", + "frame_interpolation_scale": "1.0", + "frame_interpolation_model_path": "custom-rife", + }, + ) + + assert response.status_code == 200 + video_id = response.json()["id"] + _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value) + + captured = engine.captured_sampling_params_list[0] + assert captured.num_inference_steps == 8 + assert captured.guidance_scale == 7.5 + assert captured.enable_frame_interpolation is False + assert captured.frame_interpolation_exp == 1 + assert captured.frame_interpolation_scale == 1.0 + assert captured.frame_interpolation_model_path == "custom-rife" + + def test_worker_fps_multiplier_is_applied_to_async_encoding(test_client, mocker: MockerFixture): fps_values = [] engine = test_client.app.state.openai_serving_video._engine_client @@ -1149,6 +1232,39 @@ def test_sync_frame_interpolation_params_pass_to_sampling_params(test_client, mo assert kwargs["fps"] == 8 +def test_sync_default_sampling_params_apply_to_video_requests(test_client, mocker: MockerFixture): + _mock_encode_video_bytes(mocker) + engine = test_client.app.state.openai_serving_video._engine_client + engine.default_sampling_params_list = [ + OmniDiffusionSamplingParams( + num_inference_steps=4, + guidance_scale=7.5, + enable_frame_interpolation=True, + frame_interpolation_exp=2, + frame_interpolation_scale=0.5, + frame_interpolation_model_path="default-rife", + ) + ] + + response = test_client.post( + "/v1/videos/sync", + data={ + "prompt": "sync default param pass-through", + "fps": "8", + }, + ) + + assert response.status_code == 200 + engine = test_client.app.state.openai_serving_video._engine_client + captured = engine.captured_sampling_params_list[0] + assert captured.num_inference_steps == 4 + assert captured.guidance_scale == 7.5 + assert captured.enable_frame_interpolation is True + assert captured.frame_interpolation_exp == 2 + assert captured.frame_interpolation_scale == 0.5 + assert captured.frame_interpolation_model_path == "default-rife" + + def test_worker_fps_multiplier_is_applied_to_sync_encoding(test_client, mocker: MockerFixture): engine = test_client.app.state.openai_serving_video._engine_client fps_values = [] diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 4454f5bda10..4d348a0d890 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -2361,9 +2361,9 @@ async def _parse_video_form( true_cfg_scale: float | None = Form(default=None), seed: int | None = Form(default=None), negative_prompt: str | None = Form(default=None), - enable_frame_interpolation: bool = Form(default=False), - frame_interpolation_exp: int = Form(default=1, ge=1), - frame_interpolation_scale: float = Form(default=1.0, gt=0.0), + enable_frame_interpolation: bool | None = Form(default=None), + frame_interpolation_exp: int | None = Form(default=None, ge=1), + frame_interpolation_scale: float | None = Form(default=None, gt=0.0), frame_interpolation_model_path: str | None = Form(default=None), lora: str | None = Form(default=None), extra_params: str | None = Form(default=None), diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py index 741295c7c25..a4be330eb47 100644 --- a/vllm_omni/entrypoints/openai/serving_video.py +++ b/vllm_omni/entrypoints/openai/serving_video.py @@ -3,6 +3,7 @@ from __future__ import annotations +import copy import time from dataclasses import dataclass from http import HTTPStatus @@ -95,7 +96,7 @@ async def _run_and_extract( if request.negative_prompt is not None: prompt["negative_prompt"] = request.negative_prompt - gen_params = OmniDiffusionSamplingParams() + gen_params = self._resolve_default_sampling_params() input_image = None if reference_image is None else reference_image.data vp = request.resolve_video_params() @@ -113,22 +114,27 @@ async def _run_and_extract( if vp.fps is not None: gen_params.fps = vp.fps gen_params.frame_rate = float(vp.fps) - gen_params.enable_frame_interpolation = request.enable_frame_interpolation - gen_params.frame_interpolation_exp = request.frame_interpolation_exp - gen_params.frame_interpolation_scale = request.frame_interpolation_scale - gen_params.frame_interpolation_model_path = request.frame_interpolation_model_path - - if request.num_inference_steps is not None: + provided_fields = request.model_fields_set + if "enable_frame_interpolation" in provided_fields: + gen_params.enable_frame_interpolation = request.enable_frame_interpolation + if "frame_interpolation_exp" in provided_fields: + gen_params.frame_interpolation_exp = request.frame_interpolation_exp + if "frame_interpolation_scale" in provided_fields: + gen_params.frame_interpolation_scale = request.frame_interpolation_scale + if "frame_interpolation_model_path" in provided_fields: + gen_params.frame_interpolation_model_path = request.frame_interpolation_model_path + + if "num_inference_steps" in provided_fields and request.num_inference_steps is not None: gen_params.num_inference_steps = request.num_inference_steps - if request.guidance_scale is not None: + if "guidance_scale" in provided_fields and request.guidance_scale is not None: gen_params.guidance_scale = request.guidance_scale - if request.guidance_scale_2 is not None: + if "guidance_scale_2" in provided_fields and request.guidance_scale_2 is not None: gen_params.guidance_scale_2 = request.guidance_scale_2 - if request.true_cfg_scale is not None: + if "true_cfg_scale" in provided_fields and request.true_cfg_scale is not None: gen_params.true_cfg_scale = request.true_cfg_scale - if request.seed is not None: + if "seed" in provided_fields and request.seed is not None: gen_params.seed = request.seed - if request.boundary_ratio is not None: + if "boundary_ratio" in provided_fields and request.boundary_ratio is not None: gen_params.boundary_ratio = request.boundary_ratio logger.info( @@ -136,7 +142,7 @@ async def _run_and_extract( request.boundary_ratio, gen_params.boundary_ratio, ) - if request.flow_shift is not None: + if "flow_shift" in provided_fields and request.flow_shift is not None: gen_params.extra_args["flow_shift"] = request.flow_shift # Apply model-specific extra parameters @@ -263,6 +269,17 @@ def _resolve_video_fps_multiplier(result: Any) -> int: return int(multiplier) return 1 + def _resolve_default_sampling_params(self) -> OmniDiffusionSamplingParams: + default_sampling_params_list = getattr(self._engine_client, "default_sampling_params_list", None) + if default_sampling_params_list: + for params in default_sampling_params_list: + if isinstance(params, OmniDiffusionSamplingParams): + # Requests mutate sampling params in-place, including + # nested dict fields like extra_args. Deep-copy the stage + # defaults so one request cannot leak state into another. + return copy.deepcopy(params) + return OmniDiffusionSamplingParams() + @staticmethod def _apply_lora(lora_body: Any, gen_params: OmniDiffusionSamplingParams) -> None: try: