17 changes: 9 additions & 8 deletions vllm_ascend/attention/sfa_v1.py
@@ -484,14 +484,15 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         self._process_weights_for_fused_mlapo(act_dtype)
 
     def _v_up_proj(self, x):
-        if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
-            x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
-            x = torch_npu.npu_transpose_batchmatmul(x,
-                                                    self.W_UV,
-                                                    perm_x1=[1, 0, 2],
-                                                    perm_x2=[0, 1, 2],
-                                                    perm_y=[1, 0, 2])
-            x = x.reshape(-1, self.local_num_heads * self.v_head_dim)
+        if x.dtype in [torch.float16, torch.bfloat16] \
+                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
+            x = x.view(-1, self.num_heads, self.kv_lora_rank)
+            b, _, _ = x.shape
+            res = torch.empty((b, self.num_heads, self.v_head_dim),
+                              dtype=x.dtype,
+                              device=x.device)
+            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
critical

The C++ implementation of the batch_matmul_transpose operator has two critical issues that could lead to runtime errors.

  1. Potential for out-of-bounds access: In csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h, the tiling data for different batch sizes is cached in a static array global_tiling_data of size MAX_CAPTURE_NUM (1024). The index into this array, batchIdx, is derived from the number of tokens (opShape.m). If the number of tokens exceeds 1024, which is common during prefill, batchIdx falls outside the cached range and the TORCH_CHECK in the else branch below fails, aborting the call with a runtime error.

    // csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h:104
    int32_t batchIdx = opShape.m - 1;
    ...
    if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) { // MAX_CAPTURE_NUM is 1024
        ...
    } else {
        TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
    }
  2. Not safe for multi-device execution: The global_tiling_data is a static variable, meaning it is initialized only once. Its device is set to the device of the input tensor from the first call. In a multi-device (e.g. multi-NPU) environment where workers on different devices call this operator, every call issued from a device other than the initial one will hit a device mismatch error.

    // csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h:106
    static auto global_tiling_data = at::empty(
        {tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));

    A potential solution for the multi-device issue is to use a thread-safe map from device index to the tiling data tensor, for example a std::map guarded by a std::mutex; a sketch of this approach follows below.
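
To make the suggestion concrete, here is a minimal sketch of such a per-device cache. The helper name GetTilingBufferForDevice and the idea of passing tilingSize and the capture limit in as arguments are assumptions for illustration, not the operator's actual code; a real fix would have to be folded into the existing tiling logic in batch_matmul_transpose.h.

    // Hypothetical sketch only: cache one tiling buffer per device index,
    // guarded by a mutex, instead of a single static tensor pinned to the
    // first caller's device.
    #include <map>
    #include <mutex>
    #include <ATen/ATen.h>

    static at::Tensor& GetTilingBufferForDevice(const at::Tensor& tensor_a,
                                                int64_t tilingSize,
                                                int64_t maxCaptureNum) {
        static std::mutex cacheMutex;
        static std::map<c10::DeviceIndex, at::Tensor> cache;

        const c10::DeviceIndex devIdx = tensor_a.device().index();
        std::lock_guard<std::mutex> lock(cacheMutex);
        auto it = cache.find(devIdx);
        if (it == cache.end()) {
            // Allocate the tiling buffer on the caller's device, so workers
            // on other devices never reuse a tensor created for the first
            // device that happened to call the operator.
            it = cache.emplace(devIdx,
                               at::empty({tilingSize * maxCaptureNum},
                                         at::TensorOptions()
                                             .dtype(at::kByte)
                                             .device(tensor_a.options().device())))
                     .first;
        }
        return it->second;
    }

A std::unordered_map keyed on the device index would work just as well; the important properties are that the buffer's device always matches the incoming tensor and that map lookups are serialized.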

Given these issues, the underlying implementation of this operator needs to be revised before it can be safely used.

+            x = res.reshape(-1, self.num_heads * self.v_head_dim)
         else:
             # Convert from (B, N, L) to (N, B, L)
             x = x.view(-1, self.local_num_heads,