Skip to content
Merged
1 change: 1 addition & 0 deletions .github/workflows/_e2e_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ jobs:
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_fc2_for_qwen3_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight

Expand Down
20 changes: 20 additions & 0 deletions tests/e2e/multicard/test_offline_inference_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,26 @@ def test_sp_for_qwen3_moe() -> None:
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {
    "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1",
    "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1",
})
def test_fc2_for_qwen3_moe() -> None:
    """Smoke test: Qwen3 MoE generates with FlashComm1 + FlashComm2 enabled."""
    prompts = ["Hello, my name is"]
    params = SamplingParams(max_tokens=5,
                            temperature=0.0,
                            top_k=50,
                            top_p=0.9)

    model_path = snapshot_download("Qwen/Qwen3-30B-A3B")
    runner = VllmRunner(model_path,
                        dtype="auto",
                        tensor_parallel_size=2,
                        distributed_executor_backend="mp",
                        enable_expert_parallel=True,
                        enforce_eager=True)
    with runner as vllm_model:
        vllm_model.generate(prompts, params)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_models_distributed_deepseek_v2_lite_with_flashcomm_v1() -> None:
example_prompts = [
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/singlecard/test_aclgraph_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
reason="aclgraph only support on v1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [4])
@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "0"})
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
del os.environ["VLLM_WORKER_MULTIPROC_METHOD"]
Expand Down
27 changes: 22 additions & 5 deletions tests/ut/distributed/test_parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from vllm.config import ParallelConfig

from vllm_ascend.distributed.parallel_state import (
_LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel,
get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group,
init_ascend_model_parallel)
_FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP,
destroy_ascend_model_parallel, get_flashcomm2_odp_group,
get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group,
get_otp_group, get_p_tp_group, init_ascend_model_parallel)


@pytest.fixture
Expand All @@ -21,38 +22,54 @@ def mock_distributed():
with patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.get_world_size', return_value=8), \
patch('torch.distributed.get_backend', return_value='nccl'), \
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group:
mock_group.return_value.local_rank = 0
mock_group.return_value.device_group = MagicMock()
mock_tp_group.return_value.world_size = 4
mock_dp_group.return_value.world_size = 2
yield


def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    """Initialize every Ascend model-parallel group, check all getters return
    a group, then destroy the groups and check the module globals are reset.

    NOTE(fix): the previous version asserted on names imported via
    ``from vllm_ascend.distributed.parallel_state import _MC2, ...``. Those
    names are bound once at import time (when they are already ``None``), so
    the post-destroy assertions passed vacuously and never verified that
    ``destroy_ascend_model_parallel()`` actually reset anything. We now assert
    on the live module attributes instead.
    """
    mock_ascend_config = MagicMock()
    mock_ascend_config.lmhead_tensor_parallel_size = 2
    mock_ascend_config.oproj_tensor_parallel_size = 2
    mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
    mock_ascend_config.pd_tp_ratio = 2
    mock_ascend_config.num_head_replica = 0
    mock_ascend_config.pd_head_ratio = 2
    mock_vllm_config = MagicMock()
    mock_vllm_config.kv_transfer_config.is_kv_producer = True
    mock_envs_ascend = MagicMock()
    mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
    mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
        patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
        patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \
        patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \
        patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \
        patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config):
        init_ascend_model_parallel(parallel_config)

        mc2_group = get_mc2_group()
        lmheadtp_group = get_lmhead_tp_group()
        otp_group = get_otp_group()
        flashcomm2_otp_group = get_flashcomm2_otp_group()
        flashcomm2_odp_group = get_flashcomm2_odp_group()
        p_tp_group = get_p_tp_group()
        assert mc2_group is not None
        assert otp_group is not None
        assert flashcomm2_otp_group is not None
        assert flashcomm2_odp_group is not None
        assert lmheadtp_group is not None
        assert p_tp_group is not None

        destroy_ascend_model_parallel()
        # Check the live module globals, not the stale import-time bindings.
        from vllm_ascend.distributed import parallel_state as ps
        assert ps._MC2 is None
        assert ps._LMTP is None
        assert ps._OTP is None
        assert ps._FLASHCOMM2_OTP is None
        assert ps._FLASHCOMM2_ODP is None
        assert ps._P_TP is None
4 changes: 4 additions & 0 deletions vllm_ascend/ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def __init__(self, vllm_config):
"Only support P node tp size lagger then D node tp size")
self.SLO_limits_for_dynamic_batch = additional_config.get(
"SLO_limits_for_dynamic_batch", -1)
from vllm_ascend.utils import \
get_flashcomm2_oproj_tp_size_and_validate_config
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
self, vllm_config)


class TorchairGraphConfig:
Expand Down
16 changes: 11 additions & 5 deletions vllm_ascend/ascend_forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
set_forward_context)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model
from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx,
is_moe_model)

if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
Expand Down Expand Up @@ -123,13 +124,17 @@ def set_ascend_forward_context(
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000
forward_context.mmrs_fusion = mmrs_fusion
forward_context.num_tokens = num_tokens
forward_context.sp_enabled = sp_enabled
#TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
forward_context.flashcomm_v2_enabled = flashcomm2_enable(
) and tp_world_size > 1 and num_tokens is not None

if sp_enabled:
if (forward_context.sp_enabled
or forward_context.flashcomm_v2_enabled):
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size
forward_context.sp_enabled = sp_enabled
forward_context.num_tokens = num_tokens

# set this for rope forward_oot using
forward_context.is_first_layer = True
Expand Down Expand Up @@ -181,7 +186,8 @@ def set_ascend_forward_context(
if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = \
forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
if sp_enabled:
if (forward_context.sp_enabled
or forward_context.flashcomm_v2_enabled):
padded_length = (max_tokens_across_dp + tp_world_size -
1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
Expand Down
72 changes: 70 additions & 2 deletions vllm_ascend/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@

import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
get_tp_group, get_world_group,
init_model_parallel_group)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import prefill_context_parallel_enable
from vllm_ascend.utils import (flashcomm2_enable,
prefill_context_parallel_enable)

# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None


def get_mc2_group() -> GroupCoordinator:
Expand All @@ -34,6 +38,16 @@ def get_lmhead_tp_group() -> GroupCoordinator:
return _LMTP


def get_flashcomm2_otp_group() -> Optional[GroupCoordinator]:
    # Deliberately no "is not None" assert (unlike the other getters): when
    # flashcomm2_oproj_tensor_parallel_size == 1 no dedicated o_proj TP group
    # is created and this returns None; callers must handle that case.
    return _FLASHCOMM2_OTP


def get_flashcomm2_odp_group() -> GroupCoordinator:
    """Return the FlashComm2 output data-parallel group coordinator."""
    group = _FLASHCOMM2_ODP
    assert group is not None, (
        "output data parallel group for flashcomm2 is not initialized")
    return group


def get_mlp_tp_group() -> GroupCoordinator:
    """Return the MLP tensor-parallel group coordinator."""
    group = _MLP_TP
    assert group is not None, ("mlp group is not initialized")
    return group
Expand Down Expand Up @@ -165,6 +179,48 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="lmheadtp")

    # TODO: Extract and unify the logic across different communication group.
    if flashcomm2_enable():
        # Size of the o_proj tensor-parallel group for FlashComm2; the global
        # TP group is split evenly into groups of this size.
        flashcomm2_otp_size = get_ascend_config(
        ).flashcomm2_oproj_tensor_parallel_size
        global_tp_size = get_tp_group().world_size
        global_dp_size = get_dp_group().world_size
        num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
                                                     flashcomm2_otp_size)

        global _FLASHCOMM2_OTP
        global _FLASHCOMM2_ODP

        # Defaults for flashcomm2_otp_size == 1: keep the OTP group as None
        # (avoids allocating a redundant size-1 communicator; consumers check
        # the size-1 case explicitly) and reuse the existing TP group as ODP.
        _FLASHCOMM2_OTP = None
        _FLASHCOMM2_ODP = get_tp_group()

        if flashcomm2_otp_size > 1:
            otp_group_ranks = []
            odp_group_ranks: list[list[int]] = [
                [] for _ in range(flashcomm2_otp_size * global_dp_size)
            ]

            # Build the OTP groups (ranks strided by
            # num_fc2_oproj_tensor_parallel_groups within each DP replica) and
            # the complementary ODP groups in one pass.
            for dp_group_index in range(global_dp_size):
                for i in range(num_fc2_oproj_tensor_parallel_groups):
                    ranks = []
                    for j in range(flashcomm2_otp_size):
                        rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
                        ranks.append(rank_idx)
                        odp_group_index = dp_group_index * flashcomm2_otp_size + j
                        odp_group_ranks[odp_group_index].append(rank_idx)
                    otp_group_ranks.append(ranks)

            _FLASHCOMM2_OTP = init_model_parallel_group(
                otp_group_ranks,
                get_world_group().local_rank,
                backend,
                group_name="flashcomm2_otp")
            _FLASHCOMM2_ODP = init_model_parallel_group(
                odp_group_ranks,
                get_world_group().local_rank,
                backend,
                group_name="flashcomm2_odp")


def get_mlp_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
Expand Down Expand Up @@ -201,3 +257,15 @@ def destroy_ascend_model_parallel():
if _P_TP:
_P_TP.destroy()
_P_TP = None

global _FLASHCOMM2_OTP
if _FLASHCOMM2_OTP and get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size != 1:
_FLASHCOMM2_OTP.destroy()
_FLASHCOMM2_OTP = None

global _FLASHCOMM2_ODP
if _FLASHCOMM2_ODP and get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size != 1:
_FLASHCOMM2_ODP.destroy()
_FLASHCOMM2_ODP = None
8 changes: 7 additions & 1 deletion vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@
# This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM1":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
# Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it.
# The specific value set will be used as the O-matrix TP group size for flashcomm2.
# For a detailed introduction to the parameters and the differences and applicable scenarios
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
Expand Down Expand Up @@ -185,4 +191,4 @@ def __getattr__(name: str):


def __dir__():
    # Expose the lazily-resolved environment-variable names via dir(module).
    return list(env_variables.keys())
return list(env_variables.keys())
Loading
Loading