From 2c02b99311d896a985f80a199e85a9dc5ee5a19d Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Tue, 2 Sep 2025 14:55:02 +0000
Subject: [PATCH 1/6] update num_tokens_across_dp to use nccl instead of gloo

Signed-off-by: Sage Moore
---
 vllm/forward_context.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index c57c51d289ac..806a05c92d23 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -78,10 +78,10 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = num_tokens
         num_tokens_tensor = torch.tensor(num_tokens_across_dp,
-                                         device="cpu",
+                                         device="cuda",
                                          dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
-        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
+        dist.all_reduce(num_tokens_tensor, group=get_dp_group().device_group)
         return num_tokens_tensor
 
     @staticmethod

From 66e605ac519e2e72f5e48246c36f039467939e50 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Mon, 8 Sep 2025 13:41:21 +0000
Subject: [PATCH 2/6] review comments

Signed-off-by: Sage Moore
---
 vllm/forward_context.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 806a05c92d23..153b7820ca42 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -13,6 +13,7 @@
 import vllm.envs as envs
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -75,14 +76,25 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         Gather the num_tokens across all DP ranks and return results in a
         CPU tensor of size dp_size.
         """
+        from vllm.distributed.parallel_state import get_dp_group
+        device = current_platform.device_type
+        group = get_dp_group().device_group
+
+        # Transfering this tensor from GPU to CPU will introduce a GPU sync point
+        # that could adversely affect performance of vllm with asynch scheduling.
+        # This environment variable exists to quickly disable this optimization
+        # if we run into this case.
+        if envs.VLLM_DISABLE_NCCL_DP_PADDING:
+            logger.info("Using CPU all reduce to syncronize DP padding between ranks.")
+            device = "cpu"
+            group = get_dp_group().cpu_group
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = num_tokens
         num_tokens_tensor = torch.tensor(num_tokens_across_dp,
-                                         device="cuda",
+                                         device=device,
                                          dtype=torch.int32)
-        from vllm.distributed.parallel_state import get_dp_group
-        dist.all_reduce(num_tokens_tensor, group=get_dp_group().device_group)
-        return num_tokens_tensor
+        dist.all_reduce(num_tokens_tensor, group=group)
+        return num_tokens_tensor.cpu()
 
     @staticmethod
     def make(
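
[Note] The two patches above switch the data-parallel token-count sync from a gloo all-reduce on the CPU group to an NCCL all-reduce on the device group, with an env-var escape hatch back to the CPU path. The snippet below is a minimal standalone sketch of that pattern, not part of the series; the function name, the device_group/cpu_group arguments, and the use_nccl flag are illustrative stand-ins for what vLLM obtains from get_dp_group().

    import torch
    import torch.distributed as dist

    def sync_num_tokens_across_dp(num_tokens: int,
                                  dp_size: int,
                                  dp_rank: int,
                                  device_group: dist.ProcessGroup,
                                  cpu_group: dist.ProcessGroup,
                                  use_nccl: bool = True) -> torch.Tensor:
        # Each rank fills in only its own slot; a sum all-reduce then gives
        # every rank the full per-rank token-count vector.
        counts = [0] * dp_size
        counts[dp_rank] = num_tokens
        if use_nccl:
            # Device-side all-reduce over the NCCL group. The .cpu() below
            # is the GPU sync point the patch comment warns about.
            tensor = torch.tensor(counts, device="cuda", dtype=torch.int32)
            dist.all_reduce(tensor, group=device_group)
        else:
            # gloo fallback: everything stays on the host.
            tensor = torch.tensor(counts, device="cpu", dtype=torch.int32)
            dist.all_reduce(tensor, group=cpu_group)
        return tensor.cpu()
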
From efa3a3cbaae243fa17a8d35e4ed028d41933d288 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Mon, 8 Sep 2025 13:58:50 +0000
Subject: [PATCH 3/6] env var fixes

Signed-off-by: Sage Moore
---
 vllm/envs.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/vllm/envs.py b/vllm/envs.py
index 1232bd7bf963..7ae27146e84b 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -91,6 +91,7 @@
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
+    VLLM_DISABLE_NCCL_DP_PADDING: bool = False
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
@@ -728,6 +729,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_DISABLED_KERNELS":
     lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
         "VLLM_DISABLED_KERNELS"].split(","),
+    
+    "VLLM_DISABLE_NCCL_DP_PADDING":
+    lambda: (os.getenv("VLLM_DISABLE_NCCL_DP_PADDING", "False").lower() in
+             ("true", "1")),
 
     # If set, use the V1 code path.
     "VLLM_USE_V1":

From bd40b650ae0aa09d83e6123b431872edf507802b Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Tue, 9 Sep 2025 13:20:49 +0000
Subject: [PATCH 4/6] env var fixes

Signed-off-by: Sage Moore
---
 vllm/forward_context.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 153b7820ca42..32a09694c3c2 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -85,7 +85,7 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         # This environment variable exists to quickly disable this optimization
         # if we run into this case.
         if envs.VLLM_DISABLE_NCCL_DP_PADDING:
-            logger.info("Using CPU all reduce to syncronize DP padding between ranks.")
+            logger.info_once("Using CPU all reduce to syncronize DP padding between ranks.")
             device = "cpu"
             group = get_dp_group().cpu_group
         num_tokens_across_dp = [0] * dp_size
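
[Note] Patch 3 wires the flag into vLLM's lazily evaluated env-var table, and patch 4 downgrades the log call to info_once. As a rough illustration of how that table behaves, here is an editor's sketch (not vLLM code): the dict literal only mimics vllm/envs.py, and the flag name is the pre-rename one used at this point in the series.

    import os

    # Mimics the environment_variables dict of lambdas in vllm/envs.py:
    # each entry is evaluated lazily, so the environment is read at access
    # time rather than at import time.
    environment_variables = {
        "VLLM_DISABLE_NCCL_DP_PADDING":
        lambda: os.getenv("VLLM_DISABLE_NCCL_DP_PADDING", "False").lower() in
        ("true", "1"),
    }

    os.environ["VLLM_DISABLE_NCCL_DP_PADDING"] = "1"
    assert environment_variables["VLLM_DISABLE_NCCL_DP_PADDING"]() is True
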
From 6f9c88f8dc53250c12654faeefb80cb3bd5f797e Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Thu, 11 Sep 2025 18:59:37 +0000
Subject: [PATCH 5/6] lint

Signed-off-by: Sage Moore
---
 vllm/envs.py            |  2 +-
 vllm/forward_context.py | 11 ++++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index d11dc79e1f67..63fe00319abb 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -739,7 +739,7 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_DISABLED_KERNELS":
     lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
         "VLLM_DISABLED_KERNELS"].split(","),
-    
+
     "VLLM_DISABLE_NCCL_DP_PADDING":
     lambda: (os.getenv("VLLM_DISABLE_NCCL_DP_PADDING", "False").lower() in
              ("true", "1")),
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 32a09694c3c2..5d68360e4d92 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -80,12 +80,13 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         device = current_platform.device_type
         group = get_dp_group().device_group
 
-        # Transfering this tensor from GPU to CPU will introduce a GPU sync point
-        # that could adversely affect performance of vllm with asynch scheduling.
-        # This environment variable exists to quickly disable this optimization
-        # if we run into this case.
+        # Transfering this tensor from GPU to CPU will introduce a GPU sync
+        # point that could adversely affect performance of vllm with asynch
+        # scheduling. This environment variable exists to quickly disable
+        # this optimization if we run into this case.
         if envs.VLLM_DISABLE_NCCL_DP_PADDING:
-            logger.info_once("Using CPU all reduce to syncronize DP padding between ranks.")
+            logger.info_once(
+                "Using CPU all reduce to syncronize DP padding between ranks.")
             device = "cpu"
             group = get_dp_group().cpu_group
         num_tokens_across_dp = [0] * dp_size
From 1e884dc66d61807f40150d0777c074fb10450eee Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Mon, 15 Sep 2025 15:09:17 +0000
Subject: [PATCH 6/6] change env var name

Signed-off-by: Sage Moore
---
 vllm/envs.py            | 9 ++++++---
 vllm/forward_context.py | 2 +-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 63fe00319abb..215dcde7aebb 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -92,7 +92,7 @@
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
-    VLLM_DISABLE_NCCL_DP_PADDING: bool = False
+    VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
@@ -740,8 +740,11 @@ def get_vllm_port() -> Optional[int]:
     lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
         "VLLM_DISABLED_KERNELS"].split(","),
 
-    "VLLM_DISABLE_NCCL_DP_PADDING":
-    lambda: (os.getenv("VLLM_DISABLE_NCCL_DP_PADDING", "False").lower() in
+    # Swaps the all reduce backend that we use to coordinate the DP padding
+    # information from NCCL to gloo.
+    "VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION":
+    lambda:
+    (os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() in
              ("true", "1")),
 
     # If set, use the V1 code path.
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 5d68360e4d92..b3ddd7b9a739 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -84,7 +84,7 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         # point that could adversely affect performance of vllm with asynch
        # scheduling. This environment variable exists to quickly disable
         # this optimization if we run into this case.
-        if envs.VLLM_DISABLE_NCCL_DP_PADDING:
+        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
             logger.info_once(
                 "Using CPU all reduce to syncronize DP padding between ranks.")
             device = "cpu"
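
[Note] After the final rename, the fallback is controlled by VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION. The sketch below is one quick way to check the flag wiring on a build that includes this series (editor's sketch, not part of the patches); equivalently, export the variable in the serving environment before launching vLLM.

    import os

    # Must be set before the flag is read; "1" or "true" (any case) selects
    # the gloo/CPU all-reduce path, anything else keeps the NCCL path.
    os.environ["VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION"] = "1"

    import vllm.envs as envs

    assert envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION is True
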