diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 8a9ddba63e51..de7cc03b69cb 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -49,6 +49,9 @@ import zmq from sglang.srt.elastic_ep.expert_backup_manager import run_expert_backup_manager +from sglang.srt.entrypoints.engine_info_bootstrap_server import ( + EngineInfoBootstrapServer, +) from sglang.srt.entrypoints.EngineBase import EngineBase from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, @@ -80,9 +83,6 @@ from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager_multiitem_mixin import ScoreResult -from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( - parse_remote_instance_transfer_engine_info_from_scheduler_infos, -) from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( @@ -98,7 +98,7 @@ set_prometheus_multiproc_dir, set_ulimit, ) -from sglang.srt.utils.network import get_zmq_socket +from sglang.srt.utils.network import get_zmq_socket, is_port_available from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils.watchdog import SubprocessWatchdog from sglang.version import __version__ @@ -116,6 +116,7 @@ class SchedulerInitResult: scheduler_infos: List[Dict[str, Any]] wait_for_ready: Callable[[], None] = lambda: None wait_for_completion: Callable[[], None] = lambda: None + engine_info_bootstrap_server: Optional[Any] = None def init_tokenizer_manager( @@ -201,11 +202,11 @@ def __init__(self, **kwargs): if tokenizer_manager is not None: tokenizer_manager._subprocess_watchdog = subprocess_watchdog self.port_args = port_args - self.remote_instance_transfer_engine_info = ( - parse_remote_instance_transfer_engine_info_from_scheduler_infos( - scheduler_init_result.scheduler_infos + # Access transfer engine info if bootstrap server is started. + if scheduler_init_result.engine_info_bootstrap_server is not None: + self.remote_instance_transfer_engine_info = ( + scheduler_init_result.engine_info_bootstrap_server.transfer_engine_info ) - ) # Initialize ZMQ sockets context = zmq.Context(2) @@ -642,10 +643,30 @@ def _launch_subprocesses( port_args = PortArgs.init_new(server_args) logger.info(f"{server_args=}") + # Start the engine info bootstrap server if per-rank info is needed. + engine_info_bootstrap_server = None + if ( + server_args.remote_instance_weight_loader_start_seed_via_transfer_engine + and server_args.node_rank == 0 + ): + bootstrap_port = server_args.engine_info_bootstrap_port + if not is_port_available(bootstrap_port): + raise RuntimeError( + f"engine_info_bootstrap_port {bootstrap_port} is already in use. " + f"When running multiple instances on the same node, each instance must use a " + f"different --engine-info-bootstrap-port." + ) + engine_info_bootstrap_server = EngineInfoBootstrapServer( + host=server_args.host, port=bootstrap_port + ) + # Launch scheduler processes scheduler_init_result, scheduler_procs = cls._launch_scheduler_processes( server_args, port_args, run_scheduler_process_func ) + scheduler_init_result.engine_info_bootstrap_server = ( + engine_info_bootstrap_server + ) if ( server_args.enable_elastic_expert_backup diff --git a/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py b/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py new file mode 100644 index 000000000000..77de7fc7d030 --- /dev/null +++ b/python/sglang/srt/entrypoints/engine_info_bootstrap_server.py @@ -0,0 +1,105 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import logging +import threading +from typing import Dict, Optional, Tuple + +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.responses import PlainTextResponse + +logger = logging.getLogger(__name__) + + +class EngineInfoBootstrapServer: + """Lightweight HTTP server for per-rank model info registration. + + Runs in a daemon thread on node_rank==0. Each ModelRunner registers its + info via HTTP PUT after model initialization. The Engine + accesses the collected info directly in-process; external consumers can + query via HTTP GET. + + Currently supports transfer engine memory registration info. + """ + + def __init__(self, host: str, port: int): + self.host = host + self.port = port + + # Storage: {tp_rank: (session_id, weights_info_dict)} + self.transfer_engine_info: Dict[int, Tuple] = {} + self.lock = threading.Lock() + + app = FastAPI() + + @app.get("/health") + def health(): + return PlainTextResponse("OK") + + @app.put("/register_transfer_engine_info") + def register_transfer_engine_info(data: dict): + try: + tp_rank = data["tp_rank"] + info = data["transfer_engine_info"] + session_id = info["session_id"] + weights_info_dict = info["weights_info_dict"] + + with self.lock: + self.transfer_engine_info[tp_rank] = ( + session_id, + weights_info_dict, + ) + + logger.info( + f"Registered transfer engine info for tp_rank={tp_rank}, " + f"session_id={session_id}" + ) + return PlainTextResponse("OK") + except Exception as e: + logger.error(f"Failed to register engine info: {e}") + raise HTTPException(status_code=400, detail=str(e)) + + @app.get("/get_transfer_engine_info") + def get_transfer_engine_info(rank: int): + if rank < 0: + raise HTTPException(status_code=400, detail="Invalid rank parameter") + + with self.lock: + info = self.transfer_engine_info.get(rank) + + if info is None: + raise HTTPException( + status_code=404, + detail=f"No transfer engine info for rank {rank}", + ) + + return {"rank": rank, "remote_instance_transfer_engine_info": list(info)} + + config = uvicorn.Config(app, host=host, port=port, log_level="warning") + self._server = uvicorn.Server(config) + self._thread = threading.Thread( + target=self._server.run, + daemon=True, + ) + self._thread.start() + logger.info(f"EngineInfoBootstrapServer started on {host}:{port}") + + def close(self): + self._server.should_exit = True + self._thread.join(timeout=5) + + def get_transfer_engine_info(self, rank: int) -> Optional[Tuple]: + """Direct in-process access for co-located HTTP server (no HTTP round-trip).""" + return self.transfer_engine_info.get(rank) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 9cba65bf7f08..dd85504a5a30 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -153,9 +153,6 @@ ) from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager -from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( - parse_remote_instance_transfer_engine_info_from_scheduler_infos, -) from sglang.srt.observability.func_timer import enable_func_timer from sglang.srt.observability.trace import ( process_tracing_init, @@ -196,15 +193,6 @@ class _GlobalState: tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker] template_manager: TemplateManager scheduler_info: Dict - # Dict{ - # rank: Tuple( - # session_id, - # Dict{ - # name: Tuple (d_ptr, numel, element_size) - # } - # ) - # } - remote_instance_transfer_engine_info: Optional[Dict] = None _global_state: Optional[_GlobalState] = None @@ -1030,26 +1018,39 @@ async def send_weights_to_remote_instance( @app.get("/get_remote_instance_transfer_engine_info") @auth_level(AuthLevel.ADMIN_OPTIONAL) async def get_remote_instance_transfer_engine_info(rank: int = None): - if rank is None or rank < 0: - return Response(status_code=HTTPStatus.BAD_REQUEST) + """Get the server information (deprecated - use /remote_instance_transfer_engine_info instead).""" + logger.warning( + "Endpoint '/get_remote_instance_transfer_engine_info' is deprecated and will be removed in a future version. " + "Please use '/remote_instance_transfer_engine_info' instead." + ) + return await remote_instance_transfer_engine_info(rank=rank) - if ( - _global_state.remote_instance_transfer_engine_info is None - or len(_global_state.remote_instance_transfer_engine_info) == 0 - ): - return Response(status_code=HTTPStatus.BAD_REQUEST) +@app.get("/remote_instance_transfer_engine_info") +@auth_level(AuthLevel.ADMIN_OPTIONAL) +async def remote_instance_transfer_engine_info(rank: int = None): + if rank is None or rank < 0: + return ORJSONResponse( + {"error": {"message": "Missing or invalid rank parameter"}}, + status_code=HTTPStatus.BAD_REQUEST, + ) + + server_args = _global_state.tokenizer_manager.server_args try: - result = { - "rank": rank, - "remote_instance_transfer_engine_info": _global_state.remote_instance_transfer_engine_info[ - rank - ], - } - return result - except Exception as e: - logger.error(f"Exception: {e}") - return Response(status_code=HTTPStatus.BAD_REQUEST) + resp = requests.get( + f"{server_args.engine_info_bootstrap_url}/get_transfer_engine_info", + params={"rank": rank}, + timeout=5, + ) + if resp.status_code == 200: + return resp.json() + except (requests.exceptions.RequestException, ValueError) as e: + logger.warning(f"Failed to get transfer engine info for rank {rank}: {e}") + + return ORJSONResponse( + {"error": {"message": f"Failed to get transfer engine info for rank {rank}"}}, + status_code=HTTPStatus.BAD_REQUEST, + ) @app.post("/init_weights_update_group") @@ -1993,18 +1994,12 @@ def _setup_and_run_http_server( Called by launch_server after subprocesses have been launched. """ - # Parse info got from the schedulers - remote_instance_transfer_engine_info = ( - parse_remote_instance_transfer_engine_info_from_scheduler_infos(scheduler_infos) - ) - # Set global states set_global_state( _GlobalState( tokenizer_manager=tokenizer_manager, template_manager=template_manager, scheduler_info=scheduler_infos[0], - remote_instance_transfer_engine_info=remote_instance_transfer_engine_info, ) ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4351634515aa..a2c9e4c79ec0 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1255,19 +1255,6 @@ def get_init_info(self) -> Dict[str, Any]: "max_req_input_len": self.max_req_input_len, } - if self.server_args.remote_instance_weight_loader_use_transfer_engine(): - ( - remote_instance_transfer_engine_session_id, - remote_instance_transfer_engine_weights_info_dict, - ) = self.get_remote_instance_transfer_engine_info() - result_dict.update( - { - "tp_rank": self.tp_rank, - "remote_instance_transfer_engine_session_id": remote_instance_transfer_engine_session_id, - "remote_instance_transfer_engine_weights_info_dict": remote_instance_transfer_engine_weights_info_dict, - } - ) - return result_dict def run_event_loop(self) -> None: @@ -3377,9 +3364,6 @@ def update_cache_from_scheduler( ): pass - def get_remote_instance_transfer_engine_info(self): - return self.tp_worker.get_remote_instance_transfer_engine_info() - class IdleSleeper: """ diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 46a87aeb8762..7f63610da8ee 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -441,12 +441,6 @@ def _forward_batch_generation_dllm( can_run_cuda_graph=can_run_cuda_graph, ) - def get_remote_instance_transfer_engine_info(self): - return ( - self.model_runner.remote_instance_transfer_engine_session_id, - self.model_runner.remote_instance_transfer_engine_weight_info, - ) - def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1a9579f96791..006fcdd30a70 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -520,9 +520,11 @@ def initialize(self, pre_model_load_memory: float): and self.remote_instance_transfer_engine is not None and self.remote_instance_transfer_engine_weight_info is None ): + # Register memory and upstream the transfer engine info to the bootstrap server self.remote_instance_transfer_engine_weight_info = register_memory_region( self.model, self.remote_instance_transfer_engine ) + self._register_to_engine_info_bootstrap() # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to @@ -700,6 +702,52 @@ def remote_instance_init_transfer_engine(self): local_ip, self.remote_instance_transfer_engine.get_rpc_port() ).to_host_port_str() + def _register_to_engine_info_bootstrap(self): + """Register transfer engine info with the EngineInfoBootstrapServer via HTTP PUT. + + The bootstrap server runs on node_rank==0. For multi-node setups, the + host is derived from dist_init_addr. For single-node, use 127.0.0.1. + """ + import requests as http_requests + + if self.server_args.dist_init_addr: + # Multi-node: bootstrap server is on the head node (node_rank==0). + # Derive host from dist_init_addr (shared across all nodes). + bootstrap_host = ( + NetworkAddress.parse(self.server_args.dist_init_addr).resolved().host + ) + else: + bootstrap_host = "127.0.0.1" + + bootstrap_port = self.server_args.engine_info_bootstrap_port + bootstrap_na = NetworkAddress(bootstrap_host, bootstrap_port) + url = f"{bootstrap_na.to_url()}/register_transfer_engine_info" + + payload = { + "tp_rank": self.tp_rank, + "transfer_engine_info": { + "session_id": self.remote_instance_transfer_engine_session_id, + "weights_info_dict": self.remote_instance_transfer_engine_weight_info, + }, + } + + try: + resp = http_requests.put(url, json=payload, timeout=5) + if resp.status_code == 200: + logger.info( + f"Registered transfer engine info for tp_rank={self.tp_rank} " + f"with bootstrap server at {bootstrap_na}" + ) + else: + logger.error( + f"Failed to register transfer engine info for tp_rank={self.tp_rank}: " + f"{resp.status_code}, {resp.text}" + ) + except Exception as e: + logger.error( + f"Failed to register transfer engine info for tp_rank={self.tp_rank}: {e}" + ) + def _publish_modelexpress_metadata(self): """Publish TransferEngine metadata to ModelExpress server (seed mode).""" try: diff --git a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py index 8a945bb4c2e3..2a0aeb047ed6 100644 --- a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py +++ b/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py @@ -106,21 +106,6 @@ def get_remote_instance_transfer_engine_info_per_rank(seed_url: str, rank: int): return None, None -def parse_remote_instance_transfer_engine_info_from_scheduler_infos(scheduler_infos): - remote_instance_transfer_engine_info = {} - for data in scheduler_infos: - if ( - "tp_rank" in data - and "remote_instance_transfer_engine_session_id" in data - and "remote_instance_transfer_engine_weights_info_dict" in data - ): - remote_instance_transfer_engine_info[data["tp_rank"]] = ( - data["remote_instance_transfer_engine_session_id"], - data["remote_instance_transfer_engine_weights_info_dict"], - ) - return remote_instance_transfer_engine_info - - def register_memory_region(model, transfer_engine): if importlib.util.find_spec("torch") is None: return register_memory_region_v1(model, transfer_engine) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 04a365927214..fb29dea9bfc6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -712,6 +712,7 @@ class ServerArgs: "transfer_engine", "nccl", "modelexpress" ] = "nccl" remote_instance_weight_loader_start_seed_via_transfer_engine: bool = False + engine_info_bootstrap_port: int = 6789 modelexpress_config: Optional[str] = None # For PD-Multiplexing @@ -5772,6 +5773,13 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Start seed server via transfer engine backend for remote instance weight loader.", ) + parser.add_argument( + "--engine-info-bootstrap-port", + type=int, + default=ServerArgs.engine_info_bootstrap_port, + help="Port for the engine info bootstrap server. Default is 6789. " + "Must be set explicitly when running multiple instances on the same node.", + ) parser.add_argument( "--modelexpress-config", type=str, @@ -5891,7 +5899,7 @@ def from_cli_args(cls, args: argparse.Namespace): attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) - def url(self): + def url(self, port: Optional[int] = None): scheme = "https" if self.ssl_certfile else "http" # When binding to all interfaces, use loopback for internal requests. host = self.host @@ -5899,7 +5907,13 @@ def url(self): host = "127.0.0.1" elif host == "::": host = "::1" - return NetworkAddress(host, self.port).to_url(scheme) + return NetworkAddress(host, port if port is not None else self.port).to_url( + scheme + ) + + @property + def engine_info_bootstrap_url(self): + return self.url(port=self.engine_info_bootstrap_port) def ssl_verify(self): """Return the value for the requests library's ``verify=`` parameter. diff --git a/test/manual/test_cross_node_scheduler_info_sync.py b/test/manual/test_cross_node_scheduler_info_sync.py new file mode 100755 index 000000000000..f6e5a835fb73 --- /dev/null +++ b/test/manual/test_cross_node_scheduler_info_sync.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Test cross-node scheduler_infos synchronization for remote weight loading. + +Simulates multi-node setups on a single machine using different GPU subsets. +Validates that scheduler_infos are correctly synced across nodes via Gloo. + +IMPORTANT: For multi-node tests, start both nodes within a few seconds of each +other to avoid port binding conflicts (they share the same network namespace). + +Test cases: + - tp4_nodes2: TP=4 across 2 nodes, validates basic cross-node sync + - dp2_single_node: DP=2 with dp_attention on single node + - dp2_tp2_nodes2: DP=2, TP=4 across 2 nodes with dp_attention + +Usage (multi-node): + Terminal 1: python test_cross_node_scheduler_info_sync.py --test-case tp4_nodes2 --node-rank 0 + Terminal 2: python test_cross_node_scheduler_info_sync.py --test-case tp4_nodes2 --node-rank 1 + Terminal 3: python test_cross_node_scheduler_info_sync.py --test-case tp4_nodes2 --test-only + +Usage (single-node): + Terminal 1: python test_cross_node_scheduler_info_sync.py --test-case dp2_single_node --node-rank 0 + Terminal 2: python test_cross_node_scheduler_info_sync.py --test-case dp2_single_node --test-only + +Requirements: 4 GPUs on single machine +""" + +import argparse +import socket +import subprocess +import sys +import time +from dataclasses import dataclass +from typing import List + +import requests + +from sglang.test.test_utils import ( + DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, +) + + +@dataclass +class TestCase: + name: str + tp_size: int + dp_size: int + nnodes: int + gpus_per_node: int + expected_ranks: int + extra_args: List[str] + + +TEST_CASES = { + "tp4_nodes2": TestCase( + name="tp4_nodes2", + tp_size=4, + dp_size=1, + nnodes=2, + gpus_per_node=2, + expected_ranks=4, + extra_args=[], + ), + "dp2_single_node": TestCase( + name="dp2_single_node", + tp_size=2, + dp_size=2, + nnodes=1, + gpus_per_node=2, + expected_ranks=2, + extra_args=["--enable-dp-attention", "--dp", "2", "--attention-backend", "fa3"], + ), + "dp2_tp2_nodes2": TestCase( + name="dp2_tp2_nodes2", + tp_size=4, + dp_size=2, + nnodes=2, + gpus_per_node=2, + expected_ranks=4, + extra_args=["--enable-dp-attention", "--dp", "2", "--attention-backend", "fa3"], + ), +} + +TEST_CASE_MODELS = { + "tp4_nodes2": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "dp2_single_node": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "dp2_tp2_nodes2": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, +} + + +def get_local_ip() -> str: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + except Exception: + return "127.0.0.1" + finally: + s.close() + + +def launch_node( + test_case: TestCase, node_rank: int, model_path: str, dist_init_addr: str +): + cmd = [ + sys.executable, + "-m", + "sglang.launch_server", + "--model-path", + model_path, + "--tp", + str(test_case.tp_size), + "--port", + str(30000 + node_rank * 100), + "--host", + "0.0.0.0", + "--remote-instance-weight-loader-start-seed-via-transfer-engine", + ] + if test_case.nnodes > 1: + cmd.extend( + [ + "--nnodes", + str(test_case.nnodes), + "--node-rank", + str(node_rank), + "--dist-init-addr", + dist_init_addr, + "--base-gpu-id", + str(node_rank * test_case.gpus_per_node), + ] + ) + cmd.extend(test_case.extra_args) + print(f"[Node {node_rank}] {' '.join(cmd)}") + subprocess.run(cmd) + + +def test_api(test_case: TestCase) -> bool: + base_url = "http://127.0.0.1:30000" + print(f"Testing {test_case.name}: expecting {test_case.expected_ranks} ranks") + + for _ in range(60): + try: + if requests.get(f"{base_url}/health", timeout=2).status_code == 200: + break + except Exception: + pass + time.sleep(2) + else: + print("ERROR: Server not ready") + return False + + all_passed = True + for rank in range(test_case.expected_ranks): + try: + resp = requests.get( + f"{base_url}/get_remote_instance_transfer_engine_info", + params={"rank": rank}, + timeout=5, + ) + status = "✓" if resp.status_code == 200 else "✗" + print(f"{status} Rank {rank}: {resp.status_code}") + if resp.status_code != 200: + all_passed = False + except Exception as e: + print(f"✗ Rank {rank}: {e}") + all_passed = False + + print("PASSED" if all_passed else "FAILED") + return all_passed + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--test-case", type=str, choices=list(TEST_CASES.keys()), required=True + ) + parser.add_argument("--node-rank", type=int, choices=[0, 1]) + parser.add_argument("--model-path", type=str, default=None) + parser.add_argument("--dist-init-addr", type=str, default=None) + parser.add_argument("--test-only", action="store_true") + args = parser.parse_args() + + test_case = TEST_CASES[args.test_case] + model_path = args.model_path or TEST_CASE_MODELS.get( + args.test_case, DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT + ) + + if args.test_only: + sys.exit(0 if test_api(test_case) else 1) + + if test_case.nnodes == 1: + launch_node(test_case, 0, model_path, "") + return + + if args.node_rank is None: + print(f"Usage: --node-rank 0 or 1, then --test-only in another terminal") + sys.exit(0) + + dist_init_addr = args.dist_init_addr or f"{get_local_ip()}:20000" + launch_node(test_case, args.node_rank, model_path, dist_init_addr) + + +if __name__ == "__main__": + main() diff --git a/test/registered/distributed/test_load_weights_from_remote_instance.py b/test/registered/distributed/test_load_weights_from_remote_instance.py index f1080caeb258..4402d399d30f 100644 --- a/test/registered/distributed/test_load_weights_from_remote_instance.py +++ b/test/registered/distributed/test_load_weights_from_remote_instance.py @@ -228,6 +228,8 @@ def init_process_dst( "--remote-instance-weight-loader-backend", remote_instance_loader_backend, "--remote-instance-weight-loader-start-seed-via-transfer-engine", + "--engine-info-bootstrap-port", + str(6789 + rank), ), ) torch.cuda.synchronize()