-
-
Notifications
You must be signed in to change notification settings - Fork 12.4k
[core] add nccl symmetric memory for all reduce #24532
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9406ee3
eff6ff9
5d21211
4011bdd
575e9e5
a378397
d3539c9
78c5fcd
3f28439
4193d5d
d793dba
d4c040c
c518a76
e1738ed
14ea822
81dd0dc
aa9632a
6ee0f96
0bdb252
aec4569
ba91f87
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import random | ||
| import typing | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.distributed as dist | ||
| import torch.multiprocessing as mp | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.distributed import cleanup_dist_env_and_memory | ||
| from vllm.distributed.device_communicators.cuda_communicator import ( | ||
| CudaCommunicator) | ||
| from vllm.distributed.device_communicators.pynccl import ( | ||
| register_nccl_symmetric_ops) | ||
| from vllm.distributed.device_communicators.pynccl_allocator import ( | ||
| get_nccl_mem_pool, is_symmetric_memory_enabled) | ||
| from vllm.distributed.parallel_state import (get_tp_group, | ||
| init_distributed_environment, | ||
| initialize_model_parallel) | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils import update_environment_variables | ||
|
|
||
# Deterministic seeds so the randomly generated test tensors are
# reproducible across runs and across spawned worker processes.
torch.manual_seed(42)
random.seed(44)

# Number of elements in the all-reduce test tensor (4 Mi elements).
test_size_elements = 4 * 1024 * 1024
|
|
||
|
|
||
def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
    """Per-rank worker: run an NCCL symmetric-memory all-reduce and
    compare it against a regular ``torch.distributed.all_reduce``.

    Spawned via ``torch.multiprocessing.spawn``, which supplies
    ``local_rank`` as the first positional argument.
    """
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        # Each spawned process owns exactly one GPU; drop any inherited
        # CUDA_VISIBLE_DEVICES restriction so ``cuda:{local_rank}`` resolves.
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        # NOTE(review): MASTER_PORT is hard-coded; concurrent test runs on
        # one host could collide — consider a free-port helper.
        update_environment_variables({
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        pynccl_comm = cuda_communicator.pynccl_comm
        if get_nccl_mem_pool() is None:
            pytest.skip("NCCL allocator compilation failed "
                        "(probably missing NCCL headers).")
        if not is_symmetric_memory_enabled():
            pytest.skip("NCCL symmetric memory allreduce is disabled.")

        register_nccl_symmetric_ops(pynccl_comm)
        # Renamed from ``input`` to avoid shadowing the Python builtin.
        inp = torch.randint(1,
                            23, (test_size_elements, ),
                            dtype=dtype,
                            device=device)
        inp_clone = inp.clone()
        output = torch.ops.vllm.all_reduce_symmetric_with_copy(inp)
        assert output is not None

        # Reference result from the plain NCCL all-reduce on the same data.
        group = get_tp_group().device_group
        dist.all_reduce(inp_clone, group=group)
        # bf16 accumulation order differs between the two paths, hence the
        # loose tolerances.
        torch.testing.assert_close(output, inp_clone, atol=2.5, rtol=0.1)
|
|
||
|
|
||
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
    """Spawn ``world_size`` GPU workers and validate the NCCL
    symmetric-memory all-reduce against the regular collective."""
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Turn on the symmetric-memory path plus the NCCL features it relies on.
    for env_var in ("VLLM_USE_NCCL_SYMM_MEM", "NCCL_NVLS_ENABLE",
                    "NCCL_CUMEM_ENABLE"):
        monkeypatch.setenv(env_var, "1")

    mp.spawn(nccl_symm_mem_allreduce_worker,
             args=(world_size, ),
             nprocs=world_size)
    cleanup_dist_env_and_memory()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,12 @@ | |
| from torch.distributed import ProcessGroup | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.distributed.device_communicators.all_reduce_utils import ( | ||
| should_nccl_symm_mem_allreduce) | ||
| from vllm.distributed.device_communicators.pynccl import ( | ||
| register_nccl_symmetric_ops) | ||
| from vllm.distributed.device_communicators.pynccl_allocator import ( | ||
| is_symmetric_memory_enabled) | ||
| from vllm.logger import init_logger | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
@@ -53,6 +59,8 @@ def __init__(self, | |
| group=self.cpu_group, | ||
| device=self.device, | ||
| ) | ||
| if is_symmetric_memory_enabled(): | ||
| register_nccl_symmetric_ops(self.pynccl_comm) | ||
|
|
||
| self.ca_comm: Optional[CustomAllreduce] = None | ||
| self.qr_comm: Optional[QuickAllReduce] = None | ||
|
|
@@ -107,6 +115,13 @@ def __init__(self, | |
| raise ValueError(f"Unknown all2all backend: {all2all_backend}") | ||
|
|
||
| def all_reduce(self, input_): | ||
| # since currently we perform copy input -> symm_input -> out-of-place AR | ||
| # return symm_output, we don't need to check if input is symmetric | ||
| if self.pynccl_comm is not None and \ | ||
| should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_): | ||
| out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_) | ||
| if out is not None: | ||
| return out | ||
|
Comment on lines
+122
to
+124
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't we assert that out is not None? When/why would we want to fall through to another method?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the logic for the regular nccl all reduce is: and with nccl symm memory we are still calling |
||
| # always try quick reduce first, then custom allreduce, | ||
| # and then pynccl. (quick reduce just for ROCM MI3*) | ||
| qr_comm = self.qr_comm | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test is not actually run in CI as we need to add it to a job in .buildkite/test-pipeline.yaml
It doesn't seem to work on my L40s or H100 node, so do we need to restrict it to some compute capability or library availability? We can enable this in a followup PR if this is complex.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does it fail on your l40s and h100 nodes or skipped?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It fails to compile and then fails