     get_otp_group)
 from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
                                mlp_tp_enable, oproj_tp_enable)
-from vllm_ascend.ascend_config import get_ascend_config
+
 
 _HCOMM_INFO = None
 
@@ -282,64 +282,43 @@ def _forward_oproj_tp(
             input_, num_partitions=self.tp_size)
         input_parallel = splitted_input[self.tp_rank].contiguous()
 
-        # prefill or decode
         forward_context = get_forward_context()
-        with_prefill = forward_context.with_prefill
 
         # Prepare tensors for all-to-all communication
         local_batch_size = input_parallel.size(0)
         chunk_size = self.input_size_per_partition
 
-        if with_prefill or not self.ascend_config.torchair_graph_config.enabled:
-            cu_tokens_across_dp_cpu = forward_context.dp_metadata.cu_tokens_across_dp_cpu
-            prefix_array = cu_tokens_across_dp_cpu.cpu().numpy()
-            global_batch_size = np.concatenate(
-                ([prefix_array[0]], np.diff(prefix_array)))
-            tp_group_id = self.dp_rank // self.tp_size
-            tp_group_batchsize = global_batch_size[tp_group_id * self.tp_size:tp_group_id * self.tp_size + self.tp_size]
-            total_batch_size = sum(tp_group_batchsize)
-
-            # Reshape for all-to-all communication
-            send_buf = (
-                input_parallel.reshape(-1, self.tp_size, chunk_size)
-                .transpose(0, 1)
-                .contiguous()
-                .view(-1))
-            # Create receive buffer
-            recv_buf = torch.zeros(
-                total_batch_size * chunk_size,
-                dtype=input_parallel.dtype,
-                device=input_parallel.device)
-
-            # Create split arrays
-            recv_splits = [size * chunk_size for size in tp_group_batchsize]
-            send_splits = [local_batch_size * chunk_size] * self.tp_size
-
-            # Perform all-to-all communication
-            dist.all_to_all_single(
-                recv_buf,
-                send_buf,
-                recv_splits,
-                send_splits,
-                group=self.comm_group.device_group)
-        else:
-            total_batch_size = local_batch_size * self.tp_size
-
-            # Reshape tensor for efficient cross-device transfer:
-            # [batch, dim] -> [tp_size, batch, chunk] -> flattened
-            send_buf = (input_parallel.reshape(
-                -1, self.tp_size, chunk_size).transpose(
-                    0, 1).contiguous().view(-1))
-
-            # Create receive buffer
-            recv_buf = torch.empty(total_batch_size * chunk_size,
-                                   dtype=input_parallel.dtype,
-                                   device=input_parallel.device)
-
-            # Perform all-to-all communication
-            dist.all_to_all_single(recv_buf,
-                                   send_buf,
-                                   group=self.comm_group.device_group)
+        cu_tokens_across_dp_cpu = forward_context.dp_metadata.cu_tokens_across_dp_cpu
+        prefix_array = cu_tokens_across_dp_cpu.cpu().numpy()
+        global_batch_size = np.concatenate(
+            ([prefix_array[0]], np.diff(prefix_array)))
+        tp_group_id = self.dp_rank // self.tp_size
+        tp_group_batchsize = global_batch_size[tp_group_id * self.tp_size:tp_group_id * self.tp_size + self.tp_size]
+        total_batch_size = sum(tp_group_batchsize)
+
+        # Reshape for all-to-all communication
+        send_buf = (
+            input_parallel.reshape(-1, self.tp_size, chunk_size)
+            .transpose(0, 1)
+            .contiguous()
+            .view(-1))
+        # Create receive buffer
+        recv_buf = torch.zeros(
+            total_batch_size * chunk_size,
+            dtype=input_parallel.dtype,
+            device=input_parallel.device)
+
+        # Create split size lists for the uneven all-to-all
+        recv_splits = [size * chunk_size for size in tp_group_batchsize]
+        send_splits = [local_batch_size * chunk_size] * self.tp_size
+
+        # Perform all-to-all communication
+        dist.all_to_all_single(
+            recv_buf,
+            send_buf,
+            recv_splits,
+            send_splits,
+            group=self.comm_group.device_group)
 
         input_parallel = recv_buf.view(total_batch_size, chunk_size)
 
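
The batch-size bookkeeping and send-buffer layout above can be sanity-checked on CPU without a process group. Below is a minimal sketch; `tp_size`, `chunk_size`, `dp_rank`, and the prefix sums standing in for `cu_tokens_across_dp_cpu` are made-up values, not taken from this PR.

```python
import numpy as np
import torch

# Hypothetical stand-ins for the attributes used in _forward_oproj_tp.
tp_size = 4        # self.tp_size
chunk_size = 8     # self.input_size_per_partition
dp_rank = 5        # self.dp_rank

# cu_tokens_across_dp_cpu holds cumulative token counts across DP ranks;
# the first element plus successive differences recover per-rank batches.
prefix_array = np.array([3, 7, 8, 12, 14, 19, 21, 24])
global_batch_size = np.concatenate(([prefix_array[0]], np.diff(prefix_array)))
# -> [3, 4, 1, 4, 2, 5, 2, 3]

# DP ranks are grouped into o_proj TP groups of tp_size ranks each.
tp_group_id = dp_rank // tp_size
tp_group_batchsize = global_batch_size[tp_group_id * tp_size:
                                       tp_group_id * tp_size + tp_size]
total_batch_size = int(sum(tp_group_batchsize))  # rows this rank receives

# The send buffer reorders [batch, tp_size * chunk] so the chunk_size-wide
# slice destined for peer i becomes one contiguous region, as
# all_to_all_single expects for a flattened input.
local_batch_size = int(global_batch_size[dp_rank])
input_parallel = torch.arange(
    local_batch_size * tp_size * chunk_size,
    dtype=torch.float32).reshape(local_batch_size, tp_size * chunk_size)
send_buf = (input_parallel.reshape(-1, tp_size, chunk_size)
            .transpose(0, 1).contiguous().view(-1))

# Split sizes are in elements: every peer gets this rank's full local batch,
# while this rank receives each peer's (varying) batch.
send_splits = [local_batch_size * chunk_size] * tp_size
recv_splits = [int(size) * chunk_size for size in tp_group_batchsize]

assert send_buf.numel() == sum(send_splits)
print(list(tp_group_batchsize), total_batch_size, recv_splits)
```

With these numbers, rank 5 sends its 5 local rows to each of the 4 ranks in its group and receives 2 + 5 + 2 + 3 = 12 rows back, matching `total_batch_size`.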
@@ -350,31 +329,26 @@ def _forward_oproj_tp(
             input_parallel,
             bias=bias_)
 
-        if with_prefill or not self.ascend_config.torchair_graph_config.enabled:
-            # prepare all-reduce data
-            output = torch.empty(
-                local_batch_size,
-                output_parallel.size(1),
-                dtype=output_parallel.dtype,
-                device=output_parallel.device)
-
-            recv_chunks = []
-            start_idx = 0
-            for size in tp_group_batchsize:
-                chunk = output_parallel[start_idx:start_idx + size, :]
-                recv_chunks.append(chunk.contiguous())
-                start_idx += size
-
-            # Reduce-scatter the results across devices
-            dist.reduce_scatter(
-                output,
-                recv_chunks,
-                op=dist.ReduceOp.SUM,
-                group=self.comm_group.device_group)
-
-        else:
-            # otp-specific: Combine partial results across devices
-            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+        # Prepare the output buffer for the reduce-scatter
+        output = torch.empty(
+            local_batch_size,
+            output_parallel.size(1),
+            dtype=output_parallel.dtype,
+            device=output_parallel.device)
+
+        recv_chunks = []
+        start_idx = 0
+        for size in tp_group_batchsize:
+            chunk = output_parallel[start_idx:start_idx + size, :]
+            recv_chunks.append(chunk.contiguous())
+            start_idx += size
+
+        # Reduce-scatter the results across devices
+        dist.reduce_scatter(
+            output,
+            recv_chunks,
+            op=dist.ReduceOp.SUM,
+            group=self.comm_group.device_group)
 
         # Handle bias return based on configuration
         output_bias = self.bias if self.skip_bias_add else None
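
The return path can be emulated the same way. Here is a minimal single-process sketch of what `dist.reduce_scatter(op=SUM)` computes over the uneven `recv_chunks`: destination rank d ends up with the sum, over all ranks, of each rank's d-th chunk, i.e. the finished output rows for its own tokens. The sizes are the same hypothetical values as above.

```python
import torch

# Hypothetical per-rank token counts within one o_proj TP group (see above).
tp_group_batchsize = [2, 5, 2, 3]
hidden_size = 6
tp_size = len(tp_group_batchsize)
total_batch_size = sum(tp_group_batchsize)

# After the local matmul, every rank holds a partial result for all
# total_batch_size rows in its TP group.
partials = [torch.randn(total_batch_size, hidden_size) for _ in range(tp_size)]

def split_uneven(t, sizes):
    # Slice rows into uneven per-destination chunks, mirroring the
    # recv_chunks loop in _forward_oproj_tp.
    chunks, start = [], 0
    for size in sizes:
        chunks.append(t[start:start + size, :].contiguous())
        start += size
    return chunks

chunks_per_rank = [split_uneven(p, tp_group_batchsize) for p in partials]

# Emulate dist.reduce_scatter with op=SUM: destination rank dst receives the
# elementwise sum of every source rank's dst-th chunk.
outputs = [
    torch.stack([chunks_per_rank[src][dst] for src in range(tp_size)]).sum(0)
    for dst in range(tp_size)
]

for dst, out in enumerate(outputs):
    assert out.shape == (tp_group_batchsize[dst], hidden_size)
```

Each emulated rank thus finishes with exactly `tp_group_batchsize[dst]` rows, which is why the real code sizes `output` as `[local_batch_size, output_parallel.size(1)]`.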
@@ -689,5 +663,4 @@ def __init__(
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
         self.return_bias = return_bias
-        self.disable_tp = disable_tp
-        self.ascend_config = get_ascend_config()
+        self.disable_tp = disable_tp