diff --git a/benchmarks/multi-round-qa/multi-round-qa.py b/benchmarks/multi-round-qa/multi-round-qa.py index e7d024c88..3c0b65dcd 100644 --- a/benchmarks/multi-round-qa/multi-round-qa.py +++ b/benchmarks/multi-round-qa/multi-round-qa.py @@ -40,6 +40,9 @@ class WorkloadConfig: # Whether to include user id in request header enable_user_id: bool + # Max number of unfinished queries allowed (None means no limit) + max_unfinished_queries: Optional[int] + @dataclass class UserConfig: @@ -419,6 +422,13 @@ def step(self, timestamp: float, executor: RequestExecutor): if self.start_time is None: self.start_time = timestamp + pending_queries = len([s for s in self.sessions if s.has_unfinished_request]) + # Only check limit if max_unfinished_queries is set + if (self.workload_config.max_unfinished_queries is not None and + pending_queries > self.workload_config.max_unfinished_queries): + logger.info(f"Unfinished queries exceed {self.workload_config.max_unfinished_queries}; waiting") + return + if timestamp - self.last_user_join > self.gap_between_users: self._create_user_session() self.last_user_join = timestamp @@ -625,6 +635,12 @@ def parse_arguments() -> WorkloadConfig: parser.add_argument( "--sharegpt", action="store_true", help="Whether to use ShareGPT dataset" ) + parser.add_argument( + "--max-unfinished-queries", + type=int, + default=None, + help="Maximum number of unfinished queries allowed (default: no limit)", + ) args = parser.parse_args() return args @@ -675,6 +691,7 @@ def main(): qps=args.qps, model=args.model, enable_user_id=args.request_with_user_id, + max_unfinished_queries=args.max_unfinished_queries, ) manager = UserSessionManager( diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index 0713e9c0f..580e344b7 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -109,6 +109,18 @@ async def lifespan(app: FastAPI): dyn_cfg_watcher.close() +def create_instance_id_to_url(lmcache_instances, static_backends): + if lmcache_instances is None or static_backends is None: + return None + instance_ids = [s.strip() for s in lmcache_instances.split(",") if s.strip()] + urls = parse_static_urls(static_backends) + if not instance_ids or not urls: + return None + if len(instance_ids) != len(urls): + raise ValueError("lmcache-instances and static-backends must have the same length") + return dict(zip(instance_ids, urls)) + + def initialize_all(app: FastAPI, args): """ Initialize all the components of the router with the given arguments. 
@@ -206,6 +218,10 @@ def initialize_all(app: FastAPI, args): prefill_model_labels=args.prefill_model_labels, decode_model_labels=args.decode_model_labels, kv_aware_threshold=args.kv_aware_threshold, + tokenizer=args.tokenizer, + enable_shared_cache=args.enable_shared_cache, + instance_id_to_url=create_instance_id_to_url(args.lmcache_instances, + args.static_backends), ) # Initialize feature gates diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index 8b12cf983..a030b7489 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -20,6 +20,7 @@ from vllm_router.parsers.yaml_utils import ( read_and_process_yaml_config_file, ) +from vllm_router.routers.routing_logic import RoutingLogic from vllm_router.version import __version__ try: @@ -203,13 +204,7 @@ def parse_args(): parser.add_argument( "--routing-logic", type=str, - choices=[ "roundrobin", "session", "kvaware", "prefixaware", "disaggregated_prefill", ], + choices=list(RoutingLogic), help="The routing logic to use", ) parser.add_argument( @@ -218,12 +213,30 @@ def parse_args(): default=9000, help="The port of the LMCache controller.", ) + parser.add_argument( + "--lmcache-instances", + type=str, + default=None, + help="The LMCache instance ids from the lmcache config files, comma-separated;" + " must match the length of static-backends. E.g., instance_0,instance_1", + ) parser.add_argument( "--session-key", type=str, default=None, help="The key (in the header) to identify a session.", ) + parser.add_argument( + "--tokenizer", + type=str, + default=None, + help="The name or path of the tokenizer model used to tokenize prompts for routing.", + ) + parser.add_argument( + "--enable-shared-cache", + action="store_true", + help="Enable shared KV cache across serving engine instances.", + ) parser.add_argument( "--callbacks", type=str, diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index fc8bf4039..b750722d9 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -19,10 +19,12 @@ import random import threading import uuid -from typing import Dict, List +import traceback +from typing import Dict, List, Optional, Tuple, Union import requests from fastapi import Request +from urllib.parse import urlparse try: from transformers import AutoTokenizer @@ -33,6 +35,7 @@ from lmcache.v1.cache_controller import controller_manager from lmcache.v1.cache_controller.message import ( LookupMsg, + FullLookupMsg, QueryInstMsg, ) except ImportError: @@ -42,18 +45,45 @@ from vllm_router.log import init_logger from vllm_router.service_discovery import EndpointInfo from vllm_router.stats.engine_stats import EngineStats -from vllm_router.stats.request_stats import RequestStats +from vllm_router.stats.request_stats import RequestStats, RequestStatsCacheInfo, prefill_workload from vllm_router.utils import SingletonABCMeta logger = init_logger(__name__) +def extract_prompt(request_json: Dict): + """Extract prompt message from the request json object.""" + if "messages" in request_json: + # Gather content from the messages array + messages = request_json["messages"] + if messages: + # Concatenate all message content + prompt_parts = [] + for message in messages: + content = message.get("content", "") + if isinstance(content, list): + # Handle multimodal messages + text_content = " ".join( + part.get("text", "") + for part in content + if part.get("type") == "text" + ) + prompt_parts.append(text_content) + elif content is not None: + prompt_parts.append(content) + 
return "\n".join(prompt_parts) + return "" + # Handle regular completions + return request_json["prompt"] + + class RoutingLogic(str, enum.Enum): ROUND_ROBIN = "roundrobin" SESSION_BASED = "session" KVAWARE = "kvaware" PREFIXAWARE = "prefixaware" DISAGGREGATED_PREFILL = "disaggregated_prefill" + TTFT = "ttft" class RoutingInterface(metaclass=SingletonABCMeta): @@ -110,7 +140,7 @@ def route_request( engine_stats: Dict[str, EngineStats], request_stats: Dict[str, RequestStats], request: Request, - ) -> str: + ) -> Union[str, Tuple[str, RequestStatsCacheInfo]]: """ Route the request to the appropriate engine URL @@ -231,18 +261,21 @@ def __init__( lmcache_controller_port: int, session_key: str, kv_aware_threshold: int = 2000, + tokenizer_name: Optional[str] = None, + instance_id_to_url: Optional[Dict[str, str]] = None, ): self.lmcache_controller_port = lmcache_controller_port logger.info( f"Initializing KvawareRouter with port: {self.lmcache_controller_port}" ) self.kv_manager = controller_manager.LMCacheControllerManager( - f"0.0.0.0:{self.lmcache_controller_port}" + {"pull" : f"0.0.0.0:{lmcache_controller_port}", "reply" : None} ) self.req_id = 0 - self.instance_id_to_ip = {} + self.instance_id_to_url = instance_id_to_url or {} self.session_key = session_key self.hash_ring = HashRing() + self.tokenizer_name = tokenizer_name self.tokenizer = None self.threshold = kv_aware_threshold @@ -254,8 +287,10 @@ def start_kv_manager(self): self.thread = threading.Thread(target=self.loop.run_forever, daemon=True) self.thread.start() asyncio.run_coroutine_threadsafe(self.kv_manager.start_all(), self.loop) + if self.tokenizer_name is not None: + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) - def query_manager(self, msg) -> str: + def query_manager(self, msg): """ Get the instance id for the given message """ @@ -337,7 +372,7 @@ async def route_request( return url else: queried_instance_ids = [info for info in instance_id.layout_info] - if queried_instance_ids[0] not in self.instance_id_to_ip: + if queried_instance_ids[0] not in self.instance_id_to_url: for endpoint in endpoints: event_id = "QueryInst" + str(uuid.uuid4()) logger.debug(f"QueryInst event id: {event_id}") @@ -349,14 +384,14 @@ async def route_request( ) endpoint_instance_id = await self.query_manager(query_message) - self.instance_id_to_ip[endpoint_instance_id.instance_id] = ( + self.instance_id_to_url[endpoint_instance_id.instance_id] = ( endpoint.url ) - logger.info(f"Instance id to ip: {self.instance_id_to_ip}") + logger.info(f"Instance id to ip: {self.instance_id_to_url}") logger.info( f"Routing request to {queried_instance_ids[0]} found by kvaware router" ) - return self.instance_id_to_ip[queried_instance_ids[0]] + return self.instance_id_to_url[queried_instance_ids[0]] class PrefixAwareRouter(RoutingInterface): @@ -399,33 +434,7 @@ async def route_request( request_json (Dict): The request body (needed for finding the longest prefix match) """ - - # Handle chat completions - if "messages" in request_json: - # Get the last message from the messages array - messages = request_json["messages"] - if messages: - # Concatenate all message content - prompt_parts = [] - for message in messages: - content = message.get("content", "") - if isinstance(content, list): - # Handle multimodal messages - text_content = " ".join( - part.get("text", "") - for part in content - if part.get("type") == "text" - ) - prompt_parts.append(text_content) - elif content is not None: - prompt_parts.append(content) - prompt = 
"\n".join(prompt_parts) - else: - prompt = "" - else: - # Handle regular completions - prompt = request_json["prompt"] - + prompt = extract_prompt(request_json) available_endpoints = set(endpoint.url for endpoint in endpoints) _, matched_endpoint = await self.hashtrie.longest_prefix_match( prompt, available_endpoints @@ -481,6 +490,196 @@ def route_request( return decoder_endpoints[0].url +class TtftRouter(RoutingInterface): + """ + Route the request to the qppropriate engine URL by the least estimated TTFT. + """ + + def __init__( + self, + lmcache_controller_port: int, + session_key: str, + tokenizer_name: Optional[str] = None, + enable_shared_cache : bool = False, + instance_id_to_url: Optional[Dict[str, str]] = None, + ): + logger.info( + f"Initializing TtftRouter with lmcache addr: 0.0.0.0:{lmcache_controller_port}" + ) + self.kv_manager = controller_manager.LMCacheControllerManager( + {"pull":f"0.0.0.0:{lmcache_controller_port}", "reply" : None} + ) + self.instance_id_to_url = instance_id_to_url or {} + self.session_key = session_key + self.hash_ring = HashRing() + self.tokenizer_name = tokenizer_name + self.tokenizer = None + self.enable_shared_cache = enable_shared_cache + + def start_kv_manager(self): + """ + Start the kv manager + """ + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=self.loop.run_forever, daemon=True) + self.thread.start() + asyncio.run_coroutine_threadsafe(self.kv_manager.start_all(), self.loop) + if self.tokenizer_name is not None: + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + + async def route_request( + self, + endpoints: List[EndpointInfo], + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats], + request: Request, + request_json: Dict, + ) -> Tuple[str, RequestStatsCacheInfo]: + """ + Route the request to the appropriate engine URL by where the KV cache + of the longest prefix match is found. + If there is no session id in the reqest header, it will pick a server + with round robin. 
+ + Args: + endpoints (List[EndpointInfo]): The list of engine URLs + engine_stats (Dict[str, EngineStats]): The engine stats indicating + the 'physical' load of each engine + request_stats (Dict[str, RequestStats]): The request stats + indicating the request-level performance of each engine + request (Request): The incoming request + request_json (Dict): The request body (needed for finding the + longest prefix match) + """ + if self.tokenizer is None: + # Fall back to the model of the first endpoint as the tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(endpoints[0].model_names[0]) + + token_ids = self.tokenizer.encode(extract_prompt(request_json)) + cache_info = RequestStatsCacheInfo() + cache_info.num_prefix_tokens = len(token_ids) + try: + if request_stats is None: + raise ValueError("no request stats were provided") + msg = FullLookupMsg(event_id="", tokens=token_ids) + ret_msg = await self.kv_manager.handle_orchestration_message(msg) + matched_infos = ret_msg.matched_info + if matched_infos is None: + matched_infos = [] + if self.enable_shared_cache: + best_matched_info = self._find_best_matched(matched_infos) + else: + best_matched_info = None + best_inst_url, num_cached_tokens = \ + await self._find_best_inst(endpoints, matched_infos, + best_matched_info, request_stats, + len(token_ids)) + cache_info.num_cached_tokens = num_cached_tokens + return best_inst_url, cache_info + except ValueError: + logger.info("Falling back to session-based routing due to:") + logger.info(traceback.format_exc()) + cache_info.num_cached_tokens = 0 + return self._fallback_routing(endpoints, request_stats, request), cache_info + + def _find_best_matched(self, matched_infos): + if not matched_infos: + return None + best_matched_info = None + for instance_info in matched_infos: + if best_matched_info is None or instance_info[1][-1][1] > best_matched_info[1][-1][1]: + best_matched_info = instance_info + if best_matched_info is None: + raise ValueError("no best matched instance was found") + return best_matched_info + + async def _find_best_inst(self, endpoints, matched_infos, best_matched_info, + request_stats, num_prefix_tokens): + matched_stats = [] + matched_urls = [] + for matched_info in matched_infos: + url = await self._get_instance_url(endpoints, matched_info[0]) + stats = request_stats.get(url, None) + matched_urls.append(url) + matched_stats.append(stats) + + # Assuming the computation speed of all endpoints is equal, + # comparing workload is equivalent to comparing TTFT + + # cache matched pass + min_workload = float("inf") + min_workload_url = None + min_workload_cached_tokens = 0 + for i, matched_info in enumerate(matched_infos): + workload, cached_tokens = self._estimate_workload(matched_info, best_matched_info, + matched_stats[i], num_prefix_tokens) + if min_workload_url is None or workload <= min_workload: + min_workload = workload + min_workload_url = matched_urls[i] + min_workload_cached_tokens = cached_tokens + + # cache not matched pass + matched_url_set = set(matched_urls) + not_matched_endpoints = [endpoint for endpoint in endpoints + if endpoint.url not in matched_url_set] + for endpoint in not_matched_endpoints: + url = endpoint.url + stats = request_stats.get(url, None) + workload, cached_tokens = self._estimate_workload(None, best_matched_info, + stats, num_prefix_tokens) + if min_workload_url is None or workload <= min_workload: + min_workload = workload + min_workload_url = url + min_workload_cached_tokens = cached_tokens + + if min_workload_url is None: + raise ValueError("no best 
instance was found") + return min_workload_url, min_workload_cached_tokens + + def _estimate_workload(self, matched_info, best_matched_info, stats, num_prefix_tokens): + """Estimate prefill workload.""" + num_cache_tokens = 0 + if best_matched_info is not None: + num_cache_tokens = best_matched_info[1][-1][1] + elif matched_info is not None: + num_cache_tokens = matched_info[1][-1][1] + workload = ((stats.prefill_todo_workload if stats else 0) + + prefill_workload(num_prefix_tokens, num_cache_tokens)) + return workload, num_cache_tokens + + async def _get_instance_url(self, endpoints, instance_id): + url = self.instance_id_to_url.get(instance_id, None) + if url is not None: + return url + for endpoint in endpoints: + msg = QueryInstMsg( + event_id="", + ip=urlparse(endpoint.url).hostname + ) + ret_msg = await self.kv_manager.handle_orchestration_message(msg) + self.instance_id_to_url[ret_msg.instance_id] = endpoint.url + if ret_msg.instance_id == instance_id: + url = endpoint.url + if url is None: + raise ValueError(f"cannot resolve URL for {instance_id}") + return url + + def _fallback_routing(self, endpoints, request_stats, request): + session_id = request.headers.get(self.session_key, None) + logger.debug(f"Got session id: {session_id}") + + # Update the hash ring with the current list of endpoints + self._update_hash_ring(endpoints) + + if session_id is None: + # Route base on QPS if no session ID is present + url = self._qps_routing(endpoints, request_stats) + else: + # Use the hash ring to get the endpoint for the session ID + url = self.hash_ring.get_node(session_id) + return url + + # Instead of managing a global _global_router, we can define the initialization functions as: def initialize_routing_logic( routing_logic: RoutingLogic, *args, **kwargs @@ -497,6 +696,8 @@ def initialize_routing_logic( kwargs.get("lmcache_controller_port"), kwargs.get("session_key"), kwargs.get("kv_aware_threshold"), + kwargs.get("tokenizer"), + kwargs.get("instance_id_to_url"), ) router.start_kv_manager() return router @@ -508,6 +709,17 @@ def initialize_routing_logic( return DisaggregatedPrefillRouter( kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels") ) + elif routing_logic == RoutingLogic.TTFT: + logger.info("Initializing ttft routing logic") + router = TtftRouter( + kwargs.get("lmcache_controller_port"), + kwargs.get("session_key"), + kwargs.get("tokenizer"), + kwargs.get("enable_shared_cache"), + kwargs.get("instance_id_to_url"), + ) + router.start_kv_manager() + return router else: raise ValueError(f"Invalid routing logic {routing_logic}") @@ -535,6 +747,7 @@ def get_routing_logic() -> RoutingInterface: KvawareRouter, PrefixAwareRouter, DisaggregatedPrefillRouter, + TtftRouter, ): if cls in SingletonABCMeta._instances: return cls() diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 6ae025b2e..92939a92e 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -29,6 +29,7 @@ DisaggregatedPrefillRouter, KvawareRouter, PrefixAwareRouter, + TtftRouter, ) from vllm_router.service_discovery import get_service_discovery from vllm_router.services.request_service.rewriter import ( @@ -61,6 +62,7 @@ async def process_request( endpoint, background_tasks: BackgroundTasks, debug_request=None, + cache_info=None, ): """ Process a request by sending it to the chosen backend. 
@@ -73,7 +75,7 @@ async def process_request( endpoint: The endpoint to send the request to on the backend. debug_request: The original request object from the client, used for optional debug logging. - + cache_info: Prefix-cache information used for prefill workload accounting. Yields: The response headers and status code, followed by the response content. @@ -84,7 +86,7 @@ total_len = 0 start_time = time.time() request.app.state.request_stats_monitor.on_new_request( - backend_url, request_id, start_time + backend_url, request_id, start_time, cache_info ) # Check if this is a streaming request try: @@ -228,7 +230,8 @@ async def route_general_request( ) engine_stats = request.app.state.engine_stats_scraper.get_engine_stats() request_stats = request.app.state.request_stats_monitor.get_request_stats( - time.time() + time.time(), + [endpoint.url for endpoint in endpoints], ) else: endpoints = list( @@ -262,24 +265,31 @@ headers={"X-Request-Id": request_id}, ) + route_result = None + cache_info = None + logger.debug(f"Routing request {request_id} for model: {requested_model}") if request_endpoint: server_url = endpoints[0].url logger.debug( f"Routing request {request_id} to engine with Id: {endpoints[0].Id}" ) - - elif isinstance(request.app.state.router, KvawareRouter) or isinstance( - request.app.state.router, PrefixAwareRouter - ): - server_url = await request.app.state.router.route_request( + elif isinstance(request.app.state.router, (KvawareRouter, PrefixAwareRouter, TtftRouter)): + route_result = await request.app.state.router.route_request( endpoints, engine_stats, request_stats, request, request_json ) else: - server_url = request.app.state.router.route_request( + route_result = request.app.state.router.route_request( endpoints, engine_stats, request_stats, request ) + if isinstance(route_result, (tuple, list)): + server_url = route_result[0] + if len(route_result) > 1: + cache_info = route_result[1] + elif isinstance(route_result, str): + server_url = route_result + curr_time = time.time() # Extract actual session ID from request headers for logging session_key = ( @@ -310,6 +320,7 @@ request_id, endpoint, background_tasks, + cache_info=cache_info, ) headers, status = await anext(stream_generator) headers_dict = {key: value for key, value in headers.items()} diff --git a/src/vllm_router/stats/request_stats.py b/src/vllm_router/stats/request_stats.py index f0409b912..405799936 100644 --- a/src/vllm_router/stats/request_stats.py +++ b/src/vllm_router/stats/request_stats.py @@ -14,13 +14,21 @@ import time from collections import deque from dataclasses import dataclass -from typing import Deque, Dict, Tuple +from typing import Deque, Dict, List, Optional, Tuple from vllm_router.log import init_logger logger = init_logger(__name__) +def prefill_workload(num_prefix_tokens: int, num_cached_tokens: int) -> int: + """Calculate the prefill computation workload via the trapezoid area formula.""" + top = num_cached_tokens + 1 + bottom = num_prefix_tokens + height = num_prefix_tokens - num_cached_tokens + return (top + bottom) * height // 2 + + class SingletonMeta(type): _instances = {} @@ -53,6 +62,8 @@ class RequestStats: avg_itl: float # Number of swapped requests (moved from GPU to CPU) num_swapped_requests: int + # Unfinished prefill computation workload + prefill_todo_workload: int class MovingAverageMonitor: @@ -103,6 +114,15 @@ def get_sum(self) -> float: return sum(self.values) +@dataclass +class RequestStatsCacheInfo: + """ + Prefix-cache 
information for a single request. """ + num_prefix_tokens: int = 0 + num_cached_tokens: int = 0 + + class RequestStatsMonitor(metaclass=SingletonMeta): """ Monitors the request statistics of all serving engines. @@ -127,6 +147,8 @@ def __init__(self, sliding_window_size: float = None): self.request_start_time: Dict[Tuple[str, str], float] = {} # Record time when first token is received: (engine_url, request_id) -> timestamp self.first_token_time: Dict[Tuple[str, str], float] = {} + # Cache info per request: (engine_url, request_id) -> RequestStatsCacheInfo + self.cache_infos: Dict[Tuple[str, str], RequestStatsCacheInfo] = {} # Number of requests in different stages (from the start of the router) self.in_prefill_requests: Dict[str, int] = {} @@ -142,7 +164,10 @@ def __init__ ... self.first_query_time: float = None self._initialized = True - def on_new_request(self, engine_url: str, request_id: str, timestamp: float): + def on_new_request(self, engine_url: str, + request_id: str, + timestamp: float, + cache_info: Optional[RequestStatsCacheInfo] = None): """ Tell the monitor that a new request has been created. @@ -150,9 +175,13 @@ engine_url: The URL of the serving engine request_id: The global request ID timestamp: the timestamp when the request was created + cache_info: The prefix-cache information for this request """ self.request_start_time[(engine_url, request_id)] = timestamp + if cache_info is not None: + self.cache_infos[(engine_url, request_id)] = cache_info + if engine_url not in self.in_prefill_requests: self.in_prefill_requests[engine_url] = 0 self.in_prefill_requests[engine_url] += 1 @@ -197,7 +226,8 @@ def on_request_response(self, engine_url: str, request_id: str, timestamp: float self.sliding_window_size ) # Update TTFT as time from request start to first token - ttft = timestamp - self.request_start_time[(engine_url, request_id)] + start_time = self.request_start_time[(engine_url, request_id)] + ttft = timestamp - start_time self.ttft_monitors[engine_url].update(timestamp, ttft) def on_request_complete(self, engine_url: str, request_id: str, timestamp: float): @@ -235,12 +266,13 @@ def on_request_swapped(self, engine_url: str, request_id: str, timestamp: float) self.swapped_requests[engine_url] = 0 self.swapped_requests[engine_url] += 1 - def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: + def get_request_stats(self, current_time: float, urls: Optional[List[str]] = None) -> Dict[str, RequestStats]: """ Get the request statistics for each serving engine Args: current_time: The current timestamp in seconds + urls: The engine URLs to report stats for (defaults to all engines seen so far) Returns: A dictionary where the key is the serving engine URL and the value @@ -248,10 +280,11 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: The TTFT and inter token latency will be -1 if there is no requests finished in the sliding window. 
""" + if urls is None: + urls = set(self.in_prefill_requests.keys()).union( + set(self.in_decoding_requests.keys()) + ) ret = {} - urls = set(self.in_prefill_requests.keys()).union( - set(self.in_decoding_requests.keys()) - ) for engine_url in urls: if engine_url not in self.qps_monitors: qps = -1 @@ -289,6 +322,8 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: else: swapped = 0 + prefill_todo_workload = self._get_prefill_todo_workload(engine_url) + ret[engine_url] = RequestStats( qps=qps, ttft=ttft, @@ -302,9 +337,19 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: avg_latency=avg_lat, avg_itl=avg_itl_val, num_swapped_requests=swapped, + prefill_todo_workload=prefill_todo_workload, ) return ret + def _get_prefill_todo_workload(self, engine_url: str) -> int: + amount = 0 + for (url, request_id), cache_info in self.cache_infos.items(): + if url != engine_url or (url, request_id) in self.first_token_time: + continue + amount += prefill_workload(cache_info.num_prefix_tokens, + cache_info.num_cached_tokens) + return amount + def initialize_request_stats_monitor(sliding_window_size: float): return RequestStatsMonitor(sliding_window_size)