-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[Feat] flashcomm_v2 optim solution #3232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6b765b7
63536d4
3281f6c
8ccc51e
d7a6c79
01f5e7e
1eca224
1414e69
110d673
67bd24c
929a125
ff22e1a
b8c0e64
3ad78c3
c120c59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,19 +2,23 @@ | |
|
|
||
| import torch | ||
| from vllm.config import ParallelConfig, get_current_vllm_config | ||
| from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, | ||
| from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group, | ||
| get_tp_group, get_world_group, | ||
| init_model_parallel_group) | ||
|
|
||
| import vllm_ascend.envs as envs_ascend | ||
| from vllm_ascend.ascend_config import get_ascend_config | ||
| from vllm_ascend.utils import prefill_context_parallel_enable | ||
| from vllm_ascend.utils import (flashcomm2_enable, | ||
| prefill_context_parallel_enable) | ||
|
|
||
| # Currently, mc2 op need their own group coordinator. | ||
| _MC2: Optional[GroupCoordinator] = None | ||
| _MLP_TP: Optional[GroupCoordinator] = None | ||
| _OTP: Optional[GroupCoordinator] = None | ||
| _LMTP: Optional[GroupCoordinator] = None | ||
| _P_TP: Optional[GroupCoordinator] = None | ||
| _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None | ||
| _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None | ||
|
|
||
|
|
||
| def get_mc2_group() -> GroupCoordinator: | ||
|
|
@@ -34,6 +38,16 @@ def get_lmhead_tp_group() -> GroupCoordinator: | |
| return _LMTP | ||
|
|
||
|
|
||
| def get_flashcomm2_otp_group() -> GroupCoordinator: | ||
| return _FLASHCOMM2_OTP | ||
|
|
||
|
|
||
| def get_flashcomm2_odp_group() -> GroupCoordinator: | ||
| assert _FLASHCOMM2_ODP is not None, ( | ||
| "output data parallel group for flashcomm2 is not initialized") | ||
| return _FLASHCOMM2_ODP | ||
|
|
||
|
|
||
| def get_mlp_tp_group() -> GroupCoordinator: | ||
| assert _MLP_TP is not None, ("mlp group is not initialized") | ||
| return _MLP_TP | ||
|
|
@@ -165,6 +179,48 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): | |
| backend, | ||
| group_name="lmheadtp") | ||
|
|
||
| # TODO: Extract and unify the logic across different communication group. | ||
| if flashcomm2_enable(): | ||
| flashcomm2_otp_size = get_ascend_config( | ||
| ).flashcomm2_oproj_tensor_parallel_size | ||
| global_tp_size = get_tp_group().world_size | ||
| global_dp_size = get_dp_group().world_size | ||
| num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size // | ||
| flashcomm2_otp_size) | ||
|
|
||
| global _FLASHCOMM2_OTP | ||
| global _FLASHCOMM2_ODP | ||
|
|
||
| _FLASHCOMM2_OTP = None | ||
| _FLASHCOMM2_ODP = get_tp_group() | ||
|
|
||
| if flashcomm2_otp_size > 1: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The process group creation for FlashComm2 is guarded by
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the Flashcomm2OProjRowParallelOp that uses _FLASHCOMM2_OTP, a check has been added to determine whether flashcomm2_oproj_tensor_parallel_size is 1 to avoid errors. By the way, this approach of setting it to None avoids redundant communication groups when flashcomm2_oproj_tensor_parallel_size is 1, reducing buffer consumption. |
||
| otp_group_ranks = [] | ||
| odp_group_ranks: list[list[int]] = [ | ||
| [] for _ in range(flashcomm2_otp_size * global_dp_size) | ||
| ] | ||
|
|
||
| for dp_group_index in range(global_dp_size): | ||
| for i in range(num_fc2_oproj_tensor_parallel_groups): | ||
| ranks = [] | ||
| for j in range(flashcomm2_otp_size): | ||
| rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups | ||
| ranks.append(rank_idx) | ||
| odp_group_index = dp_group_index * flashcomm2_otp_size + j | ||
| odp_group_ranks[odp_group_index].append(rank_idx) | ||
| otp_group_ranks.append(ranks) | ||
|
|
||
| _FLASHCOMM2_OTP = init_model_parallel_group( | ||
| otp_group_ranks, | ||
| get_world_group().local_rank, | ||
| backend, | ||
| group_name="flashcomm2_otp") | ||
| _FLASHCOMM2_ODP = init_model_parallel_group( | ||
| odp_group_ranks, | ||
| get_world_group().local_rank, | ||
| backend, | ||
| group_name="flashcomm2_odp") | ||
|
|
||
|
|
||
| def get_mlp_tensor_model_parallel_world_size(): | ||
| """Return world size for the tensor model parallel group.""" | ||
|
|
@@ -201,3 +257,15 @@ def destroy_ascend_model_parallel(): | |
| if _P_TP: | ||
| _P_TP.destroy() | ||
| _P_TP = None | ||
|
|
||
| global _FLASHCOMM2_OTP | ||
| if _FLASHCOMM2_OTP and get_ascend_config( | ||
| ).flashcomm2_oproj_tensor_parallel_size != 1: | ||
| _FLASHCOMM2_OTP.destroy() | ||
| _FLASHCOMM2_OTP = None | ||
|
|
||
| global _FLASHCOMM2_ODP | ||
| if _FLASHCOMM2_ODP and get_ascend_config( | ||
| ).flashcomm2_oproj_tensor_parallel_size != 1: | ||
| _FLASHCOMM2_ODP.destroy() | ||
| _FLASHCOMM2_ODP = None | ||
Uh oh!
There was an error while loading. Please reload this page.