@@ -1,5 +1,4 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
@@ -31,6 +30,7 @@
you can skip the model parallel initialization and destruction steps.
"""
import contextlib
import datetime
import os
import weakref
from collections import namedtuple
@@ -219,6 +219,7 @@ def init_distributed_environment(
local_rank: int = 0,
backend: str = "nccl",
device_id: torch.device | None = None,
timeout: int | None = None,
):
# Determine the appropriate backend based on the platform
from sglang.multimodal_gen.runtime.platforms import current_platform
@@ -229,12 +230,14 @@
logger.info("Using gloo backend for %s platform", current_platform.device_name)

logger.debug(
"world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s timeout=%s",
world_size,
rank,
local_rank,
distributed_init_method,
backend,
timeout,
)
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
@@ -248,13 +251,20 @@
if (current_platform.is_mps() or current_platform.is_musa())
else dict(device_id=device_id)
)

if timeout is not None:
    extra_args["timeout"] = datetime.timedelta(seconds=timeout)
    logger.info(f"Setting distributed timeout to {timeout} seconds")

torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank,
**extra_args,
)

# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -577,6 +587,7 @@ def maybe_init_distributed_environment_and_model_parallel(
ring_degree: int = 1,
dp_size: int = 1,
distributed_init_method: str = "env://",
dist_timeout: int | None = None,
):
from sglang.multimodal_gen.runtime.platforms import current_platform

@@ -594,9 +605,10 @@ def maybe_init_distributed_environment_and_model_parallel(
rank = int(os.environ.get("RANK", 0))
device = get_local_torch_device()
logger.info(
"Initializing distributed environment with world_size=%d, device=%s",
"Initializing distributed environment with world_size=%d, device=%s, timeout=%s",
world_size,
device,
dist_timeout,
main_process_only=False,
)

@@ -606,6 +618,7 @@
local_rank=local_rank,
distributed_init_method=distributed_init_method,
device_id=device,
timeout=dist_timeout,
)
initialize_model_parallel(
data_parallel_size=dp_size,
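For reference, the conversion above is the entire contract with torch.distributed: the caller supplies an integer number of seconds, and `init_process_group` accepts a `datetime.timedelta`. Below is a minimal, self-contained sketch of that plumbing, assuming a single-process gloo group and `env://` rendezvous (both chosen only so the snippet runs standalone; the real code forwards the same kwarg through `extra_args` with the caller's backend, rank, and world size):

```python
import datetime
import os

import torch.distributed as dist


def init_with_timeout(timeout: int | None) -> None:
    # env:// rendezvous needs these; a real launcher (torchrun) sets them.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    extra_args = {}
    if timeout is not None:
        # torch.distributed expects a datetime.timedelta, not raw seconds.
        extra_args["timeout"] = datetime.timedelta(seconds=timeout)
    dist.init_process_group(
        backend="gloo",
        init_method="env://",
        world_size=1,
        rank=0,
        **extra_args,
    )


if __name__ == "__main__":
    init_with_timeout(3600)  # collectives on this group may block up to 1 hour
    dist.destroy_process_group()
```

When `timeout` is `None`, no kwarg is passed and torch falls back to its built-in per-backend default, which is why the plumbing can stay optional end to end.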
@@ -85,6 +85,7 @@ def init_device_and_model(self) -> None:
ring_degree=self.server_args.ring_degree,
sp_size=self.server_args.sp_degree,
dp_size=self.server_args.dp_size,
dist_timeout=self.server_args.dist_timeout,
)

self.pipeline = build_pipeline(self.server_args)
5 changes: 3 additions & 2 deletions python/sglang/multimodal_gen/runtime/server_args.py
@@ -266,7 +266,7 @@ class ServerArgs:

hsdp_replicate_dim: int = 1
hsdp_shard_dim: int = -1
dist_timeout: int | None = None # timeout for torch.distributed
dist_timeout: int | None = 3600 # 1 hour

pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False)

@@ -592,7 +592,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--dist-timeout",
type=int,
default=ServerArgs.dist_timeout,
help="Set timeout for torch.distributed initialization.",
help="Timeout for torch.distributed operations in seconds. "
"Increase this value if you encounter 'Connection closed by peer' errors after the service is idle. ",
)

# Prompt text file for batch processing
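The flag itself is plain argument-parser plumbing. A sketch of how it is consumed, using stdlib `argparse` as a stand-in for the project's `FlexibleArgumentParser` (an assumption; the real parser may add extra behavior on top):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dist-timeout",
    type=int,
    default=3600,  # mirrors the new ServerArgs default of one hour
    help="Timeout for torch.distributed operations in seconds.",
)

args = parser.parse_args(["--dist-timeout", "7200"])
assert args.dist_timeout == 7200  # CLI value overrides the default
```

Raising the default from `None` (torch's built-in timeout) to one hour trades slower detection of genuinely hung collectives for resilience to long idle gaps, which matches the help text's motivation of avoiding "Connection closed by peer" errors after the service has been idle.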