1 change: 1 addition & 0 deletions src/vllm_router/app.py
@@ -204,6 +204,7 @@ def initialize_all(app: FastAPI, args):
prefill_model_labels=args.prefill_model_labels,
decode_model_labels=args.decode_model_labels,
kv_aware_threshold=args.kv_aware_threshold,
request_reroutes=args.request_reroutes,
)

# Initialize feature gates
7 changes: 7 additions & 0 deletions src/vllm_router/parsers/parser.py
@@ -367,6 +367,13 @@ def parse_args():
help="The threshold for kv-aware routing.",
)

parser.add_argument(
"--request-reroutes",
Contributor: You may consider renaming this argument to a more explicit name, such as --max-instance-failover-reroute-attempts.

Contributor Author: Agreed.
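
For illustration, a minimal sketch of what the flag could look like if the suggested rename were adopted; the longer name is only the reviewer's proposal and is not part of this diff:

import argparse

# Hypothetical rename from the review; the PR as written adds --request-reroutes.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--max-instance-failover-reroute-attempts",
    type=int,
    default=0,
    help="Number of reroute attempts per failed request.",
)
args = parser.parse_args(["--max-instance-failover-reroute-attempts", "2"])
print(args.max_instance_failover_reroute_attempts)  # 2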

type=int,
default=0,
help="Number of reroute attempts per failed request",
)

args = parser.parse_args()
args = load_initial_config_from_config_file_if_required(parser, args)

14 changes: 9 additions & 5 deletions src/vllm_router/routers/routing_logic.py
@@ -80,6 +80,9 @@ def _qps_routing(
ret = url
return ret

def set_request_migration(self, request_reroutes):
self.request_reroutes = request_reroutes

def _update_hash_ring(self, endpoints: List["EndpointInfo"]):
"""
Update the hash ring with the current list of endpoints.
Expand Down Expand Up @@ -466,10 +469,10 @@ def initialize_routing_logic(
) -> RoutingInterface:
if routing_logic == RoutingLogic.ROUND_ROBIN:
logger.info("Initializing round-robin routing logic")
return RoundRobinRouter()
router = RoundRobinRouter()
elif routing_logic == RoutingLogic.SESSION_BASED:
logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}")
return SessionRouter(kwargs.get("session_key"))
router = SessionRouter(kwargs.get("session_key"))
elif routing_logic == RoutingLogic.KVAWARE:
logger.info("Initializing kvaware routing logic")
router = KvawareRouter(
@@ -478,17 +481,18 @@
kwargs.get("kv_aware_threshold"),
)
router.start_kv_manager()
return router
elif routing_logic == RoutingLogic.PREFIXAWARE:
logger.info("Initializing prefix-aware routing logic")
return PrefixAwareRouter()
router = PrefixAwareRouter()
elif routing_logic == RoutingLogic.DISAGGREGATED_PREFILL:
logger.info("Initializing disaggregated prefill routing logic")
return DisaggregatedPrefillRouter(
router = DisaggregatedPrefillRouter(
kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels")
)
else:
raise ValueError(f"Invalid routing logic {routing_logic}")
router.set_request_migration(request_reroutes=kwargs.get("request_reroutes"))
return router


def reconfigure_routing_logic(
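With these changes, every branch of initialize_routing_logic assigns the router to a local variable so that set_request_migration() can attach the reroute budget before the instance is returned. A rough usage sketch, assuming the module path and enum values shown in this diff; the session key and the value 2 are illustrative:

from vllm_router.routers.routing_logic import RoutingLogic, initialize_routing_logic

# Sketch only: kwarg names come from this diff, concrete values are made up.
router = initialize_routing_logic(
    RoutingLogic.SESSION_BASED,
    session_key="x-session-id",   # illustrative header name
    request_reroutes=2,           # from --request-reroutes; 0 disables rerouting
)
# set_request_migration() stored the budget on the router instance.
assert router.request_reroutes == 2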
195 changes: 116 additions & 79 deletions src/vllm_router/services/request_service/request.py
@@ -13,6 +13,7 @@
# limitations under the License.

# --- Request Processing & Routing ---
import asyncio
import json
import os
import time
@@ -136,8 +137,58 @@ async def process_request(
)


def perform_service_discovery(
request, request_json, request_endpoint, requested_model, error_urls
):
# TODO (ApostaC): merge two awaits into one
service_discovery = get_service_discovery()
endpoints = service_discovery.get_endpoint_info()

aliases = getattr(service_discovery, "aliases", None)
if aliases and requested_model in aliases.keys():
requested_model = aliases[requested_model]
request_body = replace_model_in_request_body(request_json, requested_model)
update_content_length(request, request_body)

if not request_endpoint:
endpoints = list(
filter(
lambda x: requested_model in x.model_names
and not x.sleep
and x.url not in error_urls,
endpoints,
)
)
engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
request_stats = request.app.state.request_stats_monitor.get_request_stats(
time.time()
)
else:
endpoints = list(
filter(
lambda x: requested_model in x.model_names
and x.Id == request_endpoint
and not x.sleep
and x.url not in error_urls,
endpoints,
)
)
engine_stats, request_stats = None, None

if not endpoints:
return JSONResponse(
status_code=400,
content={
"error": f"Model {requested_model} not found or vLLM engine is sleeping."
},
)
return endpoints, engine_stats, request_stats


async def route_general_request(
request: Request, endpoint: str, background_tasks: BackgroundTasks
request: Request,
endpoint: str,
background_tasks: BackgroundTasks,
):
"""
Route the incoming request to the backend server and stream the response back to the client.
@@ -203,96 +254,82 @@ async def route_general_request(
status_code=400, detail="Request body is not JSON parsable."
)

service_discovery = get_service_discovery()
endpoints = service_discovery.get_endpoint_info()
# Attempt service discovery and routing up to (request_reroutes + 1) times, skipping URLs that have already failed
error_urls = set()
for _ in range(request.app.state.router.request_reroutes + 1):
endpoints, engine_stats, request_stats = await asyncio.to_thread(
perform_service_discovery,
request,
request_json,
request_endpoint,
requested_model,
error_urls,
)

aliases = getattr(service_discovery, "aliases", None)
if aliases and requested_model in aliases.keys():
requested_model = aliases[requested_model]
request_body = replace_model_in_request_body(request_json, requested_model)
update_content_length(request, request_body)
logger.debug(f"Routing request {request_id} for model: {requested_model}")
if request_endpoint:
server_url = endpoints[0].url
logger.debug(
f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
)

if not request_endpoint:
endpoints = list(
filter(
lambda x: requested_model in x.model_names and not x.sleep,
endpoints,
elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
request.app.state.router, PrefixAwareRouter
):
server_url = await request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request, request_json
)
)
engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
request_stats = request.app.state.request_stats_monitor.get_request_stats(
time.time()
)
else:
endpoints = list(
filter(
lambda x: requested_model in x.model_names
and x.Id == request_endpoint
and not x.sleep,
endpoints,
else:
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
)
)

if not endpoints:
return JSONResponse(
status_code=400,
content={
"error": f"Model {requested_model} not found or vLLM engine is sleeping."
},
curr_time = time.time()
# Extract actual session ID from request headers for logging
session_key = (
getattr(request.app.state.router, "session_key", None)
if hasattr(request.app.state.router, "session_key")
else None
)

logger.debug(f"Routing request {request_id} for model: {requested_model}")
if request_endpoint:
server_url = endpoints[0].url
logger.debug(
f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
session_id = (
request.headers.get(session_key, None) if session_key is not None else None
)
session_id_display = session_id if session_id is not None else "None"

elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
request.app.state.router, PrefixAwareRouter
):
server_url = await request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request, request_json
# Debug logging to help troubleshoot session ID extraction
logger.debug(
f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
)
else:
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
logger.debug(f"Debug session extraction - Session key config: {session_key}")
logger.debug(
f"Debug session extraction - Request headers: {dict(request.headers)}"
)
logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")

curr_time = time.time()
# Extract actual session ID from request headers for logging
session_key = (
getattr(request.app.state.router, "session_key", None)
if hasattr(request.app.state.router, "session_key")
else None
)
session_id = (
request.headers.get(session_key, None) if session_key is not None else None
)
session_id_display = session_id if session_id is not None else "None"

# Debug logging to help troubleshoot session ID extraction
logger.debug(
f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
)
logger.debug(f"Debug session extraction - Session key config: {session_key}")
logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
logger.info(
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
)
error = None
try:
stream_generator = process_request(
request,
request_body,
server_url,
request_id,
endpoint,
background_tasks,
)
headers, status = await anext(stream_generator)
headers_dict = {key: value for key, value in headers.items()}
headers_dict["X-Request-Id"] = request_id
# Stop retrying once the upstream response headers have been received
break
except Exception as e:
error_urls.add(server_url)
error = e

logger.info(
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
)
stream_generator = process_request(
request,
request_body,
server_url,
request_id,
endpoint,
background_tasks,
)
headers, status = await anext(stream_generator)
headers_dict = {key: value for key, value in headers.items()}
headers_dict["X-Request-Id"] = request_id
if error:
raise error
return StreamingResponse(
stream_generator,
status_code=status,
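
The request.py change amounts to a retry loop that remembers which backend URLs have already failed and re-runs discovery and routing without them, re-raising the last error only when every attempt fails. A simplified, self-contained sketch of that pattern; discover, pick_backend, and send stand in for the real helpers and are not the router's actual API:

async def route_with_failover(reroutes, discover, pick_backend, send):
    """Try up to reroutes + 1 backends, skipping URLs that already failed."""
    error_urls, error = set(), None
    for _ in range(reroutes + 1):
        endpoints = [e for e in discover() if e not in error_urls]
        if not endpoints:
            raise RuntimeError("no healthy endpoints left to try")
        url = pick_backend(endpoints)
        try:
            return await send(url)      # success: stop retrying
        except Exception as exc:        # failure: exclude this URL and retry
            error_urls.add(url)
            error = exc
    raise error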