diff --git a/.gitignore b/.gitignore index 84c7ea4325..5b086664ee 100644 --- a/.gitignore +++ b/.gitignore @@ -89,4 +89,7 @@ generated-values.yaml TensorRT-LLM # Local build artifacts for devcontainer -.build/ \ No newline at end of file +.build/ + +# Pytest +.coverage diff --git a/components/backends/vllm/src/dynamo/vllm/args.py b/components/backends/vllm/src/dynamo/vllm/args.py index b86649f06b..889405f6af 100644 --- a/components/backends/vllm/src/dynamo/vllm/args.py +++ b/components/backends/vllm/src/dynamo/vllm/args.py @@ -2,13 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 -import asyncio -import json import logging import os -import socket import sys -import time from typing import Optional from vllm.config import KVTransferConfig @@ -16,9 +12,20 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.utils import FlexibleArgumentParser +from .ports import ( + DEFAULT_DYNAMO_PORT_MAX, + DEFAULT_DYNAMO_PORT_MIN, + DynamoPortRange, + EtcdContext, + PortAllocationRequest, + PortMetadata, + allocate_and_reserve_port, + allocate_and_reserve_port_block, + get_host_ip, +) + logger = logging.getLogger(__name__) -# Only used if you run it manually from the command line DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" DEFAULT_MODEL = "Qwen/Qwen3-0.6B" @@ -34,6 +41,7 @@ class Config: migration_limit: int = 0 kv_port: Optional[int] = None side_channel_port: Optional[int] = None + port_range: DynamoPortRange # mirror vLLM model: str @@ -64,6 +72,18 @@ def parse_args() -> Config: default=0, help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.", ) + parser.add_argument( + "--dynamo-port-min", + type=int, + default=DEFAULT_DYNAMO_PORT_MIN, + help=f"Minimum port number for Dynamo services (default: {DEFAULT_DYNAMO_PORT_MIN}). 
Must be in registered ports range (1024-49151).", + ) + parser.add_argument( + "--dynamo-port-max", + type=int, + default=DEFAULT_DYNAMO_PORT_MAX, + help=f"Maximum port number for Dynamo services (default: {DEFAULT_DYNAMO_PORT_MAX}). Must be in registered ports range (1024-49151).", + ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() @@ -110,6 +130,9 @@ def parse_args() -> Config: config.engine_args = engine_args config.is_prefill_worker = args.is_prefill_worker config.migration_limit = args.migration_limit + config.port_range = DynamoPortRange( + min=args.dynamo_port_min, max=args.dynamo_port_max + ) if config.engine_args.block_size is None: config.engine_args.block_size = 16 @@ -120,106 +143,66 @@ def parse_args() -> Config: return config -async def allocate_and_reserve_port( - namespace, - etcd_client, - worker_id: str, - reason: str, - max_attempts: int = 100, -) -> int: - """ - Get an OS-assigned port and atomically reserve it in ETCD. - Retries until successful or max_attempts reached. 
- - Args: - max_attempts: Maximum number of ports to try (default: 100) - - Raises: - RuntimeError: If unable to reserve a port within max_attempts - OSError: If unable to create sockets (system resource issues) - """ - - node_name = socket.gethostname() - try: - node_ip = socket.gethostbyname(node_name) - except socket.gaierror: - # If hostname cannot be resolved, fall back to localhost - logger.warning( - f"Hostname '{node_name}' cannot be resolved, falling back to '127.0.0.1'" - ) - node_ip = "127.0.0.1" - - for attempt in range(1, max_attempts + 1): - # Hold socket open just long enough to reserve in ETCD - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("", 0)) - port = sock.getsockname()[1] - - # Reserve in ETCD while holding the socket - key = f"dyn://{namespace}/ports/{node_ip}/{port}" - value = { - "worker_id": worker_id, - "reason": reason, - "reserved_at": time.time(), - "pid": os.getpid(), - } - - try: - await etcd_client.kv_create( - key=key, - value=json.dumps(value).encode(), - lease_id=etcd_client.primary_lease_id(), - ) - logger.debug(f"Reserved OS-assigned port {port} for {worker_id}") - return port - - except Exception as e: - logger.debug( - f"Port {port} on {node_name} was already reserved (attempt {attempt}): {e}" - ) - - if attempt < max_attempts: - await asyncio.sleep(0.01) - - raise RuntimeError( - f"Failed to allocate and reserve a port after {max_attempts} attempts" - ) - - async def configure_ports_with_etcd(config: Config, etcd_client): """Configure all settings that require ETCD, including port allocation and vLLM overrides.""" - # First, allocate ports + etcd_context = EtcdContext(client=etcd_client, namespace=config.namespace) + dp_rank = config.engine_args.data_parallel_rank or 0 worker_id = f"vllm-{config.component}-dp{dp_rank}" # Allocate KV events port - kv_port = await allocate_and_reserve_port( - namespace=config.namespace, - 
etcd_client=etcd_client, - worker_id=f"{worker_id}", - reason="zmq_kv_event_port", + if config.engine_args.enable_prefix_caching: + kv_metadata = PortMetadata(worker_id=worker_id, reason="zmq_kv_event_port") + kv_port = await allocate_and_reserve_port( + etcd_context=etcd_context, + metadata=kv_metadata, + port_range=config.port_range, + ) + config.kv_port = kv_port + logger.info(f"Allocated ZMQ KV events port: {kv_port} (worker_id={worker_id})") + + # Allocate side channel ports + # https://github.com/vllm-project/vllm/blob/releases/v0.10.0/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py#L372 + # NIXL calculates ports as: base_port + (dp_rank * tp_size) + tp_rank + # For dp_rank, we need to reserve tp_size consecutive ports + tp_size = config.engine_args.tensor_parallel_size or 1 + + # The first port for this dp_rank will be at: base_port + (dp_rank * tp_size) + # We need to allocate tp_size consecutive ports starting from there + nixl_metadata = PortMetadata(worker_id=worker_id, reason="nixl_side_channel_port") + nixl_request = PortAllocationRequest( + etcd_context=etcd_context, + metadata=nixl_metadata, + port_range=config.port_range, + block_size=tp_size, ) + allocated_ports = await allocate_and_reserve_port_block(nixl_request) + first_port_for_dp_rank = allocated_ports[0] + + # Calculate the base port that NIXL expects + # base_port = first_port_for_dp_rank - (dp_rank * tp_size) + nixl_offset = dp_rank * tp_size + base_side_channel_port = first_port_for_dp_rank - nixl_offset + + if base_side_channel_port < 0: + raise ValueError( + f"NIXL base port calculation resulted in negative port: " + f"first_allocated_port={first_port_for_dp_rank}, offset={nixl_offset}, " + f"base_port={base_side_channel_port}. Current range: {config.port_range.min}-{config.port_range.max}. " + f"Consider using a higher port range." 
+ ) - # Allocate side channel port - side_channel_port = await allocate_and_reserve_port( - namespace=config.namespace, - etcd_client=etcd_client, - worker_id=f"{worker_id}", - reason="nixl_side_channel_port", - ) + config.side_channel_port = base_side_channel_port - # Update config with allocated ports - config.kv_port = kv_port - config.side_channel_port = side_channel_port + logger.info( + f"Allocated NIXL side channel ports: base={base_side_channel_port}, " + f"allocated_ports={allocated_ports} (worker_id={worker_id}, dp_rank={dp_rank}, tp_size={tp_size})" + ) def overwrite_args(config): """Set vLLM defaults for Dynamo.""" - assert ( - config.kv_port is not None - ), "Must set the kv_port, use configure_ports_with_etcd" assert ( config.side_channel_port is not None ), "Must set the kv_port, use configure_ports_with_etcd" @@ -263,36 +246,6 @@ def overwrite_args(config): raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.") -def get_host_ip() -> str: - """Get the IP address of the host. - This is needed for the side channel to work in multi-node deployments. 
- """ - try: - host_name = socket.gethostname() - except socket.error as e: - logger.warning(f"Failed to get hostname: {e}, falling back to '127.0.0.1'") - return "127.0.0.1" - else: - try: - # Get the IP address of the hostname - this is needed for the side channel to work in multi-node deployments - host_ip = socket.gethostbyname(host_name) - # Test if the IP is actually usable by binding to it - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket: - test_socket.bind((host_ip, 0)) - return host_ip - except socket.gaierror as e: - logger.warning( - f"Hostname '{host_name}' cannot be resolved: {e}, falling back to '127.0.0.1'" - ) - return "127.0.0.1" - except socket.error as e: - # If hostname is not usable for binding, fall back to localhost - logger.warning( - f"Hostname '{host_name}' is not usable for binding: {e}, falling back to '127.0.0.1'" - ) - return "127.0.0.1" - - def set_side_channel_host_and_port(config: Config): """vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors. This sets the port number for the side channel. diff --git a/components/backends/vllm/src/dynamo/vllm/ports.py b/components/backends/vllm/src/dynamo/vllm/ports.py new file mode 100644 index 0000000000..19fdde7279 --- /dev/null +++ b/components/backends/vllm/src/dynamo/vllm/ports.py @@ -0,0 +1,290 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Port allocation and management utilities for Dynamo services.""" + +import asyncio +import json +import logging +import os +import random +import socket +import time +from contextlib import contextmanager +from dataclasses import dataclass, field + +from dynamo.runtime import EtcdKvCache + +logger = logging.getLogger(__name__) + +# Default port range in the registered ports section +DEFAULT_DYNAMO_PORT_MIN = 20000 +DEFAULT_DYNAMO_PORT_MAX = 30000 + + +@dataclass +class DynamoPortRange: + """Port range configuration for Dynamo services""" + + min: int + max: int + + def __post_init__(self): + if self.min < 1024 or self.max > 49151: + raise ValueError( + f"Port range {self.min}-{self.max} is outside registered ports range (1024-49151)" + ) + if self.min >= self.max: + raise ValueError( + f"Invalid port range: min ({self.min}) must be less than max ({self.max})" + ) + + +@dataclass +class EtcdContext: + """Context for ETCD operations""" + + client: EtcdKvCache # etcd client instance + namespace: str # Namespace for keys (used in key prefix) + + def make_port_key(self, port: int) -> str: + """Generate ETCD key for a port reservation""" + node_ip = get_host_ip() + return f"dyn://{self.namespace}/ports/{node_ip}/{port}" + + +@dataclass +class PortMetadata: + """Metadata to store with port reservations in ETCD""" + + worker_id: str # Worker identifier (e.g., "vllm-backend-dp0") + reason: str # Purpose of the port (e.g., "nixl_side_channel_port") + block_info: dict = field(default_factory=dict) # Optional block allocation info + + def to_etcd_value(self) -> dict: + """Convert to dictionary for ETCD storage""" + value = { + "worker_id": self.worker_id, + "reason": self.reason, + "reserved_at": time.time(), + "pid": os.getpid(), + } + if self.block_info: + value.update(self.block_info) + return value + + +@dataclass +class PortAllocationRequest: + """Parameters for port allocation""" + + etcd_context: EtcdContext + metadata: 
PortMetadata + port_range: DynamoPortRange + block_size: int = 1 + max_attempts: int = 100 + + +@contextmanager +def hold_ports(ports: int | list[int]): + """Context manager to hold port binding(s). + + Holds socket bindings to ensure exclusive access to ports during reservation. + Can handle a single port or multiple ports. + + Args: + ports: Single port number or list of port numbers to hold + """ + if isinstance(ports, int): + ports = [ports] + + sockets = [] + try: + for port in ports: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("", port)) + sockets.append(sock) + + yield + + finally: + for sock in sockets: + sock.close() + + +def check_port_available(port: int) -> bool: + """Check if a specific port is available for binding.""" + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", port)) + return True + except OSError: + return False + + +async def reserve_port_in_etcd( + etcd_context: EtcdContext, + port: int, + metadata: PortMetadata, +) -> None: + """Reserve a single port in ETCD.""" + key = etcd_context.make_port_key(port) + value = metadata.to_etcd_value() + + await etcd_context.client.kv_create( + key=key, + value=json.dumps(value).encode(), + lease_id=etcd_context.client.primary_lease_id(), + ) + + +async def allocate_and_reserve_port_block(request: PortAllocationRequest) -> list[int]: + """ + Allocate a contiguous block of ports from the specified range and atomically reserve them in ETCD. + Returns a list of all allocated ports in order. + + This function uses a context manager to hold port bindings while reserving in ETCD, + preventing race conditions between multiple processes. 
+ + Args: + request: PortAllocationRequest containing all allocation parameters + + Returns: + list[int]: List of all allocated ports in ascending order + + Raises: + RuntimeError: If unable to reserve a port block within max_attempts + OSError: If unable to create sockets (system resource issues) + """ + # Create a list of valid starting ports (must have room for the entire block) + max_start_port = request.port_range.max - request.block_size + 1 + if max_start_port < request.port_range.min: + raise ValueError( + f"Port range {request.port_range.min}-{request.port_range.max} is too small for block size {request.block_size}" + ) + + available_start_ports = list(range(request.port_range.min, max_start_port + 1)) + random.shuffle(available_start_ports) + + actual_max_attempts = min(len(available_start_ports), request.max_attempts) + + for attempt in range(1, actual_max_attempts + 1): + start_port = available_start_ports[attempt - 1] + ports_to_reserve = list(range(start_port, start_port + request.block_size)) + + try: + # Try to bind to all ports in the block atomically + with hold_ports(ports_to_reserve): + logger.debug( + f"Successfully bound to ports {ports_to_reserve}, now reserving in ETCD" + ) + + # We have exclusive access to these ports, now reserve them in ETCD + for i, port in enumerate(ports_to_reserve): + port_metadata = PortMetadata( + worker_id=f"{request.metadata.worker_id}-{i}" + if request.block_size > 1 + else request.metadata.worker_id, + reason=request.metadata.reason, + block_info={ + "block_index": i, + "block_size": request.block_size, + "block_start": start_port, + } + if request.block_size > 1 + else {}, + ) + + await reserve_port_in_etcd( + etcd_context=request.etcd_context, + port=port, + metadata=port_metadata, + ) + + logger.debug( + f"Reserved port block {ports_to_reserve} from range {request.port_range.min}-{request.port_range.max} " + f"for {request.metadata.worker_id} (block_size={request.block_size})" + ) + return ports_to_reserve + 
+ except OSError as e: + logger.debug( + f"Failed to bind to port block starting at {start_port} (attempt {attempt}): {e}" + ) + except Exception as e: + logger.debug( + f"Failed to reserve port block starting at {start_port} in ETCD (attempt {attempt}): {e}" + ) + + if attempt < actual_max_attempts: + await asyncio.sleep(0.01) + + raise RuntimeError( + f"Failed to allocate and reserve a port block of size {request.block_size} from range " + f"{request.port_range.min}-{request.port_range.max} after {actual_max_attempts} attempts" + ) + + +async def allocate_and_reserve_port( + etcd_context: EtcdContext, + metadata: PortMetadata, + port_range: DynamoPortRange, + max_attempts: int = 100, +) -> int: + """ + Allocate a port from the specified range and atomically reserve it in ETCD. + This is a convenience wrapper around allocate_and_reserve_port_block with block_size=1. + + Args: + etcd_context: ETCD context for operations + metadata: Port metadata for ETCD storage + port_range: DynamoPortRange object specifying min and max ports to try + max_attempts: Maximum number of ports to try (default: 100) + + Returns: + int: The allocated port number + + Raises: + RuntimeError: If unable to reserve a port within max_attempts + OSError: If unable to create sockets (system resource issues) + """ + request = PortAllocationRequest( + etcd_context=etcd_context, + metadata=metadata, + port_range=port_range, + block_size=1, + max_attempts=max_attempts, + ) + allocated_ports = await allocate_and_reserve_port_block(request) + return allocated_ports[0] # Return the single allocated port + + +def get_host_ip() -> str: + """Get the IP address of the host. + This is needed for the side channel to work in multi-node deployments. 
+ """ + try: + host_name = socket.gethostname() + except socket.error as e: + logger.warning(f"Failed to get hostname: {e}, falling back to '127.0.0.1'") + return "127.0.0.1" + else: + try: + # Get the IP address of the hostname - this is needed for the side channel to work in multi-node deployments + host_ip = socket.gethostbyname(host_name) + # Test if the IP is actually usable by binding to it + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket: + test_socket.bind((host_ip, 0)) + return host_ip + except socket.gaierror as e: + logger.warning( + f"Hostname '{host_name}' cannot be resolved: {e}, falling back to '127.0.0.1'" + ) + return "127.0.0.1" + except socket.error as e: + # If hostname is not usable for binding, fall back to localhost + logger.warning( + f"Hostname '{host_name}' is not usable for binding: {e}, falling back to '127.0.0.1'" + ) + return "127.0.0.1" diff --git a/components/backends/vllm/src/dynamo/vllm/tests/README.md b/components/backends/vllm/src/dynamo/vllm/tests/README.md new file mode 100644 index 0000000000..1065162d04 --- /dev/null +++ b/components/backends/vllm/src/dynamo/vllm/tests/README.md @@ -0,0 +1,39 @@ +# Dynamo vLLM Backend Tests + +This directory contains unit tests for the Dynamo vLLM backend components. 
+ +## Running Tests + +### Run all tests +```bash +cd components/backends/vllm/src/dynamo/vllm/tests +python -m pytest +``` + +### Run specific test file +```bash +python -m pytest test_ports.py +``` + +### Run with coverage +```bash +python -m pytest --cov=dynamo.vllm.ports --cov-report=term +``` + +### Run specific test class or method +```bash +python -m pytest test_ports.py::TestHoldPorts +python -m pytest test_ports.py::TestHoldPorts::test_hold_single_port +``` + +## Dependencies + +The tests require: +- `pytest` - Test framework +- `pytest-asyncio` - For async test support +- `pytest-cov` - For coverage reports (optional) + +Install with: +```bash +pip install pytest pytest-asyncio pytest-cov +``` diff --git a/components/backends/vllm/src/dynamo/vllm/tests/__init__.py b/components/backends/vllm/src/dynamo/vllm/tests/__init__.py new file mode 100644 index 0000000000..c5f715b9cb --- /dev/null +++ b/components/backends/vllm/src/dynamo/vllm/tests/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for dynamo vllm backend components.""" diff --git a/components/backends/vllm/src/dynamo/vllm/tests/pytest.ini b/components/backends/vllm/src/dynamo/vllm/tests/pytest.ini new file mode 100644 index 0000000000..be536e08ce --- /dev/null +++ b/components/backends/vllm/src/dynamo/vllm/tests/pytest.ini @@ -0,0 +1,24 @@ +; [pytest] +; # Pytest configuration for dynamo vllm tests + +; testpaths = . +; python_files = test_*.py +; python_classes = Test* +; python_functions = test_* + +; # Add parent directory to Python path for imports +; pythonpath = ../.. 
+ +; # Markers for async tests +; markers = +; asyncio: marks tests as async (deselect with '-m "not asyncio"') + +; # Test output options +; addopts = +; -v +; --tb=short +; --strict-markers +; --disable-warnings + +; # Async test configuration +; asyncio_mode = auto diff --git a/components/backends/vllm/src/dynamo/vllm/tests/test_ports.py b/components/backends/vllm/src/dynamo/vllm/tests/test_ports.py new file mode 100644 index 0000000000..1d07edc33c --- /dev/null +++ b/components/backends/vllm/src/dynamo/vllm/tests/test_ports.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for port allocation and management utilities.""" + +import json +import socket +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from dynamo.vllm.ports import ( + DynamoPortRange, + EtcdContext, + PortAllocationRequest, + PortMetadata, + allocate_and_reserve_port, + allocate_and_reserve_port_block, + check_port_available, + get_host_ip, + hold_ports, + reserve_port_in_etcd, +) + + +class TestDynamoPortRange: + """Test DynamoPortRange validation.""" + + def test_valid_port_range(self): + """Test creating a valid port range.""" + port_range = DynamoPortRange(min=2000, max=3000) + assert port_range.min == 2000 + assert port_range.max == 3000 + + def test_port_range_outside_registered_range(self): + """Test that port ranges outside 1024-49151 are rejected.""" + with pytest.raises(ValueError, match="outside registered ports range"): + DynamoPortRange(min=500, max=2000) + + with pytest.raises(ValueError, match="outside registered ports range"): + DynamoPortRange(min=2000, max=50000) + + def test_invalid_port_range_min_greater_than_max(self): + """Test that min >= max is rejected.""" + with pytest.raises(ValueError, match="min .* must be less than max"): + DynamoPortRange(min=3000, max=2000) + + with pytest.raises(ValueError, match="min .* must be 
less than max"): + DynamoPortRange(min=3000, max=3000) + + +class TestPortMetadata: + """Test port metadata functionality.""" + + def test_to_etcd_value_with_block_info(self): + """Test converting metadata to ETCD value with block info.""" + metadata = PortMetadata( + worker_id="test-worker", + reason="test-reason", + block_info={"block_index": 0, "block_size": 4, "block_start": 8080}, + ) + + value = metadata.to_etcd_value() + assert value["block_index"] == 0 + assert value["block_size"] == 4 + assert value["block_start"] == 8080 + + +class TestHoldPorts: + """Test hold_ports context manager.""" + + def test_hold_single_port(self): + """Test holding a single port.""" + with socket.socket() as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + with hold_ports(port): + assert not check_port_available(port) + + # Port should be released after context exit + assert check_port_available(port) + + def test_hold_multiple_ports(self): + """Test holding multiple ports.""" + ports = [] + for _ in range(2): + with socket.socket() as s: + s.bind(("", 0)) + ports.append(s.getsockname()[1]) + + with hold_ports(ports): + for port in ports: + assert not check_port_available(port) + + # All ports should be released after context exit + for port in ports: + assert check_port_available(port) + + +class TestReservePortInEtcd: + """Test ETCD port reservation.""" + + @pytest.mark.asyncio + async def test_reserve_port_success(self): + """Test successful port reservation in ETCD.""" + mock_client = AsyncMock() + mock_client.primary_lease_id = Mock(return_value="test-lease-123") + + context = EtcdContext(client=mock_client, namespace="test-ns") + metadata = PortMetadata(worker_id="test-worker", reason="test") + + host_ip = get_host_ip() + await reserve_port_in_etcd(context, 8080, metadata) + + mock_client.kv_create.assert_called_once() + call_args = mock_client.kv_create.call_args + + assert call_args.kwargs["key"] == f"dyn://test-ns/ports/{host_ip}/8080" + assert 
call_args.kwargs["lease_id"] == "test-lease-123" + + # Check the value is valid JSON + value_bytes = call_args.kwargs["value"] + value_dict = json.loads(value_bytes.decode()) + assert value_dict["worker_id"] == "test-worker" + assert value_dict["reason"] == "test" + + +class TestAllocateAndReservePort: + """Test single port allocation.""" + + @pytest.mark.asyncio + async def test_allocate_single_port_success(self): + """Test successful single port allocation.""" + mock_client = AsyncMock() + mock_client.primary_lease_id = Mock(return_value="test-lease") + + context = EtcdContext(client=mock_client, namespace="test-ns") + metadata = PortMetadata(worker_id="test-worker", reason="test") + port_range = DynamoPortRange(min=20000, max=20010) + + # Mock that all ports are available + with patch("dynamo.vllm.ports.check_port_available", return_value=True): + with patch("dynamo.vllm.ports.hold_ports") as mock_hold: + # Set up the context manager mock + mock_hold.return_value.__enter__ = Mock() + mock_hold.return_value.__exit__ = Mock(return_value=None) + + port = await allocate_and_reserve_port( + context, metadata, port_range, max_attempts=5 + ) + + assert 20000 <= port <= 20010 + mock_client.kv_create.assert_called_once() + + +class TestAllocateAndReservePortBlock: + """Test port block allocation.""" + + @pytest.mark.asyncio + async def test_allocate_block_success(self): + """Test successful port block allocation.""" + mock_client = AsyncMock() + mock_client.primary_lease_id = Mock(return_value="test-lease") + + context = EtcdContext(client=mock_client, namespace="test-ns") + metadata = PortMetadata(worker_id="test-worker", reason="test") + port_range = DynamoPortRange(min=20000, max=20010) + + request = PortAllocationRequest( + etcd_context=context, + metadata=metadata, + port_range=port_range, + block_size=3, + max_attempts=5, + ) + + with patch("dynamo.vllm.ports.hold_ports") as mock_hold: + # Set up the context manager mock + mock_hold.return_value.__enter__ = Mock() 
+ mock_hold.return_value.__exit__ = Mock(return_value=None) + + ports = await allocate_and_reserve_port_block(request) + + assert len(ports) == 3 + assert all(20000 <= p <= 20010 for p in ports) + assert ports == list(range(ports[0], ports[0] + 3)) + + # Should have reserved 3 ports + assert mock_client.kv_create.call_count == 3 + + @pytest.mark.asyncio + async def test_allocate_block_port_range_too_small(self): + """Test error when port range is too small for block.""" + context = EtcdContext(client=Mock(), namespace="test-ns") + metadata = PortMetadata(worker_id="test-worker", reason="test") + port_range = DynamoPortRange(min=20000, max=20002) + + request = PortAllocationRequest( + etcd_context=context, + metadata=metadata, + port_range=port_range, + block_size=5, # Needs 5 ports but range only has 3 + ) + + with pytest.raises( + ValueError, match="Port range .* is too small for block size" + ): + await allocate_and_reserve_port_block(request) + + +class TestGetHostIP: + """Test get_host_ip function.""" + + def test_get_host_ip_success(self): + """Test successful host IP retrieval.""" + with patch("socket.gethostname", return_value="test-host"): + with patch("socket.gethostbyname", return_value="192.168.1.100"): + with patch("socket.socket") as mock_socket_class: + # Mock successful bind + mock_socket = MagicMock() + mock_socket_class.return_value.__enter__.return_value = mock_socket + + ip = get_host_ip() + assert ip == "192.168.1.100" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/pyproject.toml b/pyproject.toml index 245711df19..8beae1c9b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,12 @@ indent = " " skip = ["build"] known_first_party = ["dynamo"] +[pytest] +pythonpath = [ + ".", + "components/backends/vllm/src" +] + [tool.pytest.ini_options] minversion = "8.0" tmp_path_retention_policy = "failed"