diff --git a/src/srtctl/backends/base.py b/src/srtctl/backends/base.py index 62904ff13..99f87c565 100644 --- a/src/srtctl/backends/base.py +++ b/src/srtctl/backends/base.py @@ -14,7 +14,7 @@ from pathlib import Path from srtctl.core.runtime import RuntimeContext - from srtctl.core.topology import Endpoint, Process + from srtctl.core.topology import Endpoint, NodePortAllocator, Process class BackendType(str, Enum): @@ -90,6 +90,7 @@ def endpoints_to_processes( self, endpoints: list["Endpoint"], base_sys_port: int = 8081, + port_allocator: "NodePortAllocator | None" = None, ) -> list["Process"]: """Convert logical endpoints to physical processes.""" ... diff --git a/src/srtctl/backends/sglang.py b/src/srtctl/backends/sglang.py index 1f4b818da..819545a7b 100644 --- a/src/srtctl/backends/sglang.py +++ b/src/srtctl/backends/sglang.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from srtctl.backends.base import SrunConfig from srtctl.core.runtime import RuntimeContext - from srtctl.core.topology import Endpoint, Process + from srtctl.core.topology import Endpoint, NodePortAllocator, Process # Type alias for worker modes WorkerMode = Literal["prefill", "decode", "agg"] @@ -198,11 +198,12 @@ def endpoints_to_processes( self, endpoints: list["Endpoint"], base_sys_port: int = 8081, + port_allocator: "NodePortAllocator | None" = None, ) -> list["Process"]: """Convert endpoints to processes.""" from srtctl.core.topology import endpoints_to_processes - return endpoints_to_processes(endpoints, base_sys_port=base_sys_port) + return endpoints_to_processes(endpoints, base_sys_port=base_sys_port, port_allocator=port_allocator) def build_worker_command( self, @@ -270,7 +271,24 @@ def build_worker_command( cmd.extend(["--disaggregation-mode", mode]) # Bootstrap port only needed for sglang frontend (dynamo handles internally) if frontend_type == "sglang" and mode == "prefill" and process.bootstrap_port is not None: - cmd.extend(["--disaggregation-bootstrap-port", str(process.bootstrap_port)]) + user_bootstrap_port = config.get("disaggregation-bootstrap-port") + if user_bootstrap_port is None: + cmd.extend(["--disaggregation-bootstrap-port", str(process.bootstrap_port)]) + else: + try: + user_port_int = int(user_bootstrap_port) + except (TypeError, ValueError) as e: + raise ValueError( + f"Invalid disaggregation-bootstrap-port={user_bootstrap_port!r} in sglang_config.prefill" + ) from e + if user_port_int != process.bootstrap_port: + raise ValueError( + "disaggregation-bootstrap-port mismatch for sglang prefill worker: " + f"config={user_port_int}, topology={process.bootstrap_port}. " + "For sglang router frontend, router and prefill workers must use the same bootstrap port. " + "If you run multiple prefill workers on the same node, do not set a fixed " + "disaggregation-bootstrap-port in the recipe." + ) # Add multi-node coordination flags if is_multi_node: @@ -297,7 +315,7 @@ def build_worker_command( kv_cfg["endpoint"] = f"tcp://*:{process.kv_events_port}" cmd.extend(["--kv-events-config", json.dumps(kv_cfg)]) - # Add all config flags + # Add all config flags. cmd.extend(_config_to_cli_args(config)) return cmd diff --git a/src/srtctl/cli/do_sweep.py b/src/srtctl/cli/do_sweep.py index ff6eaa917..547d6dec2 100644 --- a/src/srtctl/cli/do_sweep.py +++ b/src/srtctl/cli/do_sweep.py @@ -32,9 +32,9 @@ ) from srtctl.core.runtime import RuntimeContext from srtctl.core.schema import SrtConfig -from srtctl.core.slurm import get_slurm_job_id, start_srun_process +from srtctl.core.slurm import get_port_offset, get_slurm_job_id, start_srun_process from srtctl.core.status import JobStage, JobStatus, StatusReporter -from srtctl.core.topology import Endpoint, Process +from srtctl.core.topology import Endpoint, NodePortAllocator, Process from srtctl.logging_utils import setup_logging logger = logging.getLogger(__name__) @@ -80,14 +80,67 @@ def endpoints(self) -> list[Endpoint]: @functools.cached_property def backend_processes(self) -> list[Process]: """Compute physical process topology from endpoints (cached).""" - return self.backend.endpoints_to_processes(self.endpoints) - - def start_head_infrastructure(self, registry: ProcessRegistry) -> ManagedProcess: - """Start NATS and etcd on the infra node. - - When etcd_nats_dedicated_node is enabled, services run on a dedicated node. - Otherwise, they run on the head node (default behavior). + # DYN_SYSTEM_PORT is parsed as i16 by dynamo runtime, so keep ports < 32768. + # Also avoid collisions across concurrent jobs by offsetting from the job id. + # + # Note: srtctl allocates one sys_port per Process and increments sequentially from base_sys_port. + # Therefore, base_sys_port must reserve a sufficiently large consecutive port window per job + # to avoid collisions with other jobs running concurrently. + # + # Use get_port_offset() for consistency with other services (NATS, etcd, frontend). + # get_port_offset returns 0-990 in steps of 10, giving 100 slots. + # Each slot needs ~200 ports for sys_port allocation, so we multiply offset by 20. + port_offset = get_port_offset(self.runtime.job_id) + sys_port_stride = 200 # Reserved consecutive sys ports per job. + base_sys_port = 9000 + (port_offset * 20) # Range: 9000-28800, step 200 + + port_allocator: NodePortAllocator | None = None + if self.config.frontend.type == "sglang" and getattr(self.backend, "type", None) == "sglang": + prefill_cfg: dict[str, object] = {} + try: + prefill_cfg = self.backend.get_config_for_mode("prefill") # type: ignore[assignment] + except Exception: + prefill_cfg = {} + + user_bootstrap_port = prefill_cfg.get("disaggregation-bootstrap-port") + if user_bootstrap_port is not None: + try: + base_bootstrap_port = int(user_bootstrap_port) + except (TypeError, ValueError): + logger.warning( + "Invalid disaggregation-bootstrap-port=%r; falling back to default bootstrap port allocation", + user_bootstrap_port, + ) + else: + port_allocator = NodePortAllocator(base_bootstrap_port=base_bootstrap_port) + + processes = self.backend.endpoints_to_processes( + self.endpoints, + base_sys_port=base_sys_port, + port_allocator=port_allocator, + ) + if len(processes) > sys_port_stride: + logger.warning( + "This job allocates %d processes, which may exceed the reserved sys_port window (%d). " + "Consider increasing sys_port_stride to reduce cross-job collision risk.", + len(processes), + sys_port_stride, + ) + return processes + + def start_head_infrastructure(self, registry: ProcessRegistry) -> ManagedProcess | None: + """Start head node infrastructure when required by the chosen frontend. + + Dynamo frontend requires NATS+etcd for discovery/control planes. + SGLang frontend uses direct worker connections and does not require these services. """ + if self.config.frontend.type != "dynamo": + logger.info( + "Skipping head node infrastructure (frontend.type=%s does not require NATS/etcd)", + self.config.frontend.type, + ) + return None + infra_node = self.runtime.nodes.infra logger.info("Starting infrastructure services (NATS, etcd)") logger.info("Infra node: %s", infra_node) @@ -130,14 +183,19 @@ def start_head_infrastructure(self, registry: ProcessRegistry) -> ManagedProcess critical=True, ) + port_offset = get_port_offset(self.runtime.job_id) + nats_port = 4222 + port_offset + etcd_port = 2379 + port_offset + logger.info("Port offset for this job: %d (job_id: %s)", port_offset, self.runtime.job_id) + # 300s timeout to handle slow container imports on first run - logger.info("Waiting for NATS (port 4222) on %s...", infra_node) - if not wait_for_port(infra_node, 4222, timeout=300): + logger.info("Waiting for NATS (port %d) on %s...", nats_port, infra_node) + if not wait_for_port(infra_node, nats_port, timeout=300): raise RuntimeError("NATS failed to start") logger.info("NATS is ready") - logger.info("Waiting for etcd (port 2379) on %s...", infra_node) - if not wait_for_port(infra_node, 2379, timeout=300): + logger.info("Waiting for etcd (port %d) on %s...", etcd_port, infra_node) + if not wait_for_port(infra_node, etcd_port, timeout=300): raise RuntimeError("etcd failed to start") logger.info("etcd is ready") @@ -150,11 +208,13 @@ def _print_connection_info(self) -> None: if mounts_str: container_args += f" --container-mounts={mounts_str}" + public_port = self.runtime.frontend_port + logger.info("") logger.info("=" * 60) logger.info("Connection Commands") logger.info("=" * 60) - logger.info("Frontend URL: http://%s:8000", self.runtime.nodes.head) + logger.info("Frontend URL: http://%s:%d", self.runtime.nodes.head, public_port) logger.info("") logger.info("To connect to head node (%s):", self.runtime.nodes.head) logger.info( @@ -206,7 +266,8 @@ def run(self) -> int: # Stage 1: Head infrastructure (NATS, etcd) reporter.report(JobStatus.STARTING, JobStage.HEAD_INFRASTRUCTURE, "Starting head infrastructure") head_proc = self.start_head_infrastructure(registry) - registry.add_process(head_proc) + if head_proc is not None: + registry.add_process(head_proc) # Stage 2: Workers reporter.report(JobStatus.WORKERS, JobStage.WORKERS, "Starting workers") diff --git a/src/srtctl/cli/mixins/benchmark_stage.py b/src/srtctl/cli/mixins/benchmark_stage.py index f3769a9f0..c44051973 100644 --- a/src/srtctl/cli/mixins/benchmark_stage.py +++ b/src/srtctl/cli/mixins/benchmark_stage.py @@ -78,7 +78,7 @@ def run_benchmark( hc = self.config.health_check if not wait_for_model( host=self.runtime.nodes.head, - port=8000, + port=self.runtime.frontend_port, n_prefill=n_prefill, n_decode=n_decode, poll_interval=float(hc.interval_seconds), @@ -106,7 +106,7 @@ def run_benchmark( if benchmark_type == "manual": logger.info("Benchmark type is 'manual' - server is ready for testing") - logger.info("Frontend URL: http://%s:8000", self.runtime.nodes.head) + logger.info("Frontend URL: http://%s:%d", self.runtime.nodes.head, self.runtime.frontend_port) logger.info("Press Ctrl+C to stop the job") while not stop_event.is_set(): diff --git a/src/srtctl/cli/mixins/frontend_stage.py b/src/srtctl/cli/mixins/frontend_stage.py index 1e1b0b70f..7570b1bb4 100644 --- a/src/srtctl/cli/mixins/frontend_stage.py +++ b/src/srtctl/cli/mixins/frontend_stage.py @@ -67,7 +67,7 @@ def backend(self) -> Any: @property def backend_processes(self) -> list["Process"]: """Compute physical process topology from endpoints (cached).""" - ... + raise NotImplementedError def _compute_frontend_topology(self) -> FrontendTopology: """Determine where nginx and frontends should run. @@ -76,20 +76,33 @@ def _compute_frontend_topology(self) -> FrontendTopology: - Single node OR multiple_frontends disabled: 1 frontend on head, no nginx - 2+ nodes AND multiple_frontends enabled: nginx on head, frontends on other nodes + Port offset based on job_id avoids conflicts between different SLURM jobs. + Returns: FrontendTopology describing where to run nginx and frontends. """ + from srtctl.core.slurm import get_port_offset + nodes = self.runtime.nodes.worker head = self.runtime.nodes.head fe_config = self.config.frontend + # Calculate port offset to avoid conflicts between jobs + port_offset = get_port_offset(self.runtime.job_id) + + # Base ports with offset + # Note: base_internal_port must not conflict with DYN_SYSTEM_PORT (8081+offset+worker_idx) + # With many workers, DYN_SYSTEM_PORT can reach 8081+offset+N, so use 9090 to stay clear + base_public_port = 8000 + port_offset + base_internal_port = 9090 + port_offset + # Single node or multiple frontends disabled: single frontend, no nginx if len(nodes) == 1 or not fe_config.enable_multiple_frontends: return FrontendTopology( nginx_node=None, frontend_nodes=[head], - frontend_port=8000, - public_port=8000, + frontend_port=base_public_port, + public_port=base_public_port, ) # Multiple nodes with multiple frontends enabled: @@ -104,17 +117,19 @@ def _compute_frontend_topology(self) -> FrontendTopology: frontend_nodes = other_nodes[:max_frontends] logger.info( - "Frontend topology: nginx on %s, %d frontends on %s", + "Frontend topology: nginx on %s (port %d), %d frontends on %s (port %d)", head, + base_public_port, len(frontend_nodes), frontend_nodes, + base_internal_port, ) return FrontendTopology( nginx_node=head, frontend_nodes=frontend_nodes, - frontend_port=8180, # Internal port behind nginx - public_port=8000, # Public port exposed by nginx + frontend_port=base_internal_port, # Internal port behind nginx + public_port=base_public_port, # Public port exposed by nginx ) def _start_nginx(self, topology: FrontendTopology) -> ManagedProcess: @@ -132,6 +147,7 @@ def _start_nginx(self, topology: FrontendTopology) -> ManagedProcess: # Install nginx and run it (daemon off keeps nginx in foreground so srun can manage it) # Use container path (/logs) since log_dir is mounted there + # Add retry logic for apt-get in case of mirror sync issues container_config_path = "/logs/nginx.conf" cmd = [ "bash", diff --git a/src/srtctl/cli/mixins/worker_stage.py b/src/srtctl/cli/mixins/worker_stage.py index 63c41489e..91e053394 100644 --- a/src/srtctl/cli/mixins/worker_stage.py +++ b/src/srtctl/cli/mixins/worker_stage.py @@ -110,13 +110,29 @@ def start_worker(self, process: "Process", endpoint_processes: list["Process"]) ) # Environment variables - env_to_set = { - "HEAD_NODE_IP": self.runtime.head_node_ip, - "ETCD_ENDPOINTS": f"http://{self.runtime.nodes.infra}:2379", - "NATS_SERVER": f"nats://{self.runtime.nodes.infra}:4222", - "DYN_SYSTEM_PORT": str(process.sys_port), - "DYN_REQUEST_PLANE": "nats", - } + env_to_set: dict[str, str] = {"HEAD_NODE_IP": self.runtime.head_node_ip} + + # Only Dynamo workers require etcd/NATS + system status server port. + if self.config.frontend.type == "dynamo": + from srtctl.core.slurm import get_port_offset + + port_offset = get_port_offset(self.runtime.job_id) + nats_port = 4222 + port_offset + etcd_port = 2379 + port_offset + + env_to_set.update( + { + "ETCD_ENDPOINTS": f"http://{self.runtime.infra_node_ip}:{etcd_port}", + "NATS_SERVER": f"nats://{self.runtime.infra_node_ip}:{nats_port}", + "DYN_SYSTEM_PORT": str(process.sys_port), + } + ) + + # Keep request-plane consistent across frontend/workers + frontend_plane = None + if self.config.frontend.env: + frontend_plane = self.config.frontend.env.get("DYN_REQUEST_PLANE") + env_to_set["DYN_REQUEST_PLANE"] = frontend_plane if frontend_plane else "nats" # Add mode-specific environment variables from backend # Support simple {node} and {node_id} templating @@ -139,8 +155,9 @@ def __missing__(self, key: str) -> str: # Add profiling environment variables if profiling.enabled: - profile_dir = str(self.runtime.log_dir / "profiles") - env_to_set.update(profiling.get_env_vars(mode, profile_dir)) + # /logs is the mounted host log directory inside the container. + profile_dir_in_container = "/logs/profiles" + env_to_set.update(profiling.get_env_vars(mode, profile_dir_in_container)) # Set CUDA_VISIBLE_DEVICES if not using all GPUs if len(process.gpu_indices) < self.runtime.gpus_per_node: @@ -228,6 +245,9 @@ def start_endpoint_worker(self, endpoint_processes: list["Process"]) -> ManagedP ) # Environment variables + # TODO: port-offset is only applied in start_worker() (SGLang path). + # This MPI-style path (TRTLLM) still uses hardcoded NATS/etcd ports. + # If TRTLLM needs port-offset support, mirror the dynamo env logic from start_worker(). env_to_set = { "HEAD_NODE_IP": self.runtime.head_node_ip, "ETCD_ENDPOINTS": f"http://{self.runtime.nodes.infra}:2379", diff --git a/src/srtctl/cli/setup_head.py b/src/srtctl/cli/setup_head.py index 91582e3fc..d8fde106c 100644 --- a/src/srtctl/cli/setup_head.py +++ b/src/srtctl/cli/setup_head.py @@ -17,15 +17,37 @@ import time from pathlib import Path -# Network configurations -ETCD_CLIENT_PORT = 2379 -ETCD_PEER_PORT = 2380 -NATS_PORT = 4222 +# Network configurations (base ports - offset applied at runtime based on job_id) +BASE_ETCD_CLIENT_PORT = 2379 +BASE_ETCD_PEER_PORT = 2380 +BASE_NATS_PORT = 4222 ETCD_LISTEN_ADDR = "http://0.0.0.0" logger = logging.getLogger(__name__) +def get_port_offset_from_job_id(job_id: str | None = None) -> int: + """Calculate port offset based on SLURM job ID. + + Args: + job_id: SLURM job ID (if None, reads from environment) + + Returns: + Port offset: (job_id % 100) * 10 + """ + if job_id is None: + job_id = os.environ.get("SLURM_JOB_ID") or os.environ.get("SLURM_JOBID") + + if not job_id: + return 0 + + try: + offset = (int(job_id) % 100) * 10 + return offset + except (ValueError, TypeError): + return 0 + + def get_local_ip() -> str: """Get local IP address using multiple fallback methods. @@ -119,11 +141,12 @@ def setup_logging(): ) -def start_nats(binary_path: str = "/configs/nats-server") -> subprocess.Popen: +def start_nats(binary_path: str = "/configs/nats-server", port: int = 4222) -> subprocess.Popen: """Start NATS server. Args: binary_path: Path to nats-server binary + port: Port for NATS to listen on Returns: Popen object for the NATS process @@ -131,21 +154,21 @@ def start_nats(binary_path: str = "/configs/nats-server") -> subprocess.Popen: if not os.path.exists(binary_path): raise FileNotFoundError(f"NATS binary not found: {binary_path}") - # Use /tmp for JetStream storage - avoids "Temporary storage directory" warning - # and ensures we're using fast local storage' + # Use /tmp for JetStream storage - container-local fast storage. + # Each SLURM job runs in its own container, so /tmp is isolated per job. if os.path.exists("/tmp/nats"): shutil.rmtree("/tmp/nats") nats_store_dir = "/tmp/nats" os.makedirs(nats_store_dir, exist_ok=True) - logger.info("Starting NATS server...") - cmd = [binary_path, "-js", "-sd", nats_store_dir] + logger.info("Starting NATS server on port %d...", port) + cmd = [binary_path, "-js", "-sd", nats_store_dir, "-p", str(port)] proc = subprocess.Popen( cmd, ) - logger.info("NATS server started (PID: %d)", proc.pid) + logger.info("NATS server started (PID: %d, port: %d)", proc.pid, port) return proc @@ -153,6 +176,9 @@ def start_etcd( host_ip: str, binary_path: str = "/configs/etcd", log_dir: Path | None = None, + client_port: int = 2379, + peer_port: int = 2380, + job_id: str | None = None, ) -> subprocess.Popen: """Start etcd server. @@ -160,6 +186,9 @@ def start_etcd( host_ip: IP address of this node (for peer URLs) binary_path: Path to etcd binary log_dir: Optional log directory + client_port: Client API port + peer_port: Peer communication port + job_id: SLURM job ID for unique data directory Returns: Popen object for the etcd process @@ -167,24 +196,37 @@ def start_etcd( if not os.path.exists(binary_path): raise FileNotFoundError(f"etcd binary not found: {binary_path}") - logger.info("Starting etcd server...") + logger.info("Starting etcd server on client port %d, peer port %d...", client_port, peer_port) - # Use /tmp for etcd data directory - this is typically on fast local storage - # (often tmpfs on HPC systems). Without this, etcd uses "default.etcd" in CWD - # which may be on slow network storage, causing Raft consensus timeouts. - if os.path.exists("/tmp/etcd"): - shutil.rmtree("/tmp/etcd") - etcd_data_dir = "/tmp/etcd" - os.makedirs(etcd_data_dir, exist_ok=True) + # Determine etcd data directory. Prefer log_dir so data lives alongside logs + # for easier debugging; fall back to /tmp (container-local fast storage). + if log_dir: + data_dir = log_dir / "etcd-data" + elif job_id: + data_dir = Path(f"/tmp/etcd-{job_id}") + else: + data_dir = Path(f"/tmp/etcd-{client_port}") + + if data_dir.exists(): + shutil.rmtree(data_dir) + data_dir.mkdir(parents=True, exist_ok=True) + etcd_data_dir = str(data_dir) + logger.info("etcd data directory: %s", etcd_data_dir) cmd = [ binary_path, "--data-dir", etcd_data_dir, "--listen-client-urls", - f"{ETCD_LISTEN_ADDR}:{ETCD_CLIENT_PORT}", + f"{ETCD_LISTEN_ADDR}:{client_port}", "--advertise-client-urls", - f"http://{host_ip}:{ETCD_CLIENT_PORT}", # Must be reachable IP, not 0.0.0.0 + f"http://{host_ip}:{client_port}", # Must be reachable IP, not 0.0.0.0 + "--listen-peer-urls", + f"{ETCD_LISTEN_ADDR}:{peer_port}", + "--initial-advertise-peer-urls", + f"http://{host_ip}:{peer_port}", + "--initial-cluster", + f"default=http://{host_ip}:{peer_port}", ] # Set up output handling @@ -255,6 +297,16 @@ def main(): log_dir = Path(args.log_dir) log_dir.mkdir(parents=True, exist_ok=True) + # Calculate port offset based on SLURM_JOB_ID to avoid conflicts + job_id = os.environ.get("SLURM_JOB_ID") + port_offset = get_port_offset_from_job_id(job_id) + logger.info("Port offset for this job: %d (SLURM_JOB_ID: %s)", port_offset, job_id or "N/A") + + # Calculate actual ports + nats_port = BASE_NATS_PORT + port_offset + etcd_client_port = BASE_ETCD_CLIENT_PORT + port_offset + etcd_peer_port = BASE_ETCD_PEER_PORT + port_offset + # Get our IP address using multiple fallback methods host_ip = get_local_ip() logger.info("Host IP: %s", host_ip) @@ -264,21 +316,23 @@ def main(): etcd_proc = None try: - nats_proc = start_nats(args.nats_binary) - etcd_proc = start_etcd(host_ip, args.etcd_binary, log_dir) + nats_proc = start_nats(args.nats_binary, port=nats_port) + etcd_proc = start_etcd( + host_ip, args.etcd_binary, log_dir, client_port=etcd_client_port, peer_port=etcd_peer_port, job_id=job_id + ) # Wait for services - if not wait_for_service("localhost", NATS_PORT, "NATS"): + if not wait_for_service("localhost", nats_port, "NATS"): logger.error("NATS failed to start") sys.exit(1) - if not wait_for_service("localhost", ETCD_CLIENT_PORT, "etcd"): + if not wait_for_service("localhost", etcd_client_port, "etcd"): logger.error("etcd failed to start") sys.exit(1) logger.info("Head node infrastructure is ready") - logger.info(" NATS: nats://localhost:%d", NATS_PORT) - logger.info(" etcd: http://localhost:%d", ETCD_CLIENT_PORT) + logger.info(" NATS: nats://localhost:%d", nats_port) + logger.info(" etcd: http://localhost:%d", etcd_client_port) # Keep running - wait for either process to exit while True: diff --git a/src/srtctl/core/runtime.py b/src/srtctl/core/runtime.py index 3e68bdd5c..6c0037579 100644 --- a/src/srtctl/core/runtime.py +++ b/src/srtctl/core/runtime.py @@ -257,6 +257,12 @@ def from_config( container_path = container_template.get_path(temp_context, make_absolute=False, ensure_exists=False) container_mounts[host_path] = container_path + # Calculate frontend port with job offset to avoid conflicts + from .slurm import get_port_offset + + port_offset = get_port_offset(job_id) + frontend_port = 8000 + port_offset + return cls( job_id=job_id, run_name=run_name, @@ -271,6 +277,7 @@ def from_config( container_mounts=container_mounts, srun_options=dict(config.srun_options), environment=dict(config.environment), + frontend_port=frontend_port, is_hf_model=is_hf_model, ) diff --git a/src/srtctl/core/slurm.py b/src/srtctl/core/slurm.py index 6b5f4e584..9799b15a7 100644 --- a/src/srtctl/core/slurm.py +++ b/src/srtctl/core/slurm.py @@ -16,7 +16,7 @@ import shlex import socket import subprocess -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from pathlib import Path from .ip_utils import get_node_ip @@ -34,6 +34,36 @@ def get_slurm_job_id() -> str | None: return os.environ.get("SLURM_JOB_ID") or os.environ.get("SLURM_JOBID") +def get_port_offset(job_id: str | None = None) -> int: + """Calculate port offset based on SLURM job ID to avoid conflicts. + + When multiple jobs run on the same nodes, they need different port ranges. + This function computes an offset based on job_id: (job_id % 100) * 10 + + Args: + job_id: SLURM job ID (if None, reads from environment) + + Returns: + Port offset to add to base ports (0-990) + + Example: + Job 2437: (2437 % 100) * 10 = 370 + Job 2518: (2518 % 100) * 10 = 180 + """ + if job_id is None: + job_id = get_slurm_job_id() + + if not job_id: + return 0 + + try: + offset = (int(job_id) % 100) * 10 + logger.debug("Port offset for job %s: %d", job_id, offset) + return offset + except (ValueError, TypeError): + return 0 + + def get_slurm_nodelist() -> list[str]: """Get list of nodes from SLURM_NODELIST environment variable. @@ -144,7 +174,7 @@ def start_srun_process( nodelist: Sequence[str] | None = None, output: str | None = None, container_image: str | None = None, - container_mounts: dict[Path, Path] | None = None, + container_mounts: Mapping[Path | str, Path | str] | None = None, env_to_pass_through: list[str] | None = None, env_to_set: dict[str, str] | None = None, bash_preamble: str | None = None, diff --git a/src/srtctl/core/topology.py b/src/srtctl/core/topology.py index f2a24e5d0..9144704de 100644 --- a/src/srtctl/core/topology.py +++ b/src/srtctl/core/topology.py @@ -40,39 +40,58 @@ class NodePortAllocator: - http_port: 30000+ (per node) - HTTP serving port - bootstrap_port: 31000+ (per node) - P/D coordination port (prefill only) + Port offset based on SLURM_JOB_ID avoids conflicts when multiple jobs + run on the same nodes: offset = (job_id % 100) * 10 + Example: allocator = NodePortAllocator() # Two workers on same node get different ports - port1 = allocator.next_http_port("node0") # 30000 - port2 = allocator.next_http_port("node0") # 30001 + port1 = allocator.next_http_port("node0") # 30000 + offset + port2 = allocator.next_http_port("node0") # 30001 + offset # Different node starts fresh - port3 = allocator.next_http_port("node1") # 30000 + port3 = allocator.next_http_port("node1") # 30000 + offset """ base_http_port: int = 30000 base_bootstrap_port: int = 31000 base_kv_events_port: int = 5550 base_nixl_port: int = 6550 # NIXL side channel ports (must not overlap with kv_events) + port_offset: int = 0 # Offset based on job_id to avoid cross-job conflicts _http_ports: dict[str, int] = field(default_factory=dict, repr=False) _bootstrap_ports: dict[str, int] = field(default_factory=dict, repr=False) _next_kv_events_port: int = field(default=0, repr=False) # Global counter _next_nixl_port: int = field(default=0, repr=False) # Global counter for NIXL + @classmethod + def with_job_offset(cls, job_id: str | None = None) -> "NodePortAllocator": + """Create allocator with port offset based on SLURM job ID. + + Args: + job_id: SLURM job ID (if None, reads from environment) + + Returns: + NodePortAllocator with appropriate port offset + """ + from srtctl.core.slurm import get_port_offset + + offset = get_port_offset(job_id) + return cls(port_offset=offset) + def next_http_port(self, node: str) -> int: """Get next available HTTP port for a node.""" if node not in self._http_ports: - self._http_ports[node] = self.base_http_port + self._http_ports[node] = self.base_http_port + self.port_offset port = self._http_ports[node] - self._http_ports[node] += 1000 + self._http_ports[node] += 1 return port def next_bootstrap_port(self, node: str) -> int: """Get next available bootstrap port for a node (prefill only).""" if node not in self._bootstrap_ports: - self._bootstrap_ports[node] = self.base_bootstrap_port + self._bootstrap_ports[node] = self.base_bootstrap_port + self.port_offset port = self._bootstrap_ports[node] self._bootstrap_ports[node] += 1 return port @@ -80,7 +99,7 @@ def next_bootstrap_port(self, node: str) -> int: def next_kv_events_port(self) -> int: """Get next available kv-events ZMQ port (globally unique across all nodes).""" if self._next_kv_events_port == 0: - self._next_kv_events_port = self.base_kv_events_port + self._next_kv_events_port = self.base_kv_events_port + self.port_offset port = self._next_kv_events_port self._next_kv_events_port += 1 return port @@ -88,7 +107,7 @@ def next_kv_events_port(self) -> int: def next_nixl_port(self) -> int: """Get next available NIXL side channel port (globally unique across all nodes).""" if self._next_nixl_port == 0: - self._next_nixl_port = self.base_nixl_port + self._next_nixl_port = self.base_nixl_port + self.port_offset port = self._next_nixl_port self._next_nixl_port += 1 return port @@ -367,6 +386,7 @@ def endpoints_to_processes( endpoints: list[Endpoint], base_sys_port: int = 8081, port_allocator: NodePortAllocator | None = None, + job_id: str | None = None, ) -> list[Process]: """Convert endpoints to physical processes. @@ -376,19 +396,24 @@ def endpoints_to_processes( Ports are assigned per-node to avoid conflicts when multiple workers share a node (e.g., 2 decode workers with 4 GPUs each on an 8-GPU node). + Port offset based on job_id avoids conflicts between different SLURM jobs. + Args: endpoints: List of Endpoint objects base_sys_port: Starting port for DYN_SYSTEM_PORT assignment port_allocator: NodePortAllocator for HTTP/bootstrap ports (created if None) + job_id: SLURM job ID for port offset calculation (if None, reads from env) Returns: List of Process objects """ processes: list[Process] = [] + # Note: base_sys_port already includes job-based offset from do_sweep.py, + # so we don't add port_offset here to avoid double-counting current_sys_port = base_sys_port if port_allocator is None: - port_allocator = NodePortAllocator() + port_allocator = NodePortAllocator.with_job_offset(job_id) for endpoint in endpoints: # Allocate bootstrap port once per prefill endpoint (shared by all processes) diff --git a/src/srtctl/frontends/dynamo.py b/src/srtctl/frontends/dynamo.py index 5e5109a18..884a46fe1 100644 --- a/src/srtctl/frontends/dynamo.py +++ b/src/srtctl/frontends/dynamo.py @@ -77,9 +77,15 @@ def start_frontends( cmd = ["python3", "-m", "dynamo.frontend", f"--http-port={topology.frontend_port}"] cmd.extend(self.get_frontend_args_list(config.frontend.args)) + from srtctl.core.slurm import get_port_offset + + port_offset = get_port_offset(runtime.job_id) + nats_port = 4222 + port_offset + etcd_port = 2379 + port_offset + env_to_set = { - "ETCD_ENDPOINTS": f"http://{runtime.nodes.infra}:2379", - "NATS_SERVER": f"nats://{runtime.nodes.infra}:4222", + "ETCD_ENDPOINTS": f"http://{runtime.infra_node_ip}:{etcd_port}", + "NATS_SERVER": f"nats://{runtime.infra_node_ip}:{nats_port}", "DYN_REQUEST_PLANE": "nats", } diff --git a/tests/test_frontend_topology.py b/tests/test_frontend_topology.py index 6fd741573..653544039 100644 --- a/tests/test_frontend_topology.py +++ b/tests/test_frontend_topology.py @@ -11,6 +11,11 @@ from srtctl.core.runtime import Nodes, RuntimeContext from srtctl.core.schema import FrontendConfig, ResourceConfig, SrtConfig +TEST_JOB_ID = "12345" +_PORT_OFFSET = (int(TEST_JOB_ID) % 100) * 10 # 450 +EXPECTED_PUBLIC_PORT = 8000 + _PORT_OFFSET +EXPECTED_INTERNAL_PORT = 9090 + _PORT_OFFSET + def make_config( *, @@ -39,7 +44,7 @@ def make_config( def make_runtime(nodes: list[str]) -> RuntimeContext: """Create a minimal RuntimeContext for testing.""" return RuntimeContext( - job_id="12345", + job_id=TEST_JOB_ID, run_name="test-run", nodes=Nodes(head=nodes[0], bench=nodes[0], infra=nodes[0], worker=tuple(nodes)), head_node_ip="10.0.0.1", @@ -80,7 +85,7 @@ class TestComputeFrontendTopology: """Tests for _compute_frontend_topology method.""" def test_single_node_no_nginx(self): - """Single node: no nginx, 1 frontend on head at port 8000.""" + """Single node: no nginx, 1 frontend on head with offset port.""" config = make_config(enable_multiple_frontends=True) runtime = make_runtime(["node0"]) @@ -89,8 +94,8 @@ def test_single_node_no_nginx(self): assert topology.nginx_node is None assert topology.frontend_nodes == ["node0"] - assert topology.frontend_port == 8000 - assert topology.public_port == 8000 + assert topology.frontend_port == EXPECTED_PUBLIC_PORT + assert topology.public_port == EXPECTED_PUBLIC_PORT assert topology.uses_nginx is False def test_multi_node_frontends_disabled(self): @@ -103,8 +108,8 @@ def test_multi_node_frontends_disabled(self): assert topology.nginx_node is None assert topology.frontend_nodes == ["node0"] - assert topology.frontend_port == 8000 - assert topology.public_port == 8000 + assert topology.frontend_port == EXPECTED_PUBLIC_PORT + assert topology.public_port == EXPECTED_PUBLIC_PORT assert topology.uses_nginx is False def test_two_nodes_with_nginx(self): @@ -117,8 +122,8 @@ def test_two_nodes_with_nginx(self): assert topology.nginx_node == "node0" assert topology.frontend_nodes == ["node1"] - assert topology.frontend_port == 8180 # Behind nginx - assert topology.public_port == 8000 + assert topology.frontend_port == EXPECTED_INTERNAL_PORT + assert topology.public_port == EXPECTED_PUBLIC_PORT assert topology.uses_nginx is True def test_three_nodes_with_nginx(self): @@ -131,8 +136,8 @@ def test_three_nodes_with_nginx(self): assert topology.nginx_node == "node0" assert topology.frontend_nodes == ["node1", "node2"] - assert topology.frontend_port == 8180 - assert topology.public_port == 8000 + assert topology.frontend_port == EXPECTED_INTERNAL_PORT + assert topology.public_port == EXPECTED_PUBLIC_PORT assert topology.uses_nginx is True def test_many_nodes_with_nginx(self):