Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions tests/ut/ops/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from unittest.mock import MagicMock, patch

import torch
from vllm import config
from vllm import config, forward_context
from vllm.distributed import parallel_state as vllm_parallel_state

from tests.ut.base import TestBase
from vllm_ascend import ascend_config
Expand All @@ -22,8 +23,15 @@ def setUp(self):
self.mock_group.world_size = 2
self.mock_group.rank_in_group = 0

self.mock_dp_group = mock.MagicMock()
self.mock_dp_group.world_size = 4
self.mock_dp_group.rank_in_group = 0

self._forward_context = mock.MagicMock()
forward_context._forward_context = self._forward_context
parallel_state._MLP_TP = self.mock_group
parallel_state._OTP = self.mock_group
vllm_parallel_state._DP = self.mock_dp_group

self.mock_ascend_config = MagicMock()
self.mock_ascend_config.oproj_tensor_parallel_size = 2
Expand All @@ -37,10 +45,10 @@ def setUp(self):
return_value=self.mock_group),
patch("vllm_ascend.ops.linear_op.get_tp_group",
return_value=self.mock_group),
patch(
"vllm.distributed.parallel_state.get_tp_group",
return_value=self.mock_group,
),
patch("vllm.distributed.parallel_state.get_dp_group",
return_value=self.mock_dp_group),
patch("vllm.forward_context.get_forward_context",
return_value=self._forward_context),
patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
]
Expand Down Expand Up @@ -101,15 +109,30 @@ def test_oproj_tp(self):
ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False

linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="o_proj",
)
self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP)

input_tensor = torch.randn(16, 8)
linear(input_tensor)
dp_batch_sizes = [3, 5, 7, 9]
otp_groups = [[0, 1], [2, 3]]
outputs = []

forward_context._forward_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
[3, 8, 15, 24], device='cpu')

for group in otp_groups:
for dp_rank in group:
with patch.object(self.mock_dp_group, "rank_in_group",
dp_rank):
input_tensor = torch.randn(dp_batch_sizes[dp_rank], 16)
linear = AscendRowParallelLinear(
input_size=16,
output_size=8,
prefix="o_proj",
)
self.assertEqual(linear.custom_op.comm_group,
parallel_state._OTP)
output = linear(input_tensor)
outputs.append(output[0])
self.assertEqual(output[0].shape[0],
dp_batch_sizes[dp_rank])
self.assertEqual(output[0].shape[1], 8)


class TestAscendMergedColumnParallelLinear(BaseLinearTest):
Expand Down
14 changes: 1 addition & 13 deletions tests/ut/test_ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,16 +347,4 @@ def test_ascend_config_load_error(self):
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)

with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"oproj_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=1)
init_ascend_config(test_vllm_config)
init_ascend_config(test_vllm_config)
5 changes: 2 additions & 3 deletions tests/ut/torchair/models/test_torchair_deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,9 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(attn.mla_attn,
"__call__",
"forward",
return_value=torch.randn(2, 4, 128)):
with pytest.raises(AssertionError):
attn(positions, x)
attn(positions, x)

attn = TorchairDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
Expand Down
8 changes: 0 additions & 8 deletions vllm_ascend/ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,6 @@ def __init__(self, vllm_config):
raise AssertionError(
"oproj_tensor_parallel_size is only supported in the pure DP scenario"
)
if not self.torchair_graph_config.enabled:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in graph mode"
)
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
)
self.enable_cpu_binding = additional_config.get(
"enable_cpu_binding", False)
self.pd_tp_ratio = 1
Expand Down
54 changes: 45 additions & 9 deletions vllm_ascend/ops/linear_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from typing import Optional, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
Expand All @@ -46,7 +47,7 @@
from vllm.distributed import (split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_tp_group
from vllm.distributed.parallel_state import get_dp_group, get_tp_group
from vllm.forward_context import get_forward_context

from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
Expand Down Expand Up @@ -208,6 +209,10 @@ def __init__(self, layer):
def comm_group(self):
return get_otp_group()

@property
def dp_rank(self):
    """Return ``rank_in_group`` of the group given by ``get_dp_group()``."""
    dp_group = get_dp_group()
    return dp_group.rank_in_group

def apply_impl(
self,
input_: torch.Tensor,
Expand All @@ -220,26 +225,42 @@ def apply_impl(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()

forward_context = get_forward_context()

# Prepare tensors for all-to-all communication
local_batch_size = input_parallel.size(0)
chunk_size = self.input_size_per_partition
total_batch_size = local_batch_size * self.tp_size

# Reshape tensor for efficient cross-device transfer:
# [batch, dim] -> [tp_size, batch, chunk] -> flattened
cu_tokens_across_dp_cpu = forward_context.dp_metadata.cu_tokens_across_dp_cpu
prefix_array = cu_tokens_across_dp_cpu.cpu().numpy()
global_batch_size = np.concatenate(
([prefix_array[0]], np.diff(prefix_array)))
tp_group_id = self.dp_rank // self.tp_size
tp_group_batchsize = global_batch_size[tp_group_id *
self.tp_size:tp_group_id *
self.tp_size + self.tp_size]
total_batch_size = sum(tp_group_batchsize)

# Reshape for all-to-all communication
send_buf = (input_parallel.reshape(-1,
self.tp_size, chunk_size).transpose(
0, 1).contiguous().view(-1))

# Create receive buffer
recv_buf = torch.empty(total_batch_size * chunk_size,
recv_buf = torch.zeros(total_batch_size * chunk_size,
dtype=input_parallel.dtype,
device=input_parallel.device)

# Create split array
recv_splits = [size * chunk_size for size in tp_group_batchsize]
send_splits = [local_batch_size * chunk_size] * self.tp_size

# Perform all-to-all communication
dist.all_to_all_single(recv_buf,
send_buf,
recv_splits,
send_splits,
group=self.comm_group.device_group)

input_parallel = recv_buf.view(total_batch_size, chunk_size)

# Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
Expand All @@ -249,9 +270,24 @@ def apply_impl(
input_parallel,
bias=bias_)

# otp-specific: Combine partial results across devices
output = self.comm_group.reduce_scatter(output_parallel, dim=0)
output = output.view(input_.shape[0], self.layer.output_size)
# prepare all-reduce data
output = torch.empty(local_batch_size,
output_parallel.size(1),
dtype=output_parallel.dtype,
device=output_parallel.device)

recv_chunks = []
start_idx = 0
for size in tp_group_batchsize:
chunk = output_parallel[start_idx:start_idx + size, :]
recv_chunks.append(chunk.contiguous())
start_idx += size

# Reduce-scatter the results across devices
dist.reduce_scatter(output,
recv_chunks,
op=dist.ReduceOp.SUM,
group=self.comm_group.device_group)

# Handle bias return based on configuration
output_bias = self.bias if self.skip_bias_add else None
Expand Down
131 changes: 131 additions & 0 deletions vllm_ascend/torchair/ops/torchair_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from typing import Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parameter import Parameter
from vllm.distributed import split_tensor_along_last_dim
from vllm.forward_context import get_forward_context


def torchair_oproj_tp_forward(
    self,
    input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
    """Forward pass that exchanges inputs and partial outputs over ``comm_group``.

    The local input is redistributed across ``self.comm_group`` with an
    all-to-all, passed through ``self.quant_method.apply`` and the partial
    results are then combined with a reduce-scatter.

    When ``forward_context.with_prefill`` is set, the per-rank row counts are
    derived from ``dp_metadata.cu_tokens_across_dp_cpu`` so that ranks may
    contribute different batch sizes; otherwise every rank is treated as
    holding ``local_batch_size`` rows.

    Args:
        input_: Input activations. Presumably shaped
            ``(batch, tp_size * input_size_per_partition)`` when
            ``input_is_parallel`` is true -- TODO confirm against callers.

    Returns:
        ``output`` alone when ``self.return_bias`` is false, otherwise the
        tuple ``(output, output_bias)`` where ``output_bias`` is ``self.bias``
        if ``self.skip_bias_add`` is set and ``None`` otherwise.
    """
    if self.input_is_parallel:
        input_parallel = input_
    else:
        shards = split_tensor_along_last_dim(input_,
                                             num_partitions=self.tp_size)
        input_parallel = shards[self.tp_rank].contiguous()

    # The forward context tells us whether this step is prefill or decode.
    forward_context = get_forward_context()
    with_prefill = forward_context.with_prefill

    local_batch_size = input_parallel.size(0)
    chunk_size = self.input_size_per_partition

    # Both paths send the same layout:
    # [batch, tp_size, chunk] -> [tp_size, batch, chunk] -> flattened.
    send_buf = input_parallel.reshape(-1, self.tp_size, chunk_size)
    send_buf = send_buf.transpose(0, 1).contiguous().view(-1)

    if with_prefill:
        # Turn the cumulative token counts into per-DP-rank counts.
        cumulative_tokens = (
            forward_context.dp_metadata.cu_tokens_across_dp_cpu.cpu().numpy())
        global_batch_size = np.concatenate(
            ([cumulative_tokens[0]], np.diff(cumulative_tokens)))

        # Pick out the counts that belong to the ranks of this group.
        group_begin = (self.dp_rank // self.tp_size) * self.tp_size
        tp_group_batchsize = global_batch_size[group_begin:group_begin +
                                               self.tp_size]
        total_batch_size = sum(tp_group_batchsize)

        recv_buf = torch.zeros(total_batch_size * chunk_size,
                               dtype=input_parallel.dtype,
                               device=input_parallel.device)

        # Uneven exchange: each peer sends its own number of rows, while this
        # rank sends the same amount to every peer.
        recv_splits = [rows * chunk_size for rows in tp_group_batchsize]
        send_splits = [local_batch_size * chunk_size] * self.tp_size

        dist.all_to_all_single(recv_buf,
                               send_buf,
                               recv_splits,
                               send_splits,
                               group=self.comm_group.device_group)
    else:
        total_batch_size = local_batch_size * self.tp_size

        recv_buf = torch.empty(total_batch_size * chunk_size,
                               dtype=input_parallel.dtype,
                               device=input_parallel.device)

        # Even exchange: no explicit split sizes are needed.
        dist.all_to_all_single(recv_buf,
                               send_buf,
                               group=self.comm_group.device_group)

    input_parallel = recv_buf.view(total_batch_size, chunk_size)

    # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
    assert self.quant_method is not None
    output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)

    if with_prefill:
        # Destination buffer for this rank's share of the reduce-scatter.
        output = torch.empty(local_batch_size,
                             output_parallel.size(1),
                             dtype=output_parallel.dtype,
                             device=output_parallel.device)

        # Slice the partial output into one row block per rank of the group.
        recv_chunks = []
        row_begin = 0
        for rows in tp_group_batchsize:
            row_end = row_begin + rows
            recv_chunks.append(
                output_parallel[row_begin:row_end, :].contiguous())
            row_begin = row_end

        # Reduce-scatter the results across devices
        dist.reduce_scatter(output,
                            recv_chunks,
                            op=dist.ReduceOp.SUM,
                            group=self.comm_group.device_group)
    else:
        # otp-specific: Combine partial results across devices
        output = self.comm_group.reduce_scatter(output_parallel, dim=0)

    # Handle bias return based on configuration
    if self.return_bias:
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
    return output
5 changes: 5 additions & 0 deletions vllm_ascend/torchair/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,18 +213,23 @@ def torchair_quant_method_register():
def torchair_ops_patch():
from vllm_ascend.ops.activation import AscendSiluAndMul
from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
from vllm_ascend.ops.linear import AscendRowParallelLinear
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
from vllm_ascend.torchair.ops import (torchair_activation,
torchair_layernorm)
from vllm_ascend.torchair.ops.torchair_linear import \
torchair_oproj_tp_forward
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
deepseek_rope_init_func, native_rope_deepseek_forward,
qwen_rope_init_func, rope_forward)
from vllm_ascend.torchair.ops.torchair_vocab_parallel_embedding import \
vocab_embedding_forward

AscendRowParallelLinear._forward_oproj_tp = torchair_oproj_tp_forward # type: ignore[method-assign]

AscendRotaryEmbedding.__init__ = qwen_rope_init_func # type: ignore[method-assign]
AscendRotaryEmbedding.forward_oot = rope_forward # type: ignore[method-assign]

Expand Down
Loading