
[Bugfix] out-of-bounds error for routed experts capture #37118

Open
HollowMan6 wants to merge 1 commit into vllm-project:main from HollowMan6:router_replay

Conversation


@HollowMan6 HollowMan6 commented Mar 15, 2026

Purpose

This PR fixes an out-of-bounds error in routed expert capture when
`enable_return_routed_experts=True` is used with hybrid KV cache groups.

```logs
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 927, in worker_busy_loop
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 756, in sample_tokens
    return self.model_runner.sample_tokens(grammar_output)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 4060, in sample_tokens
    capturer.save_captured_experts(indices=self.slot_mapping)  # noqa
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py", line 217, in save_captured_experts
    self._host_buffer_view[indices, :, :] = data
    ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
IndexError: index 699312 is out of bounds for axis 0 with size 699312
```

The routed-experts side buffer was sized with:

```python
(num_blocks // num_groups) * min_block_size
```

on both the worker and scheduler sides.
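The mismatch is easy to see with concrete numbers. The figures below are illustrative, back-derived from the traceback rather than taken from the actual failing config: a coarse buffer of 699,312 slots would follow from, e.g., 87,414 total blocks split across two hybrid groups with 16-token blocks.

```python
# Illustrative hybrid KV-cache numbers (hypothetical, chosen so the coarse
# estimate reproduces the 699312-slot buffer seen in the traceback).
num_blocks = 87414      # total KV-cache blocks (kv_cache_config.num_blocks)
num_groups = 2          # number of hybrid KV cache groups
min_block_size = 16     # smallest block size across groups
block_size = 16         # block size of the routed-experts attention group

# Old sizing: a coarse aggregate token-capacity estimate.
coarse_buffer_slots = (num_blocks // num_groups) * min_block_size
print(coarse_buffer_slots)        # 699312

# Address space the group's slot_mapping can actually produce.
slot_address_space = num_blocks * block_size
print(slot_address_space)         # 1398624 -- twice the buffer size
```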

That formula is only a coarse, aggregate token-capacity estimate for hybrid
KV-cache layouts. Routed expert capture and readback, however, index the buffer
with the selected attention group's actual `slot_mapping`, whose address space
is derived directly from that attention KV group.

As a result, in hybrid/padded KV-cache configurations, the routed-experts
buffer can be smaller than the valid range of slot_mapping, which leads to
out-of-bounds writes/reads.
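The failure mode can be sketched in a few lines of NumPy. The buffer name mirrors `_host_buffer_view` from the traceback; the sizes and shapes here are arbitrary illustrations, not vLLM's real dimensions:

```python
import numpy as np

# Buffer sized by the coarse estimate: too small for the slot address space.
buffer_slots = 8            # e.g. (num_blocks // num_groups) * min_block_size
host_buffer_view = np.zeros((buffer_slots, 2, 3), dtype=np.int32)

# slot_mapping from the attention group can legally address a larger range.
slot_mapping = np.array([3, 8])   # 8 >= buffer_slots -> out of bounds
data = np.ones((2, 2, 3), dtype=np.int32)

try:
    host_buffer_view[slot_mapping, :, :] = data
except IndexError as e:
    print(e)   # index 8 is out of bounds for axis 0 with size 8
```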

Use the routed-experts attention group's full KV address space to size the
buffer consistently on both sides:

```python
kv_cache_config.num_blocks * attn_group.kv_cache_spec.block_size
```

Routed expert capture is indexed by the attention group's `slot_mapping`, so
the auxiliary buffer must match the full addressable range of that mapping.
Sizing it from the specific attention group keeps the writer and reader aligned.
- fixes crashes when `enable_return_routed_experts=True`
- no behavior change when routed expert return is disabled
- no change to model weights, routing logic, or sampling semantics
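Since the same formula has to be applied on both the worker and the scheduler side, one option (echoing the refactoring suggestion in the review comments) is a small shared helper. The dataclasses below are simplified, hypothetical stand-ins for vLLM's config objects, not its actual API:

```python
from dataclasses import dataclass

# Hypothetical stand-ins for vLLM's config objects, for illustration only.
@dataclass
class KVCacheSpec:
    block_size: int

@dataclass
class AttnGroup:
    kv_cache_spec: KVCacheSpec

@dataclass
class KVCacheConfig:
    num_blocks: int

def routed_experts_buffer_slots(kv_cache_config: KVCacheConfig,
                                attn_group: AttnGroup) -> int:
    """Full addressable range of the attention group's slot_mapping."""
    return kv_cache_config.num_blocks * attn_group.kv_cache_spec.block_size

# Using the illustrative numbers from above:
print(routed_experts_buffer_slots(KVCacheConfig(num_blocks=87414),
                                  AttnGroup(KVCacheSpec(block_size=16))))
# 1398624
```

Calling the one helper from both `gpu_model_runner.py` and `scheduler.py` would make it impossible for the writer and reader views to drift apart again.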

Test Plan

End-to-end tests.

Test Result

The error is now gone.


Essential Elements of an Effective PR Description Checklist
- The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- The test plan, such as providing test command.
- The test results, such as pasting the results comparison before and after, or e2e results.
- (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
- (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Hollow Man <hollowman@opensuse.org>
Copilot AI review requested due to automatic review settings March 15, 2026 18:29
mergify bot added the v1 and bug (Something isn't working) labels Mar 15, 2026

gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses an out-of-bounds error that occurs during routed expert capture with hybrid KV cache groups. The fix, which adjusts the calculation of the routed-experts buffer size to align with the slot_mapping address space, is sound. However, the logic for this calculation has been duplicated in both the scheduler and the GPU model runner. I've added comments to highlight this duplication and recommend refactoring it into a shared utility to enhance maintainability and prevent future inconsistencies.


Copilot AI left a comment


Pull request overview

Fixes a crash when enable_return_routed_experts=True is used with hybrid (multi-group / padded) KV-cache layouts by sizing the routed-experts side buffer to match the attention KV group’s full slot-mapping address space.

Changes:

- Compute `max_num_kv_tokens` as `kv_cache_config.num_blocks * attn_group.kv_cache_spec.block_size` (instead of the coarse `num_blocks // num_groups * min_block_size` estimate) on the GPU worker side.
- Apply the same sizing logic on the scheduler side to keep writer/reader buffer views aligned.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| `vllm/v1/worker/gpu_model_runner.py` | Sizes routed-experts capture buffer to the selected attention KV group's full addressable slot range. |
| `vllm/v1/core/sched/scheduler.py` | Sizes routed-experts readback shared-memory view consistently with the attention KV group's slot-mapping range. |

