Commit 06241d9

Code refactoring: separate torchair from eager.
Signed-off-by: funanyang <[email protected]>
Co-authored-by: funanyang <[email protected]>
1 parent 4308c0f commit 06241d9

3 files changed: 178 additions & 80 deletions

vllm_ascend/ops/linear.py

Lines changed: 53 additions & 80 deletions
@@ -41,7 +41,7 @@
                              get_otp_group)
 from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
                                mlp_tp_enable, oproj_tp_enable)
-from vllm_ascend.ascend_config import get_ascend_config
+
 
 _HCOMM_INFO = None
 
@@ -282,64 +282,43 @@ def _forward_oproj_tp(
                 input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()
 
-        # prefill or decode
         forward_context = get_forward_context()
-        with_prefill = forward_context.with_prefill
 
         # Prepare tensors for all-to-all communication
         local_batch_size = input_parallel.size(0)
         chunk_size = self.input_size_per_partition
 
-        if with_prefill or not self.ascend_config.torchair_graph_config.enabled:
-            cu_tokens_across_dp_cpu = forward_context.dp_metadata.cu_tokens_across_dp_cpu
-            prefix_array = cu_tokens_across_dp_cpu.cpu().numpy()
-            global_batch_size = np.concatenate(
-                ([prefix_array[0]], np.diff(prefix_array)))
-            tp_group_id = self.dp_rank // self.tp_size
-            tp_group_batchsize = global_batch_size[tp_group_id * self.tp_size: tp_group_id * self.tp_size + self.tp_size]
-            total_batch_size = sum(tp_group_batchsize)
-
-            # Reshape for all-to-all communication
-            send_buf = (
-                input_parallel.reshape(-1, self.tp_size, chunk_size)
-                .transpose(0, 1)
-                .contiguous()
-                .view(-1))
-            # Create receive buffer
-            recv_buf = torch.zeros(
-                total_batch_size * chunk_size,
-                dtype=input_parallel.dtype,
-                device=input_parallel.device)
-
-            # Create split array
-            recv_splits = [size * chunk_size for size in tp_group_batchsize]
-            send_splits = [local_batch_size * chunk_size] * self.tp_size
-
-            # Perform all-to-all communication
-            dist.all_to_all_single(
-                recv_buf,
-                send_buf,
-                recv_splits,
-                send_splits,
-                group=self.comm_group.device_group)
-        else:
-            total_batch_size = local_batch_size * self.tp_size
-
-            # Reshape tensor for efficient cross-device transfer:
-            # [batch, dim] -> [tp_size, batch, chunk] -> flattened
-            send_buf = (input_parallel.reshape(-1,
-                self.tp_size, chunk_size).transpose(
-                    0, 1).contiguous().view(-1))
-
-            # Create receive buffer
-            recv_buf = torch.empty(total_batch_size * chunk_size,
-                                   dtype=input_parallel.dtype,
-                                   device=input_parallel.device)
-
-            # Perform all-to-all communication
-            dist.all_to_all_single(recv_buf,
-                                   send_buf,
-                                   group=self.comm_group.device_group)
+        cu_tokens_across_dp_cpu = forward_context.dp_metadata.cu_tokens_across_dp_cpu
+        prefix_array = cu_tokens_across_dp_cpu.cpu().numpy()
+        global_batch_size = np.concatenate(
+            ([prefix_array[0]], np.diff(prefix_array)))
+        tp_group_id = self.dp_rank // self.tp_size
+        tp_group_batchsize = global_batch_size[tp_group_id * self.tp_size: tp_group_id * self.tp_size + self.tp_size]
+        total_batch_size = sum(tp_group_batchsize)
+
+        # Reshape for all-to-all communication
+        send_buf = (
+            input_parallel.reshape(-1, self.tp_size, chunk_size)
+            .transpose(0, 1)
+            .contiguous()
+            .view(-1))
+        # Create receive buffer
+        recv_buf = torch.zeros(
+            total_batch_size * chunk_size,
+            dtype=input_parallel.dtype,
+            device=input_parallel.device)
+
+        # Create split array
+        recv_splits = [size * chunk_size for size in tp_group_batchsize]
+        send_splits = [local_batch_size * chunk_size] * self.tp_size
+
+        # Perform all-to-all communication
+        dist.all_to_all_single(
+            recv_buf,
+            send_buf,
+            recv_splits,
+            send_splits,
+            group=self.comm_group.device_group)
 
         input_parallel = recv_buf.view(total_batch_size, chunk_size)
 
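
Note: a minimal standalone sketch (toy numbers, not part of the commit) of how this now-unconditional eager path turns the cumulative DP token counts into per-rank batch sizes and into the element-count splits handed to dist.all_to_all_single; cu_tokens, tp_size, dp_rank and chunk_size below are assumed values.

import numpy as np

# Assumed toy setup: 4 DP ranks, TP groups of 2, this rank is dp_rank 1.
cu_tokens = np.array([5, 9, 12, 20])   # cumulative tokens across DP ranks
tp_size, dp_rank, chunk_size = 2, 1, 16

# Cumulative counts -> per-rank counts: [5, 4, 3, 8]
global_batch_size = np.concatenate(([cu_tokens[0]], np.diff(cu_tokens)))
local_batch_size = int(global_batch_size[dp_rank])   # tokens held locally (assumed mapping)

# Counts for the TP group this DP rank belongs to: [5, 4]
tp_group_id = dp_rank // tp_size
tp_group_batchsize = global_batch_size[tp_group_id * tp_size:(tp_group_id + 1) * tp_size]
total_batch_size = int(sum(tp_group_batchsize))       # tokens received after the all-to-all

# Element counts for dist.all_to_all_single: uneven receives, equal-sized sends.
recv_splits = [int(size) * chunk_size for size in tp_group_batchsize]
send_splits = [local_batch_size * chunk_size] * tp_size
print(local_batch_size, total_batch_size, recv_splits, send_splits)
# -> 4 9 [80, 64] [64, 64]
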
@@ -350,31 +329,26 @@ def _forward_oproj_tp(
                                                   input_parallel,
                                                   bias=bias_)
 
-        if with_prefill or not self.ascend_config.torchair_graph_config.enabled:
-            # prepare all-reduce data
-            output = torch.empty(
-                local_batch_size,
-                output_parallel.size(1),
-                dtype=output_parallel.dtype,
-                device=output_parallel.device)
-
-            recv_chunks = []
-            start_idx = 0
-            for size in tp_group_batchsize:
-                chunk = output_parallel[start_idx:start_idx + size, :]
-                recv_chunks.append(chunk.contiguous())
-                start_idx += size
-
-            # Reduce-scatter the results across devices
-            dist.reduce_scatter(
-                output,
-                recv_chunks,
-                op=dist.ReduceOp.SUM,
-                group=self.comm_group.device_group)
-
-        else:
-            # otp-specific: Combine partial results across devices
-            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+        # prepare all-reduce data
+        output = torch.empty(
+            local_batch_size,
+            output_parallel.size(1),
+            dtype=output_parallel.dtype,
+            device=output_parallel.device)
+
+        recv_chunks = []
+        start_idx = 0
+        for size in tp_group_batchsize:
+            chunk = output_parallel[start_idx:start_idx + size, :]
+            recv_chunks.append(chunk.contiguous())
+            start_idx += size
+
+        # Reduce-scatter the results across devices
+        dist.reduce_scatter(
+            output,
+            recv_chunks,
+            op=dist.ReduceOp.SUM,
+            group=self.comm_group.device_group)
 
         # Handle bias return based on configuration
         output_bias = self.bias if self.skip_bias_add else None
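
Note: a single-process sketch (assumed sizes, not from the commit) of the chunking that feeds dist.reduce_scatter above: the partial output is sliced row-wise into one variable-sized chunk per TP peer, which matches torch.split with explicit section sizes.

import torch

tp_group_batchsize = [5, 4]   # assumed per-peer token counts
hidden = 6
output_parallel = torch.randn(sum(tp_group_batchsize), hidden)

# Carve the partial output into one chunk per peer, as the loop above does.
recv_chunks, start_idx = [], 0
for size in tp_group_batchsize:
    recv_chunks.append(output_parallel[start_idx:start_idx + size, :].contiguous())
    start_idx += size

# Equivalent to torch.split with explicit section sizes.
for manual, split in zip(recv_chunks,
                         torch.split(output_parallel, tp_group_batchsize, dim=0)):
    assert torch.equal(manual, split)
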
@@ -689,5 +663,4 @@ def __init__(
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
         self.return_bias = return_bias
-        self.disable_tp = disable_tp
-        self.ascend_config = get_ascend_config()
+        self.disable_tp = disable_tp
vllm_ascend/torchair/ops/torchair_linear.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+import numpy as np
+from torch.nn.parameter import Parameter
+from vllm.distributed import split_tensor_along_last_dim
+from vllm.forward_context import get_forward_context
+
+
+def torchair_oproj_tp_forward(
+    self,
+    input_: torch.Tensor,
+) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+    if self.input_is_parallel:
+        input_parallel = input_
+    else:
+        splitted_input = split_tensor_along_last_dim(
+            input_, num_partitions=self.tp_size)
+        input_parallel = splitted_input[self.tp_rank].contiguous()
+
+    # prefill or decode
+    forward_context = get_forward_context()
+    with_prefill = forward_context.with_prefill
+
+    # Prepare tensors for all-to-all communication
+    local_batch_size = input_parallel.size(0)
+    chunk_size = self.input_size_per_partition
+
+    if with_prefill:
+        cu_tokens_across_dp_cpu = forward_context.dp_metadata.cu_tokens_across_dp_cpu
+        prefix_array = cu_tokens_across_dp_cpu.cpu().numpy()
+        global_batch_size = np.concatenate(
+            ([prefix_array[0]], np.diff(prefix_array)))
+        tp_group_id = self.dp_rank // self.tp_size
+        tp_group_batchsize = global_batch_size[tp_group_id * self.tp_size: tp_group_id * self.tp_size + self.tp_size]
+        total_batch_size = sum(tp_group_batchsize)
+
+        # Reshape for all-to-all communication
+        send_buf = (
+            input_parallel.reshape(-1, self.tp_size, chunk_size)
+            .transpose(0, 1)
+            .contiguous()
+            .view(-1))
+        # Create receive buffer
+        recv_buf = torch.zeros(
+            total_batch_size * chunk_size,
+            dtype=input_parallel.dtype,
+            device=input_parallel.device)
+
+        # Create split array
+        recv_splits = [size * chunk_size for size in tp_group_batchsize]
+        send_splits = [local_batch_size * chunk_size] * self.tp_size
+
+        # Perform all-to-all communication
+        dist.all_to_all_single(
+            recv_buf,
+            send_buf,
+            recv_splits,
+            send_splits,
+            group=self.comm_group.device_group)
+    else:
+        total_batch_size = local_batch_size * self.tp_size
+
+        # Reshape tensor for efficient cross-device transfer:
+        # [batch, dim] -> [tp_size, batch, chunk] -> flattened
+        send_buf = (input_parallel.reshape(-1,
+            self.tp_size, chunk_size).transpose(
+                0, 1).contiguous().view(-1))
+
+        # Create receive buffer
+        recv_buf = torch.empty(total_batch_size * chunk_size,
+                               dtype=input_parallel.dtype,
+                               device=input_parallel.device)
+
+        # Perform all-to-all communication
+        dist.all_to_all_single(recv_buf,
+                               send_buf,
+                               group=self.comm_group.device_group)
+
+    input_parallel = recv_buf.view(total_batch_size, chunk_size)
+
+    # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
+    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+    assert self.quant_method is not None
+    output_parallel = self.quant_method.apply(self,
+                                              input_parallel,
+                                              bias=bias_)
+
+    if with_prefill:
+        # prepare all-reduce data
+        output = torch.empty(
+            local_batch_size,
+            output_parallel.size(1),
+            dtype=output_parallel.dtype,
+            device=output_parallel.device)
+
+        recv_chunks = []
+        start_idx = 0
+        for size in tp_group_batchsize:
+            chunk = output_parallel[start_idx:start_idx + size, :]
+            recv_chunks.append(chunk.contiguous())
+            start_idx += size
+
+        # Reduce-scatter the results across devices
+        dist.reduce_scatter(
+            output,
+            recv_chunks,
+            op=dist.ReduceOp.SUM,
+            group=self.comm_group.device_group)
+
+    else:
+        # otp-specific: Combine partial results across devices
+        output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+
+    # Handle bias return based on configuration
+    output_bias = self.bias if self.skip_bias_add else None
+    if not self.return_bias:
+        return output
+    return output, output_bias
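
Note: a single-process sketch (toy shapes, not part of the commit) of the decode-path layout in torchair_oproj_tp_forward: the reshape/transpose turns each rank's [batch, hidden] activation into tp_size contiguous slabs, one hidden-dimension chunk per peer, before the all-to-all ships slab i to TP rank i.

import torch

tp_size, local_batch, chunk = 4, 3, 8           # assumed toy sizes
x = torch.randn(local_batch, tp_size * chunk)   # stand-in for input_parallel

# [batch, tp_size*chunk] -> [batch, tp_size, chunk] -> [tp_size, batch, chunk] -> flat
send_buf = x.reshape(-1, tp_size, chunk).transpose(0, 1).contiguous().view(-1)

# Slab i is exactly the i-th hidden-dimension chunk of every local token.
slabs = send_buf.view(tp_size, local_batch, chunk)
for i in range(tp_size):
    assert torch.equal(slabs[i], x[:, i * chunk:(i + 1) * chunk])
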

vllm_ascend/torchair/utils.py

Lines changed: 5 additions & 0 deletions
@@ -199,11 +199,16 @@ def torchair_quant_method_register():
 
 
 def torchair_ops_patch():
+    from vllm_ascend.ops.linear import AscendRowParallelLinear
     from vllm_ascend.ops.rotary_embedding import (
         AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
     from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
         deepseek_rope_init_func, native_rope_deepseek_forward,
         qwen_rope_init_func, rope_forward)
+    from vllm_ascend.torchair.ops.torchair_linear import (
+        torchair_oproj_tp_forward)
+
+    AscendRowParallelLinear.forward = torchair_oproj_tp_forward  # type: ignore[method-assign]
 
     AscendRotaryEmbedding.__init__ = qwen_rope_init_func  # type: ignore[method-assign]
     AscendRotaryEmbedding.forward_oot = rope_forward  # type: ignore[method-assign]
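
Note: the patch relies on plain class-attribute assignment; a generic sketch (toy classes, not vllm code) of why that rebinds forward for existing and future instances alike.

class RowParallelLinearStub:               # hypothetical stand-in class
    def forward(self, x):
        return f"eager forward({x})"

def torchair_forward(self, x):             # hypothetical replacement
    return f"torchair forward({x})"

layer = RowParallelLinearStub()
print(layer.forward(1))                    # eager forward(1)

RowParallelLinearStub.forward = torchair_forward   # class-level rebind
print(layer.forward(1))                    # torchair forward(1)
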
