Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,49 @@ class DiffusionOutput:
peak_memory_mb: float = 0.0


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dataclass here is a no-op — you define a custom __init__, so dataclass won't generate one, and there are no class-level field annotations for it to act on. Just drop the decorator.

Suggested change
class OmniRequestError(RuntimeError):

@dataclass
class OmniRequestError(RuntimeError):
def __init__(
self,
message: str,
*,
status_code: int = 500,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

status_code is an HTTP-related field; it should be defined in an HTTP-related component, not in diffusion.

request_id: str | None = None,
stage_id: int | None = None,
error_type: str | None = None,
detail: dict[str, Any] | None = None,
):
super().__init__(message)
self.status_code = status_code
self.request_id = request_id
self.stage_id = stage_id
self.error_type = error_type or self.__class__.__name__
self.detail = detail or {}


def normalize_omni_error(
    exc: Exception,
    *,
    status_code: int = 500,
    request_id: str | None = None,
    stage_id: int | None = None,
) -> OmniRequestError:
    """Coerce an arbitrary exception into an ``OmniRequestError``.

    Args:
        exc: The exception to normalize.
        status_code: Status code used when ``exc`` has to be wrapped. It is
            ignored when ``exc`` is already an ``OmniRequestError``, whose own
            ``status_code`` is kept.
        request_id: Request identifier to attach when none is set yet.
        stage_id: Stage identifier to attach when none is set yet.

    Returns:
        ``exc`` itself (mutated in place to fill a missing ``request_id`` /
        ``stage_id``) when it is already an ``OmniRequestError``; otherwise a
        new ``OmniRequestError`` whose ``error_type`` is the original
        exception's class name.

    Note:
        The original exception is not chained here; callers that need the
        traceback should use ``raise normalize_omni_error(e) from e``.
    """
    if isinstance(exc, OmniRequestError):
        # Context recorded closer to the failure wins; only fill the gaps.
        if exc.request_id is None:
            exc.request_id = request_id
        if exc.stage_id is None:
            exc.stage_id = stage_id
        return exc

    error_type = type(exc).__name__
    # Exceptions raised without arguments stringify to "", which would yield
    # an empty error message; fall back to the class name in that case.
    message = str(exc) or error_type
    return OmniRequestError(
        message,
        status_code=status_code,
        request_id=request_id,
        stage_id=stage_id,
        error_type=error_type,
    )


class AttentionBackendEnum(enum.Enum):
FA = enum.auto()
SLIDING_TILE_ATTN = enum.auto()
Expand Down
78 changes: 51 additions & 27 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from vllm.logger import init_logger

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig, OmniRequestError, normalize_omni_error
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: line too long. Split the import.

Suggested change
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig, OmniRequestError, normalize_omni_error
from vllm_omni.diffusion.data import (
DiffusionOutput,
OmniDiffusionConfig,
OmniRequestError,
normalize_omni_error,
)

from vllm_omni.diffusion.executor.abstract import DiffusionExecutor
from vllm_omni.diffusion.registry import (
DiffusionModelRegistry,
Expand Down Expand Up @@ -99,7 +99,12 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
exec_total_time = time.perf_counter() - exec_start_time

if output.error:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the commented-out code. Same for the other # raise RuntimeError(...) / # output = ... / # error_msg = ... / # self.scheduler.pop_request_state(...) lines throughout the PR.

raise Exception(f"{output.error}")
# raise Exception(f"{output.error}")
raise OmniRequestError(
output.error,
status_code=500,
error_type="DiffusionExecutionError",
)
logger.info("Generation completed successfully.")

if output.output is None:
Expand Down Expand Up @@ -280,33 +285,52 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus
target_sched_req_id = self.scheduler.add_request(request)

# keep scheduling and executing until the target request is finished
while True:
sched_output = self.scheduler.schedule()
if sched_output.is_empty:
if not self.scheduler.has_requests():
raise RuntimeError("Diffusion scheduler has no runnable requests.")
continue

# NOTE: add_req_and_wait_for_response() is synchronous, and
# the scheduler currently enforces _max_batch_size = 1 (see
# vllm_omni/diffusion/sched/base_scheduler.py), so we directly
# take the single scheduled request here.
sched_req_id = sched_output.scheduled_req_ids[0]
req = sched_output.scheduled_new_reqs[0].req
try:
while True:
sched_output = self.scheduler.schedule()
if sched_output.is_empty:
if not self.scheduler.has_requests():
# raise RuntimeError("Diffusion scheduler has no runnable requests.")
raise OmniRequestError(
"Diffusion scheduler has no runnable requests.",
status_code=500,
error_type="SchedulerError",
)
continue

# NOTE: add_req_and_wait_for_response() is synchronous, and
# the scheduler currently enforces _max_batch_size = 1 (see
# vllm_omni/diffusion/sched/base_scheduler.py), so we directly
# take the single scheduled request here.
sched_req_id = sched_output.scheduled_req_ids[0]
req = sched_output.scheduled_new_reqs[0].req
try:
output = self.executor.add_req(req)
except Exception as exc:
logger.error(
"Execution failed for diffusion request %s",
sched_req_id,
exc_info=True,
)
# output = DiffusionOutput(error=str(exc))
raise normalize_omni_error(exc) from exc
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Update scheduler state before rethrowing executor failures

In DiffusionEngine.add_req_and_wait_for_response, this rethrow happens before self.scheduler.update_from_output(...) runs, so the failed scheduled request is never transitioned out of the scheduler’s running set. The finally block only calls pop_request_state, which does not clear _running; after one executor error (for example OOM), later requests can spin indefinitely because schedule() sees capacity as full, returns empty work, and has_requests() remains true forever. The failure path needs to mark the scheduled request finished (or otherwise remove it from _running) before propagating the exception.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreeing with the bot comment above — this raise skips scheduler.update_from_output(...), so the request stays in the scheduler's running set. On a single-slot scheduler (_max_batch_size=1), this permanently blocks all future requests. You need to call scheduler.abort_request(sched_req_id) (or equivalent) before re-raising.


finished_req_ids = self.scheduler.update_from_output(sched_output, output)

if output.error:
raise OmniRequestError(
output.error,
status_code=500,
error_type="DiffusionExecutionError",
)
if target_sched_req_id in finished_req_ids:
# self.scheduler.pop_request_state(target_sched_req_id)
return output
finally:
try:
output = self.executor.add_req(req)
except Exception as exc:
logger.error(
"Execution failed for diffusion request %s",
sched_req_id,
exc_info=True,
)
output = DiffusionOutput(error=str(exc))

finished_req_ids = self.scheduler.update_from_output(sched_output, output)
if target_sched_req_id in finished_req_ids:
self.scheduler.pop_request_state(target_sched_req_id)
return output
except Exception:
logger.debug("Request state already removed: %s", target_sched_req_id, exc_info=True)

def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
"""Start or stop torch profiling on all diffusion workers.
Expand Down
16 changes: 15 additions & 1 deletion vllm_omni/diffusion/stage_diffusion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from vllm.logger import init_logger

from vllm_omni.diffusion.data import normalize_omni_error
from vllm_omni.engine.stage_init_utils import StageMetadata
from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion
from vllm_omni.outputs import OmniRequestOutput
Expand Down Expand Up @@ -75,11 +76,24 @@ async def _run(
result = await self._engine.generate(prompt, sampling_params, request_id)
await self._output_queue.put(result)
except Exception as e:
err = normalize_omni_error(e, request_id=request_id, stage_id=self.stage_id)
logger.exception(
"[StageDiffusionClient] Stage-%s req=%s failed: %s",
self.stage_id,
request_id,
e,
err,
)
await self._output_queue.put(
{
"type": "error",
"request_id": request_id,
"stage_id": self.stage_id,
"status_code": err.status_code,
"error": str(err),
"error_type": err.error_type,
"detail": err.detail,
"finished": True,
}
)
finally:
self._tasks.pop(request_id, None)
Expand Down
23 changes: 23 additions & 0 deletions vllm_omni/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,29 @@ async def _orchestration_loop(self) -> None:
output = stage_client.get_diffusion_output_async()
if output is not None:
idle = False

if isinstance(output, dict) and output.get("type") == "error":
req_id = output.get("request_id")
req_state = self.request_states.get(req_id)

await self.output_async_queue.put(
{
"type": "error",
"request_id": req_id,
"stage_id": stage_id,
"status_code": output.get("status_code", 500),
"error": output.get("error", "Unknown diffusion stage error"),
"error_type": output.get("error_type"),
"detail": output.get("detail") or {},
"finished": True,
"stage_submit_ts": req_state.stage_submit_ts.get(stage_id) if req_state else None,
}
)

self._cleanup_companion_state(req_id)
self.request_states.pop(req_id, None)
continue

req_state = self.request_states.get(output.request_id)
if req_state is not None:
stage_metrics = self._build_stage_metrics(stage_id, output.request_id, [output], req_state)
Expand Down
34 changes: 26 additions & 8 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vllm.tasks import SupportedTask
from vllm.v1.engine.exceptions import EngineDeadError

from vllm_omni.diffusion.data import OmniRequestError
from vllm_omni.entrypoints.client_request_state import ClientRequestState
from vllm_omni.entrypoints.omni_base import OmniBase
from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics
Expand Down Expand Up @@ -290,14 +291,15 @@ async def _process_orchestrator_results(
stage_id = result.get("stage_id", 0)

# Check for errors
if "error" in result:
logger.error(
"[AsyncOmni] Orchestrator error for req=%s stage-%s: %s",
request_id,
stage_id,
result["error"],
if isinstance(result, dict) and result.get("type") == "error":
raise OmniRequestError(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Clean up request state before propagating stage errors

Raising OmniRequestError directly here exits generate() through the exception path, but AsyncOmni.generate only calls _log_summary_and_cleanup(request_id) after the normal completion path, so failed requests remain in self.request_states. With the new diffusion error messages, repeated stage failures (e.g., OOMs) will accumulate stale per-request queues/metrics in memory and can eventually degrade long-running servers.

Useful? React with 👍 / 👎.

result.get("error", "Unknown orchestrator error"),
status_code=result.get("status_code", 500),
request_id=result.get("request_id", request_id),
stage_id=result.get("stage_id", stage_id),
error_type=result.get("error_type"),
detail=result.get("detail") or {},
)
raise RuntimeError(result)

# Process the result (constructs OmniRequestOutput)
output_to_yield = self._process_single_result(
Expand Down Expand Up @@ -344,6 +346,13 @@ async def _final_output_loop():
await asyncio.sleep(_FINAL_OUTPUT_IDLE_SLEEP_S)
continue

if isinstance(msg, dict) and msg.get("type") == "error":
req_id = msg.get("request_id")
req_state = self.request_states.get(req_id)
if req_state is not None:
await req_state.queue.put(msg)
continue

should_continue, _, stage_id, req_state = self._handle_output_message(msg)
if should_continue:
continue
Expand All @@ -358,7 +367,16 @@ async def _final_output_loop():
except Exception as e:
logger.exception("[AsyncOmni] final_output_loop failed.")
for req_state in list(self.request_states.values()):
error_msg = {"request_id": req_state.request_id, "error": str(e)}
# error_msg = {"request_id": req_state.request_id, "error": str(e)}
error_msg = {
"type": "error",
"request_id": req_state.request_id,
"status_code": 500,
"error": str(e),
"error_type": type(e).__name__,
"detail": {},
"finished": True,
}
await req_state.queue.put(error_msg)
self.final_output_task = None

Expand Down
8 changes: 5 additions & 3 deletions vllm_omni/entrypoints/async_omni_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_hf_file_to_dict

from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig
from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig, normalize_omni_error
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType
Expand Down Expand Up @@ -237,7 +237,8 @@ async def _generate_batch(
)
except Exception as e:
logger.error("Batch generation failed for request %s: %s", request_id, e)
raise RuntimeError(f"Diffusion batch generation failed: {e}") from e
# raise RuntimeError(f"Diffusion batch generation failed: {e}") from e
raise normalize_omni_error(e, request_id=request_id) from e

# Combine all per-prompt results into a single OmniRequestOutput
all_images = []
Expand Down Expand Up @@ -310,7 +311,8 @@ async def generate(
result = result[0]
except Exception as e:
logger.error("Generation failed for request %s: %s", request_id, e)
raise RuntimeError(f"Diffusion generation failed: {e}") from e
# raise RuntimeError(f"Diffusion generation failed: {e}") from e
raise normalize_omni_error(e, request_id=request_id) from e

if not result.request_id:
result.request_id = request_id
Expand Down
10 changes: 9 additions & 1 deletion vllm_omni/entrypoints/omni_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.logger import init_logger
from vllm.v1.engine.exceptions import EngineDeadError

from vllm_omni.diffusion.data import OmniRequestError
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
from vllm_omni.entrypoints.client_request_state import ClientRequestState
from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e
Expand Down Expand Up @@ -206,7 +207,14 @@ def _handle_output_message(
return True, None, None, None

if msg_type == "error":
raise RuntimeError(msg.get("error", "Orchestrator returned an error message"))
raise OmniRequestError(
msg.get("error", "Orchestrator returned an error message"),
status_code=msg.get("status_code", 500),
request_id=msg.get("request_id"),
stage_id=msg.get("stage_id"),
error_type=msg.get("error_type"),
detail=msg.get("detail") or {},
)

if msg_type != "output":
logger.warning("[%s] got unexpected msg type: %s", self.__class__.__name__, msg_type)
Expand Down
25 changes: 19 additions & 6 deletions vllm_omni/entrypoints/openai/serving_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

from vllm_omni.diffusion.data import OmniRequestError
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.protocol.videos import (
VideoData,
Expand Down Expand Up @@ -214,12 +215,24 @@ async def _run_generation(
sampling_params_list: list[OmniSamplingParams] = [gen_params for _ in stage_configs]

result = None
async for output in engine_client.generate(
prompt=prompt,
request_id=request_id,
sampling_params_list=sampling_params_list,
):
result = output
try:
async for output in engine_client.generate(
prompt=prompt,
request_id=request_id,
sampling_params_list=sampling_params_list,
):
result = output
except OmniRequestError as e:
raise HTTPException(
status_code=e.status_code,
detail={
"message": str(e),
"request_id": e.request_id,
"stage_id": e.stage_id,
"error_type": e.error_type,
"detail": e.detail,
},
) from e

if result is None:
raise HTTPException(
Expand Down
Loading