
[0.9.1] Add LMhead TP communication groups.#1956

Merged
ganyi1996ppo merged 10 commits intovllm-project:v0.9.1-devfrom
Angazenn:lmhead
Aug 14, 2025

Conversation

Collaborator

@Angazenn Angazenn commented Jul 23, 2025

What this PR does / why we need it?

In pure DP scenarios (such as DP32), LMHead computation takes 1~2 ms. This PR customizes the parallelism of LMHead, enabling a separate TP group for LMHead. The computation flow is as follows:

get_lmhead_group().all_gather      # [num_tokens, hid_dim] -> [num_tokens * lmhead_tp, hid_dim]
--> lmhead matmul                  # [num_tokens * lmhead_tp, hid_dim] -> [num_tokens * lmhead_tp, vocab_size // lmhead_tp]
--> get_lmhead_group().all_to_all  # [num_tokens * lmhead_tp, vocab_size // lmhead_tp] -> [num_tokens, vocab_size]
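The flow above can be checked at the shape level with a single-process sketch. This is not the vllm-ascend implementation: the `get_lmhead_group()` collectives are modeled as plain array reshuffles across a list of per-rank arrays, and `lmhead_tp_matmul` is an illustrative name.

```python
import numpy as np

def lmhead_tp_matmul(hidden_per_rank, weight_shards):
    """Single-process stand-in for the LMHead TP flow.

    hidden_per_rank: one [num_tokens, hid_dim] array per rank in the group.
    weight_shards:   one [hid_dim, vocab_size // tp] LMHead weight shard per rank.
    Returns one [num_tokens, vocab_size] logits array per rank.
    """
    tp = len(hidden_per_rank)
    num_tokens = hidden_per_rank[0].shape[0]
    # all_gather: every rank sees all tokens -> [num_tokens * tp, hid_dim]
    gathered = np.concatenate(hidden_per_rank, axis=0)
    # each rank multiplies against its vocab shard
    # -> [num_tokens * tp, vocab_size // tp] per rank
    shard_out = [gathered @ w for w in weight_shards]
    # all_to_all: rank r keeps only the rows for its own tokens from every
    # shard, concatenated along vocab -> [num_tokens, vocab_size]
    outputs = []
    for r in range(tp):
        rows = slice(r * num_tokens, (r + 1) * num_tokens)
        outputs.append(np.concatenate([s[rows] for s in shard_out], axis=1))
    return outputs
```

Because every rank's shard output contains all gathered tokens, the all_to_all only swaps the token axis for the vocab axis, and each rank's result equals the dense matmul of its own hidden states against the full LMHead weight.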

This saves 0.5~1 ms for DeepSeek at batch size 28 on a single die, with MTP.

In addition, this PR fixes a bug introduced by LMHead quantization: the op npu_quant_matmul only accepts dims < 65536, while vocab_size exceeds 65536 when using TP 1. Setting LMHead TP size > 1 avoids this bug.

Main version of this PR: #2309 .

Does this PR introduce any user-facing change?

Yes. We introduce a new configurable option, lmhead_tp_size, in ascend_config. For example:

additional_config={
        "lmhead_tp_size": 16,
}

The default value is -1, in which case lmhead_tp_size is automatically set to tensor_parallel_size. It is recommended to use this option only when running pure DP, to avoid the additional communication introduced by TP. Accordingly, if TP > 1 the parallel size of the LMHead group is also reset to tensor_parallel_size, falling back to the normal TP+DP case.
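The fallback behavior described above can be sketched as a small resolution function. `resolve_lmhead_tp_size` and its signature are illustrative names, not the actual vllm-ascend code:

```python
def resolve_lmhead_tp_size(lmhead_tp_size: int, tensor_parallel_size: int) -> int:
    """Illustrative sketch of how the effective LMHead TP size is chosen."""
    if lmhead_tp_size == -1:
        # default: follow the model's TP size
        return tensor_parallel_size
    if tensor_parallel_size > 1:
        # not pure DP: fall back to the normal TP+DP case
        return tensor_parallel_size
    # pure DP (TP == 1): honor the user-configured LMHead TP size
    return lmhead_tp_size
```

So `"lmhead_tp_size": 16` only takes effect under pure DP; with TP > 1 the group size silently reverts to tensor_parallel_size.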

How was this patch tested?

@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.


@Angazenn Angazenn force-pushed the lmhead branch 2 times, most recently from ae9413c to 9b86bbf Compare July 27, 2025 14:39
@Angazenn Angazenn force-pushed the lmhead branch 3 times, most recently from 5ef27e1 to 74cd78d Compare July 29, 2025 11:32

@Angazenn Angazenn force-pushed the lmhead branch 2 times, most recently from 27cbf3a to 49933c3 Compare July 30, 2025 12:30

angazenn added 2 commits July 31, 2025 19:10
Signed-off-by: angazenn <zengyanjia@huawei.com>
angazenn added 2 commits August 4, 2025 10:43
Signed-off-by: angazenn <zengyanjia@huawei.com>
@Angazenn Angazenn changed the title [DRAFT] Lmhead TP [0.9.1] Lmhead TP Aug 4, 2025
@Angazenn Angazenn changed the title [0.9.1] Lmhead TP [0.9.1] Add LMhead TP communication groups. Aug 4, 2025
if not with_prefill:
    padded_num_indices = num_tokens
else:
    padded_num_indices = max_num_reqs
Collaborator
Will the padding here increase latency when the DP load is seriously uneven?

Collaborator Author

Yes, there might be performance degradation. However, in some cases (see _get_forward_metadata_across_dp) the all_reduce communication used for gathering metadata is skipped, so using the true num_tokens_across_dp would incur an additional all_reduce in those cases. Maybe we can find a better solution for this.

Comment thread vllm_ascend/distributed/parallel_state.py Outdated
backend,
group_name="mc2")

all_ranks = torch.arange(world_size).reshape(-1, lm_head_tp_size)
Collaborator
TODO: please create this parallel group only when running DeepSeek

Collaborator Author
done
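The rank grouping in the excerpt above (the `torch.arange(world_size).reshape(-1, lm_head_tp_size)` line) can be sketched in isolation. This is a shape-only illustration with numpy standing in for torch; the real code would pass each row to a process-group constructor such as torch.distributed's new_group:

```python
import numpy as np

def lmhead_group_ranks(world_size: int, lm_head_tp_size: int):
    """Partition world ranks into consecutive LMHead TP groups.

    Each row of the reshaped range becomes one LMHead TP communication
    group, mirroring the reshape in the PR's parallel_state.py excerpt.
    """
    assert world_size % lm_head_tp_size == 0, "world_size must divide evenly"
    all_ranks = np.arange(world_size).reshape(-1, lm_head_tp_size)
    return [row.tolist() for row in all_ranks]
```

For example, with 8 devices and lm_head_tp_size=4, ranks 0-3 and ranks 4-7 would each form an LMHead TP group.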

zengyanjia added 3 commits August 12, 2025 20:11
Signed-off-by: zengyanjia <z00883269@china.huawei.com>
@ganyi1996ppo ganyi1996ppo merged commit f5226e3 into vllm-project:v0.9.1-dev Aug 14, 2025
17 checks passed
@Angazenn Angazenn deleted the lmhead branch September 8, 2025 03:16
False)  # Whether to enable DeepSeek models' prefill optimizations
self.enable_cpu_binding = additional_config.get(
    "enable_cpu_binding", False)  # Whether to enable CPU binding
self.lmhead_tp_size = additional_config.get("lmhead_tp_size", -1)
Contributor
It would be better for the default value to be 1.
