25 changes: 23 additions & 2 deletions tests/ut/attention/test_sfa_v1.py
@@ -100,8 +100,22 @@ def test_ascend_sfa_metadata_builder_default(self):
assert builder.device == device
assert builder.vllm_config == vllm_config

@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
@patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
def test_ascend_sfa_metadata_builder_build(
self,
mock_enable_dsa_cp,
mock_get_cos_and_sin_mla,
mock_get_current_vllm_config,
):
mock_enable_dsa_cp.return_value = False

cfg = MagicMock()
cfg.model_config = MagicMock()
cfg.model_config.hf_text_config = MagicMock()

mock_get_current_vllm_config.return_value = cfg
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
@@ -144,9 +158,16 @@ def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
assert metadata.slot_mapping.shape == (100, 4, 1024)

@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
self, mock_get_cos_and_sin_mla):
self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
cfg = MagicMock()
cfg.model_config = MagicMock()
cfg.model_config.hf_text_config = MagicMock()

mock_get_current_vllm_config.return_value = cfg

kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
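Note on the test changes above: the builder now pulls `hf_text_config` through `get_current_vllm_config`, so both tests patch that symbol alongside `get_cos_and_sin_mla` (and, in the first test, `enable_dsa_cp`). A minimal sketch of the pattern, assuming only the module paths visible in this diff — the body is illustrative, not the repo's test:

```python
from unittest.mock import MagicMock, patch

# Decorators apply bottom-up, so the innermost patch arrives as the
# first mock argument, mirroring the argument order in the diff above.
@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
@patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
def sketch_build_test(mock_enable_dsa_cp, mock_get_cos_and_sin_mla,
                      mock_get_current_vllm_config):
    mock_enable_dsa_cp.return_value = False  # force the non-CP path

    # The builder reads model_config.hf_text_config from the current
    # vllm config, so the mock must provide that attribute chain.
    cfg = MagicMock()
    cfg.model_config.hf_text_config = MagicMock()
    mock_get_current_vllm_config.return_value = cfg
```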
51 changes: 4 additions & 47 deletions vllm_ascend/attention/sfa_v1.py
@@ -10,8 +10,7 @@
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (ReplicatedLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -34,7 +33,7 @@
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_sp, maybe_trans_nz, replace_layer)
enable_dsa_cp, maybe_trans_nz)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

if TYPE_CHECKING:
@@ -149,8 +148,7 @@ def __init__(
got {self.decode_threshold}"

self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_sfa_cp = enable_sp() and \
hasattr(self.model_config.hf_text_config, "index_topk")
self.enable_sfa_cp = enable_dsa_cp()

assert not (
self.enable_sfa_cp
@@ -368,13 +366,11 @@ def __init__(

assert self.indexer is not None, "Indexer is required for DSA."

self.enable_sfa_cp = enable_sp()
self.enable_sfa_cp = enable_dsa_cp()
self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config()
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size

self._replace_linear_class_for_sfa_cp()
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
@@ -925,42 +921,3 @@ def indexer_select(
sparse_count=2048,
sparse_mode=3)
return topk_indices

def _replace_linear_class_for_sfa_cp(self):

vllm_config = get_current_vllm_config()
# Dispose tensor from the original q_proj
dispose_layer(self.q_proj)
# Construct the new q_proj using ReplicatedLinear
new_q_proj = ReplicatedLinear(self.q_lora_rank,
self.local_num_heads * self.qk_head_dim,
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.q_proj.prefix)
# Replace the q_proj with the new one
replace_layer(self.q_proj, new_q_proj)

# Dispose tensor from the original kv_b_proj
dispose_layer(self.kv_b_proj)
# Construct the new kv_b_proj using ReplicatedLinear
new_kv_b_proj = ReplicatedLinear(
self.kv_lora_rank,
self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.kv_b_proj.prefix)
# Replace the kv_b_proj with the new one
replace_layer(self.kv_b_proj, new_kv_b_proj)

# Dispose tensor from the original o_proj
dispose_layer(self.o_proj)
# Construct the new o_proj using ReplicatedLinear
config = vllm_config.model_config.hf_text_config
new_o_proj = ReplicatedLinear(config.num_attention_heads *
config.v_head_dim,
config.hidden_size,
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.o_proj.prefix)
# Replace the o_proj with the new one
replace_layer(self.o_proj, new_o_proj)
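The deleted `_replace_linear_class_for_sfa_cp` rebuilt `q_proj`, `kv_b_proj`, and `o_proj` as `ReplicatedLinear` after construction, disposing and swapping each layer in place. With this PR the same effect comes from the custom-op dispatch in `linear_op.py` below, selected by prefix at layer-construction time. A condensed sketch of that dispatch, assuming the factory shape shown in the diff; the wrapper function itself is illustrative:

```python
from vllm_ascend.ops.linear_op import (ShardedCPColumnParallelOp,
                                       ShardedCPRowParallelOp)
from vllm_ascend.utils import enable_dsa_cp

def pick_sfa_cp_op(prefix: str, direct: str, layer):
    # Under DSA CP, the projections that used to be rebuilt as
    # ReplicatedLinear instead get ops whose fake comm group
    # (world_size == 1) bypasses TP collectives entirely.
    if direct == "row" and enable_dsa_cp() and "o_proj" in prefix:
        return ShardedCPRowParallelOp(layer)
    if direct == "column" and enable_dsa_cp() and (
            "q_b_proj" in prefix or "kv_b_proj" in prefix):
        return ShardedCPColumnParallelOp(layer)
    return None  # fall through to the other custom-op checks
```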
65 changes: 60 additions & 5 deletions vllm_ascend/ops/linear_op.py
@@ -38,6 +38,7 @@

import re
from functools import lru_cache
from types import SimpleNamespace
from typing import Optional, Union

import torch
@@ -59,7 +60,7 @@
get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable, shared_expert_dp_enabled)
@@ -609,9 +610,60 @@ def update_attrs(self):
self.unique_prefix = self.layer.unique_prefix


class ShardedCPRowParallelOp(CustomRowParallelOp):

@property
def comm_group(self):
# fake comm group to bypass tp logic
return SimpleNamespace(world_size=1,
rank_in_group=0,
device_group=None)

def apply_impl(
self,
input_,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
assert self.quant_method is not None
output = self.quant_method.apply(self.layer, input_, bias_)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias

def update_attrs(self):
super().update_attrs()
self.layer.reduce_results = False


class ShardedCPColumnParallelOp(CustomColumnParallelOp):

@property
def comm_group(self):
# fake comm group to bypass tp logic
return SimpleNamespace(world_size=1,
rank_in_group=0,
device_group=None)

def apply_impl(
self,
input_,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self.layer, input_, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias


def _get_column_parallel_op(
prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
ShardedCPColumnParallelOp]]:
if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
return ShardedCPColumnParallelOp(layer)
if "gate_up_proj" in prefix and mlp_tp_enable(
) and not is_moe_layer(prefix):
return MLPColumnParallelOp(layer)
@@ -636,7 +688,9 @@ def _get_row_parallel_op(
prefix, layer
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]]:
SequenceRowParallelOp, ShardedCPRowParallelOp]]:
if enable_dsa_cp() and "o_proj" in prefix:
return ShardedCPRowParallelOp(layer)
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
return MLPRowParallelOp(layer)
if "o_proj" in prefix and oproj_tp_enable():
@@ -670,7 +724,8 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]] = None
SequenceRowParallelOp, ShardedCPRowParallelOp,
ShardedCPColumnParallelOp]] = None
if direct == "row":
custom_op = _get_row_parallel_op(prefix, layer)

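The two `ShardedCP*ParallelOp` classes above rely on a small trick: `comm_group` returns a `SimpleNamespace` with `world_size=1`, so any caller that gates collectives on group size takes the no-communication path, and `apply_impl` reduces to a plain quantized matmul plus bias handling. A standalone sketch of that trick — the consumer function below is illustrative, not the vLLM base class:

```python
from types import SimpleNamespace

import torch

# Stand-in for a real process group: world_size == 1 means "no peers".
fake_group = SimpleNamespace(world_size=1, rank_in_group=0,
                             device_group=None)

def maybe_all_reduce(t: torch.Tensor, group) -> torch.Tensor:
    # Mirrors the usual guard in row-parallel layers: a single-rank
    # group has nothing to reduce, so the tensor passes through.
    if group.world_size == 1:
        return t
    raise NotImplementedError("real path would call dist.all_reduce")

out = maybe_all_reduce(torch.ones(2, 2), fake_group)
assert out.shape == (2, 2)
```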
31 changes: 20 additions & 11 deletions vllm_ascend/utils.py
@@ -1119,11 +1119,6 @@ def dispose_layer(layer: Any):
dispose_tensor(attr_value)


def replace_layer(original_layer: Any, new_layer: Any):
original_layer.__class__ = new_layer.__class__
original_layer.__dict__ = new_layer.__dict__


def check_kv_extra_config(vllm_config):

def _check(name: str, config: dict):
@@ -1166,17 +1161,31 @@ def get_instance(*args, **kwargs):
return get_instance


#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32; subsequent updates will introduce new interfaces. --zzhx1
@lru_cache(maxsize=1)
def get_current_model_config():
def enable_dsa_cp() -> bool:
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
return vllm_config.model_config
if vllm_config is None:
return False

model_config = getattr(vllm_config, "model_config", None)
if model_config is None:
return False

hf_text_config = getattr(model_config, "hf_text_config", None)
if hf_text_config is None:
return False

return hasattr(hf_text_config, "index_topk")


#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32; subsequent updates will introduce new interfaces. --zzhx1
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
def enable_dsa_cp_with_layer_shard() -> bool:
if not enable_dsa_cp():
return False
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
return is_ds_v32 and enable_sp()
is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
return is_prefill_instance
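Taken together, the two utils toggles compose as follows: `enable_dsa_cp()` keys off the DeepSeek V3.2 marker attribute (`index_topk` on `hf_text_config`) and degrades to `False` when any link in the config chain is missing, while `enable_dsa_cp_with_layer_shard()` additionally requires a KV-producer (prefill) instance. A hedged usage sketch; the call site below is illustrative:

```python
from vllm_ascend.utils import enable_dsa_cp, enable_dsa_cp_with_layer_shard

def select_attention_path() -> str:
    # Both helpers are cached with lru_cache(maxsize=1), so repeated
    # calls on the hot path are cheap after the first evaluation.
    if enable_dsa_cp_with_layer_shard():
        return "dsa_cp_layer_shard"  # DS V3.2 prefill (KV producer)
    if enable_dsa_cp():
        return "dsa_cp"              # DS V3.2 without layer sharding
    return "default"
```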