-
Notifications
You must be signed in to change notification settings - Fork 1k
[Bugfix]fix wan2.2 RuntimeError no response to user #2390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -675,6 +675,49 @@ class DiffusionOutput: | |
| peak_memory_mb: float = 0.0 | ||
|
|
||
|
|
||
| @dataclass | ||
| class OmniRequestError(RuntimeError): | ||
| def __init__( | ||
| self, | ||
| message: str, | ||
| *, | ||
| status_code: int = 500, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| request_id: str | None = None, | ||
| stage_id: int | None = None, | ||
| error_type: str | None = None, | ||
| detail: dict[str, Any] | None = None, | ||
| ): | ||
| super().__init__(message) | ||
| self.status_code = status_code | ||
| self.request_id = request_id | ||
| self.stage_id = stage_id | ||
| self.error_type = error_type or self.__class__.__name__ | ||
| self.detail = detail or {} | ||
|
|
||
|
|
||
| def normalize_omni_error( | ||
| exc: Exception, | ||
| *, | ||
| status_code: int = 500, | ||
| request_id: str | None = None, | ||
| stage_id: int | None = None, | ||
| ) -> OmniRequestError: | ||
| if isinstance(exc, OmniRequestError): | ||
| if exc.request_id is None: | ||
| exc.request_id = request_id | ||
| if exc.stage_id is None: | ||
| exc.stage_id = stage_id | ||
| return exc | ||
|
|
||
| return OmniRequestError( | ||
| str(exc), | ||
| status_code=status_code, | ||
| request_id=request_id, | ||
| stage_id=stage_id, | ||
| error_type=type(exc).__name__, | ||
| ) | ||
|
|
||
|
|
||
| class AttentionBackendEnum(enum.Enum): | ||
| FA = enum.auto() | ||
| SLIDING_TILE_ATTN = enum.auto() | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,7 +11,7 @@ | |||||||||||||||
| import torch | ||||||||||||||||
| from vllm.logger import init_logger | ||||||||||||||||
|
|
||||||||||||||||
| from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig | ||||||||||||||||
| from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig, OmniRequestError, normalize_omni_error | ||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: line too long. Split the import.
Suggested change
|
||||||||||||||||
| from vllm_omni.diffusion.executor.abstract import DiffusionExecutor | ||||||||||||||||
| from vllm_omni.diffusion.registry import ( | ||||||||||||||||
| DiffusionModelRegistry, | ||||||||||||||||
|
|
@@ -99,7 +99,12 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: | |||||||||||||||
| exec_total_time = time.perf_counter() - exec_start_time | ||||||||||||||||
|
|
||||||||||||||||
| if output.error: | ||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the commented-out code. Same for the other |
||||||||||||||||
| raise Exception(f"{output.error}") | ||||||||||||||||
| # raise Exception(f"{output.error}") | ||||||||||||||||
| raise OmniRequestError( | ||||||||||||||||
| output.error, | ||||||||||||||||
| status_code=500, | ||||||||||||||||
| error_type="DiffusionExecutionError", | ||||||||||||||||
| ) | ||||||||||||||||
| logger.info("Generation completed successfully.") | ||||||||||||||||
|
|
||||||||||||||||
| if output.output is None: | ||||||||||||||||
|
|
@@ -280,33 +285,52 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus | |||||||||||||||
| target_sched_req_id = self.scheduler.add_request(request) | ||||||||||||||||
|
|
||||||||||||||||
| # keep scheduling and executing until the target request is finished | ||||||||||||||||
| while True: | ||||||||||||||||
| sched_output = self.scheduler.schedule() | ||||||||||||||||
| if sched_output.is_empty: | ||||||||||||||||
| if not self.scheduler.has_requests(): | ||||||||||||||||
| raise RuntimeError("Diffusion scheduler has no runnable requests.") | ||||||||||||||||
| continue | ||||||||||||||||
|
|
||||||||||||||||
| # NOTE: add_req_and_wait_for_response() is synchronous, and | ||||||||||||||||
| # the scheduler currently enforces _max_batch_size = 1 (see | ||||||||||||||||
| # vllm_omni/diffusion/sched/base_scheduler.py), so we directly | ||||||||||||||||
| # take the single scheduled request here. | ||||||||||||||||
| sched_req_id = sched_output.scheduled_req_ids[0] | ||||||||||||||||
| req = sched_output.scheduled_new_reqs[0].req | ||||||||||||||||
| try: | ||||||||||||||||
| while True: | ||||||||||||||||
| sched_output = self.scheduler.schedule() | ||||||||||||||||
| if sched_output.is_empty: | ||||||||||||||||
| if not self.scheduler.has_requests(): | ||||||||||||||||
| # raise RuntimeError("Diffusion scheduler has no runnable requests.") | ||||||||||||||||
| raise OmniRequestError( | ||||||||||||||||
| "Diffusion scheduler has no runnable requests.", | ||||||||||||||||
| status_code=500, | ||||||||||||||||
| error_type="SchedulerError", | ||||||||||||||||
| ) | ||||||||||||||||
| continue | ||||||||||||||||
|
|
||||||||||||||||
| # NOTE: add_req_and_wait_for_response() is synchronous, and | ||||||||||||||||
| # the scheduler currently enforces _max_batch_size = 1 (see | ||||||||||||||||
| # vllm_omni/diffusion/sched/base_scheduler.py), so we directly | ||||||||||||||||
| # take the single scheduled request here. | ||||||||||||||||
| sched_req_id = sched_output.scheduled_req_ids[0] | ||||||||||||||||
| req = sched_output.scheduled_new_reqs[0].req | ||||||||||||||||
| try: | ||||||||||||||||
| output = self.executor.add_req(req) | ||||||||||||||||
| except Exception as exc: | ||||||||||||||||
| logger.error( | ||||||||||||||||
| "Execution failed for diffusion request %s", | ||||||||||||||||
| sched_req_id, | ||||||||||||||||
| exc_info=True, | ||||||||||||||||
| ) | ||||||||||||||||
| # output = DiffusionOutput(error=str(exc)) | ||||||||||||||||
| raise normalize_omni_error(exc) from exc | ||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In Useful? React with 👍 / 👎.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreeing with the bot comment above — this |
||||||||||||||||
|
|
||||||||||||||||
| finished_req_ids = self.scheduler.update_from_output(sched_output, output) | ||||||||||||||||
|
|
||||||||||||||||
| if output.error: | ||||||||||||||||
| raise OmniRequestError( | ||||||||||||||||
| output.error, | ||||||||||||||||
| status_code=500, | ||||||||||||||||
| error_type="DiffusionExecutionError", | ||||||||||||||||
| ) | ||||||||||||||||
| if target_sched_req_id in finished_req_ids: | ||||||||||||||||
| # self.scheduler.pop_request_state(target_sched_req_id) | ||||||||||||||||
| return output | ||||||||||||||||
| finally: | ||||||||||||||||
| try: | ||||||||||||||||
| output = self.executor.add_req(req) | ||||||||||||||||
| except Exception as exc: | ||||||||||||||||
| logger.error( | ||||||||||||||||
| "Execution failed for diffusion request %s", | ||||||||||||||||
| sched_req_id, | ||||||||||||||||
| exc_info=True, | ||||||||||||||||
| ) | ||||||||||||||||
| output = DiffusionOutput(error=str(exc)) | ||||||||||||||||
|
|
||||||||||||||||
| finished_req_ids = self.scheduler.update_from_output(sched_output, output) | ||||||||||||||||
| if target_sched_req_id in finished_req_ids: | ||||||||||||||||
| self.scheduler.pop_request_state(target_sched_req_id) | ||||||||||||||||
| return output | ||||||||||||||||
| except Exception: | ||||||||||||||||
| logger.debug("Request state already removed: %s", target_sched_req_id, exc_info=True) | ||||||||||||||||
|
|
||||||||||||||||
| def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: | ||||||||||||||||
| """Start or stop torch profiling on all diffusion workers. | ||||||||||||||||
|
|
||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |
| from vllm.tasks import SupportedTask | ||
| from vllm.v1.engine.exceptions import EngineDeadError | ||
|
|
||
| from vllm_omni.diffusion.data import OmniRequestError | ||
| from vllm_omni.entrypoints.client_request_state import ClientRequestState | ||
| from vllm_omni.entrypoints.omni_base import OmniBase | ||
| from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics | ||
|
|
@@ -290,14 +291,15 @@ async def _process_orchestrator_results( | |
| stage_id = result.get("stage_id", 0) | ||
|
|
||
| # Check for errors | ||
| if "error" in result: | ||
| logger.error( | ||
| "[AsyncOmni] Orchestrator error for req=%s stage-%s: %s", | ||
| request_id, | ||
| stage_id, | ||
| result["error"], | ||
| if isinstance(result, dict) and result.get("type") == "error": | ||
| raise OmniRequestError( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Raising Useful? React with 👍 / 👎. |
||
| result.get("error", "Unknown orchestrator error"), | ||
| status_code=result.get("status_code", 500), | ||
| request_id=result.get("request_id", request_id), | ||
| stage_id=result.get("stage_id", stage_id), | ||
| error_type=result.get("error_type"), | ||
| detail=result.get("detail") or {}, | ||
| ) | ||
| raise RuntimeError(result) | ||
|
|
||
| # Process the result (constructs OmniRequestOutput) | ||
| output_to_yield = self._process_single_result( | ||
|
|
@@ -344,6 +346,13 @@ async def _final_output_loop(): | |
| await asyncio.sleep(_FINAL_OUTPUT_IDLE_SLEEP_S) | ||
| continue | ||
|
|
||
| if isinstance(msg, dict) and msg.get("type") == "error": | ||
| req_id = msg.get("request_id") | ||
| req_state = self.request_states.get(req_id) | ||
| if req_state is not None: | ||
| await req_state.queue.put(msg) | ||
| continue | ||
|
|
||
| should_continue, _, stage_id, req_state = self._handle_output_message(msg) | ||
| if should_continue: | ||
| continue | ||
|
|
@@ -358,7 +367,16 @@ async def _final_output_loop(): | |
| except Exception as e: | ||
| logger.exception("[AsyncOmni] final_output_loop failed.") | ||
| for req_state in list(self.request_states.values()): | ||
| error_msg = {"request_id": req_state.request_id, "error": str(e)} | ||
| # error_msg = {"request_id": req_state.request_id, "error": str(e)} | ||
| error_msg = { | ||
| "type": "error", | ||
| "request_id": req_state.request_id, | ||
| "status_code": 500, | ||
| "error": str(e), | ||
| "error_type": type(e).__name__, | ||
| "detail": {}, | ||
| "finished": True, | ||
| } | ||
| await req_state.queue.put(error_msg) | ||
| self.final_output_task = None | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`@dataclass` here is a no-op — you define a custom `__init__`, so `dataclass` won't generate one, and there are no class-level field annotations for it to act on. Just drop the decorator.