diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 476948ba3be..0666e9fc27b 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -195,6 +195,7 @@ jobs: pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_fc2_for_qwen3_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1 pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index a8102ec7428..320c3bdf0b9 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -189,6 +189,26 @@ def test_sp_for_qwen3_moe() -> None: vllm_model.generate(example_prompts, sampling_params) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"}) +def test_fc2_for_qwen3_moe() -> None: + example_prompts = [ + "Hello, my name is", + ] + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + + with VllmRunner(snapshot_download("Qwen/Qwen3-30B-A3B"), + dtype="auto", + tensor_parallel_size=2, + distributed_executor_backend="mp", + enable_expert_parallel=True, + enforce_eager=True) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) def 
test_models_distributed_deepseek_v2_lite_with_flashcomm_v1() -> None: example_prompts = [ diff --git a/tests/e2e/singlecard/test_aclgraph_mem.py b/tests/e2e/singlecard/test_aclgraph_mem.py index c7e50788283..df7d355eff5 100644 --- a/tests/e2e/singlecard/test_aclgraph_mem.py +++ b/tests/e2e/singlecard/test_aclgraph_mem.py @@ -34,6 +34,7 @@ reason="aclgraph only support on v1") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [4]) +@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "0"}) @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) def test_aclgraph_mem_use(model: str, max_tokens: int) -> None: del os.environ["VLLM_WORKER_MULTIPROC_METHOD"] diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py index c6724ce05b4..15a5c50986b 100644 --- a/tests/ut/distributed/test_parallel_state.py +++ b/tests/ut/distributed/test_parallel_state.py @@ -4,9 +4,10 @@ from vllm.config import ParallelConfig from vllm_ascend.distributed.parallel_state import ( - _LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel, - get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group, - init_ascend_model_parallel) + _FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP, + destroy_ascend_model_parallel, get_flashcomm2_odp_group, + get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group, + get_otp_group, get_p_tp_group, init_ascend_model_parallel) @pytest.fixture @@ -21,9 +22,13 @@ def mock_distributed(): with patch('torch.distributed.is_initialized', return_value=True), \ patch('torch.distributed.get_world_size', return_value=8), \ patch('torch.distributed.get_backend', return_value='nccl'), \ - patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group: + patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \ + patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \ + 
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group: mock_group.return_value.local_rank = 0 mock_group.return_value.device_group = MagicMock() + mock_tp_group.return_value.world_size = 4 + mock_dp_group.return_value.world_size = 2 yield @@ -31,23 +36,33 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config): mock_ascend_config = MagicMock() mock_ascend_config.lmhead_tensor_parallel_size = 2 mock_ascend_config.oproj_tensor_parallel_size = 2 + mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2 mock_ascend_config.pd_tp_ratio = 2 mock_ascend_config.num_head_replica = 0 mock_ascend_config.pd_head_ratio = 2 mock_vllm_config = MagicMock() mock_vllm_config.kv_transfer_config.is_kv_producer = True + mock_envs_ascend = MagicMock() + mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2 + mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0 with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \ patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \ patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \ - patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config): + patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \ + patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \ + patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config): init_ascend_model_parallel(parallel_config) mc2_group = get_mc2_group() lmheadtp_group = get_lmhead_tp_group() otp_group = get_otp_group() + flashcomm2_otp_group = get_flashcomm2_otp_group() + flashcomm2_odp_group = get_flashcomm2_odp_group() p_tp_group = get_p_tp_group() assert mc2_group is not None assert otp_group is not None + assert flashcomm2_otp_group is not None + assert flashcomm2_odp_group is not None assert lmheadtp_group is not None assert 
p_tp_group is not None @@ -55,4 +70,6 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config): assert _MC2 is None assert _LMTP is None assert _OTP is None + assert _FLASHCOMM2_OTP is None + assert _FLASHCOMM2_ODP is None assert _P_TP is None diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 82eb78ead71..f947fc62ec4 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -130,6 +130,10 @@ def __init__(self, vllm_config): "Only support P node tp size lagger then D node tp size") self.SLO_limits_for_dynamic_batch = additional_config.get( "SLO_limits_for_dynamic_batch", -1) + from vllm_ascend.utils import \ + get_flashcomm2_oproj_tp_size_and_validate_config + self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config( + self, vllm_config) class TorchairGraphConfig: diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 80e6541e5f4..df4b22dbf3d 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -11,7 +11,8 @@ set_forward_context) import vllm_ascend.envs as envs_ascend -from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model +from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx, + is_moe_model) if TYPE_CHECKING: from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod @@ -123,13 +124,17 @@ def set_ascend_forward_context( tp_world_size > 1 and \ num_tokens is not None and num_tokens > 1000 forward_context.mmrs_fusion = mmrs_fusion + forward_context.num_tokens = num_tokens + forward_context.sp_enabled = sp_enabled + #TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 + forward_context.flashcomm_v2_enabled = flashcomm2_enable( + ) and tp_world_size > 1 and num_tokens is not None - if sp_enabled: + if (forward_context.sp_enabled + or forward_context.flashcomm_v2_enabled): pad_size = (tp_world_size - (num_tokens % 
tp_world_size)) % tp_world_size forward_context.pad_size = pad_size - forward_context.sp_enabled = sp_enabled - forward_context.num_tokens = num_tokens # set this for rope forward_oot using forward_context.is_first_layer = True @@ -181,7 +186,8 @@ def set_ascend_forward_context( if dp_world_size > 1 and forward_context.dp_metadata is not None: max_tokens_across_dp = \ forward_context.dp_metadata.max_tokens_across_dp_cpu.item() - if sp_enabled: + if (forward_context.sp_enabled + or forward_context.flashcomm_v2_enabled): padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size pad_size = padded_length - num_tokens diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 4885d4d12e3..9b5dde0fee1 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -2,12 +2,14 @@ import torch from vllm.config import ParallelConfig, get_current_vllm_config -from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, +from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group, + get_tp_group, get_world_group, init_model_parallel_group) import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import prefill_context_parallel_enable +from vllm_ascend.utils import (flashcomm2_enable, + prefill_context_parallel_enable) # Currently, mc2 op need their own group coordinator. 
_MC2: Optional[GroupCoordinator] = None @@ -15,6 +17,8 @@ _OTP: Optional[GroupCoordinator] = None _LMTP: Optional[GroupCoordinator] = None _P_TP: Optional[GroupCoordinator] = None +_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None +_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None def get_mc2_group() -> GroupCoordinator: @@ -34,6 +38,16 @@ def get_lmhead_tp_group() -> GroupCoordinator: return _LMTP +def get_flashcomm2_otp_group() -> GroupCoordinator: + return _FLASHCOMM2_OTP + + +def get_flashcomm2_odp_group() -> GroupCoordinator: + assert _FLASHCOMM2_ODP is not None, ( + "output data parallel group for flashcomm2 is not initialized") + return _FLASHCOMM2_ODP + + def get_mlp_tp_group() -> GroupCoordinator: assert _MLP_TP is not None, ("mlp group is not initialized") return _MLP_TP @@ -165,6 +179,48 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): backend, group_name="lmheadtp") + # TODO: Extract and unify the logic across different communication group. + if flashcomm2_enable(): + flashcomm2_otp_size = get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size + global_tp_size = get_tp_group().world_size + global_dp_size = get_dp_group().world_size + num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size // + flashcomm2_otp_size) + + global _FLASHCOMM2_OTP + global _FLASHCOMM2_ODP + + _FLASHCOMM2_OTP = None + _FLASHCOMM2_ODP = get_tp_group() + + if flashcomm2_otp_size > 1: + otp_group_ranks = [] + odp_group_ranks: list[list[int]] = [ + [] for _ in range(flashcomm2_otp_size * global_dp_size) + ] + + for dp_group_index in range(global_dp_size): + for i in range(num_fc2_oproj_tensor_parallel_groups): + ranks = [] + for j in range(flashcomm2_otp_size): + rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups + ranks.append(rank_idx) + odp_group_index = dp_group_index * flashcomm2_otp_size + j + odp_group_ranks[odp_group_index].append(rank_idx) + otp_group_ranks.append(ranks) + + _FLASHCOMM2_OTP = 
init_model_parallel_group( + otp_group_ranks, + get_world_group().local_rank, + backend, + group_name="flashcomm2_otp") + _FLASHCOMM2_ODP = init_model_parallel_group( + odp_group_ranks, + get_world_group().local_rank, + backend, + group_name="flashcomm2_odp") + def get_mlp_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" @@ -201,3 +257,15 @@ def destroy_ascend_model_parallel(): if _P_TP: _P_TP.destroy() _P_TP = None + + global _FLASHCOMM2_OTP + if _FLASHCOMM2_OTP and get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size != 1: + _FLASHCOMM2_OTP.destroy() + _FLASHCOMM2_OTP = None + + global _FLASHCOMM2_ODP + if _FLASHCOMM2_ODP and get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size != 1: + _FLASHCOMM2_ODP.destroy() + _FLASHCOMM2_ODP = None diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 8f9e1d98996..aa2e539507d 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -132,6 +132,12 @@ # This feature will get better performance when concurrency is large. "VLLM_ASCEND_ENABLE_FLASHCOMM1": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))), + # Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it. + # The specific value set will be used as the O-matrix TP group size for flashcomm2. + # For a detailed introduction to the parameters and the differences and applicable scenarios + # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation. + "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": + lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)), # Whether to enable MLP weight prefetch, only used in small concurrency. 
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), @@ -185,4 +191,4 @@ def __getattr__(name: str): def __dir__(): - return list(env_variables.keys()) \ No newline at end of file + return list(env_variables.keys()) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 1271f8e986f..2bffa44cc6f 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -24,6 +24,7 @@ └── CustomRowParallelOp │ ├── MLPRowParallelOp │ ├── OProjRowParallelOp +| ├── Flashcomm2OProjRowParallelOp │ ├── MatmulAllreduceRowParallelOp │ └── SequenceRowParallelOp └── CustomReplicatedOp @@ -41,6 +42,7 @@ import torch.distributed as dist import torch.nn.functional as F import torch_npu +from torch import nn from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter from vllm.distributed import (split_tensor_along_last_dim, @@ -49,9 +51,14 @@ from vllm.distributed.parallel_state import get_tp_group from vllm.forward_context import get_forward_context -from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group, + get_flashcomm2_otp_group, + get_mlp_tp_group, get_otp_group) from vllm_ascend.utils import (dense_optim_enable, enable_sp, + flashcomm2_enable, + get_flashcomm2_reorgnized_batch_ids, matmul_allreduce_enable, mlp_tp_enable, oproj_tp_enable, shared_expert_dp_enabled) @@ -263,6 +270,135 @@ def update_attrs(self): self.input_size_per_partition = self.layer.input_size_per_partition +class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): + + def __init__(self, layer): + super().__init__(layer) + self.odp_group = get_flashcomm2_odp_group() + self.odp_size = self.odp_group.world_size + self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids( + get_tp_group().world_size) + self.group_indices = 
torch.tensor(self.reorgnized_batch_ids).npu() + self.layer._quant_comm_config = {} + + @property + def comm_group(self): + return get_flashcomm2_otp_group() + + @property + def tp_rank(self): + if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1: + return 0 + return self.comm_group.rank_in_group + + @property + def tp_size(self): + if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1: + return 1 + return self.comm_group.world_size + + def apply_impl( + self, + input_: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer for Flashcomm2. + Input.shape = [batchsize*seqlength, headnum*headdim/TP] + Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize] + """ + # Handle input parallelism - split or use as-is + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = self.tp_rank + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # padding for all-to-all + forward_context = get_forward_context() + num_padding_tokens = forward_context.pad_size + if num_padding_tokens > 0: + input_parallel = nn.functional.pad(input_parallel, + (0, 0, 0, num_padding_tokens)) + + def otp_maybe_quant_comm(x): + + # Reorganize the tensor so that the batch id and rank id correspond to each other.
+ chunk_num = len(self.reorgnized_batch_ids) * len( + self.reorgnized_batch_ids[0]) + batch_size = x.size(0) + + assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})" + + batch_size_per_chunk = batch_size // chunk_num + # Indices of reorganized tensor + chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1]) + reorganized_chunks = chunked[self.group_indices] + send_buf = reorganized_chunks.flatten(1, 2) + + # all-to-all operation parameters + all2all_tp_size = self.odp_size + local_intermediate_size = x.size(1) + chunk_size = x.size(0) // all2all_tp_size + total_intermediate_size = local_intermediate_size * all2all_tp_size + + # Create receive buffer + recv_buf = torch.empty(total_intermediate_size * chunk_size, + dtype=x.dtype, + device=x.device) + + # Perform all-to-all communication + dist.all_to_all_single(recv_buf, + send_buf, + group=self.odp_group.device_group) + + return recv_buf.view(all2all_tp_size, chunk_size, + -1).transpose(0, 1).reshape(chunk_size, -1) + + if not hasattr(self, "_quant_comm_config"): + self.layer._quant_comm_config = {} + self.layer._quant_comm_config[ + "communication_fn"] = otp_maybe_quant_comm + actual_quant_method = getattr(self.quant_method, 'quant_method', + self.quant_method) + from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod + if not isinstance(actual_quant_method, AscendW8A8LinearMethod): + # Check if w8a8 quantization is enabled. If not, communicate immediately. + input_parallel = otp_maybe_quant_comm(input_parallel) + + # Matrix multiply. 
+ assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + + output_parallel = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate] + if self.tp_size > 1: + # flashcomm2 with reduce-scatter + output = self.comm_group.reduce_scatter(output_parallel, dim=0) + else: + output = output_parallel + + if not forward_context.sp_enabled: + # flashcomm1 not enabled + output = get_tp_group().all_gather(output, 0) + if num_padding_tokens > 0: + output = output[:-num_padding_tokens] + + # Handle bias return based on configuration + output_bias = self.bias if self.skip_bias_add else None + + return output, output_bias + + def update_attrs(self): + super().update_attrs() + self.input_is_parallel = self.layer.input_is_parallel + self.input_size_per_partition = self.layer.input_size_per_partition + + class MatmulAllreduceRowParallelOp(CustomRowParallelOp): _HCOMM_INFO = None @@ -487,13 +623,17 @@ def _get_column_parallel_op( def _get_row_parallel_op( prefix, layer ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp, - MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]: + Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp, + SequenceRowParallelOp]]: if "down_proj" in prefix and mlp_tp_enable(): return MLPRowParallelOp(layer) if "o_proj" in prefix and oproj_tp_enable(): return OProjRowParallelOp(layer) if matmul_allreduce_enable(): return MatmulAllreduceRowParallelOp(layer) + if flashcomm2_enable(): + if "o_proj" in prefix or "out_proj" in prefix: + return Flashcomm2OProjRowParallelOp(layer) if enable_sp(): if "shared_expert" in prefix: return None @@ -509,6 +649,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct): return None, 0, 1 custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, 
MLPRowParallelOp, OProjRowParallelOp, + Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp, SequenceRowParallelOp]] = None if direct == "row": diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 5960d2f8574..c0760c800ed 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -35,12 +35,14 @@ from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs -from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, + get_mlp_tp_group, get_otp_group) from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod -from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable, - oproj_tp_enable) +from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable, + mlp_tp_enable, oproj_tp_enable) from .utils import get_quant_method @@ -348,6 +350,13 @@ def apply( tp_rank = get_otp_group().rank_in_group elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable(): tp_rank = get_mlp_tp_group().rank_in_group + elif (layer.prefix.find("o_proj") != -1 or + layer.prefix.find("out_proj") != -1) and flashcomm2_enable(): + if get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size == 1: + tp_rank = 0 + else: + tp_rank = get_flashcomm2_otp_group().rank_in_group else: tp_rank = get_tensor_model_parallel_rank() else: diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 07b7cac2557..dcd692acfb6 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -115,12 +115,30 @@ def apply( weight=layer.weight, start_flag=x, ) - # quant - x = quant_per_tensor( - x, - layer.aclnn_input_scale_reciprocal, - 
layer.aclnn_input_offset, - ) + + quant_comm_config = getattr(layer, "_quant_comm_config", {}) + comm_fn = quant_comm_config.get("communication_fn") + enable_flashcomm2_quant_comm = comm_fn is not None and ( + "o_proj" in layer.prefix or "out_proj" in layer.prefix) + if enable_flashcomm2_quant_comm: + quant_input_x = x.contiguous().view( + -1, layer.aclnn_input_scale_reciprocal.size(0)) + quant_x = quant_per_tensor( + quant_input_x, + layer.aclnn_input_scale_reciprocal, + layer.aclnn_input_offset, + ) + comm_input = quant_x.view(x.size(0), -1) + assert comm_fn is not None + x = comm_fn(comm_input) + else: + # quant + x = quant_per_tensor( + x, + layer.aclnn_input_scale_reciprocal, + layer.aclnn_input_offset, + ) + # prefetch qkvo_proj.weight postprocess if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index e1afd24a082..bdff9695723 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -803,3 +803,68 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool: _HAS_LAYER_IDX = hasattr(model_instance, "model") and \ hasattr(model_instance.model, "start_layer") return _HAS_LAYER_IDX + + +def flashcomm2_enable() -> bool: + return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0 + + +def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config, + vllm_config): + flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE + global_tp_size = vllm_config.parallel_config.tensor_parallel_size + + if not flashcomm2_enable(): + logger.info("FLASHCOMM2 not enable.") + return flashcomm2_oproj_tp_size + + logger.info( + f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}" + ) + if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1: + logger.warning_once( + "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance." 
+ ) + if ascend_config.oproj_tensor_parallel_size is not None: + raise AssertionError( + "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size" + ) + if global_tp_size <= flashcomm2_oproj_tp_size: + raise AssertionError( + f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed global tensor parallel size ({global_tp_size})" + ) + if global_tp_size % flashcomm2_oproj_tp_size != 0: + raise AssertionError( + f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})" + ) + if vllm_config.kv_transfer_config is None: + logger.warning_once( + "It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment may lead to decode performance degradation." + ) + if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer: + raise AssertionError( + "FLASHCOMM2 primarily targets P-scenario deployments, " + "with additional support for hybrid deployment scenarios. " + "It is not applicable in D-scenario environments.") + + return flashcomm2_oproj_tp_size + + +def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]: + # Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain. + # For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2, + # the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]]. 
+ flashcomm2_otp_size = get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size + num_oproj_tensor_parallel_groups: int = (global_tp_size // + flashcomm2_otp_size) + + reorgnized_batch_ids = [] + for i in range(num_oproj_tensor_parallel_groups): + ranks = [] + for j in range(flashcomm2_otp_size): + rank_idx = i + j * num_oproj_tensor_parallel_groups + ranks.append(rank_idx) + reorgnized_batch_ids.append(ranks) + + return reorgnized_batch_ids