From c444bf61f807fd2f43b1768738feecb3a5590fb8 Mon Sep 17 00:00:00 2001 From: leo-pony Date: Mon, 2 Feb 2026 16:33:09 +0800 Subject: [PATCH 1/7] Refact MLP weight prefetch to consist with moe weight prefetch Signed-off-by: leo-pony --- docs/source/tutorials/Qwen3-Dense.md | 7 +- .../test_offline_inference_distributed.py | 3 +- .../2-cards/test_qwen3_performance.py | 3 +- .../single_node/models/test_qwen3_32b_int8.py | 4 +- .../test_qwen3_32b_int8_a3_feature_stack3.py | 4 +- .../single_node/models/test_qwq_32b.py | 5 +- tests/ut/ops/test_activation.py | 20 ---- vllm_ascend/_310p/ops/activation.py | 10 +- vllm_ascend/ascend_config.py | 1 + vllm_ascend/ascend_forward_context.py | 14 +-- vllm_ascend/envs.py | 10 -- vllm_ascend/ops/activation.py | 9 +- vllm_ascend/ops/layernorm.py | 6 +- vllm_ascend/ops/linear_op.py | 10 +- vllm_ascend/ops/register_custom_ops.py | 102 ------------------ vllm_ascend/ops/weight_prefetch.py | 90 +++++++++++++++- 16 files changed, 124 insertions(+), 174 deletions(-) diff --git a/docs/source/tutorials/Qwen3-Dense.md b/docs/source/tutorials/Qwen3-Dense.md index 86f13ee3811..e543dc7d5d6 100644 --- a/docs/source/tutorials/Qwen3-Dense.md +++ b/docs/source/tutorials/Qwen3-Dense.md @@ -171,9 +171,6 @@ export TASK_QUEUE_ENABLE=1 # Enable the AIVector core to directly schedule ROCE communication export HCCL_OP_EXPANSION_MODE="AIV" -# Enable MLP prefetch for better performance. -export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1 - # Enable FlashComm_v1 optimization when tensor parallel is enabled. 
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1 @@ -187,7 +184,7 @@ vllm serve /model/Qwen3-32B-W8A8 \ --max-model-len 5500 \ --max-num-batched-tokens 40960 \ --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ - --additional-config '{"pa_shape_list":[48,64,72,80]}' \ + --additional-config '{"pa_shape_list":[48,64,72,80], "weight_prefetch_config":{"enabled":true}}' \ --port 8113 \ --block-size 128 \ --gpu-memory-utilization 0.9 @@ -350,8 +347,6 @@ In dense model scenarios, the MLP's gate_up_proj and down_proj linear layers oft It is important to emphasize that, since we use vector computations to hide the weight prefetching pipeline, the setting of the prefetch buffer size is crucial. If the buffer size is too small, the optimization benefits will not be fully realized, while a larger buffer size may lead to resource contention, resulting in performance degradation. To accommodate different scenarios, we have exposed two environment variables `VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE` and `VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE` to allow for flexible buffer size configuration based on the specific workload. -This optimization requires setting the environment variable `VLLM_ASCEND_ENABLE_PREFETCH_MLP = 1` to be enabled. - ### 6. Zerolike Elimination This elimination removes unnecessary operations related to zero-like tensors in Attention forward, improving the efficiency of matrix operations and reducing memory usage. 
diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py index 77e28ece331..d840be3aa37 100644 --- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py @@ -222,7 +222,7 @@ def test_qwen3_dense_fc1_tp2(model): @pytest.mark.parametrize("model", QWEN_DENSE_MODELS) -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) def test_qwen3_dense_prefetch_mlp_weight_tp2(model): example_prompts = [ "Hello, my name is", @@ -236,6 +236,7 @@ def test_qwen3_dense_prefetch_mlp_weight_tp2(model): tensor_parallel_size=2, cudagraph_capture_sizes=[1, 2, 4, 8], quantization="ascend", + additional_config={"weight_prefetch_config": {"enabled": True}}, ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/e2e/multicard/2-cards/test_qwen3_performance.py b/tests/e2e/multicard/2-cards/test_qwen3_performance.py index ef30db68e5d..c6e6378e4d2 100644 --- a/tests/e2e/multicard/2-cards/test_qwen3_performance.py +++ b/tests/e2e/multicard/2-cards/test_qwen3_performance.py @@ -57,7 +57,6 @@ async def test_models(model: str) -> None: env_dict = { "TASK_QUEUE_ENABLE": "1", "HCCL_OP_EXPANSION_MODE": "AIV", - "VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1", } server_args = [ "--async-scheduling", @@ -74,7 +73,7 @@ async def test_models(model: str) -> None: "--compilation-config", '{"cudagraph_mode": "FULL_DECODE_ONLY"}', "--additional-config", - '{"pa_shape_list":[48,64,72,80]}', + '{"pa_shape_list":[48,64,72,80],"weight_prefetch_config":{"enabled":true}}', "--block-size", "128", "--trust-remote-code", diff --git a/tests/e2e/nightly/single_node/models/test_qwen3_32b_int8.py b/tests/e2e/nightly/single_node/models/test_qwen3_32b_int8.py index f486f90a633..a9c7d7a891c 100644 --- 
a/tests/e2e/nightly/single_node/models/test_qwen3_32b_int8.py +++ b/tests/e2e/nightly/single_node/models/test_qwen3_32b_int8.py @@ -83,7 +83,6 @@ async def test_models(model: str, mode: str, tp_size: int) -> None: "TASK_QUEUE_ENABLE": "1", "HCCL_OP_EXPANSION_MODE": "AIV", "VLLM_ASCEND_ENABLE_FLASHCOMM": "1", - "VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1" } compilation_config = { "cudagraph_mode": @@ -98,7 +97,8 @@ async def test_models(model: str, mode: str, tp_size: int) -> None: str(port), "--max-model-len", "40960", "--max-num-batched-tokens", "40960", "--block-size", "128", "--trust-remote-code", "--reasoning-parser", "qwen3", "--gpu-memory-utilization", "0.9", - "--async-scheduling" + "--async-scheduling", "--additional-config", + '{"weight_prefetch_config":{"enabled":true}}', ] if mode == "single": server_args.append("--enforce-eager") diff --git a/tests/e2e/nightly/single_node/models/test_qwen3_32b_int8_a3_feature_stack3.py b/tests/e2e/nightly/single_node/models/test_qwen3_32b_int8_a3_feature_stack3.py index 2edae4005d9..3410f31cb3e 100644 --- a/tests/e2e/nightly/single_node/models/test_qwen3_32b_int8_a3_feature_stack3.py +++ b/tests/e2e/nightly/single_node/models/test_qwen3_32b_int8_a3_feature_stack3.py @@ -72,7 +72,6 @@ async def test_models(model: str, tp_size: int) -> None: "OMP_PROC_BIND": "false", "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1", "VLLM_ASCEND_ENABLE_FLASHCOMM": "1", - "VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1" } server_args = [ "--quantization", "ascend", "--tensor-parallel-size", @@ -82,7 +81,8 @@ async def test_models(model: str, tp_size: int) -> None: "0.9", "--block-size", "128", "--max-num-seqs", "256", "--enforce-eager", "--max-model-len", "35840", "--max-num-batched-tokens", "35840", "--additional-config", - '{"enable_weight_nz_layout":true}', "--compilation-config", + '{"enable_weight_nz_layout":true, "weight_prefetch_config":{"enabled": true}}', + "--compilation-config", '{"cudagraph_mode":"FULL_DECODE_ONLY", 
"cudagraph_capture_sizes":[1,8,24,48,60]}' ] with RemoteOpenAIServer(model, diff --git a/tests/e2e/nightly/single_node/models/test_qwq_32b.py b/tests/e2e/nightly/single_node/models/test_qwq_32b.py index c0998343320..b5bcb89b5a6 100644 --- a/tests/e2e/nightly/single_node/models/test_qwq_32b.py +++ b/tests/e2e/nightly/single_node/models/test_qwq_32b.py @@ -75,8 +75,7 @@ async def test_models(model: str, mode: str, tp_size: int) -> None: "OMP_PROC_BIND": "false", "HCCL_OP_EXPANSION_MODE": "AIV", "VLLM_ASCEND_ENABLE_FLASHCOMM": "1", - "VLLM_ASCEND_ENABLE_DEBSE_OPTIMIZE": "1", - "VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1" + "VLLM_ASCEND_ENABLE_DEBSE_OPTIMIZE": "1" } server_args = [ "--tensor-parallel-size", @@ -86,7 +85,7 @@ async def test_models(model: str, mode: str, tp_size: int) -> None: "--gpu-memory-utilization", "0.9", "--compilation_config", '{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 8, 24, 48, 60]}', "--reasoning-parser", "deepseek_r1", "--distributed_executor_backend", - "mp" + "mp", "--additional-config", '{"weight_prefetch_config":{"enabled":true}}' ] if mode == "single": server_args.remove("--compilation_config") diff --git a/tests/ut/ops/test_activation.py b/tests/ut/ops/test_activation.py index d05c7df128d..db668868651 100644 --- a/tests/ut/ops/test_activation.py +++ b/tests/ut/ops/test_activation.py @@ -54,11 +54,7 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor, default_vllm_config): @pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.") @patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1) -@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) -@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", side_effect=lambda x: None) def test_SiluAndMul_forward( - mock_maybe_prefetch_mlp_down_proj, - mock_maybe_wait_prefetch_done, mock_swiglu, dummy_tensor, default_vllm_config, @@ -67,15 +63,9 @@ def test_SiluAndMul_forward( out = layer.forward(dummy_tensor) expected_arg = 
dummy_tensor - # assert mock_maybe_prefetch_mlp_down_proj.call_count == 1 - mock_maybe_prefetch_mlp_down_proj.assert_called_once() - # assert mock_swiglu.call_count == 1 mock_swiglu.assert_called_once() - # assert mock_maybe_wait_prefetch_done.call_count == 1 - mock_maybe_wait_prefetch_done.assert_called_once() - actual_arg = mock_swiglu.call_args[0][0] assert torch.allclose(actual_arg, expected_arg), "npu_swiglu called with unexpected input" @@ -85,11 +75,7 @@ def test_SiluAndMul_forward( @pytest.mark.skipif(not is_310p_hw(), reason="310P device unittest case.") @patch("torch.nn.functional.silu", side_effect=lambda x: x + 1) -@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) -@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", side_effect=lambda x: None) def test_SiluAndMul_forward_310p( - mock_maybe_prefetch_mlp_down_proj, - mock_maybe_wait_prefetch_done, mock_silu, dummy_tensor, default_vllm_config, @@ -99,15 +85,9 @@ def test_SiluAndMul_forward_310p( h = dummy_tensor.shape[-1] // 2 expected_arg = dummy_tensor[..., :h] - # assert mock_maybe_prefetch_mlp_down_proj.call_count == 1 - mock_maybe_prefetch_mlp_down_proj.assert_called_once() - # assert mock_silu.call_count == 1 mock_silu.assert_called_once() - # assert mock_maybe_wait_prefetch_done.call_count == 1 - mock_maybe_wait_prefetch_done.assert_called_once() - actual_arg = mock_silu.call_args[0][0] assert torch.allclose(actual_arg, expected_arg), "swiglu called with unexpected input" diff --git a/vllm_ascend/_310p/ops/activation.py b/vllm_ascend/_310p/ops/activation.py index 241a955f7b2..ad0ff94ad75 100644 --- a/vllm_ascend/_310p/ops/activation.py +++ b/vllm_ascend/_310p/ops/activation.py @@ -19,12 +19,16 @@ import torch.nn.functional as F from vllm_ascend.ops.activation import AscendSiluAndMul +from vllm_ascend.utils import get_weight_prefetch_method class AscendSiluAndMul310(AscendSiluAndMul): def forward(self, x: torch.Tensor) -> torch.Tensor: - 
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) + weight_prefetch_method = get_weight_prefetch_method() + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x) h = x.shape[-1] // 2 - out = F.silu(x[..., :h]) * x[..., h:] - torch.ops.vllm.maybe_wait_prefetch_done(out) + out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16) + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out) return out diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index c2420a46e3e..9c45f1d0178 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -311,6 +311,7 @@ class WeightPrefetchConfig: "o": 1.0, }, "moe": {"gate_up": 0.8}, + "mlp": {"gate_up": 1, "down": 1.0}, } def __init__(self, weight_prefetch_config: dict): diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 04ffa7b30e5..61d3f97543c 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -117,18 +117,8 @@ def set_ascend_forward_context( if has_layer_idx(model_instance): forward_context.layer_idx = model_instance.model.start_layer - # TODO(rjg-lyh): refactor mlp weight prefetch method - # set for mlp weight prefetch - prefetch_mlp_enabled = ( - envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP - and forward_context.layer_idx is not None - and num_tokens is not None - and num_tokens < 500 - ) - if prefetch_mlp_enabled: - forward_context.prefetch_mlp_gate_up_proj = False - forward_context.prefetch_mlp_down_proj = False - forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False forward_context.model_instance = model_instance forward_context.is_draft_model = is_draft_model diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 3e5c4cc7a24..257523559dd 
100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -75,16 +75,6 @@ # For a detailed introduction to the parameters and the differences and applicable scenarios # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation. "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)), - # Whether to enable MLP weight prefetch, only used in small concurrency. - "VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))), - # buffer size for gate up prefetch - "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": lambda: int( - os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024) - ), - # buffer size for down proj prefetch - "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": lambda: int( - os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024) - ), # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. "MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))), # Whether to enable MLAPO optimization for DeepSeek W8A8 series models. 
diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index a605b87c89c..a22513fee16 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -17,7 +17,7 @@ import torch from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul - +from vllm_ascend.utils import get_weight_prefetch_method class AscendQuickGELU(QuickGELU): @@ -33,7 +33,10 @@ class AscendSiluAndMul(SiluAndMul): def forward_oot(self, x: torch.Tensor) -> torch.Tensor: import torch_npu - torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) + weight_prefetch_method = get_weight_prefetch_method() + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x) out = torch_npu.npu_swiglu(x) - torch.ops.vllm.maybe_wait_prefetch_done(out) + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out) return out diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index dd676e8f498..fa7ef0ae92e 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -24,7 +24,7 @@ from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu from vllm_ascend.utils import enable_custom_op - +from vllm_ascend.utils import get_weight_prefetch_method class AscendRMSNorm(RMSNorm): @@ -67,6 +67,10 @@ def forward_oot( self.variance_epsilon) if self.bias is not None: x.add_(self.bias) + + weight_prefetch_method = get_weight_prefetch_method() + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x) return x diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index e4afa07c8a4..1b28f6d2658 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -65,8 +65,8 @@ from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids, matmul_allreduce_enable, mlp_tp_enable, - 
oproj_tp_enable, shared_expert_dp_enabled) - + oproj_tp_enable, shared_expert_dp_enabled, + get_weight_prefetch_method) class CustomLinearOp: @@ -138,8 +138,10 @@ def update_attrs(self): def apply(self, input_): output, output_bias = self.apply_impl(input_) - if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: - torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix) + weight_prefetch_method = get_weight_prefetch_method() + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_GATE_UP, output, self.prefix) + if not self.return_bias: return output return output, output_bias diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 0e43a5989a1..11d6e8ed699 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -110,33 +110,6 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, 0) -def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, - prefix: str) -> None: - try: - forward_context = get_forward_context() - except AssertionError: - return - - if not getattr(forward_context, 'prefetch_mlp_enabled', False): - return - model_instance = forward_context.model_instance - weight_prefetch_stream = prefetch_stream() - layer_idx = int(prefix.split('.')[2]) - - # start point of gate_up_proj weight prefetch - if prefix.split('.')[-2] == "self_attn": - forward_context.prefetch_mlp_gate_up_proj = True - if forward_context.prefetch_mlp_gate_up_proj: - weight_prefetch_stream.wait_stream(torch.npu.current_stream()) - - with torch.npu.stream(weight_prefetch_stream): - mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE - torch_npu.npu_prefetch( - model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, - x_dependency, mlp_gate_up_prefetch_size) - return - - def _maybe_all_gather_and_maybe_unpad_fake( x: torch.Tensor, label: bool, @@ -164,63 +137,6 @@ def _maybe_pad_and_reduce_fake(x: 
torch.Tensor, return x -def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, - prefix: str) -> None: - return - - -def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None: - try: - forward_context = get_forward_context() - except AssertionError: - return - - if not getattr(forward_context, 'prefetch_mlp_enabled', False): - return - forward_context.prefetch_mlp_down_proj = True - model_instance = forward_context.model_instance - weight_prefetch_stream = prefetch_stream() - layer_idx = forward_context.layer_idx - - # start point of down_proj weight prefetch - weight_prefetch_stream.wait_stream(torch.npu.current_stream()) - - with torch.npu.stream(weight_prefetch_stream): - mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE - torch_npu.npu_prefetch( - model_instance.model.layers[layer_idx].mlp.down_proj.weight, - x_dependency, mlp_down_prefetch_size) - forward_context.layer_idx += 1 - return - - -def _maybe_prefetch_mlp_down_proj_impl_fake( - x_dependency: torch.Tensor) -> None: - return - - -def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None: - try: - forward_context = get_forward_context() - except AssertionError: - return - - if not getattr(forward_context, 'prefetch_mlp_enabled', False): - return - if forward_context.prefetch_mlp_gate_up_proj or \ - forward_context.prefetch_mlp_down_proj: - weight_prefetch_stream = prefetch_stream() - # wait until prefetch done - torch.npu.current_stream().wait_stream(weight_prefetch_stream) - forward_context.prefetch_mlp_gate_up_proj = False - forward_context.prefetch_mlp_down_proj = False - return - - -def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None: - return - - def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor, max_weight_size: int) -> None: calculation_stream = torch_npu.npu.current_stream() @@ -331,24 +247,6 @@ def _rope_forward_triton_fake( mutates_args=[], dispatch_key="PrivateUse1") 
-direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj", - op_func=_maybe_prefetch_mlp_gate_up_proj_impl, - fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") - -direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj", - op_func=_maybe_prefetch_mlp_down_proj_impl, - fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") - -direct_register_custom_op(op_name="maybe_wait_prefetch_done", - op_func=_maybe_wait_prefetch_done_impl, - fake_impl=_maybe_wait_prefetch_done_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") - direct_register_custom_op(op_name="prefetch_preprocess", op_func=_prefetch_preprocess_impl, fake_impl=_prefetch_preprocess_impl_fake, diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index 42ff7e01703..3ae0af80f08 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -2,15 +2,18 @@ import torch import torch_npu -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.config import get_current_vllm_config +from vllm.logger import logger from vllm_ascend.ascend_config import WeightPrefetchConfig from vllm_ascend.ops.linear import (AscendQKVParallelLinear, AscendRowParallelLinear) +from vllm_ascend.utils import is_moe_model SUPPORTED_MODULES = ["attn", "mlp", "moe"] MOE_PREFETCH_TOKEN_THRESHOLD = 96 - +MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024 @dataclass class ModuleWeightPrefetchConfig: @@ -38,8 +41,13 @@ class WeightPrefetchMethod: """ Unified weight prefetch method. 
""" + is_moe: bool = True + MLP_GATE_UP: str = "gate_up" + MLP_DOWN: str = "down" def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: + self.is_moe = is_moe_model(get_current_vllm_config()) + self.attn = ModuleWeightPrefetchConfig( module_name="attn", enable=weight_prefetch_config.enabled, @@ -51,10 +59,16 @@ def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: }) self.moe = ModuleWeightPrefetchConfig( module_name="moe", - enable=weight_prefetch_config.enabled, + enable=weight_prefetch_config.enabled and self.is_moe, prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( "moe", {})) + self.mlp = ModuleWeightPrefetchConfig( + module_name="mlp", + enable=weight_prefetch_config.enabled and not self.is_moe, + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( + "mlp", {}) or {'gate_up': 1.0, 'down': 1.0}) + def maybe_prefetch_attn_weight_preprocess( self, layer_cls_name: str, weight: torch.Tensor, start_flag: torch.Tensor) -> None: @@ -97,6 +111,76 @@ def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor): torch.ops.vllm.prefetch_postprocess(stop_flag) + # x_dependency only eager mode can pass None + def maybe_prefetch_mlp_weight_preprocess(self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None): + if not self.mlp.enable: + self.mlp.is_active_this_forward = False + return + + try: + forward_context = get_forward_context() + except AssertionError: + return + self.mlp.is_active_this_forward = ( + forward_context.layer_idx is not None + and forward_context.num_tokens is not None + and forward_context.num_tokens < 500 + ) + if not self.mlp.is_active_this_forward: + return + + if prefetch_layer_name == self.MLP_GATE_UP: + self._maybe_prefetch_mlp_gate_up_weight_preprocess(x_dependency, forward_context, curr_layer_prefix) + elif prefetch_layer_name == self.MLP_DOWN: + self._maybe_prefetch_mlp_down_weight_preprocess(x_dependency, forward_context) + else: + 
raise ValueError(f"Unsupported prefetch weight name: {prefetch_weight_name}") + + def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str): + if not curr_layer_prefix: + raise ValueError("curr_layer_prefix must been specified when prefetching mlp gate_up_proj weight") + + # start point of gate_up_proj weight prefetch + if curr_layer_prefix.split('.')[-2] == "self_attn": + model_instance = forward_context.model_instance + layer_idx = int(curr_layer_prefix.split('.')[2]) + weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight + weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get("gate_up", 0) + if weight_size > MAX_PREFETCH_WEIGHT_SIZE: + weight_size = MAX_PREFETCH_WEIGHT_SIZE + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=x_dependency, + max_weight_size=int(weight_size)) + forward_context.prefetch_mlp_gate_up_proj = True + + def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext): + layer_idx = forward_context.layer_idx + model_instance = forward_context.model_instance + weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight + weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get("down", 0) + if weight_size > MAX_PREFETCH_WEIGHT_SIZE: + weight_size = MAX_PREFETCH_WEIGHT_SIZE + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=x_dependency, + max_weight_size=int(weight_size)) + forward_context.prefetch_mlp_down_proj = True + forward_context.layer_idx += 1 + + def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor): + if not self.mlp.is_active_this_forward: + return + + try: + forward_context = get_forward_context() + except AssertionError: + return + + if forward_context.prefetch_mlp_gate_up_proj or \ + forward_context.prefetch_mlp_down_proj: + 
torch.ops.vllm.prefetch_postprocess(stop_flag) + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + def maybe_npu_prefetch(inputs: torch.Tensor, dependency: torch.Tensor, From de318780b5c510ae4c64ee8b34914350f77f7d65 Mon Sep 17 00:00:00 2001 From: leo-pony Date: Mon, 2 Feb 2026 16:37:59 +0800 Subject: [PATCH 2/7] Remove unused description in qwen3-dense.md Signed-off-by: leo-pony --- docs/source/tutorials/Qwen3-Dense.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/Qwen3-Dense.md b/docs/source/tutorials/Qwen3-Dense.md index e543dc7d5d6..5b9d7d6a3c5 100644 --- a/docs/source/tutorials/Qwen3-Dense.md +++ b/docs/source/tutorials/Qwen3-Dense.md @@ -345,7 +345,7 @@ Weight prefetching optimizes memory usage by preloading weights into the cache b In dense model scenarios, the MLP's gate_up_proj and down_proj linear layers often exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as RMSNorm and SiLU, before the MLP. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the MLP computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. -It is important to emphasize that, since we use vector computations to hide the weight prefetching pipeline, the setting of the prefetch buffer size is crucial. If the buffer size is too small, the optimization benefits will not be fully realized, while a larger buffer size may lead to resource contention, resulting in performance degradation. To accommodate different scenarios, we have exposed two environment variables `VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE` and `VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE` to allow for flexible buffer size configuration based on the specific workload. 
+For details, please refer to the weight_prefetch_config section in additional_config. ### 6. Zerolike Elimination From 2fbab466cdbe691d4c199e960c7e0498b146ab74 Mon Sep 17 00:00:00 2001 From: leo-pony Date: Tue, 3 Feb 2026 11:57:02 +0800 Subject: [PATCH 3/7] Add backward cenv config cmpitable to MLP prefetch Signed-off-by: leo-pony --- docs/source/tutorials/Qwen3-Dense.md | 2 +- vllm_ascend/ascend_config.py | 37 +++++++++++++++++++++++++--- vllm_ascend/envs.py | 10 ++++++++ vllm_ascend/ops/weight_prefetch.py | 16 +++++++++--- 4 files changed, 58 insertions(+), 7 deletions(-) diff --git a/docs/source/tutorials/Qwen3-Dense.md b/docs/source/tutorials/Qwen3-Dense.md index 5b9d7d6a3c5..6ae2a57c528 100644 --- a/docs/source/tutorials/Qwen3-Dense.md +++ b/docs/source/tutorials/Qwen3-Dense.md @@ -345,7 +345,7 @@ Weight prefetching optimizes memory usage by preloading weights into the cache b In dense model scenarios, the MLP's gate_up_proj and down_proj linear layers often exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as RMSNorm and SiLU, before the MLP. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the MLP computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. -For details, please refer to the weight_prefetch_config section in additional_config. +Previously, the environment variables VLLM_ASCEND_ENABLE_PREFETCH_MLP used to enable MLP weight prefetch and VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE and VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE used to set the weight prefetch size for MLP gate_up_proj and down_proj were deprecated. Please use the following configuration instead: "weight_prefetch_config": { "enabled": true, "prefetch_ratio": { "mlp": { "gate_up": 1.0, "down": 1.0}}}. 
See User Guide->Configuration Guide->Additional Configuration for details. ### 6. Zerolike Elimination diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 9c45f1d0178..5521b2527e8 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings from typing import TYPE_CHECKING from vllm.logger import logger @@ -48,9 +49,7 @@ def __init__(self, vllm_config: "VllmConfig"): # Dump / PrecisionDebugger configuration self.dump_config_path = additional_config.get("dump_config_path", None) - - weight_prefetch_config = additional_config.get("weight_prefetch_config", {}) - self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config) + self._construct_weight_prefetch_config(additional_config) self.layer_sharding = additional_config.get("layer_sharding", None) logger.info_once( f"Linear layer sharding enabled with config: {self.layer_sharding}. " @@ -138,6 +137,29 @@ def __init__(self, vllm_config: "VllmConfig"): "enable_kv_nz is only supported in pd scenario and can only be used in D node." 
) + def _construct_weight_prefetch_config(self, additional_config): + weight_prefetch_config = additional_config.get("weight_prefetch_config", {}) + self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config) + # Deprecated env var handling for backward compatibility + if os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0") == "1": + MAX_PREFETCH_WEIGHT_SIZE: int = 18 * 1024 * 1024 + gate_up_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE)) + down_prefetch_szie = int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE)) + self.weight_prefetch_config.set_mlp_pre_version_compatibale_config( + gate_up_prefetch_size, down_prefetch_szie + ) + logger.info_once( + f"MLP weight prefetch enabled from env variable VLLM_ASCEND_ENABLE_PREFETCH_MLP." + f"gate_up_prefetch_size={gate_up_prefetch_size}, " + f"down_prefetch_szie={down_prefetch_szie}." + ) + warnings.warn( + "VLLM_ASCEND_ENABLE_PREFETCH_MLP is deprecated and will be removed in a v0.16.0 version. 
" + "Please use weight_prefetch_config in additional-config for now instead.", + DeprecationWarning, + stacklevel=2, + ) + class FinegrainedTPConfig: """ @@ -305,6 +327,8 @@ class WeightPrefetchConfig: Configuration Object for weight_prefetch_config from additional_config """ + mlp_pre_version_compatibale_config: dict = {} + prefetch_ratio: dict = { "attn": { "qkv": 1.0, @@ -318,6 +342,13 @@ def __init__(self, weight_prefetch_config: dict): self.enabled = weight_prefetch_config.get("enabled", False) self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio) + def set_mlp_pre_version_compatibale_config(self, gate_up_prefetch_size: int, down_prefetch_size: int): + config = { + "gate_up": gate_up_prefetch_size, + "down": down_prefetch_size, + } + self.mlp_pre_version_compatibale_config = config + class EplbConfig: """ diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 257523559dd..3e5c4cc7a24 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -75,6 +75,16 @@ # For a detailed introduction to the parameters and the differences and applicable scenarios # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation. "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)), + # Whether to enable MLP weight prefetch, only used in small concurrency. + "VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))), + # buffer size for gate up prefetch + "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": lambda: int( + os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024) + ), + # buffer size for down proj prefetch + "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": lambda: int( + os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024) + ), # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. 
"MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))), # Whether to enable MLAPO optimization for DeepSeek W8A8 series models. diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index 3ae0af80f08..45090b9efc8 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -45,6 +45,9 @@ class WeightPrefetchMethod: MLP_GATE_UP: str = "gate_up" MLP_DOWN: str = "down" + # backward compatibility: delete in future versions + mlp_pre_version_compatibale_config: dict = {} + def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: self.is_moe = is_moe_model(get_current_vllm_config()) @@ -68,6 +71,7 @@ def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: enable=weight_prefetch_config.enabled and not self.is_moe, prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( "mlp", {}) or {'gate_up': 1.0, 'down': 1.0}) + self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config def maybe_prefetch_attn_weight_preprocess( self, layer_cls_name: str, weight: torch.Tensor, @@ -113,7 +117,7 @@ def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor): # x_dependency only eager mode can pass None def maybe_prefetch_mlp_weight_preprocess(self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None): - if not self.mlp.enable: + if not self.mlp.enable and not self.mlp_pre_version_compatibale_config: self.mlp.is_active_this_forward = False return @@ -145,7 +149,10 @@ def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tens model_instance = forward_context.model_instance layer_idx = int(curr_layer_prefix.split('.')[2]) weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight - weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get("gate_up", 0) + if self.mlp_pre_version_compatibale_config: + 
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0) + else: + weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0) if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE torch.ops.vllm.prefetch_preprocess(weight=weight, @@ -157,7 +164,10 @@ def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, layer_idx = forward_context.layer_idx model_instance = forward_context.model_instance weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight - weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get("down", 0) + if self.mlp_pre_version_compatibale_config: + weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0) + else: + weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0) if weight_size > MAX_PREFETCH_WEIGHT_SIZE: weight_size = MAX_PREFETCH_WEIGHT_SIZE torch.ops.vllm.prefetch_preprocess(weight=weight, From 41e16f4ed2122fb3b95e16dfc61e1367e2c1d654 Mon Sep 17 00:00:00 2001 From: leo-pony Date: Tue, 3 Feb 2026 14:48:55 +0800 Subject: [PATCH 4/7] Add description to weight prefetch Signed-off-by: leo-pony --- docs/source/user_guide/configuration/additional_config.md | 7 ++++++- vllm_ascend/ops/weight_prefetch.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index b5e54ed3706..d574f04d8d9 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -60,7 +60,12 @@ The details of each configuration option are as follows: | Name | Type | Default | Description | |------------------|------|-------------------------------------------------------------|------------------------------------| | `enabled` | 
bool | `False` | Whether to enable weight prefetch. | -| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weight. | +| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}, "mlp": { "gate_up": 1.0, "down": 1.0}}` | Prefetch ratio of each weight. | + +Weight prefetching optimizes memory usage by preloading weights into the cache before they are needed, minimizing delays caused by memory access during model execution. Linear layers sometimes exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as quantize, MoE gating top_k, RMSNorm and SiLU. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the linear layer computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. +The “attn” and “moe” configuration options are used to optimize the performance of the MoE model, while the “mlp” configuration option is used to optimize the performance of the Dense model. +Additionally, if you prioritize low latency over high throughput, then do not enable prefetching. +It is important to emphasize that, since we use vector computations to hide the weight prefetching pipeline, the setting of the prefetch buffer size is crucial. If the buffer size is too small, the optimization benefits will not be fully realized, while a larger buffer size may lead to resource contention, resulting in performance degradation. 
**finegrained_tp_config** diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index 45090b9efc8..8958a8beea3 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -138,9 +138,9 @@ def maybe_prefetch_mlp_weight_preprocess(self, prefetch_layer_name: str, x_depen elif prefetch_layer_name == self.MLP_DOWN: self._maybe_prefetch_mlp_down_weight_preprocess(x_dependency, forward_context) else: - raise ValueError(f"Unsupported prefetch weight name: {prefetch_weight_name}") + raise ValueError(f"Unsupported prefetch weight name: {prefetch_layer_name}") - def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str): + def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None): if not curr_layer_prefix: raise ValueError("curr_layer_prefix must been specified when prefetching mlp gate_up_proj weight") From 4b453bed3aa40441532411488c3a3c99fdce5e21 Mon Sep 17 00:00:00 2001 From: leo-pony Date: Tue, 3 Feb 2026 15:02:01 +0800 Subject: [PATCH 5/7] Add new line for md Signed-off-by: leo-pony --- docs/source/user_guide/configuration/additional_config.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index d574f04d8d9..12a60aa3389 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -63,8 +63,11 @@ The details of each configuration option are as follows: | `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}, "mlp": { "gate_up": 1.0, "down": 1.0}}` | Prefetch ratio of each weight. 
| Weight prefetching optimizes memory usage by preloading weights into the cache before they are needed, minimizing delays caused by memory access during model execution. Linear layers sometimes exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as quantize, MoE gating top_k, RMSNorm and SiLU. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the linear layer computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. + The “attn” and “moe” configuration options are used to optimize the performance of the MoE model, while the “mlp” configuration option is used to optimize the performance of the Dense model. + Additionally, if you prioritize low latency over high throughput, then do not enable prefetching. + It is important to emphasize that, since we use vector computations to hide the weight prefetching pipeline, the setting of the prefetch buffer size is crucial. If the buffer size is too small, the optimization benefits will not be fully realized, while a larger buffer size may lead to resource contention, resulting in performance degradation. 
**finegrained_tp_config** From 3ace646c4acaaadd3047f388f85f2c6a9d0d2487 Mon Sep 17 00:00:00 2001 From: leo-pony Date: Tue, 3 Feb 2026 17:23:27 +0800 Subject: [PATCH 6/7] Add MLP Prefetch user guide Signed-off-by: leo-pony --- .../configuration/additional_config.md | 12 ++-- docs/source/user_guide/feature_guide/index.md | 1 + .../feature_guide/weight_prefetch.md | 58 +++++++++++++++++++ vllm_ascend/ops/weight_prefetch.py | 4 +- 4 files changed, 65 insertions(+), 10 deletions(-) create mode 100644 docs/source/user_guide/feature_guide/weight_prefetch.md diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index 12a60aa3389..169ffeabc6e 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -62,14 +62,6 @@ The details of each configuration option are as follows: | `enabled` | bool | `False` | Whether to enable weight prefetch. | | `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}, "mlp": { "gate_up": 1.0, "down": 1.0}}` | Prefetch ratio of each weight. | -Weight prefetching optimizes memory usage by preloading weights into the cache before they are needed, minimizing delays caused by memory access during model execution. Linear layers sometimes exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as quantize, MoE gating top_k, RMSNorm and SiLU. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the linear layer computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. 
- -The “attn” and “moe” configuration options are used to optimize the performance of the MoE model, while the “mlp” configuration option is used to optimize the performance of the Dense model. - -Additionally, if you prioritize low latency over high throughput, then do not enable prefetching. - -It is important to emphasize that, since we use vector computations to hide the weight prefetching pipeline, the setting of the prefetch buffer size is crucial. If the buffer size is too small, the optimization benefits will not be fully realized, while a larger buffer size may lead to resource contention, resulting in performance degradation. - **finegrained_tp_config** | Name | Type | Default | Description | @@ -123,6 +115,10 @@ An example of additional configuration is as follows: }, "moe": { "gate_up": 0.8 + }, + "mlp": { + "gate_up": 1.0, + "down": 1.0 } }, }, diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md index a8318be48cb..0e15a081c37 100644 --- a/docs/source/user_guide/feature_guide/index.md +++ b/docs/source/user_guide/feature_guide/index.md @@ -23,4 +23,5 @@ layer_sharding speculative_decoding context_parallel npugraph_ex +weight_prefetch ::: diff --git a/docs/source/user_guide/feature_guide/weight_prefetch.md b/docs/source/user_guide/feature_guide/weight_prefetch.md new file mode 100644 index 00000000000..edfb15ac09b --- /dev/null +++ b/docs/source/user_guide/feature_guide/weight_prefetch.md @@ -0,0 +1,58 @@ +# Weight Prefetch Guide + +Weight prefetching optimizes memory usage by preloading weights into the cache before they are needed, minimizing delays caused by memory access during model execution. Linear layers sometimes exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as quantize, MoE gating top_k, RMSNorm and SiLU. 
This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the linear layer computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. + +Since we use vector computations to hide the weight prefetching pipeline, it has effect on computation, if you prioritize low latency over high throughput, then it it best not to enable prefetching. + +## How to Use + +With `--additional-config '{"weight_prefetch_config": {"enabled": true}}'` to open weight prefetch. +With `prefetch_ratio` in `"weight_prefetch_config"` to custom the weight prefetch ratio for specify linear layers. +The “attn” and “moe” configuration options are used for MoE model, detail as following: +`"attn": { "qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}` +The “mlp” configuration option is used to optimize the performance of the Dense model, detail as following: + `"mlp": {"gate_up": 1.0, "down": 1.0}` + +Notices: + +1) Due to the current size of the L2 cache, the maximum prefetch cannot exceed 18MB. If `prefetch_ration * lineaer_layer_weight_size >= 18 * 1024 * 1024` bytes, the backend will only prefetch 18MB. +2) Weight prefetch of MLP `down` project prefetch dependence sequence parallel, if you want open for mlp `down` please also enable sequence parallel. 
+ +## Example + +1) For MoE model: + +```shell + --additional-config \ + '{ + "weight_prefetch_config": { + "enabled": true, + "prefetch_ratio": { + "attn": { + "qkv": 1.0, + "o": 1.0 + }, + "moe": { + "gate_up": 0.8 + } + } + } + }' +``` + +2) For dense model: + +```shell + --additional-config \ + '{ + "weight_prefetch_config": { + "enabled": true, + "prefetch_ratio": { + "mlp": { + "gate_up": 1.0, + "down": 1.0 + } + } + } + }' +``` diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index 8958a8beea3..d01c3cfca39 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -55,7 +55,7 @@ def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: module_name="attn", enable=weight_prefetch_config.enabled, prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( - "attn", {}), + "attn", {}) or {'qkv': 1.0, 'o': 1.0}, linear_prefix_map={ AscendQKVParallelLinear.__name__: "qkv", AscendRowParallelLinear.__name__: "o", @@ -64,7 +64,7 @@ def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: module_name="moe", enable=weight_prefetch_config.enabled and self.is_moe, prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( - "moe", {})) + "moe", {})) or {'gate_up': 0.8} self.mlp = ModuleWeightPrefetchConfig( module_name="mlp", From 3e0bdad8a8d583198b44817862655088f25b39d9 Mon Sep 17 00:00:00 2001 From: leo-pony Date: Tue, 3 Feb 2026 21:16:48 +0800 Subject: [PATCH 7/7] Add prefeth ratio change method Signed-off-by: leo-pony --- docs/source/tutorials/Qwen3-Dense.md | 2 +- .../feature_guide/weight_prefetch.md | 23 +++++++++++++++---- vllm_ascend/ops/weight_prefetch.py | 2 +- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/docs/source/tutorials/Qwen3-Dense.md b/docs/source/tutorials/Qwen3-Dense.md index 6ae2a57c528..cf99beb86a5 100644 --- a/docs/source/tutorials/Qwen3-Dense.md +++ b/docs/source/tutorials/Qwen3-Dense.md @@ -345,7 +345,7 @@ Weight prefetching 
optimizes memory usage by preloading weights into the cache b In dense model scenarios, the MLP's gate_up_proj and down_proj linear layers often exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as RMSNorm and SiLU, before the MLP. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the MLP computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. -Previously, the environment variables VLLM_ASCEND_ENABLE_PREFETCH_MLP used to enable MLP weight prefetch and VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE and VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE used to set the weight prefetch size for MLP gate_up_proj and down_proj were deprecated. Please use the following configuration instead: "weight_prefetch_config": { "enabled": true, "prefetch_ratio": { "mlp": { "gate_up": 1.0, "down": 1.0}}}. See User Guide->Configuration Guide->Additional Configuration for details. +Previously, the environment variables VLLM_ASCEND_ENABLE_PREFETCH_MLP used to enable MLP weight prefetch and VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE and VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE used to set the weight prefetch size for MLP gate_up_proj and down_proj were deprecated. Please use the following configuration instead: "weight_prefetch_config": { "enabled": true, "prefetch_ratio": { "mlp": { "gate_up": 1.0, "down": 1.0}}}. See User Guide->Feature Guide->Weight Prefetch Guide for details. ### 6. 
Zerolike Elimination diff --git a/docs/source/user_guide/feature_guide/weight_prefetch.md b/docs/source/user_guide/feature_guide/weight_prefetch.md index edfb15ac09b..ad01d2f820b 100644 --- a/docs/source/user_guide/feature_guide/weight_prefetch.md +++ b/docs/source/user_guide/feature_guide/weight_prefetch.md @@ -1,22 +1,35 @@ # Weight Prefetch Guide -Weight prefetching optimizes memory usage by preloading weights into the cache before they are needed, minimizing delays caused by memory access during model execution. Linear layers sometimes exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as quantize, MoE gating top_k, RMSNorm and SiLU. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the linear layer computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. +Weight prefetching optimizes memory usage by preloading weights into the cache before they are needed, minimizing delays caused by memory access during model execution. Linear layers sometimes exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as quantize, MoE gating top_k, RMSNorm and SwiGlu. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the linear layer computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow. Since we use vector computations to hide the weight prefetching pipeline, it has effect on computation, if you prioritize low latency over high throughput, then it it best not to enable prefetching. 
-## How to Use
+## Quick Start
 
 With `--additional-config '{"weight_prefetch_config": {"enabled": true}}'` to open weight prefetch.
+
+## Fine-tune Prefetch Ratio
+
+Since weight prefetch uses vector computations to hide the weight prefetching pipeline, the setting of the prefetch size is crucial. If the size is too small, the optimization benefits will not be fully realized, while a larger size may lead to resource contention, resulting in performance degradation. To accommodate different scenarios, we have added `prefetch_ratio` to allow for flexible size configuration based on the specific workload, details as follows:
+
 With `prefetch_ratio` in `"weight_prefetch_config"` to custom the weight prefetch ratio for specify linear layers.
+
 The “attn” and “moe” configuration options are used for MoE model, detail as following:
+
 `"attn": { "qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}`
+
 The “mlp” configuration option is used to optimize the performance of the Dense model, detail as following:
+
  `"mlp": {"gate_up": 1.0, "down": 1.0}`
+The values above are the default configuration; these defaults give good performance for Qwen3-235B-A22B-W8A8 when `--max-num-seqs` is 144, and for Qwen3-32B-W8A8 when `--max-num-seqs` is 72.
+
+However, this may not be the optimal configuration for your scenario. For higher concurrency, you can try increasing the prefetch size. For lower concurrency, prefetching may not offer any advantages, so you can decrease the size or disable prefetching. Determine if the prefetch size is appropriate by collecting profiling data. Specifically, check if the time required for the prefetch operation (e.g., MLP Down Proj weight prefetching) overlaps with the time required for parallel vector computation operators (e.g., SwiGlu computation), and whether the prefetch operation is no later than the completion time of the vector computation operator. In the profiling timeline, a prefetch operation appears as a CMO operation on a single stream; this CMO operation is the prefetch operation.
+
 Notices:
 
-1) Due to the current size of the L2 cache, the maximum prefetch cannot exceed 18MB. If `prefetch_ration * lineaer_layer_weight_size >= 18 * 1024 * 1024` bytes, the backend will only prefetch 18MB.
-2) Weight prefetch of MLP `down` project prefetch dependence sequence parallel, if you want open for mlp `down` please also enable sequence parallel.
+1) Weight prefetch of the MLP `down` projection depends on sequence parallelism; if you want to enable prefetch for the MLP `down` projection, please also enable sequence parallelism.
+2) Due to the current size of the L2 cache, the maximum prefetch cannot exceed 18MB. If `prefetch_ratio * linear_layer_weight_size >= 18 * 1024 * 1024` bytes, the backend will only prefetch 18MB.
 
 ## Example
 
@@ -42,6 +55,8 @@ Notices:
 
 2) For dense model:
 
+Following is the default configuration, which gives good performance for Qwen3-32B-W8A8 when `--max-num-seqs` is 72.
+
 ```shell
 --additional-config \
 '{
diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py
index d01c3cfca39..e41390ee43e 100644
--- a/vllm_ascend/ops/weight_prefetch.py
+++ b/vllm_ascend/ops/weight_prefetch.py
@@ -64,7 +64,7 @@ def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
             module_name="moe",
             enable=weight_prefetch_config.enabled and self.is_moe,
             prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
-                "moe", {})) or {'gate_up': 0.8}
+                "moe", {}) or {'gate_up': 0.8})
 
         self.mlp = ModuleWeightPrefetchConfig(
             module_name="mlp",