
[Op] DeepSeekV3.2 support bmm_transpose operator#4631

Merged
wangxiyuan merged 3 commits into vllm-project:main from rjg-lyh:pr_bmmtrans on Dec 8, 2025

Conversation

@ZYang6263 (Collaborator) commented Dec 2, 2025

What this PR does / why we need it?

Add bmm_transpose operator support for DeepSeekV3.2.

Does this PR introduce any user-facing change?

How was this patch tested?

Signed-off-by: ZYang6263 <zy626375@gmail.com>
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a new bmm_transpose operator to replace torch_npu.npu_transpose_batchmatmul in _v_up_proj. While the Python-side change is straightforward, my review of the C++ implementation for the new operator (torch.ops._C_ascend.batch_matmul_transpose) revealed critical issues related to out-of-bounds memory access and multi-device safety. These issues, detailed in the line comment, could lead to runtime errors and must be addressed before this change is merged.

    res = torch.empty((b, self.num_heads, self.v_head_dim),
                      dtype=x.dtype,
                      device=x.device)
    torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
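For reference, the semantics the fused operator is expected to provide (a per-head batched matmul whose result is written out in token-major, transposed layout, as `torch_npu.npu_transpose_batchmatmul` did before) can be sketched in plain NumPy. The shapes below are illustrative assumptions about `_v_up_proj`, not the kernel's actual tiling or layout:

```python
import numpy as np

def batch_matmul_transpose_ref(x, w):
    """Reference semantics (assumption): batched matmul over the head
    dimension, followed by a transpose into (tokens, heads, v_head_dim).

    x: (num_heads, num_tokens, kv_lora_rank)
    w: (num_heads, kv_lora_rank, v_head_dim)
    returns: (num_tokens, num_heads, v_head_dim)
    """
    # np.matmul broadcasts over the leading (head) batch dimension.
    out = np.matmul(x, w)            # (num_heads, num_tokens, v_head_dim)
    return out.transpose(1, 0, 2)    # the kernel fuses this transpose into the write

# Illustrative shapes only.
heads, tokens, rank, v_dim = 4, 3, 8, 5
x = np.random.randn(heads, tokens, rank).astype(np.float32)
w = np.random.randn(heads, rank, v_dim).astype(np.float32)
res = batch_matmul_transpose_ref(x, w)
assert res.shape == (tokens, heads, v_dim)
```

Fusing the transpose into the matmul's output write is what saves the extra memory pass that a separate `bmm` + `transpose` would cost.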

critical

The C++ implementation of the batch_matmul_transpose operator has a couple of critical issues that could lead to runtime errors.

  1. Potential for out-of-bounds access: In csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h, the tiling data for different batch sizes is cached in a static array global_tiling_data of size MAX_CAPTURE_NUM (1024). The index into this array, batchIdx, is derived from the number of tokens (opShape.m). If the number of tokens exceeds 1024, which is common during prefill, batchIdx falls outside the array, and the bounds check below fails with a runtime error.

    // csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h:104
    int32_t batchIdx = opShape.m - 1;
    ...
    if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) { // MAX_CAPTURE_NUM is 1024
        ...
    } else {
        TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
    }
  2. Not safe for multi-device execution: The global_tiling_data is a static variable, meaning it is initialized only once; its device is fixed to the device of the input tensor on the first call. In a multi-device environment where workers on different NPUs call this operator, every call not on that initial device will hit a device mismatch error.

    // csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h:106
    static auto global_tiling_data = at::empty(
        {tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));

    A potential solution for the multi-device issue is to use a thread-safe map from device index to the tiling data tensor, for example using std::map with a std::mutex.

Given these issues, the underlying implementation of this operator needs to be revised before it can be safely used.
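The reviewer's suggested fix for the multi-device issue (in the real C++ code, a std::map from device index to tiling tensor guarded by a std::mutex) can be sketched in Python. The names and the bytearray factory here are illustrative stand-ins for at::empty(..., device=...), not the actual implementation:

```python
import threading

class PerDeviceTilingCache:
    """Sketch of the suggested fix: one tiling buffer per device,
    lazily created under a lock, so workers on different devices
    never share (or race on) a single static buffer.
    """
    def __init__(self, make_buffer):
        self._make_buffer = make_buffer   # factory: device_index -> buffer
        self._buffers = {}                # device_index -> buffer
        self._lock = threading.Lock()

    def get(self, device_index):
        # Double-checked lookup: fast path without the lock, then
        # re-check under the lock before creating the buffer.
        buf = self._buffers.get(device_index)
        if buf is None:
            with self._lock:
                buf = self._buffers.get(device_index)
                if buf is None:
                    buf = self._make_buffer(device_index)
                    self._buffers[device_index] = buf
        return buf

# Illustrative factory standing in for a per-device tiling tensor.
cache = PerDeviceTilingCache(lambda dev: bytearray(1024))
assert cache.get(0) is cache.get(0)       # cached: same buffer per device
assert cache.get(0) is not cache.get(1)   # distinct buffers across devices
```

In the C++ version, the map lookup and insertion would both happen under a std::lock_guard<std::mutex>, which also sidesteps the thread-safety caveats of function-local statics being mutated after initialization.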

@ZYang6263 ZYang6263 changed the title [Op] Support bmm_transpose operator [Op] DeepSeekV3.2 support bmm_transpose operator Dec 2, 2025
github-actions bot commented Dec 2, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@hust17yixuan (Contributor) commented:

These ops may have problems on higher CANN versions; consider merging later.

github-actions bot commented Dec 6, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
@wangxiyuan wangxiyuan merged commit a433f32 into vllm-project:main Dec 8, 2025
16 checks passed
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 9, 2025
### What this PR does / why we need it?
DeepSeekV3.2 support bmm_transpose operator.

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: tanqingshan (A) <50050625@china.huawei.com>
weijinqian0 pushed a commit to weijinqian0/vllm-ascend that referenced this pull request Dec 9, 2025
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 10, 2025
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Dec 10, 2025
4 participants