
[FEAT] support multi-stage deployment #2396

Merged
hsliuustc0106 merged 98 commits into vllm-project:main from ZhengWG:support-stage-scale-out on May 6, 2026

Conversation

@ZhengWG (Contributor) commented Apr 1, 2026

Purpose

Implement Multi-Instance Stage Deployment for vLLM-Omni, enabling horizontal scaling of individual logical stages (e.g., talker) with multiple replicas.

Problem: Each YAML stage_id maps 1:1 to a single StageEngineCoreClient. The same logical role (e.g., talker) cannot be horizontally scaled — compute and throughput cannot be independently scaled per stage.

**Solution:** Introduce a StagePool layer between the Orchestrator and individual engine clients. Each logical stage owns a pool of replicas; the Orchestrator operates exclusively on StageReplica handles, never on flat indices. Replica selection (round-robin + per-request affinity + CFG companion binding) is encapsulated inside the pool. Related to #2634. cc @yinpeiqi @fake0fan

Main Design:

  1. StagePool abstraction (vllm_omni/engine/stage_pool.py)
  • select_replica: Three-phase resolution — (1) req_state cache hit → same replica as before; (2) affinity_from cross-request binding (CFG companion → parent); (3) round-robin.
  • admit: Atomically couples select_replica + output_processor.add_request, ensuring the processor that receives raw outputs is always the one that registered the request. This fixes a silent output-loss bug when num_replicas > 1 on stage 0. (A minimal sketch of select_replica and admit follows this list.)
  2. Orchestrator — handle-based, no flat indices
    All orchestrator methods (_orchestration_loop, _route_output, _forward_to_next_stage, _poll_stage_raw, _process_stage_outputs, _build_stage_metrics, _handle_abort, _handle_collective_rpc, _shutdown_stages, etc.) operate on StageReplica handles directly. No flat-index resolution anywhere.
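
For readers scanning the design, here is a minimal, illustrative sketch of the selection/admission flow described above. The class and method names follow the PR description, but the fields and bookkeeping are assumptions rather than the actual vllm_omni/engine/stage_pool.py implementation:

```python
# Illustrative sketch only -- names follow the PR description, internals are assumed.
from dataclasses import dataclass, field
from itertools import count
from typing import Any


@dataclass
class StageReplica:
    stage_id: int
    replica_id: int
    client: Any            # StageEngineCoreClient for this replica
    output_processor: Any  # per-replica output processor


@dataclass
class StagePool:
    stage_id: int
    replicas: list[StageReplica]
    _rr: count = field(default_factory=count)                            # round-robin cursor
    _assignments: dict[str, StageReplica] = field(default_factory=dict)  # request -> replica

    def select_replica(self, request_id: str, affinity_from: str | None = None) -> StageReplica:
        # (1) cache hit: a request already bound to a replica stays on it
        if request_id in self._assignments:
            return self._assignments[request_id]
        # (2) cross-request affinity: e.g. a CFG companion follows its parent request
        if affinity_from is not None and affinity_from in self._assignments:
            replica = self._assignments[affinity_from]
        else:
            # (3) round-robin over the replicas of this logical stage
            replica = self.replicas[next(self._rr) % len(self.replicas)]
        self._assignments[request_id] = replica
        return replica

    def admit(self, request_id: str, request: Any, affinity_from: str | None = None) -> StageReplica:
        # Couple selection with output-processor registration so the processor that
        # later receives raw outputs is always the one that registered the request.
        replica = self.select_replica(request_id, affinity_from)
        replica.output_processor.add_request(request)
        return replica
```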

Usage:

TP=2 with 2 replicas:

    runtime:
      num_replicas: 2
      devices: "1,2,3,4"   # r0→"1,2", r1→"3,4"
    engine_args:
      tensor_parallel_size: 2
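
For illustration, the per-replica device split implied by the comment above ("1,2,3,4" with num_replicas=2 and tensor_parallel_size=2 giving r0="1,2" and r1="3,4") can be computed as follows. The helper name and exact logic are assumptions for exposition, not code from this PR:

```python
# Hypothetical helper: split a flat device string into per-replica groups of
# tensor_parallel_size devices each, mirroring the comment in the YAML above.
def split_devices(devices: str, num_replicas: int, tp_size: int) -> list[str]:
    ids = [d.strip() for d in devices.split(",") if d.strip()]
    assert len(ids) == num_replicas * tp_size, "need num_replicas * tp_size devices"
    return [",".join(ids[i * tp_size:(i + 1) * tp_size]) for i in range(num_replicas)]


assert split_devices("1,2,3,4", num_replicas=2, tp_size=2) == ["1,2", "3,4"]
```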

Backward compatible: When num_replicas is omitted (defaults to 1), behavior is identical to the original code. No existing YAML configs need modification. More details in vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk_multi_replicas.yaml.

Scope / Known Limitations:

  1. Diffusion stages: Multi-replica not yet supported (single replica enforced).
  2. Remote (headless) multi-replica: In single_stage_mode, remote stages are pinned to 1 replica. Extending to N replicas requires OmniMasterServer protocol changes ((stage_id, replica_index) addressing). See TODO in _initialize_stages.
  3. Dynamic scaling: Replica count is static at startup; no runtime add/remove.

Test Plan

Unit tests — tests/engine/test_orchestrator.py (9 cases):

  • Existing single-replica cases adapted to StagePool interface: two-stage LLM pipeline, single-stage diffusion, LLM→diffusion, async-chunk prewarm, shutdown, abort
  • New multi-replica cases: RR distribution across 2 replicas with forwarding to stage 1, abort broadcasts to all replicas, shutdown covers all replicas (an illustrative sketch of the RR case follows)
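
For illustration only, a round-robin distribution check against the hypothetical StagePool sketched above could look like the following; this is not the actual test from tests/engine/test_orchestrator.py:

```python
# Illustrative only -- exercises the hypothetical StagePool/StageReplica sketch above,
# not the real interfaces used in tests/engine/test_orchestrator.py.
def test_round_robin_spreads_requests_across_replicas():
    replicas = [
        StageReplica(stage_id=0, replica_id=i, client=None, output_processor=None)
        for i in range(2)
    ]
    pool = StagePool(stage_id=0, replicas=replicas)
    chosen = [pool.select_replica(f"req-{i}").replica_id for i in range(4)]
    assert chosen == [0, 1, 0, 1]
```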

Test Result

================================================================================================================ test session starts ================================================================================================================
platform linux -- Python 3.10.16, pytest-9.0.3, pluggy-1.5.0 -- /opt/conda/bin/python
cachedir: .pytest_cache
rootdir: /home/nas/pengyu.zwg/vllm-omni-dev
configfile: pyproject.toml
plugins: asyncio-1.3.0, anyio-4.13.0
asyncio: mode=auto, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 9 items

tests/engine/test_orchestrator.py::test_run_two_stage_llm PASSED                                                                                                                                                                              [ 11%]
tests/engine/test_orchestrator.py::test_run_single_stage_diffusion PASSED                                                                                                                                                                     [ 22%]
tests/engine/test_orchestrator.py::test_run_llm_to_diffusion PASSED                                                                                                                                                                           [ 33%]
tests/engine/test_orchestrator.py::test_run_async_chunk PASSED                                                                                                                                                                                [ 44%]
tests/engine/test_orchestrator.py::test_run_shutdown PASSED                                                                                                                                                                                   [ 55%]
tests/engine/test_orchestrator.py::test_run_abort PASSED                                                                                                                                                                                      [ 66%]
tests/engine/test_orchestrator.py::test_multi_replica_round_robin_distribution PASSED                                                                                                                                                         [ 77%]
tests/engine/test_orchestrator.py::test_multi_replica_abort_broadcasts_to_all_replicas PASSED                                                                                                                                                 [ 88%]
tests/engine/test_orchestrator.py::test_multi_replica_shutdown_all_replicas PASSED


@ZhengWG ZhengWG requested a review from hsliuustc0106 as a code owner April 1, 2026 03:05
@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 660e3c64d8


Comment thread: vllm_omni/engine/orchestrator.py (outdated)

    if 0 <= lid < self.num_logical_stages:
        stage_ids.extend(self.logical_stage_to_clients[lid])
    else:
        stage_ids.append(lid)  # keep invalid id for error reporting

P1: Reject invalid logical stage IDs before RPC expansion

The API comment here states that stage_ids are logical stage IDs, but the fallback path appends an out-of-range logical ID directly into the client target list. In multi-replica deployments (num_clients > num_logical_stages), an invalid logical ID (for example lid=2 when only logical stages 0-1 exist) can still pass the later stage_id < self.num_clients check and execute the RPC on an unintended client, so control operations like sleep/wake_up/profile may affect the wrong stage instead of returning an invalid-stage error.
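
One possible shape of the fix, sketched only for illustration (the method name and error handling are assumptions; the attribute names are taken from the snippet above):

```python
# Hypothetical guard: reject out-of-range logical stage IDs up front so they can
# never alias a valid client index after expansion.
def _expand_logical_stage_ids(self, logical_ids: list[int]) -> list[int]:
    client_ids: list[int] = []
    for lid in logical_ids:
        if not (0 <= lid < self.num_logical_stages):
            raise ValueError(f"invalid logical stage id: {lid}")
        client_ids.extend(self.logical_stage_to_clients[lid])
    return client_ids
```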


@ZhengWG ZhengWG force-pushed the support-stage-scale-out branch 2 times, most recently from ee7e60f to bf789f9 on April 2, 2026 08:39
ZhengWG added 6 commits April 2, 2026 16:47
@ZhengWG ZhengWG force-pushed the support-stage-scale-out branch from bf789f9 to 4d9b160 on April 2, 2026 08:52
@yinpeiqi (Contributor) commented Apr 2, 2026

Thanks for the PR! This feature is very critical.
We are just working on DP deployment for different stages, and I have several comments:

  1. It may be better to use the name DP rather than replica.
  2. The DP routing and orchestration logic should be dispatched more clearly into separate classes, rather than living entirely in the orchestrator. In particular, the orchestrator is the entry point for both Omni and Diffusion models, so it should preferably not contain complex logic.
  3. Several replica workers could share the same output processor (or the same output coroutine).
  4. We may follow the design of vLLM DP (https://docs.vllm.ai/en/latest/serving/data_parallel_deployment/).
  5. We can check whether we can reuse the current OmniCoordinator, and decide where to place our DP load balancer.
  6. For the usage, we can (i) write the CUDA devices "1,2,3,4" in the YAML file and split them internally, or (ii) write talker1: "1,2" and talker2: "3,4" in the YAML file. We can discuss later how to set it.

cc list: @fake0fan @chickeyton @wuhang2014 @Gaohan123 @tzhouam

@ZhengWG (Contributor, Author) commented Apr 2, 2026


Thanks for the feedback! The suggestions on DP naming, routing separation, and aligning with vLLM DP design all make sense — we're thinking along similar lines.
You mentioned you're also working on DP deployment for different stages — that's great, sounds like we have a lot of overlap. We've done an initial implementation (orchestration framework + per-stage process isolation) and ran some benchmarks on Qwen3-Omni. Happy to share our RFC and benchmark data if useful.
Would you be open to a quick sync to align on the approach and see how we can collaborate?

@yinpeiqi (Contributor) commented Apr 2, 2026

Yeah great! We can have a sync for how to collaborate then!

Comment thread: vllm_omni/engine/async_omni_engine.py (outdated)

    self.stage_clients = flat_clients
    self.output_processors = flat_output_processors
    self.stage_vllm_configs = flat_vllm_configs
    self.logical_stage_to_clients = logical_stage_to_clients

Contributor:
Here I suggest adding a lightweight StageEnginePool / StagePool class that serves as a list of EngineCoreClients from the same stage. The output processors could also be placed in the StagePool.

After init, the AsyncOmniEngine would hold:

    self.stage_pool_list = [pool1, pool2]  # List[StagePool]

The routing logic could be put inside the stage pool. For example, we may call functions like self.stage_pool_list[stage_id].select to choose a logical engine core client.

Contributor:
vllm_configs could also be placed inside the StagePool
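
A rough sketch of the shape being suggested in this thread, with hypothetical field names and a placeholder routing policy (not code from this PR):

```python
# Hypothetical sketch of the reviewer's suggestion: one pool per logical stage,
# bundling that stage's engine clients, output processors, and vllm_configs.
from dataclasses import dataclass
from typing import Any


@dataclass
class StagePool:
    clients: list[Any]            # EngineCoreClients for one logical stage
    output_processors: list[Any]  # could also be shared across replicas
    vllm_configs: list[Any]
    _next: int = 0

    def select(self) -> Any:
        # placeholder policy; the real routing (RR / affinity) would live here
        client = self.clients[self._next % len(self.clients)]
        self._next += 1
        return client


# The engine would then only hold the list of pools, e.g.:
#     self.stage_pool_list = [pool_thinker, pool_talker]  # List[StagePool]
#     client = self.stage_pool_list[stage_id].select()
```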

Comment thread: vllm_omni/engine/async_omni_engine.py (outdated)

    all_output_processors[stage_id] = stage_output_procs
    all_vllm_configs[stage_id] = stage_vllm_cfgs
    # Use first replica for finalize_initialized_stages
    logical_stage_clients_for_finalize[stage_id] = stage_clients_list[0]

Contributor:
The init function becomes more complex. We could launch the stages and then register them with the StagePool.

Comment thread: vllm_omni/engine/async_omni_engine.py (outdated)

    if any(getattr(stage_client, "is_comprehension", False) for stage_client in flat_clients):
        supported_tasks.add("generate")
    if any(metadata.get("final_output_type") == "audio" for metadata in stage_metadata):
    if any(metadata.get("final_output_type") == "audio" for metadata in logical_stage_metadata):

Contributor:
Wrap them into the stage pool?

Comment thread: vllm_omni/engine/orchestrator.py (outdated)

    for logical_id, client_indices in enumerate(self.logical_stage_to_clients):
        for ri, ci in enumerate(client_indices):
            self._client_to_logical[ci] = logical_id
            self._client_to_replica[ci] = ri

Contributor:
If we pass a List[StagePool] in here, do we still need this data?

Comment thread: vllm_omni/engine/orchestrator.py (outdated)

    # Multi-replica: maps logical_stage_id -> client_index chosen for this
    # request. Ensures the same request always hits the same replica within
    # a given logical stage (KV / intermediate-state affinity).
    chosen_client_index: dict[int, int] = field(default_factory=dict)

Contributor:
We can use the Router here to record the chosen client index, using a map inside the router. This can be considered later.

Comment thread: vllm_omni/engine/orchestrator.py (outdated)

        raise
    except Exception:
        for stage_id in range(self.num_logical_stages):
            for replica_index in range(len(self.logical_stage_to_clients[stage_id])):

Contributor:
I am afraid this could become a bottleneck of the orchestrator thread at larger scale. But we can keep it as is for now.

Comment thread: vllm_omni/engine/orchestrator.py (outdated)

    async def _poll_stage_raw(self, stage_id: int) -> EngineCoreOutputs | None:
        """Pull raw EngineCoreOutputs from a stage client without processing.
    async def _poll_stage_raw(

Contributor:
I found that all these functions take both a stage_id and a replica_index; what about offloading these functions to StagePool? If so, the orchestrator would become lighter. For example:

    class StagePool:
        async def poll_stage_raw(..., replica_id):
            ...

    # in orchestrator:
    pool = self.stage_pool_list[stage_id]
    await pool.poll_stage_raw(replica_id)

Contributor:
BTW, rename all replica_index to replica_id, to align with the naming of the parameters.

Comment thread: vllm_omni/engine/orchestrator.py (outdated)

    return outputs

    async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> list[RequestOutput]:
    async def _process_stage_outputs(

Contributor:
These functions could also be considered for moving to StagePool? We can discuss this further.

Comment thread: vllm_omni/engine/orchestrator.py (outdated)

    companion_state.stage_submit_ts[0] = _time.time()
    self.request_states[companion_id] = companion_state

    # Use same replica as the parent for affinity, or choose one

Contributor:
Do we support CFG in this pr?

Contributor (Author):
CFG is partially supported in this PR.

The companion flow is implemented, and it works for the common setup with single-replica stage-0. For multi-replica stage-0, there is still a processor/client alignment risk. We’ll add follow-up hardening + tests for that case.

    chosen_client_index: dict[int, int] = field(default_factory=dict)


    class Orchestrator:

Contributor:
We need to add tests for multi-stage scenarios

Contributor (Author):
Sure. I will add it.

@ZhengWG (Contributor, Author) commented Apr 13, 2026

@yinpeiqi I understand that your main concern is having a standalone StagePool design. I’ll take a deeper look and start with a baseline StagePool implementation.

ZhengWG added 2 commits April 13, 2026 14:51
ZhengWG added 3 commits April 14, 2026 11:49
@Gaohan123 Gaohan123 added this to the v0.20.0 milestone Apr 15, 2026
ZhengWG added 5 commits April 16, 2026 11:11
@ZhengWG ZhengWG force-pushed the support-stage-scale-out branch from 56a8be6 to 658415e on April 16, 2026 09:24
@ZhengWG (Contributor, Author) commented May 5, 2026

Here is the Multi-Instance Stage Scale-Out benchmark summary.

1. Benchmark Parameters

  • Model: Qwen3-Omni-MoE (3-stage pipeline: thinker → talker → code2wav)
  • Dataset: Random, 128 prompts, input_len=1024, output_len=512
  • Concurrency sweep: 8, 16, 24, 32
  • Command: vllm bench serve --omni --backend openai-chat-omni --dataset-name random --num-prompts 128 --random-input-len 1024 --output-len 512 --ignore-eos

I think we should still add benchmark data for 2gpu_base and 3gpu_single, especially at 32 concurrency for a more complete comparison.

The multi-replica setup clearly shows strong throughput gains at higher concurrency, but TTFT and TPOT also degrade quite a bit. It would be helpful to clarify whether this latency regression is acceptable, or to explicitly present it as a throughput/latency tradeoff.

Also, although we scaled out talker and code2wav, the current summary does not show stage-level metrics for those stages, so it is hard to directly attribute the gains to them. My understanding is that the reported TTFT and TPOT may still be dominated by the thinker stage. If so, these metrics alone may not fully demonstrate the benefit of multi-replica on the scaled stages.

Overall, the direction looks good, but adding the missing baseline points and clarifying the latency/stage-level impact would strengthen the conclusion.

Updated Benchmark Results

Re-ran the full benchmark on the latest branch with the same parameters as before:

  • Command: vllm bench serve --omni --backend openai-chat-omni --dataset-name random --num-prompts 128 --random-input-len 1024 --output-len 512 --ignore-eos
  • Hardware: 8× NVIDIA H20-96G
  • Concurrency sweep: 8, 16, 24, 32

Complete Results (2gpu_base c=24,32 and 3gpu_single c=32 added)

| Config | c | req/s | tok/s | TTFT med (ms) | TPOT med (ms) | Audio throughput |
|---|---|---|---|---|---|---|
| 2gpu_base | 8 | 0.189 | 68.2 | 123 | 12.6 | 23.8 |
| 2gpu_base | 16 | 0.229 | 81.7 | 138 | 14.9 | 29.1 |
| 2gpu_base | 24 | 0.197 | 73.7 | 138 | 15.3 | 27.8 |
| 2gpu_base | 32 | 0.241 | 89.9 | 144 | 16.7 | 32.4 |
| 3gpu_single | 8 | 0.207 | 73.5 | 131 | 14.7 | 26.8 |
| 3gpu_single | 16 | 0.248 | 94.1 | 139 | 14.3 | 33.8 |
| 3gpu_single | 24 | 0.249 | 91.6 | 150 | 16.0 | 33.2 |
| 3gpu_single | 32 | 0.282 | 106.4 | 162 | 18.3 | 37.9 |
| 3gpu_replica2 | 8 | 0.230 | 84.7 | 131 | 12.5 | 31.0 |
| 3gpu_replica2 | 16 | 0.351 | 126.8 | 157 | 15.8 | 44.9 |
| 3gpu_replica2 | 24 | 0.392 | 145.8 | 183 | 17.1 | 51.6 |
| 3gpu_replica2 | 32 | 0.414 | 156.2 | 201 | 21.4 | 51.2 |

Peak per-GPU efficiency: 3gpu_replica2 reaches 0.138 req/s/GPU (+14% vs 2gpu_base, +47% vs 3gpu_single).
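
For reference, the per-GPU figures follow directly from the peak req/s rows in the table above:

```python
# Quick arithmetic check of the per-GPU efficiency claim (values from the table above).
per_gpu_2gpu_base     = 0.241 / 2  # 0.1205 req/s/GPU at c=32
per_gpu_3gpu_single   = 0.282 / 3  # 0.094  req/s/GPU at c=32
per_gpu_3gpu_replica2 = 0.414 / 3  # 0.138  req/s/GPU at c=32
# 0.138 / 0.1205 ~= 1.145 -> roughly +14% vs 2gpu_base
# 0.138 / 0.094  ~= 1.468 -> roughly +47% vs 3gpu_single
```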

On the Latency/Throughput Tradeoff

The TTFT/TPOT degradation at higher concurrency is not specific to multi-replica — it applies to all configurations equally. A fairer comparison is at equivalent QPS:

  • To reach ~0.25 req/s: 3gpu_single needs c=24 (TTFT=150ms), while 3gpu_replica2 only needs c=8 (TTFT=131ms)
  • To reach ~0.23 req/s: 2gpu_base needs c=16 (TTFT=138ms), while 3gpu_replica2 only needs c=8 (TTFT=131ms)

At the same throughput level, multi-replica achieves equal or lower latency — this is a Pareto improvement, not a tradeoff.

On Stage-Level Attribution

Agreed that TTFT/TPOT are dominated by the thinker stage (1 replica in all configs). Two signals directly attribute the gains to talker/code2wav:

1) TTFT is nearly identical across configs at the same concurrency (e.g., c=8: 123/131/131ms), confirming the thinker bottleneck is unchanged and multi-replica adds no overhead.

2) Audio throughput directly measures talker+code2wav capacity (thinker does not produce audio):

| c | 3gpu_single | 3gpu_replica2 | Improvement |
|---|---|---|---|
| 8 | 26.8 | 31.0 | +16% |
| 16 | 33.8 | 44.9 | +33% |
| 24 | 33.2 | 51.6 | +55% |
| 32 | 37.9 | 51.2 | +35% |

Additionally, 3gpu_single plateaus at c=16–24 (0.248→0.249 req/s), showing its talker/code2wav pipeline is saturated. Multi-replica breaks through this ceiling to 0.414 req/s at c=32.

Per-step profiling confirms framework overhead is <0.3ms/step (<0.5% of compute), and the gains come from parallel talker/code2wav execution across replicas.

@fake0fan @yinpeiqi @hsliuustc0106

@Gaohan123 Gaohan123 added the omni-test (trigger buildkite omni model test in nightly CI) and diffusion-x2iat-test (trigger buildkite x2image + x2audio + x2text diffusion model tests in nightly CI) labels May 5, 2026
@hsliuustc0106 hsliuustc0106 removed the diffusion-x2iat-test label May 5, 2026
@Gaohan123 Gaohan123 removed the ready (trigger buildkite CI) and omni-test labels May 6, 2026
@Gaohan123 Gaohan123 added the diffusion-x2iat-test label May 6, 2026
@hsliuustc0106 hsliuustc0106 removed the diffusion-x2iat-test label May 6, 2026
A replica consumes the devices required by that diffusion stage. For single-device diffusion pipelines, set `runtime.devices` to one device per replica:

```yaml
stages:
```

Collaborator:
We will deprecate these YAMLs in v0.22.0; please remember to remove them in the follow-up PRs.

@fake0fan (Contributor) commented May 6, 2026

Given the issues that still exist in the Bagel model itself in the current version of the code, I think we should leave out the diffusion part for now. You can go ahead and remove the corresponding CI, YAML, and README content, and any other related materials.

@hsliuustc0106 hsliuustc0106 added the ready (trigger buildkite CI) label May 6, 2026
@hsliuustc0106 (Collaborator) left a comment

lgtm, please rm the redundant docs&tests in a follow-up PR @fake0fan @ZhengWG

@hsliuustc0106 hsliuustc0106 merged commit 1e5f288 into vllm-project:main May 6, 2026
7 of 8 checks passed
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
Signed-off-by: ZhengWG <zwg0606@gmail.com>
Signed-off-by: Zheng Wengang <zwg0606@gmail.com>
Signed-off-by: Peiqi Yin <60515999+yinpeiqi@users.noreply.github.com>
Signed-off-by: yinpe <11810305@mail.sustech.edu.cn>
Signed-off-by: yinpeiqi <yinpeiqi809@gmail.com>
Co-authored-by: Peiqi Yin <60515999+yinpeiqi@users.noreply.github.com>
Co-authored-by: yinpe <11810305@mail.sustech.edu.cn>
Co-authored-by: yinpeiqi <yinpeiqi809@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Co-authored-by: Gao Han <hgaoaf@connect.ust.hk>
Co-authored-by: Chenguang Zheng <645327136@qq.com>

Labels

high priority (high priority issue, needs to be done asap), ready (trigger buildkite CI)

10 participants