
Conversation

@trevor-m
Contributor

@trevor-m trevor-m commented Jun 26, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Adds all-gatherv, which supports varying sizes per rank and can take a list of input tensors.
Adds reduce-scatterv, which supports varying sizes per rank.

These new collectives can be used for MoE dispatch/combine for DP without padding. See #20037 for usage.
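For readers new to the "v" (vector) variants, here is a toy single-process sketch of their semantics. The sizes, shapes, and 4-rank setup below are illustrative only, not the vLLM API; the real collectives run over NCCL across ranks.

```python
import torch

sizes = [3, 1, 4, 2]                 # rows contributed by ranks 0..3 (note: unequal)
hidden = 8
per_rank_inputs = [torch.randn(n, hidden) for n in sizes]

# all-gatherv: every rank ends up with the concatenation of all ranks' rows,
# with no padding to the largest per-rank size.
gathered = torch.cat(per_rank_inputs, dim=0)      # (sum(sizes), hidden) on every rank

# reduce-scatterv: element-wise reduce across ranks, then rank r keeps only
# its own sizes[r] rows of the result.
partials = [torch.randn(sum(sizes), hidden) for _ in sizes]   # one partial per rank
reduced = torch.stack(partials).sum(dim=0)
offsets = [0]
for n in sizes:
    offsets.append(offsets[-1] + n)
rank = 2
my_rows = reduced[offsets[rank]:offsets[rank + 1]]            # (sizes[rank], hidden)
```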

Test Plan

Added tests to test_pynccl.py
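Such a test worker follows the same pattern as the existing pynccl tests. The sketch below is hypothetical: the helper name, sizes, and exact all_gatherv signature are assumptions, not the merged test code.

```python
import torch

# Hypothetical sketch of an all-gatherv test worker; assumes a PyNcclCommunicator
# (`pynccl_comm`) created the same way the other tests in test_pynccl.py do it.
def all_gatherv_worker_sketch(pynccl_comm, rank: int, world_size: int, device):
    sizes = [i + 1 for i in range(world_size)]               # unequal rows per rank
    tensor = torch.full((sizes[rank], 8), float(rank), device=device)
    result = torch.empty(sum(sizes), 8, device=device)

    # Signature assumed; the CI discussion further down shows the call being renamed.
    pynccl_comm.all_gatherv(result, tensor, sizes=sizes)

    expected = torch.cat(
        [torch.full((n, 8), float(r), device=device) for r, n in enumerate(sizes)])
    torch.testing.assert_close(result, expected)
```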

Test Result

# python3 -m pytest tests/distributed/test_pynccl.py
================================================================================================================================= test session starts ==================================================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /trevor/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, shard-0.1.2, rerunfailures-15.1, xdoctest-1.0.2, flakefinder-1.1.0, xdist-3.6.1, hypothesis-6.130.8, typeguard-4.3.0
collected 12 items
Running 12 items in this shard

tests/distributed/test_pynccl.py ............                                                                                                                                                                                                                                    [100%]

=================================================================================================================================== warnings summary ===================================================================================================================================
tests/distributed/test_pynccl.py: 30 warnings
  /usr/lib/python3.12/multiprocessing/popen_fork.py:66: DeprecationWarning: This process (pid=23651) is multi-threaded, use of fork() may lead to deadlocks in the child.
    self.pid = os.fork()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================================================================================================== 12 passed, 30 warnings in 40.05s ===========================================================================================================================

(Optional) Documentation Update

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @trevor-m, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the distributed communication capabilities by introducing variable-sized all_gather and reduce_scatter operations (all_gatherv and reduce_scatterv) to the pynccl backend. This change is vital for optimizing distributed workflows, particularly for Mixture-of-Experts (MoE) models, by allowing for more efficient data handling without the overhead of padding.

Highlights

  • New Collective Operations: Implemented all_gatherv and reduce_scatterv within the PyNcclCommunicator to support distributed tensor operations where input/output sizes can vary per rank. This is a crucial enhancement for flexible distributed execution.
  • MoE Optimization: These new variable-sized collectives are specifically designed to enable more efficient Mixture-of-Experts (MoE) dispatch and combine operations in data-parallel (DP) settings by eliminating the need for padding, which can lead to significant performance gains.
  • NCCL Primitives Integration: Extended the pynccl wrapper to expose the underlying ncclReduce, ncclGroupStart, and ncclGroupEnd NCCL primitives. These low-level functions are used to implement the logic required for the variable-sized all_gatherv (via ncclBroadcast calls within a group) and reduce_scatterv (via ncclReduce calls within a group) operations.
  • Comprehensive Testing: Added new unit tests (test_pynccl_all_gatherv and test_pynccl_reduce_scatterv) to tests/distributed/test_pynccl.py to thoroughly validate the correctness and functionality of the newly introduced variable-sized collective operations across multiple GPUs.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature               Command               Description
Code Review           /gemini review        Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary       Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist   Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                  /gemini help          Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request introduces new collective communication operations (all-gatherv and reduce-scatterv) using PyNCCL. The changes include modifications to the base device communicator, CUDA communicator, and PyNCCL wrapper to support these new operations. Tests have also been added to verify the correctness of the new implementations. The code appears to be well-structured and addresses the intended functionality. However, there are some areas where improvements can be made to enhance efficiency, error handling, and code clarity.

Comment on lines 171 to 209
Contributor

high

The assertion assert dim == 0, "only dim 0 all-gatherv is supported" raises an AssertionError whenever dim is not 0. It might be better to raise a NotImplementedError to indicate that other dimensions are not supported.

Suggested change
-        assert dim == 0, "only dim 0 all-gatherv is supported"
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None and not pynccl_comm.disabled
+        if dim != 0:
+            raise NotImplementedError("only dim 0 all-gatherv is supported")

Member

+1, better to raise rather than assert here

Comment on lines 151 to 169
Contributor

medium

The condition sizes is not None is checked twice, once here and once in the all_gather function. Consider refactoring to avoid this redundancy and improve maintainability.

Suggested change
         if sizes is not None:
             assert output_tensor.shape[0] == sum(sizes)
             numel_base = int(np.prod(output_tensor.shape[1:]))
             split_offset = 0
             self.nccl.ncclGroupStart()
             for root, split_size in enumerate(sizes):
                 dst_slice = output_tensor[split_offset:split_offset +
                                           split_size]
                 self.nccl.ncclBroadcast(
                     buffer_type(input_tensor.data_ptr()),
                     buffer_type(dst_slice.data_ptr()),
                     split_size * numel_base,
                     ncclDataTypeEnum.from_torch(input_tensor.dtype),
                     root,
                     self.comm,
                     cudaStream_t(stream.cuda_stream),
                 )
                 split_offset += split_size
             self.nccl.ncclGroupEnd()
+            return  # early return to avoid duplicate code

Comment on lines 194 to 212
Contributor

medium

The logic for reduce_scatter when sizes is not None uses ncclReduce in a loop. This approach might not be the most efficient way to perform reduce-scatter with varying sizes. It could be beneficial to investigate alternative approaches or optimized NCCL functions that can handle this operation more efficiently. Also, it is not clear whether the root argument in ncclReduce is used correctly, since it simply follows the loop index from enumerate(sizes).

Suggested change
        if sizes is not None:
            numel_base = int(np.prod(input_tensor.shape[1:]))
            split_offset = 0
            self.nccl.ncclGroupStart()
            for root, split_size in enumerate(sizes):
                chunk = input_tensor[split_offset:split_offset + split_size, ...]
                self.nccl.ncclReduce(
                    buffer_type(chunk.data_ptr()),
                    buffer_type(output_tensor.data_ptr()),
                    split_size * numel_base,
                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
                    ncclRedOpTypeEnum.from_torch(op),
                    root,
                    self.comm,
                    cudaStream_t(stream.cuda_stream)
                )
                split_offset += split_size
            self.nccl.ncclGroupEnd()

@robertgshaw2-redhat
Collaborator

cc @bnellnm @tlrmchlsmth @varun-sundar-rabindranath

@robertgshaw2-redhat
Collaborator

QQ - how should I think about these kernels in comparison to PPLX and DeepEP?

@trevor-m
Contributor Author

QQ - how should I think about these kernels in comparison to PPLX and DeepEP?

These collectives are used by TRT-LLM for the "min latency mode" path. Do pplx/deepep target latency or max throughput?
I don't know if direct performance comparisons against PPLX and DeepEP have been done before.

@robertgshaw2-redhat Allgatherv will outperform the current naive_multicast used for MoE dispatch, and reducescatterv will outperform the current allreduce+slice combine. Being able to all-gather a list of tensors will also allow us to quantize to fp4 before the all-gather.
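To make the combine comparison concrete, here is a single-process stand-in for the two strategies. It is illustrative only; the per-rank sizes and variable names are made up, and the real collectives run over NCCL.

```python
import torch

sizes = [3, 1, 4, 2]                       # tokens per DP rank (illustrative)
hidden = 8
offsets = [0]
for n in sizes:
    offsets.append(offsets[-1] + n)
rank = 1
partials = [torch.randn(sum(sizes), hidden) for _ in sizes]   # one partial per rank

# Current combine: all-reduce the full buffer on every rank, then slice --
# every rank pays for reducing rows it immediately discards.
allreduced = torch.stack(partials).sum(dim=0)
combined_allreduce_slice = allreduced[offsets[rank]:offsets[rank + 1]]

# reduce-scatterv combine: rank r only ever receives its own sizes[r] rows,
# so communication scales with the local slice instead of the full buffer.
combined_reduce_scatterv = torch.stack(
    [p[offsets[rank]:offsets[rank + 1]] for p in partials]).sum(dim=0)

assert torch.allclose(combined_allreduce_slice, combined_reduce_scatterv)
```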

Member

@tlrmchlsmth tlrmchlsmth left a comment

please run the pre-commit to fix the linter errors https://docs.vllm.ai/en/stable/contributing/#code-quality

Member

@tlrmchlsmth tlrmchlsmth left a comment

Comment on lines 171 to 209
Member

+1, better to raise rather than assert here

Member

This looks interesting -- How do ncclGroupStart() and ncclGroupEnd() work?

Contributor Author

These functions allow you to batch multiple NCCL calls into a single launch.
You can read more about it here: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/groups.html
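For context, the all-gatherv path in this PR uses that batching roughly as sketched below. The sketch assumes the pynccl wrapper's names (buffer_type, cudaStream_t, ncclDataTypeEnum), omits error handling, and is not the exact merged code.

```python
from vllm.distributed.device_communicators.pynccl_wrapper import (
    buffer_type, cudaStream_t, ncclDataTypeEnum)

# Sketch of the grouped-broadcast pattern: one broadcast per rank, fused into a
# single NCCL launch between ncclGroupStart() and ncclGroupEnd().
def grouped_all_gatherv(nccl, comm, input_tensor, output_tensor, sizes, stream):
    numel_base = output_tensor[0].numel()      # elements per row
    offset = 0
    nccl.ncclGroupStart()                      # calls below are only queued
    for root, n in enumerate(sizes):
        dst = output_tensor[offset:offset + n]
        # Rank `root` broadcasts its rows into every rank's slice of the output.
        nccl.ncclBroadcast(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(dst.data_ptr()),
            n * numel_base,
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            root, comm, cudaStream_t(stream.cuda_stream))
        offset += n
    nccl.ncclGroupEnd()                        # all queued broadcasts launch together
```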

Signed-off-by: Trevor Morris <[email protected]>
Member

One last comment - I think it would be cleaner and clearer if the all_gather and all_gatherv implementations were completely separate. Right now it's slightly awkward that all_gatherv calls pynccl all_gather with a list of sizes. Ditto for reduce_scatter/reduce_scatterv.

Otherwise looks good to me, thanks!
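For illustration, the requested split might look like the following; the signatures are hypothetical and the merged code may differ.

```python
class PyNcclCommunicator:  # signatures are illustrative only
    def all_gather(self, output_tensor, input_tensor, stream=None):
        """Fixed-size all-gather: every rank contributes the same number of rows."""

    def all_gatherv(self, output_tensor, input_tensor, sizes, stream=None):
        """Variable-size all-gather: rank r contributes sizes[r] rows."""

    def reduce_scatter(self, output_tensor, input_tensor, op=None, stream=None):
        """Fixed-size reduce-scatter."""

    def reduce_scatterv(self, output_tensor, input_tensor, sizes, op=None, stream=None):
        """Variable-size reduce-scatter: rank r receives sizes[r] reduced rows."""
```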

Contributor Author

Thanks, done!

Contributor Author

@tlrmchlsmth Could you take a look?

trevor-m added 5 commits July 4, 2025 03:57
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: Trevor Morris <[email protected]>
                   dim: int) -> torch.Tensor:
        return self.device_communicator.all_gather(input_, dim)

    def all_gatherv(self,
Contributor

@wenscarl wenscarl Jul 5, 2025

The self.use_custom_call_op in this file is set in a way that the newly added all-gatherv and reduce-scatterv features cannot be enabled simply by toggling an environment variable.

Member

@mgoin mgoin left a comment

Looks reasonable to me, just one nit

Member

@tlrmchlsmth tlrmchlsmth left a comment

LGTM, thank you!

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 9, 2025
@mgoin mgoin enabled auto-merge (squash) July 9, 2025 22:08
@mgoin
Member

mgoin commented Jul 11, 2025

@trevor-m It seems there is an issue with your tests on the CI, PTAL https://buildkite.com/vllm/ci/builds/23741/steps/canvas?sid=0197faaf-9704-4b36-a45d-24076ea18d02#0197faaf-97ff-4e37-b73c-8d7912f280a9/6-10388

[2025-07-11T19:34:38Z] Traceback (most recent call last):
[2025-07-11T19:34:38Z]   File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
[2025-07-11T19:34:38Z]     self.run()
[2025-07-11T19:34:38Z]   File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
[2025-07-11T19:34:38Z]     self._target(*self._args, **self._kwargs)
[2025-07-11T19:34:38Z]   File "/vllm-workspace/tests/distributed/test_pynccl.py", line 54, in wrapped_fn
[2025-07-11T19:34:38Z]     fn()
[2025-07-11T19:34:38Z]   File "/vllm-workspace/tests/distributed/test_pynccl.py", line 202, in all_gatherv_worker_fn
[2025-07-11T19:34:38Z]     pynccl_comm.all_gather(result, tensor, sizes=sizes)
[2025-07-11T19:34:38Z]   File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
[2025-07-11T19:34:38Z]     self.run()
[2025-07-11T19:34:38Z] TypeError: PyNcclCommunicator.all_gather() got an unexpected keyword argument 'sizes'
[2025-07-11T19:34:38Z]   File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
[2025-07-11T19:34:38Z]     self._target(*self._args, **self._kwargs)
[2025-07-11T19:34:38Z]   File "/vllm-workspace/tests/distributed/test_pynccl.py", line 54, in wrapped_fn
[2025-07-11T19:34:38Z]     fn()
[2025-07-11T19:34:38Z]   File "/vllm-workspace/tests/distributed/test_pynccl.py", line 202, in all_gatherv_worker_fn
[2025-07-11T19:34:38Z]     pynccl_comm.all_gather(result, tensor, sizes=sizes)
[2025-07-11T19:34:38Z] TypeError: PyNcclCommunicator.all_gather() got an unexpected keyword argument 'sizes'

EDIT: I pushed a fix since it was just the function rename
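Presumably the fix was a one-line change in the test worker along these lines (assumed; the exact diff is not shown in this thread):

```python
# Before (what the CI traceback above shows):
pynccl_comm.all_gather(result, tensor, sizes=sizes)

# After the rename (assumed), matching the split all_gather / all_gatherv API:
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
```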

@trevor-m
Contributor Author

Thank you @mgoin !

@simon-mo simon-mo disabled auto-merge July 12, 2025 01:59
@simon-mo simon-mo merged commit a859323 into vllm-project:main Jul 12, 2025
67 of 70 checks passed
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: x22x22 <[email protected]>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
Signed-off-by: Trevor Morris <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>

Labels

ready ONLY add when PR is ready to merge/full CI is needed


6 participants