
[feat]: oproj tensor parallelism in pure DP and graph-mode scenarios.#2167

Merged
wangxiyuan merged 8 commits into vllm-project:main from lidenghui1110:oproj
Sep 7, 2025

Conversation

@lidenghui1110
Contributor

@lidenghui1110 lidenghui1110 commented Aug 1, 2025

What this PR does / why we need it?

This PR introduces tensor model parallelism for the o_proj matrix to reduce memory consumption. It is only supported in graph mode in the pure-DP scenario.

On a DeepSeek R1 W8A8 PD-disaggregated decode instance using pure DP, setting oproj_tensor_parallel_size = 8 increases TPOT by 1 ms while saving 5.8 GB of NPU memory per rank. We got the best trade-off with oproj_tensor_parallel_size = 4, which saves memory with no TPOT increase.

Performance data: (image attached in the original PR)

Does this PR introduce any user-facing change?

This PR introduces one new config in additional_config.

| Name | Effect | Required | Type | Constraints |
| --- | --- | --- | --- | --- |
| oproj_tensor_parallel_size | Splits the o_proj matrix along the row dimension (head num * head dim) into oproj_tensor_parallel_size pieces. | No | int | Default value is None; once this value is set, the feature is enabled. head num * head dim must be divisible by this value. |

example

--additional_config={"oproj_tensor_parallel_size": 8}
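The memory saving comes from row-parallel sharding of o_proj across the group. A minimal NumPy sketch (all shapes here are illustrative, not the real model's) of why summing the per-rank partial products reproduces the full matmul:

```python
import numpy as np

# Illustrative shapes only; as the constraint table above states,
# oproj_tensor_parallel_size must divide num_heads * head_dim.
num_heads, head_dim, hidden = 4, 8, 16
otp_size = 2  # stands in for oproj_tensor_parallel_size

rng = np.random.default_rng(0)
x = rng.standard_normal((1, num_heads * head_dim))
w = rng.standard_normal((num_heads * head_dim, hidden))

# Row-parallel: each rank holds a slice of the input (head num * head dim) dim.
x_shards = np.split(x, otp_size, axis=1)
w_shards = np.split(w, otp_size, axis=0)

# Each rank computes a partial product; an all-reduce (sum) restores the output.
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
out = sum(partials)

assert np.allclose(out, x @ w)
```

Each rank thus stores only 1/otp_size of the o_proj weight, which is where the per-rank memory saving comes from.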

How was this patch tested?

@github-actions
Contributor

github-actions Bot commented Aug 1, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Comment thread vllm_ascend/ops/linear.py Outdated
else:
self.register_parameter("bias", None)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
Contributor


This function seems to be identical with that of RowParallelLinear, why do we need to rewrite it here?

Contributor


In the original weight_loader:

tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()

we need to replace these with the custom comm group:

tp_rank = self.tp_rank
tp_size = self.tp_size

It seems that the latest vLLM no longer has this problem.
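A hedged sketch of the sharded load being described: each rank of the custom o_proj TP group copies only its slice of the full checkpoint weight. The function name is hypothetical; tp_rank/tp_size stand in for the layer's self.tp_rank/self.tp_size from the custom comm group.

```python
import numpy as np

def load_row_parallel_shard(loaded_weight, tp_rank, tp_size):
    # Row-parallel weights are sharded along the input dimension (axis 1),
    # so each rank keeps a contiguous column slice of the full weight.
    input_dim = loaded_weight.shape[1]
    assert input_dim % tp_size == 0
    shard = input_dim // tp_size
    return loaded_weight[:, tp_rank * shard:(tp_rank + 1) * shard]

full = np.arange(32).reshape(4, 8)
shards = [load_row_parallel_shard(full, r, 2) for r in range(2)]
# The per-rank shards reassemble into the original weight.
assert np.concatenate(shards, axis=1).tolist() == full.tolist()
```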

Contributor


Understood, thanks

@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.



else:
tp_rank = get_tensor_model_parallel_rank()
else:
tp_rank = 0
Collaborator


What does tp_rank = 0 mean?

Contributor


This is the original code:

if isinstance(layer, RowParallelLinear):
            tp_rank = get_tensor_model_parallel_rank()
            return self.quant_method.apply(layer, x, bias, tp_rank)
        return self.quant_method.apply(layer, x, bias)

By default no tp_rank is passed, which is equivalent to tp_rank = 0.
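A minimal sketch of that dispatch, with all names reduced to hypothetical stand-ins: RowParallelLinear layers pass the current TP rank explicitly, while other layers omit the argument and fall back to the default of 0.

```python
# Hypothetical stand-ins for the classes and helpers in the snippet above.
class RowParallelLinear:
    pass

class ColumnParallelLinear:
    pass

def get_tensor_model_parallel_rank():
    return 3  # pretend this process is TP rank 3

def apply(layer, x, bias=None, tp_rank=0):
    # tp_rank defaults to 0 when the caller does not pass it.
    return tp_rank

def dispatch(layer, x=None, bias=None):
    if isinstance(layer, RowParallelLinear):
        return apply(layer, x, bias, get_tensor_model_parallel_rank())
    return apply(layer, x, bias)

assert dispatch(RowParallelLinear()) == 3      # explicit rank passed through
assert dispatch(ColumnParallelLinear()) == 0   # falls back to the default
```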

Comment thread mypy.ini Outdated
Comment on lines +6 to +8
[mypy-numpy.*]
ignore_missing_imports = True

Collaborator


Why do we need to update these configurations? If this is a bug in the repo, I suggest creating a separate PR to fix it.

Contributor


Alright, I will remove this; it was only for some local CI checks.

Comment on lines +517 to 531

if oproj_tp_enable():
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
elif (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.torchair_graph_config.enable_multistream_moe
or self.enable_shared_expert_dp)):
self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
Collaborator


Is it still not possible to eliminate these if-else branches even with CustomOp? @wangxiyuan @Yikun

Comment thread vllm_ascend/ops/linear.py
Comment on lines +139 to +153
if prefix.find("down_proj") != -1 and mlp_tp_enable():
comm_group = get_mlp_tp_group()
self.forward_type = "mlp_tp"
elif prefix.find("o_proj") != -1 and oproj_tp_enable():
comm_group = get_otp_group()
self.forward_type = "oproj_tp"
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
comm_group = get_tp_group()
self.forward_type = "normal"
self.comm_group = comm_group
Collaborator


Is adding more if-else conditions the way to extend support for new models?

@zzhx1 zzhx1 force-pushed the oproj branch 2 times, most recently from 0e2fce6 to b1582b4 on August 28, 2025 at 14:52
Comment thread vllm_ascend/ops/linear.py Outdated
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
assert self.quant_method is not None
# Choose different forward function according to the type of TP group
Collaborator


Using a dict here may be more extensible. The same logic applies as mentioned above.

Contributor


I tried modifying it and found the result not very intuitive, but I did change part of it to use super().forward().
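For reference, a minimal sketch of the dict-based dispatch the reviewer suggests, using the forward_type keys from the quoted snippet; the class name and handlers here are placeholders, not the PR's actual implementation.

```python
class RowParallelDispatch:
    """Toy layer selecting a forward implementation via a lookup table."""

    def __init__(self, forward_type):
        self.forward_type = forward_type
        # Table lookup replaces the chain of if/elif branches: adding a new
        # TP variant means adding one entry here rather than another branch.
        self._forwards = {
            "mlp_tp": self._forward_mlp_tp,
            "oproj_tp": self._forward_oproj_tp,
            "normal": self._forward_normal,
        }

    def forward(self, x):
        return self._forwards[self.forward_type](x)

    def _forward_mlp_tp(self, x):
        return ("mlp_tp", x)

    def _forward_oproj_tp(self, x):
        return ("oproj_tp", x)

    def _forward_normal(self, x):
        return ("normal", x)

assert RowParallelDispatch("oproj_tp").forward("h") == ("oproj_tp", "h")
```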

Comment thread vllm_ascend/utils.py
name="SiluAndMul")
CustomOp.register_oot(_decorated_op_cls=AscendRotaryEmbedding,
name="RotaryEmbedding")
CustomOp.register_oot(_decorated_op_cls=AscendColumnParallelLinear,
Collaborator


If this component is enabled by default, modifications to the original vLLM repository will require ongoing maintenance and updates. What is the long-term maintenance strategy for this?

Contributor


I think there won't be many changes here; we just need to focus on maintaining the __init__ method going forward.


@wangxiyuan
Collaborator

According to the comment in #2678 (comment), please remove the patch_linear as well.

@wangxiyuan
Collaborator

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
@zzhx1
Contributor

zzhx1 commented Sep 6, 2025

@wangxiyuan This PR is ready, and it also fixes the bug related to LinearBase.

@@ -0,0 +1,15 @@
import vllm
Collaborator


Looks like these 3 files can be merged into one.
