Merged
vllm_ascend/ascend_forward_context.py (3 additions & 3 deletions)
@@ -12,7 +12,7 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import (
     AscendDeviceType,
-    enable_flash_comm_v1,
+    enable_sp,
     flashcomm2_enable,
     get_ascend_device_type,
     has_layer_idx,
@@ -91,14 +91,14 @@ def set_ascend_forward_context(
     # main model and drafter model may have different architecture
     is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
     if is_context_moe_model:
-        flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None
+        flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
         mmrs_fusion = False
     elif is_draft_model:
         # TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
         # Disable it to avoid more problems.
         flash_comm_v1_enabled = False
     else:
-        flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None and num_tokens > 1000
+        flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
     forward_context.mmrs_fusion = mmrs_fusion
     forward_context.num_tokens = num_tokens
     forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
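Condensed, the branching above reduces to the following; a minimal sketch (the wrapper function is hypothetical, while the imported predicates are the real helpers this diff touches):

```python
from vllm_ascend.utils import enable_sp, is_drafter_moe_model, is_moe_model


def resolve_flash_comm_v1(vllm_config, num_tokens, is_draft_model: bool) -> bool:
    """Sketch of the SP (flashcomm1) gate in set_ascend_forward_context."""
    is_moe = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
    if is_moe:
        # MoE models (main or drafter): enable SP whenever the token count is known.
        return bool(enable_sp(vllm_config) and num_tokens is not None)
    if is_draft_model:
        # Dense drafters: SP is redundant and incompatible with dp/graph, keep it off.
        return False
    # Dense main models: SP only pays off for large batches.
    return bool(enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000)
```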
vllm_ascend/ops/linear.py (2 additions & 2 deletions)
@@ -40,7 +40,7 @@
 from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
-from vllm_ascend.utils import enable_flash_comm_v1, maybe_trans_nz
+from vllm_ascend.utils import enable_sp, maybe_trans_nz


 class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
@@ -240,7 +240,7 @@ def __init__(
         disable_tp: bool = False,
     ):
         # TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM.
-        if enable_flash_comm_v1():
+        if enable_sp():
             compilation_config = get_current_vllm_config().compilation_config
             unique_prefix = prefix
             if prefix in compilation_config.static_forward_context:
vllm_ascend/ops/linear_op.py (4 additions & 4 deletions)
@@ -70,7 +70,7 @@
 from vllm_ascend.utils import (
     enable_dsa_cp,
     enable_dsa_cp_with_layer_shard,
-    enable_flash_comm_v1,
+    enable_sp,
     flashcomm2_enable,
     get_flashcomm2_reorgnized_batch_ids,
     get_weight_prefetch_method,
@@ -466,7 +466,7 @@ def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor,
         # Matrix multiply.
         assert self.quant_method is not None

-        if enable_flash_comm_v1():
+        if enable_sp():
             input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)

         # Trigger async broadcast before matmul to overlap communication.
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
     if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
         if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
             return Flashcomm2OshardQKVParallelOp(layer)
-    if enable_flash_comm_v1():
+    if enable_sp():
         if "shared_expert" in prefix:
             return None
         sp_column_prefix = [
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
     if flashcomm2_enable():
         if "o_proj" in prefix or "out_proj" in prefix:
             return Flashcomm2OProjRowParallelOp(layer)
-    if enable_flash_comm_v1():
+    if enable_sp():
         if "shared_expert" in prefix:
             return None
         sp_row_prefixes = [
vllm_ascend/platform.py (2 additions & 3 deletions)
@@ -39,7 +39,6 @@
     COMPRESSED_TENSORS_METHOD,
     AscendDeviceType,
     check_kv_extra_config,
-    enable_sp,
     flashcomm2_enable,
     get_ascend_device_type,
     is_moe_model,
@@ -48,7 +47,7 @@
     update_aclgraph_sizes,
     update_cudagraph_capture_sizes,
     is_310p,
-    enable_flash_comm_v1,
+    enable_sp,
 )

 if TYPE_CHECKING:
@@ -402,7 +401,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             )
             vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size

-        if enable_flash_comm_v1():
+        if enable_sp(vllm_config):
             assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
 Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
 For optimal performance with VL models, we recommend enabling Sequence Parallelism \
vllm_ascend/utils.py (7 additions & 11 deletions)
@@ -719,15 +719,6 @@ def matmul_allreduce_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE


-def enable_flash_comm_v1():
-    return (
-        envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
-        # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
-        # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
-        or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
-    )
-
-
 def enable_sp_by_pass(vllm_config: VllmConfig):
     return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
@@ -739,7 +730,7 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
         from vllm.config import get_current_vllm_config

         vllm_config = get_current_vllm_config()
-    _ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1()
+    _ENABLE_SP = (
+        envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
+        # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
+        # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
+        or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
+    )
Comment on lines +733 to +738 (Contributor review, severity: high):

After this refactoring, the vllm_config parameter in the enable_sp function signature is no longer used within the function body. This makes the code less clear. Consider removing the vllm_config parameter and the associated logic that handles it when it's None. This would require updating all call sites to no longer pass this argument.
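A minimal sketch of the simplification suggested here, reusing the module's existing os/envs_ascend imports and covering only the logic visible in this hunk (call sites such as enable_sp(vllm_config) elsewhere in the diff would then drop the argument):

```python
def enable_sp(enable_shared_expert_dp: bool = False) -> bool:
    # Flash comm 1 is switched on via VLLM_ASCEND_ENABLE_FLASHCOMM1; the legacy
    # VLLM_ASCEND_ENABLE_FLASHCOMM env var is still honored for backward compatibility.
    flashcomm1 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1 or bool(
        int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))
    )
    return flashcomm1 or enable_shared_expert_dp
```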


     if not _ENABLE_SP and enable_shared_expert_dp:
         _ENABLE_SP = True
@@ -1104,7 +1100,7 @@ def enable_dsa_cp() -> bool:
     is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
         vllm_config.model_config.hf_text_config, "index_topk"
     )
-    return bool(is_ds_v32 and enable_flash_comm_v1())
+    return bool(is_ds_v32 and enable_sp())


 @lru_cache(maxsize=1)
vllm_ascend/worker/model_runner_v1.py (3 additions & 3 deletions)
@@ -113,8 +113,8 @@
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (
     check_gdn_layer,
-    enable_flash_comm_v1,
     enable_sp,
+    enable_sp_by_pass,
     is_drafter_moe_model,
     is_moe_model,
     lmhead_tp_enable,
@@ -1722,7 +1722,7 @@ def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
         # Pad tokens to multiple of tensor_parallel_size when
         # enabled collective fusion for SP
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-        if enable_sp(self.vllm_config):
+        if enable_sp(self.vllm_config) or enable_sp_by_pass(self.vllm_config):
             return round_up(num_scheduled_tokens, tp_size)
         return num_scheduled_tokens

@@ -2272,7 +2272,7 @@ def _dummy_run(
             # tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
             # to incorrect memory estimation and potentially causing OOM.
             intermediate_tokens = num_tokens_padded
-            if enable_flash_comm_v1():
+            if enable_sp():
                 tp_size = get_tensor_model_parallel_world_size()
                 intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
             if self.intermediate_tensors is None:
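Taken together, these two hunks size buffers with the same arithmetic; a small worked example (the standalone helpers are hypothetical):

```python
def pad_for_sp(num_scheduled_tokens: int, tp_size: int) -> int:
    # round_up(num_scheduled_tokens, tp_size), as in _pad_for_sequence_parallelism.
    return (num_scheduled_tokens + tp_size - 1) // tp_size * tp_size


def sp_intermediate_tokens(num_tokens_padded: int, tp_size: int) -> int:
    # Ceil-division: each non-first PP rank holds one sequence-parallel shard,
    # as in the _dummy_run hunk above.
    return (num_tokens_padded + tp_size - 1) // tp_size


assert pad_for_sp(1022, tp_size=4) == 1024  # padded up to a multiple of tp_size
assert sp_intermediate_tokens(1024, tp_size=4) == 256  # one shard per TP rank
```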
vllm_ascend/worker/worker.py (3 additions & 3 deletions)
@@ -55,7 +55,7 @@
 from vllm_ascend.utils import (
     AscendDeviceType,
     check_ascend_device_type,
-    enable_flash_comm_v1,
+    enable_sp,
     get_ascend_device_type,
     register_ascend_customop,
 )
@@ -376,7 +376,7 @@ def execute_model(
         if forward_pass and not get_pp_group().is_first_rank:
             # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
             # it will conflict with the all-gather operation in flashcomm1.
-            if enable_flash_comm_v1():
+            if enable_sp():
                 all_gather_group = None
             else:
                 all_gather_group = get_tp_group()
@@ -393,7 +393,7 @@
             assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
             # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
             # it will conflict with the all-gather operation in flashcomm1.
-            if enable_flash_comm_v1():
+            if enable_sp():
                 all_gather_group = None
             else:
                 all_gather_group = get_tp_group()
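Both execute_model call sites reduce to the same selection; a condensed sketch (the helper function is hypothetical, and get_tp_group's import path is assumed):

```python
from vllm.distributed import get_tp_group  # assumed import path

from vllm_ascend.utils import enable_sp


def select_all_gather_group():
    # With flashcomm1 enabled, the model already all-gathers hidden states itself,
    # so passing a TP group here would add a conflicting, redundant collective.
    return None if enable_sp() else get_tp_group()
```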