12 changes: 6 additions & 6 deletions tests/ut/attention/test_mla_cp.py
@@ -248,9 +248,10 @@ def test_init(self):
self.assertEqual(self.impl.dcp_size, 2)

@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
return_value=MagicMock())
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_mla_preprocess_dcp(self, magic_npu_fetch,
def test_mla_preprocess_dcp(self, mock_get_weight_prefetch_method,
mock_maybe_all_gather_and_maybe_unpad):

self.impl.num_kv_heads = 1
@@ -309,7 +310,6 @@ def test_mla_preprocess_dcp(self, magic_npu_fetch,
self.impl.qk_rope_head_dim)
]

magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x

decode_res, prefill_res = self.impl._mla_preprocess(
@@ -324,9 +324,10 @@ def test_mla_preprocess_dcp(self, magic_npu_fetch,

@patch('torch_npu._npu_reshape_and_cache')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
return_value=MagicMock())
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_mla_preprocess_pcp(self, magic_npu_fetch,
def test_mla_preprocess_pcp(self, mock_get_weight_prefetch_method,
mock_maybe_all_gather_and_maybe_unpad,
mock_npu_reshape_and_cache):
self.impl.num_kv_heads = 1
@@ -389,7 +390,6 @@ def test_mla_preprocess_pcp(self, magic_npu_fetch,
self.impl.qk_rope_head_dim)
]

magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x

self.impl.kv_a_layernorm = MagicMock()
6 changes: 3 additions & 3 deletions tests/ut/attention/test_mla_v1.py
@@ -967,10 +967,10 @@ def test_forward_decode_without_graph(self,
mock_npu_fused_infer_attention_score.assert_called_once()

@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch,
@patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
return_value=MagicMock())
def test_mla_preprocess(self, mock_get_weight_prefetch_method,
mock_maybe_all_gather_and_maybe_unpad):
magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
batch_size = 4
seq_len = 8
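Across these unit-test diffs the pattern is the same: the old @patch on maybe_npu_prefetch is replaced by a @patch on get_weight_prefetch_method that returns a MagicMock, so the code under test can resolve the prefetch method and call its hooks unconditionally. A minimal sketch of that mocking pattern, assuming vllm_ascend is importable in the test environment (the test body below is illustrative only, not part of this PR):

from unittest import TestCase
from unittest.mock import MagicMock, patch


class WeightPrefetchMockingSketch(TestCase):

    @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
           return_value=MagicMock())
    def test_prefetch_hook_is_absorbed_by_mock(self, mock_get_weight_prefetch_method):
        # Production code resolves the prefetch method lazily and then calls it;
        # a bare MagicMock() absorbs those calls without any NPU-specific behaviour.
        prefetcher = mock_get_weight_prefetch_method()
        prefetcher.maybe_prefetch_mla_or_sla_weight_in_current_stream(
            inputs=MagicMock(), dependency=MagicMock())
        prefetcher.maybe_prefetch_mla_or_sla_weight_in_current_stream.assert_called_once()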
3 changes: 3 additions & 0 deletions tests/ut/ops/test_activation.py
@@ -53,9 +53,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor, default_vllm_config):


@pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
@patch("vllm_ascend.ops.activation.get_weight_prefetch_method",
return_value=MagicMock())
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(
mock_swiglu,
mock_get_weight_prefetch_method,
dummy_tensor,
default_vllm_config,
):
15 changes: 12 additions & 3 deletions tests/ut/ops/test_fused_moe.py
@@ -296,6 +296,8 @@ def test_cumsum_group_list_with_type_2(self):

class TestUnifiedApplyMLP(TestBase):

@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
return_value=MagicMock())
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
@@ -306,7 +308,8 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_npu_dynamic_quant,
mock_npu_grouped_matmul,
mock_soc_version,
mock_get_forward_context):
mock_get_forward_context,
mock_get_weight_prefetch_method):

mock_forward_context = MagicMock()
mock_forward_context.moe_comm_type = MoECommType.MC2
@@ -402,13 +405,16 @@ def test_unified_apply_mlp_without_quantization(self,
self.assertEqual(result.dtype, torch.float16)

@patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
return_value=MagicMock())
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_get_forward_context):
mock_npu_grouped_matmul, mock_get_forward_context,
mock_get_weight_prefetch_method):

mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
@@ -505,6 +511,8 @@ def test_unified_apply_mlp_without_quantization_310p(
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)

@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method",
return_value=MagicMock())
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@@ -513,7 +521,8 @@ def test_unified_apply_mlp_without_quantization_310p(
def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
mock_npu_swiglu, mock_npu_grouped_matmul,
mock_get_forward_context):
mock_get_forward_context,
mock_get_weight_prefetch_method):

mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
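Note the parameter ordering in these updated signatures: the new @patch for get_weight_prefetch_method is stacked above the existing decorators, so its mock is injected last, because stacked @patch decorators supply their mocks bottom-up. A small self-contained sketch of that rule, using standard-library targets purely for illustration:

from unittest.mock import patch


@patch("os.path.exists", return_value=True)   # topmost decorator -> last mock argument
@patch("os.path.getsize", return_value=0)     # bottom decorator  -> first mock argument
def check_mock_order(mock_getsize, mock_exists):
    # Mirrors why mock_get_weight_prefetch_method is appended at the end of
    # the test signatures above.
    assert mock_getsize("dummy") == 0
    assert mock_exists("dummy") is True


check_mock_order()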
8 changes: 6 additions & 2 deletions tests/ut/ops/test_linear.py
@@ -83,7 +83,9 @@ def test_process_weights_after_loading_with_nz2(self, mock_format_cast):

class TestAscendRowParallelLinear(BaseLinearTest):

def test_mlp_optimize(self):
@patch("vllm_ascend.ops.linear_op.get_weight_prefetch_method",
return_value=MagicMock())
def test_mlp_optimize(self, mock_get_weight_prefetch_method):

ascend_config._ASCEND_CONFIG = MagicMock()
ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False
@@ -100,7 +102,9 @@ def test_mlp_optimize(self):
input_tensor = torch.randn(16, 8)
linear(input_tensor)

def test_oproj_tp(self):
@patch("vllm_ascend.ops.linear_op.get_weight_prefetch_method",
return_value=MagicMock())
def test_oproj_tp(self, mock_get_weight_prefetch_method):

config._current_vllm_config = MagicMock()

3 changes: 1 addition & 2 deletions vllm_ascend/_310p/fused_moe/experts_selector.py
@@ -57,8 +57,7 @@ def select_experts(
"""
# prefetch w1_w3_proj.weight preprocess
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
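The dropped `if weight_prefetch_method:` guards here and in the ops/ files below rely on get_weight_prefetch_method() always returning a usable object rather than None. A hypothetical null-object sketch of that contract (class and function names are illustrative, not the actual vllm_ascend implementation; the hook signatures follow the call sites in this PR):

class _NoOpWeightPrefetchMethod:
    """Stand-in returned when weight prefetching is disabled."""

    MLP_GATE_UP = "gate_up"
    MLP_DOWN = "down"

    def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix):
        # No-op: callers may invoke prefetch hooks unconditionally.
        return None

    def maybe_prefetch_mlp_weight_preprocess(self, prefix, output, layer_prefix=None):
        return None

    def maybe_prefetch_mlp_weight_postprocess(self, stop_flag):
        return None


_NOOP_PREFETCH_METHOD = _NoOpWeightPrefetchMethod()


def get_weight_prefetch_method():
    # Real code would return the configured WeightPrefetchMethod when
    # prefetching is enabled; the point is that None is never returned.
    return _NOOP_PREFETCH_METHOD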
21 changes: 12 additions & 9 deletions vllm_ascend/attention/mla_v1.py
@@ -43,9 +43,13 @@
register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, maybe_trans_nz, weak_ref_tensors
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
get_weight_prefetch_method,
maybe_trans_nz,
weak_ref_tensors,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

if TYPE_CHECKING:
@@ -703,7 +707,6 @@ def __init__(

ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.enable_kv_nz

self.ring_mla_mask_size = 512
@@ -1412,8 +1415,9 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache, attn_metadata, ne
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
if self.fused_qkv_a_proj is not None:
maybe_npu_prefetch(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
@@ -1545,14 +1549,13 @@ def forward(

o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
maybe_npu_prefetch(
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch,
linear_layer=self.o_proj,
)

output[...] = self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0]

del o_proj_input
13 changes: 7 additions & 6 deletions vllm_ascend/attention/sfa_v1.py
@@ -37,14 +37,14 @@
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
_round_up,
dispose_layer,
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
get_weight_prefetch_method,
maybe_trans_nz,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -410,7 +410,6 @@ def __init__(

ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled

# In sfa, prefill and decode have the same calculation formula,
# so do not distinguish between prefill and decode here.
@@ -800,8 +799,9 @@ def forward(
)
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
@@ -917,11 +917,12 @@ def forward(
)

attn_output = self._v_up_proj(attn_output)
maybe_npu_prefetch(
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.o_proj.weight,
dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch,
linear_layer=self.o_proj,
)

if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
6 changes: 2 additions & 4 deletions vllm_ascend/ops/activation.py
@@ -34,9 +34,7 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
import torch_npu

weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
out = torch_npu.npu_swiglu(x)
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
return out
3 changes: 1 addition & 2 deletions vllm_ascend/ops/fused_moe/experts_selector.py
@@ -59,8 +59,7 @@ def select_experts(
"""
# prefetch w1_w3_proj.weight preprocess
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
hidden_states=hidden_states,
top_k=top_k,
3 changes: 1 addition & 2 deletions vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -100,8 +100,7 @@ def quant_apply_mlp(
_output_dtype = w2_scale[0].dtype

weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and w1_offset is None and is_mc2:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
3 changes: 1 addition & 2 deletions vllm_ascend/ops/layernorm.py
@@ -66,8 +66,7 @@ def forward_oot(
x.add_(self.bias)

weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x)
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x)
return x


7 changes: 3 additions & 4 deletions vllm_ascend/ops/linear_op.py
@@ -149,10 +149,9 @@ def update_attrs(self):
def apply(self, input_):
output, output_bias = self.apply_impl(input_)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(
weight_prefetch_method.MLP_GATE_UP, output, self.prefix
)
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(
weight_prefetch_method.MLP_GATE_UP, output, self.prefix
)

if not self.return_bias:
return output
31 changes: 31 additions & 0 deletions vllm_ascend/ops/weight_prefetch.py
@@ -47,6 +47,7 @@ class WeightPrefetchMethod:

def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
self.is_moe = is_moe_model(get_current_vllm_config())
self.mla_sfa_prefetch_enable = weight_prefetch_config.enabled

self.attn = ModuleWeightPrefetchConfig(
module_name="attn",
@@ -94,6 +95,9 @@ def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix):
if not self.moe.is_active_this_forward:
return
forward_context = get_forward_context()
if not forward_context or forward_context.model_instance is None:
return

# layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight
weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0)
@@ -184,6 +188,33 @@ def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor):
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False

def maybe_prefetch_mla_or_sla_weight_in_current_stream(
self,
inputs: torch.Tensor,
dependency: torch.Tensor,
max_size: int = 0,
linear_layer: torch.nn.Module | None = None,
) -> None:
if not self.mla_sfa_prefetch_enable:
return

# The o_proj weight prefetch in the W8A8 case is already performed once
# inside AscendW8A8LinearMethod, so it is not needed here.
if linear_layer is not None:
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod

if isinstance(
getattr(linear_layer.quant_method, "quant_method", None),
AscendW8A8LinearMethod,
):
return

input_size = inputs.element_size() * inputs.numel()
if max_size <= 0 or max_size > input_size:
max_size = input_size
torch.ops.vllm.prefetch_preprocess(weight=inputs, start_flag=dependency, max_weight_size=int(max_size))


def maybe_npu_prefetch(
inputs: torch.Tensor, dependency: torch.Tensor, max_size: int = 0, offset: int = 0, *, enabled: bool = True
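A caller-side sketch of the new method, mirroring the mla_v1.py and sfa_v1.py call sites above; attn_layer and the tensors are placeholders rather than real module state, and the import assumes the vllm_ascend.utils re-export used in this PR:

from vllm_ascend.utils import get_weight_prefetch_method


def prefetch_o_proj_weight(attn_layer, o_proj_input):
    weight_prefetch_method = get_weight_prefetch_method()
    # No enabled= flag any more: the method itself checks
    # mla_sfa_prefetch_enable and returns early when prefetching is off.
    # Passing linear_layer lets it skip the prefetch when o_proj is
    # W8A8-quantized (already prefetched inside AscendW8A8LinearMethod).
    weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
        inputs=attn_layer.o_proj.weight,
        dependency=o_proj_input,
        linear_layer=attn_layer.o_proj,
    )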