Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 155 additions & 50 deletions examples/vllm/components/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import logging
import os
import socket
import sys
import time
from typing import Optional

from vllm.config import KVTransferConfig
Expand All @@ -30,23 +34,16 @@
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"


def find_free_port() -> int:
"""Find a free port by binding to port 0."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
port = s.getsockname()[1]
return port


class Config:
"""Command line parameters or defaults"""

# dynamo specific
namespace: str
component: str
endpoint: str
kv_events_port: int
is_prefill_worker: bool
kv_port: Optional[int] = None
side_channel_port: Optional[int] = None

# mirror vLLM
model: str
Expand All @@ -56,38 +53,6 @@ class Config:
engine_args: AsyncEngineArgs


def overwrite_args(config):
defaults = {
"task": "generate",
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
# Always set up KV Events for routing
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=f"tcp://*:{config.kv_events_port}",
),
# Always setting up kv transfer for disagg
"kv_transfer_config": KVTransferConfig(
kv_connector="NixlConnector", kv_role="kv_both"
),
}

# Made decision to always overwrite.
# Respecting users original cmd line args at all costs requires a bunch of arg parse work

logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items():
if hasattr(config.engine_args, key):
setattr(config.engine_args, key, value)
logger.debug(f" engine_args.{key} = {value}")
else:
raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")


def parse_args() -> Config:
parser = FlexibleArgumentParser(
description="vLLM server integrated with Dynamo LLM."
Expand All @@ -103,12 +68,6 @@ def parse_args() -> Config:
action="store_true",
help="Enable prefill functionality for this worker. Currently overwrites the --endpoint to be a specially chosen dyn://dynamo.prefill.generate",
)
parser.add_argument(
"--kv-events-port",
type=int,
default=find_free_port(),
help="Endpoint where vLLM publishes metrics for dynamo. For DP, we handle the port iteration.",
)

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
Expand Down Expand Up @@ -143,14 +102,160 @@ def parse_args() -> Config:
config.endpoint = parsed_endpoint_name
config.engine_args = engine_args
config.is_prefill_worker = args.is_prefill_worker
config.kv_events_port = args.kv_events_port

if config.engine_args.block_size is None:
config.engine_args.block_size = 16
logger.debug(
f"Setting reasonable default of {config.engine_args.block_size} for block_size"
)

overwrite_args(config)

return config


async def allocate_and_reserve_port(
namespace,
etcd_client,
worker_id: str,
reason: str,
max_attempts: int = 100,
) -> int:
"""
Get an OS-assigned port and atomically reserve it in ETCD.
Retries until successful or max_attempts reached.

Args:
max_attempts: Maximum number of ports to try (default: 100)

Raises:
RuntimeError: If unable to reserve a port within max_attempts
OSError: If unable to create sockets (system resource issues)
"""

node_name = socket.gethostname()

for attempt in range(1, max_attempts + 1):
# Hold socket open just long enough to reserve in ETCD
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("", 0))
port = sock.getsockname()[1]

# Reserve in ETCD while holding the socket
key = f"dyn://{namespace}/ports/{node_name}/{port}"
value = {
"worker_id": worker_id,
"reason": reason,
"reserved_at": time.time(),
"pid": os.getpid(),
}

try:
await etcd_client.kv_create(
key=key,
value=json.dumps(value).encode(),
lease_id=etcd_client.primary_lease_id(),
)
logger.debug(f"Reserved OS-assigned port {port} for {worker_id}")
return port

except Exception as e:
logger.debug(
f"Port {port} on {node_name} was already reserved (attempt {attempt}): {e}"
)

if attempt < max_attempts:
await asyncio.sleep(0.01)

raise RuntimeError(
f"Failed to allocate and reserve a port after {max_attempts} attempts"
)


async def configure_ports_with_etcd(config: Config, etcd_client):
"""Configure all settings that require ETCD, including port allocation and vLLM overrides."""

# First, allocate ports
dp_rank = config.engine_args.data_parallel_rank or 0
worker_id = f"vllm-{config.component}-dp{dp_rank}"

# Allocate KV events port
kv_port = await allocate_and_reserve_port(
namespace=config.namespace,
etcd_client=etcd_client,
worker_id=f"{worker_id}",
reason="zmq_kv_event_port",
)

# Allocate side channel port
side_channel_port = await allocate_and_reserve_port(
namespace=config.namespace,
etcd_client=etcd_client,
worker_id=f"{worker_id}",
reason="nixl_side_channel_port",
)

# Update config with allocated ports
config.kv_port = kv_port
config.side_channel_port = side_channel_port


def overwrite_args(config):
"""Set vLLM defaults for Dynamo."""
assert (
config.kv_port is not None
), "Must set the kv_port, use configure_ports_with_etcd"
assert (
config.side_channel_port is not None
), "Must set the kv_port, use configure_ports_with_etcd"

dp_rank = config.engine_args.data_parallel_rank or 0

defaults = {
"task": "generate",
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
# Always setting up kv transfer for disagg
"kv_transfer_config": KVTransferConfig(
kv_connector="NixlConnector", kv_role="kv_both"
),
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=f"tcp://*:{config.kv_port - dp_rank}", # vLLM will iterate dp_rank for us, so we need to subtract it out TODO: fix in vLLM
),
}

set_side_channel_host_and_port(config)

logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items():
if hasattr(config.engine_args, key):
setattr(config.engine_args, key, value)
logger.debug(f" engine_args.{key} = {value}")
else:
raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")


def set_side_channel_host_and_port(config: Config, hostname: Optional[str] = None):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel.
"""
if hostname is None:
hostname = socket.gethostname()
# Test if hostname is usable by attempting to bind to it
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket:
test_socket.bind((hostname, 0))
except (socket.error, socket.gaierror):
# If hostname is not usable, fall back to localhost
logger.warning(
f"Hostname '{hostname}' is not usable, falling back to '127.0.0.1'"
)
hostname = "127.0.0.1"

os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(config.side_channel_port)
logger.debug(f"Set NIXL side channel to {hostname}:{config.side_channel_port}")
36 changes: 5 additions & 31 deletions examples/vllm/components/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
import logging
import os
import signal
import socket
from typing import Optional

import uvloop
from args import Config, find_free_port, parse_args
from args import Config, configure_ports_with_etcd, overwrite_args, parse_args
from handlers import DecodeWorkerHandler, PrefillWorkerHandler
from publisher import StatLoggerFactory
from vllm.distributed.kv_events import ZmqEventPublisher
Expand Down Expand Up @@ -57,6 +55,10 @@ async def graceful_shutdown(runtime):
async def worker(runtime: DistributedRuntime):
config = parse_args()

etcd_client = runtime.etcd_client()
await configure_ports_with_etcd(config, etcd_client)
overwrite_args(config)

# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()

Expand All @@ -78,8 +80,6 @@ def setup_vllm_engine(config, stat_logger=None):
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

set_side_channel_host_and_port()

engine_args = config.engine_args
# Load default sampling params from `generation_config.json`
default_sampling_params = (
Expand All @@ -105,32 +105,6 @@ def setup_vllm_engine(config, stat_logger=None):
return engine_client, vllm_config, default_sampling_params


def set_side_channel_host_and_port(
hostname: Optional[str] = None, port: Optional[int] = None
):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel.
"""
if hostname is None:
hostname = socket.gethostname()
# Test if hostname is usable by attempting to bind to it
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket:
test_socket.bind((hostname, 0))
except (socket.error, socket.gaierror):
# If hostname is not usable, fall back to localhost
logger.warning(
f"Hostname '{hostname}' is not usable, falling back to '127.0.0.1'"
)
hostname = "127.0.0.1"
if port is None:
port = find_free_port()
logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_HOST to %s", hostname)
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_PORT to %s", port)
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port)


async def init_prefill(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
Expand Down
3 changes: 1 addition & 2 deletions examples/vllm/launch/dep.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ for i in {0..3}; do
--data-parallel-rank $i \
--data-parallel-size 4 \
--enable-expert-parallel \
--enforce-eager \
--kv-events-port 49500 &
--enforce-eager &
done

echo "All workers starting. (press Ctrl+C to stop)..."
Expand Down
3 changes: 1 addition & 2 deletions examples/vllm/launch/dsr1_dep.sh
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ for ((i=0; i<GPUS_PER_NODE; i++)); do
--data-parallel-address $MASTER_ADDR \
--data-parallel-rpc-port 13345 \
--gpu-memory-utilization 0.95 \
--enforce-eager \
--kv-events-port 49700 2>&1 | tee $LOG_DIR/dsr1_dep_${dp_rank}.log &
--enforce-eager 2>&1 | tee $LOG_DIR/dsr1_dep_${dp_rank}.log &
done

echo "All workers starting. (press Ctrl+C to stop)..."
Expand Down
Loading