diff --git a/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py b/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py index b0264adb39be..24d1dca770d3 100644 --- a/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py +++ b/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py @@ -1,5 +1,4 @@ # Copied and adapted from: https://github.com/hao-ai-lab/FastVideo - # SPDX-License-Identifier: Apache-2.0 # Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. @@ -31,6 +30,7 @@ you can skip the model parallel initialization and destruction steps. """ import contextlib +import datetime import os import weakref from collections import namedtuple @@ -219,6 +219,7 @@ def init_distributed_environment( local_rank: int = 0, backend: str = "nccl", device_id: torch.device | None = None, + timeout: int | None = None, ): # Determine the appropriate backend based on the platform from sglang.multimodal_gen.runtime.platforms import current_platform @@ -229,12 +230,14 @@ def init_distributed_environment( logger.info("Using gloo backend for %s platform", current_platform.device_name) logger.debug( - "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + "world_size=%d rank=%d local_rank=%d " + "distributed_init_method=%s backend=%s timeout=%s", world_size, rank, local_rank, distributed_init_method, backend, + timeout, ) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( @@ -248,6 +251,12 @@ def init_distributed_environment( if (current_platform.is_mps() or current_platform.is_musa()) else dict(device_id=device_id) ) + + if timeout is not None: + + extra_args["timeout"] = datetime.timedelta(seconds=timeout) + logger.info(f"Setting distributed timeout to {timeout} seconds") + torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, @@ -255,6 +264,7 @@ def 
init_distributed_environment( rank=rank, **extra_args, ) + # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -577,6 +587,7 @@ def maybe_init_distributed_environment_and_model_parallel( ring_degree: int = 1, dp_size: int = 1, distributed_init_method: str = "env://", + dist_timeout: int | None = None, ): from sglang.multimodal_gen.runtime.platforms import current_platform @@ -594,9 +605,10 @@ def maybe_init_distributed_environment_and_model_parallel( rank = int(os.environ.get("RANK", 0)) device = get_local_torch_device() logger.info( - "Initializing distributed environment with world_size=%d, device=%s", + "Initializing distributed environment with world_size=%d, device=%s, timeout=%s", world_size, device, + dist_timeout, main_process_only=False, ) @@ -606,6 +618,7 @@ def maybe_init_distributed_environment_and_model_parallel( local_rank=local_rank, distributed_init_method=distributed_init_method, device_id=device, + timeout=dist_timeout, ) initialize_model_parallel( data_parallel_size=dp_size, diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 5008c1d61968..9af2e1efce4a 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -85,6 +85,7 @@ def init_device_and_model(self) -> None: ring_degree=self.server_args.ring_degree, sp_size=self.server_args.sp_degree, dp_size=self.server_args.dp_size, + dist_timeout=self.server_args.dist_timeout, ) self.pipeline = build_pipeline(self.server_args) diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 4aed80cea16d..61bebc8d8add 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -266,7 +266,7 @@ class ServerArgs: hsdp_replicate_dim: 
int = 1 hsdp_shard_dim: int = -1 - dist_timeout: int | None = None # timeout for torch.distributed + dist_timeout: int | None = 3600 # timeout for torch.distributed, in seconds (default: 1 hour) pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False) @@ -592,7 +592,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--dist-timeout", type=int, default=ServerArgs.dist_timeout, - help="Set timeout for torch.distributed initialization.", + help="Timeout for torch.distributed operations in seconds. " + "Increase this value if you encounter 'Connection closed by peer' errors after the service is idle. ", ) # Prompt text file for batch processing