From d4e5e870cb93447e755bc30569fe1ec4a240fcda Mon Sep 17 00:00:00 2001
From: Mick
Date: Sat, 7 Feb 2026 09:44:02 +0800
Subject: [PATCH 1/3] [diffusion] fix: respect dist_timeout option

---
 .../runtime/distributed/parallel_state.py    | 18 ++++++++++++++++--
 .../runtime/managers/gpu_worker.py           |  1 +
 .../multimodal_gen/runtime/server_args.py    |  8 ++++++--
 3 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py b/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py
index b0264adb39be..68894eddf664 100644
--- a/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py
+++ b/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py
@@ -219,6 +219,7 @@ def init_distributed_environment(
     local_rank: int = 0,
     backend: str = "nccl",
     device_id: torch.device | None = None,
+    timeout: int | None = None,
 ):
     # Determine the appropriate backend based on the platform
     from sglang.multimodal_gen.runtime.platforms import current_platform
@@ -229,12 +230,14 @@
         logger.info("Using gloo backend for %s platform", current_platform.device_name)

     logger.debug(
-        "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
+        "world_size=%d rank=%d local_rank=%d "
+        "distributed_init_method=%s backend=%s timeout=%s",
         world_size,
         rank,
         local_rank,
         distributed_init_method,
         backend,
+        timeout,
     )
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
@@ -248,6 +251,14 @@
             if (current_platform.is_mps() or current_platform.is_musa())
             else dict(device_id=device_id)
         )
+
+        # set time out in seconds
+        if timeout is not None:
+            import datetime
+
+            extra_args["timeout"] = datetime.timedelta(seconds=timeout)
+            logger.info(f"Setting distributed timeout to {timeout} seconds")
+
         torch.distributed.init_process_group(
             backend=backend,
             init_method=distributed_init_method,
@@ -577,6 +588,7 @@ def maybe_init_distributed_environment_and_model_parallel(
     ring_degree: int = 1,
     dp_size: int = 1,
     distributed_init_method: str = "env://",
+    dist_timeout: int | None = None,
 ):
     from sglang.multimodal_gen.runtime.platforms import current_platform

@@ -594,9 +606,10 @@
     rank = int(os.environ.get("RANK", 0))
     device = get_local_torch_device()
     logger.info(
-        "Initializing distributed environment with world_size=%d, device=%s",
+        "Initializing distributed environment with world_size=%d, device=%s, timeout=%s",
         world_size,
         device,
+        dist_timeout,
         main_process_only=False,
     )

@@ -606,6 +619,7 @@
         local_rank=local_rank,
         distributed_init_method=distributed_init_method,
         device_id=device,
+        timeout=dist_timeout,
     )
     initialize_model_parallel(
         data_parallel_size=dp_size,
diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
index 5008c1d61968..9af2e1efce4a 100644
--- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
+++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
@@ -85,6 +85,7 @@ def init_device_and_model(self) -> None:
             ring_degree=self.server_args.ring_degree,
             sp_size=self.server_args.sp_degree,
             dp_size=self.server_args.dp_size,
+            dist_timeout=self.server_args.dist_timeout,
         )
         self.pipeline = build_pipeline(self.server_args)

diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py
index 4aed80cea16d..ab40769a00f8 100644
--- a/python/sglang/multimodal_gen/runtime/server_args.py
+++ b/python/sglang/multimodal_gen/runtime/server_args.py
@@ -266,7 +266,9 @@ class ServerArgs:
     hsdp_replicate_dim: int = 1
     hsdp_shard_dim: int = -1

-    dist_timeout: int | None = None  # timeout for torch.distributed
+    dist_timeout: int | None = (
+        3600  # timeout for torch.distributed (in seconds), default 1 hour
+    )

     pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False)

@@ -592,7 +594,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "--dist-timeout",
         type=int,
         default=ServerArgs.dist_timeout,
-        help="Set timeout for torch.distributed initialization.",
+        help="Timeout for torch.distributed operations in seconds. "
+        "Increase this value if you encounter 'Connection closed by peer' errors after the service is idle. "
+        "Default is 3600 seconds (1 hour). Set to a larger value for services that may be idle for longer periods.",
     )

     # Prompt text file for batch processing

From e20a4966a5a09f28c926ce69fb5ddddebd9a9675 Mon Sep 17 00:00:00 2001
From: Mick
Date: Sat, 7 Feb 2026 09:48:01 +0800
Subject: [PATCH 2/3] upd

---
 python/sglang/multimodal_gen/runtime/server_args.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py
index ab40769a00f8..61bebc8d8add 100644
--- a/python/sglang/multimodal_gen/runtime/server_args.py
+++ b/python/sglang/multimodal_gen/runtime/server_args.py
@@ -266,9 +266,7 @@ class ServerArgs:
     hsdp_replicate_dim: int = 1
     hsdp_shard_dim: int = -1

-    dist_timeout: int | None = (
-        3600  # timeout for torch.distributed (in seconds), default 1 hour
-    )
+    dist_timeout: int | None = 3600  # 1 hour

     pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False)

@@ -595,8 +593,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         type=int,
         default=ServerArgs.dist_timeout,
         help="Timeout for torch.distributed operations in seconds. "
-        "Increase this value if you encounter 'Connection closed by peer' errors after the service is idle. "
-        "Default is 3600 seconds (1 hour). Set to a larger value for services that may be idle for longer periods.",
+        "Increase this value if you encounter 'Connection closed by peer' errors after the service is idle. ",
     )

     # Prompt text file for batch processing

From a5e7e948353a7544d47cf97c699616ba985883ad Mon Sep 17 00:00:00 2001
From: Mick
Date: Sat, 7 Feb 2026 19:07:54 +0800
Subject: [PATCH 3/3] upd

---
 .../multimodal_gen/runtime/distributed/parallel_state.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py b/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py
index 68894eddf664..24d1dca770d3 100644
--- a/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py
+++ b/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py
@@ -1,5 +1,4 @@
 # Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
-
 # SPDX-License-Identifier: Apache-2.0
 # Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py
 # Copyright 2023 The vLLM team.
@@ -31,6 +30,7 @@
 you can skip the model parallel initialization and destruction steps.
 """
 import contextlib
+import datetime
 import os
 import weakref
 from collections import namedtuple
@@ -252,9 +252,7 @@ def init_distributed_environment(
             else dict(device_id=device_id)
         )

-        # set time out in seconds
         if timeout is not None:
-            import datetime

             extra_args["timeout"] = datetime.timedelta(seconds=timeout)
             logger.info(f"Setting distributed timeout to {timeout} seconds")
@@ -266,6 +264,7 @@ def init_distributed_environment(
         rank=rank,
         **extra_args,
     )
+
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816