diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 678cd45800de..10f29393fe6f 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -48,6 +48,7 @@ def naive_multicast( rank = self.rank if is_sequence_parallel else self.dp_rank world_size = self.world_size if is_sequence_parallel else self.dp_world_size + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1] end = cu_tokens_across_sp_cpu[rank] @@ -55,7 +56,7 @@ def naive_multicast( for idx in range(world_size): start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1] end = cu_tokens_across_sp_cpu[idx] - get_ep_group().broadcast(buffer[start:end, :], idx) + dist_group.broadcast(buffer[start:end, :], idx) return buffer @@ -125,7 +126,8 @@ def combine( start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1] end = cu_tokens_across_sp_cpu[ep_rank] - all_hidden_states = get_ep_group().all_reduce(hidden_states) + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() + all_hidden_states = dist_group.all_reduce(hidden_states) hidden_states = all_hidden_states[start:end, :] return hidden_states