Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmarks/diffusion/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ async def async_request_v1_videos(
video_bytes = await content_response.read()
output.response_body = video_bytes
output.success = True
if "stage_durations" in poll_json:
Comment thread
david6666666 marked this conversation as resolved.
output.stage_durations = poll_json["stage_durations"] or {}
if "peak_memory_mb" in poll_json:
output.peak_memory_mb = poll_json["peak_memory_mb"]
elif "peak_memory_mb" in resp_json:
Expand Down
56 changes: 55 additions & 1 deletion tests/entrypoints/openai_api/test_video_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@


class MockVideoResult:
def __init__(self, videos, audios=None, sample_rate=None):
def __init__(self, videos, audios=None, sample_rate=None, stage_durations=None, peak_memory_mb=0.0):
self.multimodal_output = {"video": videos}
if audios is not None:
self.multimodal_output["audio"] = audios
if sample_rate is not None:
self.multimodal_output["audio_sample_rate"] = sample_rate
self.stage_durations = stage_durations or {}
self.peak_memory_mb = peak_memory_mb


class FakeAsyncOmni:
Expand Down Expand Up @@ -371,6 +373,33 @@ async def _generate(prompt, request_id, sampling_params_list):
assert audio_sample_rates == [16000]


def test_video_job_persists_profiler_metadata(test_client, mocker: MockerFixture):
engine = test_client.app.state.openai_serving_video._engine_client

async def _generate(prompt, request_id, sampling_params_list):
engine.captured_prompt = prompt
engine.captured_sampling_params_list = sampling_params_list
yield MockVideoResult(
[object()],
stage_durations={"diffuse": 2.5, "vae.decode": 0.3},
peak_memory_mb=4096.5,
)

engine.generate = _generate
mocker.patch(
"vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
return_value="Zg==",
)

response = test_client.post("/v1/videos", data={"prompt": "profile me"})
assert response.status_code == 200
video_id = response.json()["id"]
completed = _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)

assert completed["stage_durations"] == {"diffuse": 2.5, "vae.decode": 0.3}
assert completed["peak_memory_mb"] == 4096.5


def test_missing_handler_returns_503():
app = FastAPI()
app.include_router(router)
Expand Down Expand Up @@ -770,6 +799,31 @@ def test_sync_t2v_returns_video_bytes(test_client, mocker: MockerFixture):
assert response.headers["x-request-id"].startswith("video_sync-")
assert response.headers["x-model"] == "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
assert float(response.headers["x-inference-time-s"]) >= 0
assert json.loads(response.headers["x-stage-durations"]) == {}
assert float(response.headers["x-peak-memory-mb"]) == 0.0


def test_sync_t2v_returns_profiler_headers(test_client, mocker: MockerFixture):
engine = test_client.app.state.openai_serving_video._engine_client

async def _generate(prompt, request_id, sampling_params_list):
engine.captured_prompt = prompt
engine.captured_sampling_params_list = sampling_params_list
yield MockVideoResult(
[object()],
stage_durations={"diffuse": 1.75},
peak_memory_mb=1234.25,
)

engine.generate = _generate
_mock_encode_video_bytes(mocker, b"profiled-video")

response = test_client.post("/v1/videos/sync", data={"prompt": "sync profile"})

assert response.status_code == 200
assert response.content == b"profiled-video"
assert json.loads(response.headers["x-stage-durations"]) == {"diffuse": 1.75}
assert float(response.headers["x-peak-memory-mb"]) == pytest.approx(1234.25, rel=0, abs=1e-3)


def test_sync_i2v_returns_video_bytes(test_client, mocker: MockerFixture):
Expand Down
21 changes: 21 additions & 0 deletions tests/entrypoints/test_async_omni_diffusion_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,24 @@ def test_serve_cli_accepts_ulysses_mode():
assert args.ulysses_mode == "advanced_uaa"
assert parallel_config.ulysses_degree == 4
assert parallel_config.ulysses_mode == "advanced_uaa"


def test_serve_cli_accepts_diffusion_pipeline_profiler_flag():
"""Ensure diffusion serve CLI exposes the profiler switch."""
parser = FlexibleArgumentParser()
subparsers = parser.add_subparsers(dest="command")
OmniServeCommand().subparser_init(subparsers)

args = parser.parse_args(
[
"serve",
"Wan-AI/Wan2.2-T2V-A14B-Diffusers",
"--omni",
"--enable-diffusion-pipeline-profiler",
]
)

stage_cfg = _create_default_diffusion_stage_cfg(args)[0]

assert args.enable_diffusion_pipeline_profiler is True
assert stage_cfg["engine_args"]["enable_diffusion_pipeline_profiler"] is True
6 changes: 5 additions & 1 deletion vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2008,6 +2008,8 @@ async def _run_video_generation_job(
"file_name": file_name,
"completed_at": int(time.time()),
"inference_time_s": time.perf_counter() - started_at,
"stage_durations": response.stage_durations,
"peak_memory_mb": response.peak_memory_mb,
},
)
except Exception as exc:
Expand Down Expand Up @@ -2181,7 +2183,7 @@ async def create_video_sync(
request_id = f"video_sync-{random_uuid()}"
started_at = time.perf_counter()
try:
video_bytes = await asyncio.wait_for(
video_bytes, stage_durations, peak_memory_mb = await asyncio.wait_for(
handler.generate_video_bytes(request, request_id, reference_image=reference_image),
timeout=VIDEO_SYNC_TIMEOUT_S,
)
Expand All @@ -2207,6 +2209,8 @@ async def create_video_sync(
"X-Request-Id": request_id,
"X-Model": effective_model_name,
"X-Inference-Time-S": f"{inference_time_s:.3f}",
"X-Stage-Durations": json.dumps(stage_durations, separators=(",", ":")),
"X-Peak-Memory-MB": f"{peak_memory_mb:.3f}",
},
)

Expand Down
16 changes: 16 additions & 0 deletions vllm_omni/entrypoints/openai/protocol/videos.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ class VideoGenerationResponse(BaseModel):

created: int = Field(..., description="Unix timestamp of when the generation completed")
data: list[VideoData] = Field(..., description="Array of generated videos")
stage_durations: dict[str, float] = Field(
default_factory=dict,
description="Profiler stage durations reported by the diffusion pipeline.",
)
peak_memory_mb: float = Field(
default=0.0,
description="Peak device memory usage in MB reported by the diffusion pipeline.",
)


class VideoError(BaseModel):
Expand Down Expand Up @@ -250,6 +258,14 @@ class VideoResponse(BaseModel):
description="Filename of the saved output video files for this job.",
)
inference_time_s: float | None = Field(default=None, description="End-to-end inference time in seconds.")
stage_durations: dict[str, float] = Field(
default_factory=dict,
description="Profiler stage durations reported by the diffusion pipeline.",
)
peak_memory_mb: float = Field(
default=0.0,
description="Peak device memory usage in MB reported by the diffusion pipeline.",
)

@property
def file_extension(self) -> str:
Expand Down
87 changes: 59 additions & 28 deletions vllm_omni/entrypoints/openai/serving_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ class ReferenceImage:
data: Image.Image


@dataclass
class VideoGenerationArtifacts:
"""Normalized outputs and profiler metadata extracted from one request."""

videos: list[Any]
audios: list[Any | None]
audio_sample_rate: int
output_fps: int
stage_durations: dict[str, float]
peak_memory_mb: float


class OmniOpenAIServingVideo:
"""OpenAI-style video generation handler for omni diffusion models."""

Expand Down Expand Up @@ -77,12 +89,8 @@ async def _run_and_extract(
reference_id: str,
*,
reference_image: ReferenceImage | None = None,
) -> tuple[list[Any], list[Any | None], int, int]:
"""Run the generation pipeline and extract video/audio outputs.

Returns:
Tuple of (videos, audios, audio_sample_rate, output_fps).
"""
) -> VideoGenerationArtifacts:
"""Run the generation pipeline and extract video/audio/profiler outputs."""
prompt: OmniTextPrompt = OmniTextPrompt(prompt=request.prompt)
if request.negative_prompt is not None:
prompt["negative_prompt"] = request.negative_prompt
Expand Down Expand Up @@ -153,7 +161,14 @@ async def _run_and_extract(
audios = self._extract_audio_outputs(result, expected_count=len(videos))
audio_sample_rate = self._resolve_audio_sample_rate(result)
output_fps = vp.fps or self._resolve_fps(result) or 24
return videos, audios, audio_sample_rate, output_fps
return VideoGenerationArtifacts(
videos=videos,
audios=audios,
audio_sample_rate=audio_sample_rate,
output_fps=output_fps,
stage_durations=self._extract_stage_durations(result),
peak_memory_mb=self._extract_peak_memory_mb(result),
)

async def generate_videos(
self,
Expand All @@ -162,54 +177,57 @@ async def generate_videos(
*,
reference_image: ReferenceImage | None = None,
) -> VideoGenerationResponse:
videos, audios, audio_sample_rate, output_fps = await self._run_and_extract(
request, reference_id, reference_image=reference_image
)
artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image)
_t_encode_start = time.perf_counter()
video_data = [
VideoData(
b64_json=(
encode_video_base64(video, fps=output_fps)
if audios[idx] is None
encode_video_base64(video, fps=artifacts.output_fps)
if artifacts.audios[idx] is None
else encode_video_base64(
video,
fps=output_fps,
audio=audios[idx],
audio_sample_rate=audio_sample_rate,
fps=artifacts.output_fps,
audio=artifacts.audios[idx],
audio_sample_rate=artifacts.audio_sample_rate,
)
)
)
for idx, video in enumerate(videos)
for idx, video in enumerate(artifacts.videos)
]
_t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000
logger.info("Video response encoding (MP4+base64): %.2f ms", _t_encode_ms)
return VideoGenerationResponse(created=int(time.time()), data=video_data)
return VideoGenerationResponse(
created=int(time.time()),
data=video_data,
stage_durations=artifacts.stage_durations,
peak_memory_mb=artifacts.peak_memory_mb,
)

async def generate_video_bytes(
self,
request: VideoGenerationRequest,
reference_id: str,
*,
reference_image: ReferenceImage | None = None,
) -> bytes:
) -> tuple[bytes, dict[str, float], float]:
"""Generate a video and return raw MP4 bytes, bypassing base64 encoding."""
videos, audios, audio_sample_rate, output_fps = await self._run_and_extract(
request, reference_id, reference_image=reference_image
)
if len(videos) > 1:
artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image)
if len(artifacts.videos) > 1:
logger.warning(
"Video request %s generated %d outputs; returning only the first.", reference_id, len(videos)
"Video request %s generated %d outputs; returning only the first.",
reference_id,
len(artifacts.videos),
)
audio = audios[0]
audio = artifacts.audios[0]
_t_encode_start = time.perf_counter()
video_bytes = _encode_video_bytes(
videos[0],
fps=output_fps,
**({"audio": audio, "audio_sample_rate": audio_sample_rate} if audio is not None else {}),
artifacts.videos[0],
fps=artifacts.output_fps,
**({"audio": audio, "audio_sample_rate": artifacts.audio_sample_rate} if audio is not None else {}),
)
_t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000
logger.info("Video response encoding (MP4 bytes): %.2f ms", _t_encode_ms)
return video_bytes
return video_bytes, artifacts.stage_durations, artifacts.peak_memory_mb

@staticmethod
def _apply_lora(lora_body: Any, gen_params: OmniDiffusionSamplingParams) -> None:
Expand Down Expand Up @@ -483,3 +501,16 @@ def _coerce_audio_sample_rate(value: Any) -> int | None:
return None

return sample_rate if sample_rate > 0 else None

@staticmethod
def _extract_stage_durations(result: Any) -> dict[str, float]:
stage_durations = getattr(result, "stage_durations", None)
return stage_durations if isinstance(stage_durations, dict) else {}

@staticmethod
def _extract_peak_memory_mb(result: Any) -> float:
peak_memory_mb = getattr(result, "peak_memory_mb", 0.0)
try:
return float(peak_memory_mb or 0.0)
except (TypeError, ValueError):
return 0.0
Loading