[Refactor] MLP weight prefetch for consistency with the MoE model's prefetching in code and usage #6442

Merged
wangxiyuan merged 7 commits into vllm-project:main from
leo-pony:weight_prefetch_p1
Feb 4, 2026

Conversation

@leo-pony
Collaborator

@leo-pony leo-pony commented Jan 30, 2026

What this PR does / why we need it?

Refactors MLP weight prefetch to be consistent with the MoE model's prefetching in code and usage.
The environment variables VLLM_ASCEND_ENABLE_PREFETCH_MLP, VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE, and VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE are removed; usage is now as follows:

--additional-config '{"weight_prefetch_config": { "enabled": true, "prefetch_ratio": {"mlp": { "gate_up": 1.0, "down": 1.0} }}}'
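For reference, the shape of the new config can be exercised with a minimal parsing sketch. This is illustrative only: `parse_weight_prefetch_config` and `DEFAULT_MLP_RATIO` are hypothetical names, not the actual vllm-ascend implementation; only the JSON shape comes from the PR description.

```python
import json

# Default ratios assumed from the documented example payload.
DEFAULT_MLP_RATIO = {"gate_up": 1.0, "down": 1.0}

def parse_weight_prefetch_config(additional_config_json: str) -> dict:
    """Extract the weight_prefetch_config section with safe defaults."""
    cfg = json.loads(additional_config_json).get("weight_prefetch_config", {})
    return {
        "enabled": bool(cfg.get("enabled", False)),
        "mlp_ratio": cfg.get("prefetch_ratio", {}).get("mlp", DEFAULT_MLP_RATIO),
    }

# Same payload as the --additional-config example above:
cfg = parse_weight_prefetch_config(
    '{"weight_prefetch_config": {"enabled": true,'
    ' "prefetch_ratio": {"mlp": {"gate_up": 1.0, "down": 0.5}}}}'
)
```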

Does this PR introduce any user-facing change?

How was this patch tested?

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@leo-pony leo-pony marked this pull request as draft January 30, 2026 12:56
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request successfully refactors the MLP weight prefetch mechanism, moving from environment variables to a more unified additional-config approach, consistent with the MoE Model's prefetching. This change streamlines configuration and aligns the prefetching logic across different model types. The documentation and test cases have been updated to reflect this new configuration method. However, there are a few critical issues identified in the new prefetching logic and some test cases that need to be addressed to ensure correctness and maintainability.

Comment thread vllm_ascend/ops/weight_prefetch.py Outdated
elif prefetch_layer_name == self.MLP_DOWN:
self._maybe_prefetch_mlp_down_weight_preprocess(x_dependency, forward_context)
else:
raise ValueError(f"Unsupported prefetch weight name: {prefetch_weight_name}")
Contributor


critical

There is a typo in the ValueError message. The variable prefetch_weight_name is used, but it is not defined in this scope. It should be prefetch_layer_name.

Suggested change
raise ValueError(f"Unsupported prefetch weight name: {prefetch_weight_name}")
raise ValueError(f"Unsupported prefetch weight name: {prefetch_layer_name}")
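With that fix applied, a minimal sketch of the corrected dispatch could look like the following. The class and method shapes are assumed from the snippet above; the real method in `vllm_ascend/ops/weight_prefetch.py` is more involved.

```python
class WeightPrefetchSketch:
    """Simplified stand-in for the prefetch dispatcher in the PR."""
    MLP_GATE_UP = "gate_up"
    MLP_DOWN = "down"

    def maybe_prefetch(self, prefetch_layer_name: str) -> str:
        if prefetch_layer_name == self.MLP_GATE_UP:
            return "prefetch gate_up"
        elif prefetch_layer_name == self.MLP_DOWN:
            return "prefetch down"
        else:
            # Fixed: reference prefetch_layer_name, the variable in scope.
            raise ValueError(
                f"Unsupported prefetch weight name: {prefetch_layer_name}")
```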

raise ValueError("curr_layer_prefix must be specified when prefetching mlp gate_up_proj weight")

# start point of gate_up_proj weight prefetch
if curr_layer_prefix.split('.')[-2] == "self_attn":
Contributor


critical

The condition curr_layer_prefix.split('.')[-2] == "self_attn" is used to determine if MLP gate_up_proj weight prefetching should occur. However, "self_attn" refers to the attention mechanism, not the MLP layer. This logic seems incorrect for MLP prefetching and could lead to the prefetching being triggered at the wrong time or not at all for MLP layers. The condition should accurately identify MLP layers.
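For context, here is what the prefix check actually matches; the prefix strings below are hypothetical examples modeled on typical module naming, not taken from the PR. Note one plausible reading of the original intent: firing the check inside `self_attn` starts the gate_up prefetch while attention is still executing, so it completes before the MLP runs, meaning the check may be deliberate rather than a bug.

```python
def is_in_self_attn(curr_layer_prefix: str) -> bool:
    """True when the second-to-last dotted component is 'self_attn'."""
    parts = curr_layer_prefix.split('.')
    return len(parts) >= 2 and parts[-2] == "self_attn"

# Hypothetical module prefixes:
attn_prefix = "model.layers.0.self_attn.o_proj"
mlp_prefix = "model.layers.0.mlp.gate_up_proj"
```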


@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
Contributor


high

The test test_qwen3_dense_prefetch_mlp_weight_tp2 is intended to test MLP weight prefetching, yet it patches the environment variable VLLM_ASCEND_ENABLE_FLASHCOMM1, which relates to the FlashComm optimization, not MLP prefetching. Since MLP prefetching is now configured via additional_config (as correctly done at line 240 of this test), this environment variable patch is misleading and potentially incorrect: it may not enable the intended feature for this test, or it patches an unrelated one. This could lead to false positives or incorrect test coverage.

Suggested change
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)

SUPPORTED_MODULES = ["attn", "mlp", "moe"]
MOE_PREFETCH_TOKEN_THRESHOLD = 96

MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024
Contributor


high

The constant MAX_PREFETCH_WEIGHT_SIZE is defined with a magic number 18 * 1024 * 1024. While this value was previously a default in environment variables, it's good practice to define such critical configuration values with a descriptive name and potentially make it configurable if it's a tuning parameter. Consider adding a comment explaining the origin or purpose of this specific size, or making it configurable through WeightPrefetchConfig if it's meant to be dynamic.
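A sketch of the suggested direction: keep the constant, document where the value comes from (per the comment above, 18 MiB was previously the default of the removed environment variables), and express the capping with `min()`. The helper name `capped_prefetch_size` is illustrative, not the PR's actual code.

```python
# 18 MiB: former default of the removed VLLM_ASCEND_MLP_*_PREFETCH_SIZE
# environment variables, kept as an upper bound on prefetch buffer size.
MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024  # bytes

def capped_prefetch_size(element_size: int, numel: int, ratio: float) -> int:
    """Prefetch the requested ratio of the weight's bytes, capped at the max."""
    return min(int(element_size * numel * ratio), MAX_PREFETCH_WEIGHT_SIZE)

# A 4096x4096 fp16 weight is 32 MiB, so a ratio of 1.0 hits the cap:
size = capped_prefetch_size(2, 4096 * 4096, 1.0)
```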

Comment thread vllm_ascend/ops/weight_prefetch.py Outdated
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"mlp", {}) or {'gate_up': 1.0, 'down': 1.0})

print(f'mlp prefetch config: {self.mlp} self.is_moe:{self.is_moe} ==============================================================')
Contributor


high

This print statement appears to be for debugging purposes. Debug prints should be removed from production code to avoid unnecessary console output and potential performance overhead.

Comment thread vllm_ascend/ops/weight_prefetch.py Outdated
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get("gate_up", 0)
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
weight_size = MAX_PREFETCH_WEIGHT_SIZE
print(f'mlp prefetch gate_up current layer prefix:{curr_layer_prefix}, weight size: {weight_size} ==============================================================')
Contributor


high

This print statement appears to be for debugging purposes. Debug prints should be removed from production code to avoid unnecessary console output and potential performance overhead.
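If such diagnostics are worth keeping, a common alternative to deleting them outright is debug-level logging, which is silenced at the default log level. This is a sketch under that assumption; the logger name and helper are hypothetical, not the PR's code.

```python
import logging

logger = logging.getLogger("vllm_ascend.weight_prefetch")

def log_prefetch(curr_layer_prefix: str, weight_size: int) -> None:
    # Emitted only when the logger is set to DEBUG, unlike a bare print().
    logger.debug("mlp prefetch gate_up layer prefix: %s, weight size: %d",
                 curr_layer_prefix, weight_size)
```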

Comment thread vllm_ascend/ops/weight_prefetch.py Outdated
max_weight_size=int(weight_size))
forward_context.prefetch_mlp_down_proj = True
forward_context.layer_idx += 1
print(f'mlp prefetch down layer idx:{layer_idx}, layer_idx for next forward:{forward_context.layer_idx}, weight size: {weight_size} ==============================================================')
Contributor


high

This print statement appears to be for debugging purposes. Debug prints should be removed from production code to avoid unnecessary console output and potential performance overhead.

@github-actions
Contributor

github-actions bot commented Feb 2, 2026

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: leo-pony <nengjunma@outlook.com>
@leo-pony leo-pony marked this pull request as ready for review February 2, 2026 08:33
@leo-pony leo-pony added the ready (read for review) and ready-for-test (start test by label for PR) labels Feb 3, 2026
Comment thread docs/source/tutorials/Qwen3-Dense.md Outdated
Comment thread docs/source/tutorials/Qwen3-Dense.md
Comment thread docs/source/tutorials/Qwen3-Dense.md
class AscendSiluAndMul310(AscendSiluAndMul):
def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
weight_prefetch_method = get_weight_prefetch_method()
Collaborator


Maybe we should drop support for 310P first.

It is important to emphasize that, since we use vector computations to hide the weight prefetching pipeline, the setting of the prefetch buffer size is crucial. If the buffer size is too small, the optimization benefits will not be fully realized, while a larger buffer size may lead to resource contention, resulting in performance degradation. To accommodate different scenarios, we have exposed two environment variables `VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE` and `VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE` to allow for flexible buffer size configuration based on the specific workload.

This optimization requires setting the environment variable `VLLM_ASCEND_ENABLE_PREFETCH_MLP = 1` to be enabled.
The environment variables previously used for MLP weight prefetch (VLLM_ASCEND_ENABLE_PREFETCH_MLP to enable it, plus VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE and VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE to set the prefetch sizes for the MLP gate_up_proj and down_proj weights) are deprecated. Please use the following configuration instead: "weight_prefetch_config": { "enabled": true, "prefetch_ratio": { "mlp": { "gate_up": 1.0, "down": 1.0}}}. See User Guide -> Feature Guide -> Weight Prefetch Guide for details.
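For migrating scripts, the flag value can be generated rather than hand-written. A sketch using only the standard library; the `--additional-config` flag name comes from the PR description, everything else here is illustrative.

```python
import json

# Build the new weight_prefetch_config payload and serialize it
# for the --additional-config CLI flag.
additional_config = {
    "weight_prefetch_config": {
        "enabled": True,
        "prefetch_ratio": {"mlp": {"gate_up": 1.0, "down": 1.0}},
    }
}
flag_value = json.dumps(additional_config)
# Pass on the command line as: --additional-config '<flag_value>'
```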
Collaborator


"See User Guide -> Feature Guide -> Weight Prefetch Guide for details." This can be made a link instead.

@wangxiyuan wangxiyuan merged commit 78fad4e into vllm-project:main Feb 4, 2026
40 of 41 checks passed
ZYang6263 pushed a commit to rjg-lyh/vllm-ascend that referenced this pull request Feb 4, 2026
…ching in terms of code and usage (vllm-project#6442)

Refactors MLP weight prefetch to be consistent with the MoE model's prefetching in code and usage.
The environment variables VLLM_ASCEND_ENABLE_PREFETCH_MLP,
VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE, and
VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE are removed; usage is now as follows:

--additional-config '{"weight_prefetch_config": { "enabled": true,
"prefetch_ratio": {"mlp": { "gate_up": 1.0, "down": 1.0} }}}'

- vLLM version: v0.14.1
- vLLM main:
vllm-project/vllm@dc917cc

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Feb 6, 2026
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (59 commits)
  [Feat.]: 310p support MOE models (vllm-project#6530)
  [Doc] backport 0.13.0 release note (vllm-project#6584)
  [CI] Update UT CANN version to 8.5.0 for main branch (vllm-project#6564)
  [CI] Change A2 runner (vllm-project#6557)
  [Bugfix] Fix the incorrect use of the output parameter in _forward_fia_slidingwindow (vllm-project#6469)
  [main2main] upgrade vllm main 0202 (vllm-project#6560)
  [CI][npugraph_ex]Fix npugraph ex e2e test (vllm-project#6553)
  [Feature]KV pool supports sparse attention (vllm-project#6339)
  [bugfix]Fix accuracy issue in PCP/DCP with speculative decoding (vllm-project#6491)
  perf: adaptive block size selection in linear_persistent kernel (vllm-project#6537)
  [ModelRunner][Fix] Pads query_start_loc to satisfy FIA/TND constraint (vllm-project#6475)
  [Bugfix]Fix of Pooling Code and Update of Pooling Usage Guide (vllm-project#6126)
  [Fusion] Add rmsnorm dynamic quant fusion pass (vllm-project#6274)
  [Bugfix] Synchronize only the current stream to avoid device sync (vllm-project#6432)
  [CI] Add long and short prompt tests for DeepSeek-V3.2 (vllm-project#6499)
  [Refactor] MLP weight prefetch to consistency with MoE Model's prefetching in terms of code and usage (vllm-project#6442)
  [bugfix][npugraph_ex]duplicate pattern issue (vllm-project#6513)
  [bugfix][npugraph_ex]add the extra check for allreduce rmsnorm fusion pass (vllm-project#6430)
  [Quant] GLM4.7-Flash Support W8A8 (vllm-project#6492)
  [Nightly][BugFix] Remove kv_cache nz test case for test_mla_preprocess_nq.py (vllm-project#6505)
  ...
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
…ching in terms of code and usage (vllm-project#6442)

ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…ching in terms of code and usage (vllm-project#6442)

maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
…ching in terms of code and usage (vllm-project#6442)

ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…ching in terms of code and usage (vllm-project#6442)

LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
…ching in terms of code and usage (vllm-project#6442)

jiangyunfan1 pushed a commit to jiangyunfan1/vllm-ascend that referenced this pull request Apr 9, 2026
…ching in terms of code and usage (vllm-project#6442)


Labels

documentation (Improvements or additions to documentation), module:core, module:ops, module:tests, ready (read for review), ready-for-test (start test by label for PR)


2 participants