diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py
index e8b94c570d8..39696a832f5 100755
--- a/tests/ut/attention/test_mla_cp.py
+++ b/tests/ut/attention/test_mla_cp.py
@@ -248,9 +248,10 @@ def test_init(self):
         self.assertEqual(self.impl.dcp_size, 2)
 
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
-    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
+    @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
+           return_value=MagicMock())
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
-    def test_mla_preprocess_dcp(self, magic_npu_fetch,
+    def test_mla_preprocess_dcp(self, mock_get_weight_prefetch_method,
                                 mock_maybe_all_gather_and_maybe_unpad):
         self.impl.num_kv_heads = 1
@@ -309,7 +310,6 @@ def test_mla_preprocess_dcp(self, magic_npu_fetch,
                                 self.impl.qk_rope_head_dim)
         ]
 
-        magic_npu_fetch.return_value = MagicMock()
         mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
 
         decode_res, prefill_res = self.impl._mla_preprocess(
@@ -324,9 +324,10 @@ def test_mla_preprocess_dcp(self, magic_npu_fetch,
 
     @patch('torch_npu._npu_reshape_and_cache')
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
-    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
+    @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
+           return_value=MagicMock())
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
-    def test_mla_preprocess_pcp(self, magic_npu_fetch,
+    def test_mla_preprocess_pcp(self, mock_get_weight_prefetch_method,
                                 mock_maybe_all_gather_and_maybe_unpad,
                                 mock_npu_reshape_and_cache):
         self.impl.num_kv_heads = 1
@@ -389,7 +390,6 @@ def test_mla_preprocess_pcp(self, magic_npu_fetch,
                                 self.impl.qk_rope_head_dim)
         ]
 
-        magic_npu_fetch.return_value = MagicMock()
         mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
 
         self.impl.kv_a_layernorm = MagicMock()
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 5c412a8ba55..4a7ee9e1a12 100755
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -967,10 +967,10 @@ def test_forward_decode_without_graph(self,
         mock_npu_fused_infer_attention_score.assert_called_once()
 
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
-    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
-    def test_mla_preprocess(self, magic_npu_fetch,
+    @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
+           return_value=MagicMock())
+    def test_mla_preprocess(self, mock_get_weight_prefetch_method,
                             mock_maybe_all_gather_and_maybe_unpad):
-        magic_npu_fetch.return_value = MagicMock()
         mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
         batch_size = 4
         seq_len = 8
diff --git a/tests/ut/ops/test_activation.py b/tests/ut/ops/test_activation.py
index db668868651..a8467440daa 100644
--- a/tests/ut/ops/test_activation.py
+++ b/tests/ut/ops/test_activation.py
@@ -53,9 +53,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor, default_vllm_config):
 
 
 @pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
+@patch("vllm_ascend.ops.activation.get_weight_prefetch_method",
+       return_value=MagicMock())
 @patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
 def test_SiluAndMul_forward(
     mock_swiglu,
+    mock_get_weight_prefetch_method,
     dummy_tensor,
     default_vllm_config,
 ):
diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py
index ffb85a197be..b82ca200406 100644
--- a/tests/ut/ops/test_fused_moe.py
+++ b/tests/ut/ops/test_fused_moe.py
@@ -296,6 +296,8 @@ def test_cumsum_group_list_with_type_2(self):
 
 class TestUnifiedApplyMLP(TestBase):
 
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
+           return_value=MagicMock())
     @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
     @patch('vllm_ascend.utils.get_ascend_device_type', return_value=AscendDeviceType.A3)
@@ -306,7 +308,8 @@ def test_unified_apply_mlp_with_quantization_mc2(self,
                                                      mock_npu_dequant,
                                                      mock_npu_dynamic_quant,
                                                      mock_npu_grouped_matmul,
                                                      mock_soc_version,
-                                                     mock_get_forward_context):
+                                                     mock_get_forward_context,
+                                                     mock_get_weight_prefetch_method):
         mock_forward_context = MagicMock()
         mock_forward_context.moe_comm_type = MoECommType.MC2
@@ -402,13 +405,16 @@ def test_unified_apply_mlp_without_quantization(self,
         self.assertEqual(result.dtype, torch.float16)
 
     @patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
+           return_value=MagicMock())
     @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
     def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
             self, mock_npu_dynamic_quant, mock_npu_swiglu,
-            mock_npu_grouped_matmul, mock_get_forward_context):
+            mock_npu_grouped_matmul, mock_get_forward_context,
+            mock_get_weight_prefetch_method):
         mock_forward_context = MagicMock()
         mock_forward_context.with_quant = True
@@ -505,6 +511,8 @@ def test_unified_apply_mlp_without_quantization_310p(
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
 
+    @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method",
+           return_value=MagicMock())
     @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
     @patch("torch_npu.npu_grouped_matmul")
     @patch("torch_npu.npu_swiglu")
@@ -513,7 +521,8 @@ def test_unified_apply_mlp_without_quantization_310p(
     def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
             self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
             mock_npu_swiglu, mock_npu_grouped_matmul,
-            mock_get_forward_context):
+            mock_get_forward_context,
+            mock_get_weight_prefetch_method):
         mock_forward_context = MagicMock()
         mock_forward_context.with_quant = True
diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py
index 4718002c421..74a480660bb 100644
--- a/tests/ut/ops/test_linear.py
+++ b/tests/ut/ops/test_linear.py
@@ -83,7 +83,9 @@ def test_process_weights_after_loading_with_nz2(self, mock_format_cast):
 
 class TestAscendRowParallelLinear(BaseLinearTest):
 
-    def test_mlp_optimize(self):
+    @patch("vllm_ascend.ops.linear_op.get_weight_prefetch_method",
+           return_value=MagicMock())
+    def test_mlp_optimize(self, mock_get_weight_prefetch_method):
         ascend_config._ASCEND_CONFIG = MagicMock()
         ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False
 
@@ -100,7 +102,9 @@ def test_mlp_optimize(self):
         input_tensor = torch.randn(16, 8)
         linear(input_tensor)
 
-    def test_oproj_tp(self):
+    @patch("vllm_ascend.ops.linear_op.get_weight_prefetch_method",
+           return_value=MagicMock())
+    def test_oproj_tp(self, mock_get_weight_prefetch_method):
         config._current_vllm_config = MagicMock()
 
diff --git a/vllm_ascend/_310p/fused_moe/experts_selector.py b/vllm_ascend/_310p/fused_moe/experts_selector.py
index 71200c992de..e3f953403c2 100644
--- a/vllm_ascend/_310p/fused_moe/experts_selector.py
+++ b/vllm_ascend/_310p/fused_moe/experts_selector.py
@@ -57,8 +57,7 @@ def select_experts(
     """
     # prefetch w1_w3_proj.weight preprocess
    weight_prefetch_method = get_weight_prefetch_method()
-    if weight_prefetch_method:
-        weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
+    weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
     topk_weights, topk_ids = _native_select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 7d936ab23d7..34d8e250a1b 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -43,9 +43,13 @@
     register_all_layers_to_shard_weight_series,
 )
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
-from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, maybe_trans_nz, weak_ref_tensors
+from vllm_ascend.utils import (
+    ACL_FORMAT_FRACTAL_ND,
+    get_weight_prefetch_method,
+    maybe_trans_nz,
+    weak_ref_tensors,
+)
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
 
 if TYPE_CHECKING:
@@ -703,7 +707,6 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
-        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.enable_kv_nz
 
         self.ring_mla_mask_size = 512
@@ -1412,8 +1415,9 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache, attn_metadata, ne
         has_decode = attn_metadata.num_decodes > 0
         has_prefill = attn_metadata.num_prefills > 0
         if self.fused_qkv_a_proj is not None:
-            maybe_npu_prefetch(
-                inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
+            weight_prefetch_method = get_weight_prefetch_method()
+            weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
+                inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
             )
             qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
             q_c, kv_no_split = qkv_lora.split(
@@ -1545,14 +1549,13 @@ def forward(
             o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
 
         # O proj
-        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
-        maybe_npu_prefetch(
+        weight_prefetch_method = get_weight_prefetch_method()
+        weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
             inputs=self.o_proj.weight,
             dependency=o_proj_input,
             max_size=MAX_O_PROJ_PREFETCH_SIZE,
-            enabled=self.enable_prefetch,
+            linear_layer=self.o_proj,
         )
-
         output[...] = self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0]
         del o_proj_input
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index c926fdf5423..26984947d9c 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -37,7 +37,6 @@
 )
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.triton.rope import rope_forward_triton
-from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
 from vllm_ascend.utils import (
     ACL_FORMAT_FRACTAL_ND,
@@ -45,6 +44,7 @@
     dispose_layer,
     enable_dsa_cp,
     enable_dsa_cp_with_layer_shard,
+    get_weight_prefetch_method,
     maybe_trans_nz,
 )
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -410,7 +410,6 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
-        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
 
         # In sfa, prefill and decode have the same calculation formula,
         # so do not distinguish between prefill and decode here.
@@ -800,8 +799,9 @@ def forward(
             )
         else:
             assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
-            maybe_npu_prefetch(
-                inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
+            weight_prefetch_method = get_weight_prefetch_method()
+            weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
+                inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
             )
             qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
             q_c, kv_no_split = qkv_lora.split(
@@ -917,11 +917,12 @@ def forward(
             )
         attn_output = self._v_up_proj(attn_output)
 
-        maybe_npu_prefetch(
+        weight_prefetch_method = get_weight_prefetch_method()
+        weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
             inputs=self.o_proj.weight,
             dependency=attn_output,
             max_size=MAX_O_PROJ_PREFETCH_SIZE,
-            enabled=self.enable_prefetch,
+            linear_layer=self.o_proj,
         )
 
         if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py
index ac8730af60a..e8b685c3718 100644
--- a/vllm_ascend/ops/activation.py
+++ b/vllm_ascend/ops/activation.py
@@ -34,9 +34,7 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
         import torch_npu
 
         weight_prefetch_method = get_weight_prefetch_method()
-        if weight_prefetch_method:
-            weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
+        weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
         out = torch_npu.npu_swiglu(x)
-        if weight_prefetch_method:
-            weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
+        weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
         return out
diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py
index d9775fbe8b3..3f7a3fd032b 100644
--- a/vllm_ascend/ops/fused_moe/experts_selector.py
+++ b/vllm_ascend/ops/fused_moe/experts_selector.py
@@ -59,8 +59,7 @@ def select_experts(
     """
     # prefetch w1_w3_proj.weight preprocess
     weight_prefetch_method = get_weight_prefetch_method()
-    if weight_prefetch_method:
-        weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
+    weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
     is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
         hidden_states=hidden_states,
         top_k=top_k,
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
index 65673b71148..a3d3af6bd9b 100644
--- a/vllm_ascend/ops/fused_moe/moe_mlp.py
+++ b/vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -100,8 +100,7 @@ def quant_apply_mlp(
         _output_dtype = w2_scale[0].dtype
 
     weight_prefetch_method = get_weight_prefetch_method()
-    if weight_prefetch_method:
-        weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
+    weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
     is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and w1_offset is None and is_mc2:
         if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index b3a503e4e9e..17214afbddf 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -66,8 +66,7 @@ def forward_oot(
             x.add_(self.bias)
 
         weight_prefetch_method = get_weight_prefetch_method()
-        if weight_prefetch_method:
-            weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x)
+        weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x)
 
         return x
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 9a2d4a0d608..4aa558f45b4 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -149,10 +149,9 @@ def update_attrs(self):
 
     def apply(self, input_):
         output, output_bias = self.apply_impl(input_)
         weight_prefetch_method = get_weight_prefetch_method()
-        if weight_prefetch_method:
-            weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(
-                weight_prefetch_method.MLP_GATE_UP, output, self.prefix
-            )
+        weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(
+            weight_prefetch_method.MLP_GATE_UP, output, self.prefix
+        )
         if not self.return_bias:
             return output
diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py
index e53e899bea4..2464b51c7d3 100644
--- a/vllm_ascend/ops/weight_prefetch.py
+++ b/vllm_ascend/ops/weight_prefetch.py
@@ -47,6 +47,7 @@ class WeightPrefetchMethod:
 
     def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
         self.is_moe = is_moe_model(get_current_vllm_config())
+        self.mla_sfa_prefetch_enable = weight_prefetch_config.enabled
 
         self.attn = ModuleWeightPrefetchConfig(
             module_name="attn",
@@ -94,6 +95,9 @@ def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix):
         if not self.moe.is_active_this_forward:
             return
         forward_context = get_forward_context()
+        if not forward_context or forward_context.model_instance is None:
+            return
+
         # layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
         weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight
         weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0)
@@ -184,6 +188,33 @@ def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor):
         forward_context.prefetch_mlp_gate_up_proj = False
         forward_context.prefetch_mlp_down_proj = False
 
+    def maybe_prefetch_mla_or_sla_weight_in_current_stream(
+        self,
+        inputs: torch.Tensor,
+        dependency: torch.Tensor,
+        max_size: int = 0,
+        linear_layer: torch.nn.Module | None = None,
+    ) -> None:
+        if not self.mla_sfa_prefetch_enable:
+            return
+
+        # The prefetching of the weights of the o_proj matrix in the W8A8
+        # scene is already performed once in AscendW8A8LinearMethod, so it
+        # is not needed here.
+        if linear_layer is not None:
+            from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
+
+            if isinstance(
+                getattr(linear_layer.quant_method, "quant_method", None),
+                AscendW8A8LinearMethod,
+            ):
+                return
+
+        input_size = inputs.element_size() * inputs.numel()
+        if max_size <= 0 or max_size > input_size:
+            max_size = input_size
+        torch.ops.vllm.prefetch_preprocess(weight=inputs, start_flag=dependency, max_weight_size=int(max_size))
+
 
 def maybe_npu_prefetch(
     inputs: torch.Tensor, dependency: torch.Tensor, max_size: int = 0, offset: int = 0, *, enabled: bool = True
diff --git a/vllm_ascend/quantization/methods/w8a8_static.py b/vllm_ascend/quantization/methods/w8a8_static.py
index a2101010a0f..47ffc201004 100644
--- a/vllm_ascend/quantization/methods/w8a8_static.py
+++ b/vllm_ascend/quantization/methods/w8a8_static.py
@@ -82,12 +82,11 @@ def apply(
         layer_cls_name = layer.__class__.__name__
         weight_prefetch_method = get_weight_prefetch_method()
         # prefetch qkvo_proj.weight preprocess
-        if weight_prefetch_method:
-            weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
-                layer_cls_name=layer_cls_name,
-                weight=layer.weight,
-                start_flag=x,
-            )
+        weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
+            layer_cls_name=layer_cls_name,
+            weight=layer.weight,
+            start_flag=x,
+        )
         try:
             quant_comm_config = layer._quant_comm_config
         except AttributeError:
@@ -117,11 +116,10 @@ def apply(
             )
 
         # prefetch qkvo_proj.weight postprocess
-        if weight_prefetch_method:
-            weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
-                layer_cls_name=layer_cls_name,
-                stop_flag=x,
-            )
+        weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
+            layer_cls_name=layer_cls_name,
+            stop_flag=x,
+        )
         quant_bias = layer.quant_bias if tp_rank == 0 else None
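
Note on the call-site pattern above: the `if weight_prefetch_method:` guards are dropped everywhere because `get_weight_prefetch_method()` is expected to always return a usable object, including when prefetching is disabled. A minimal sketch of that contract follows; it is not part of this patch, the fallback class name and the string values of the MLP constants are illustrative assumptions, and only the method names and call shapes mirror the call sites touched above.

    # Hypothetical sketch (assumed wiring, not the vllm_ascend implementation):
    # a disabled fallback keeps every maybe_prefetch_* call a harmless no-op,
    # so callers never need to re-check for None.
    import torch


    class _DisabledWeightPrefetchMethod:
        # Constant values are illustrative; call sites only reference the attribute names.
        MLP_GATE_UP = "gate_up"
        MLP_DOWN = "down"

        def maybe_prefetch_moe_weight_preprocess(self, hidden_states: torch.Tensor, prefix: str) -> None:
            return

        def maybe_prefetch_mlp_weight_preprocess(self, prefix: str, dependency: torch.Tensor, layer_prefix: str = "") -> None:
            return

        def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor) -> None:
            return


    _WEIGHT_PREFETCH_METHOD = _DisabledWeightPrefetchMethod()


    def get_weight_prefetch_method():
        # The real helper would return the WeightPrefetchMethod built from the
        # ascend config; returning a disabled singleton keeps call sites branch-free.
        return _WEIGHT_PREFETCH_METHOD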