diff --git a/examples/python_rs/llm/vllm/kv_router/router.py b/examples/python_rs/llm/vllm/kv_router/router.py
index 29b9b06364..a28dc66fcf 100644
--- a/examples/python_rs/llm/vllm/kv_router/router.py
+++ b/examples/python_rs/llm/vllm/kv_router/router.py
@@ -23,7 +23,7 @@
 from common.protocol import Tokens
 from vllm.logger import logger as vllm_logger
 
-from dynemo.llm import KvRouter
+from dynemo.llm import KvIndexer, KvMetricsAggregator, KvRouter
 from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker
 
 WorkerId = str
@@ -78,6 +78,60 @@ async def generate(self, request) -> AsyncIterator[WorkerId]:
         )
 
 
+class CustomRouter:
+    """
+    Request handler for the generate endpoint that selects a worker via a custom cost function
+    """
+
+    def __init__(
+        self,
+        indexer: KvIndexer,
+        metrics_aggregator: KvMetricsAggregator,
+    ):
+        self.indexer = indexer
+        self.metrics_aggregator = metrics_aggregator
+
+    def _cost_function(self, scores, metrics):
+        # naive cost function for demonstration purposes
+        current_best = ("", 0)
+        for worker_id, score in scores.scores.items():
+            if score > current_best[1]:
+                current_best = (worker_id, score)
+        for endpoint in metrics.endpoints:
+            if endpoint.worker_id == current_best[0]:
+                print(f"Metrics of endpoint: {endpoint.worker_id}")
+                print(
+                    f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}"
+                )
+                print(
+                    f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}"
+                )
+        return current_best[0]
+
+    @dynemo_endpoint(Tokens, WorkerId)
+    async def generate(self, request) -> AsyncIterator[WorkerId]:
+        lora_id = 0
+        worker_id = ""
+        try:
+            scores = await self.indexer.find_matches_for_request(
+                request.tokens, lora_id
+            )
+            metrics = await self.metrics_aggregator.get_metrics()
+            worker_id = self._cost_function(scores, metrics)
+
+        # [NOTE][TODO] Now that the scheduler may return more error messages,
+        # we catch all exceptions and log them. Catch specific router
+        # exceptions once we have dedicated types.
+        except Exception as e:
+            # Fall back to the empty worker id so the caller can handle the failure
+            worker_id = ""
+            vllm_logger.exception(f"Error during worker selection: {e}")
+
+        vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
+
+        yield str(worker_id)
+
+
 @dynemo_worker()
 async def worker(runtime: DistributedRuntime, args: Namespace):
     """
@@ -116,10 +170,17 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
     router_component = runtime.namespace("dynemo").component("router")
     await router_component.create_service()
 
-    router = KvRouter(runtime, kv_listener)
-
     endpoint = router_component.endpoint("generate")
-    await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
+
+    if args.custom_router:
+        indexer = KvIndexer(kv_listener)
+        metrics_aggregator = KvMetricsAggregator(kv_listener)
+        await endpoint.serve_endpoint(
+            CustomRouter(indexer, metrics_aggregator).generate
+        )
+    else:
+        router = KvRouter(runtime, kv_listener)
+        await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
 
 
 if __name__ == "__main__":
@@ -147,6 +208,12 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
         default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         help="Model that is being served",
     )
+    parser.add_argument(
+        "--custom-router",
+        action="store_true",
+        default=False,
+        help="Whether to use the custom router instead of the built-in KvRouter",
+    )
 
     args = parser.parse_args()
     asyncio.run(worker(args))
diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs
index f60afc104b..a47f09e16f 100644
--- a/lib/bindings/python/rust/lib.rs
+++ b/lib/bindings/python/rust/lib.rs
@@ -70,6 +70,11 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<llm::kv::OverlapScores>()?;
+    m.add_class::<llm::kv::KvIndexer>()?;
+    m.add_class::<llm::kv::EndpointKvMetrics>()?;
+    m.add_class::<llm::kv::AggregatedMetrics>()?;
+    m.add_class::<llm::kv::KvMetricsAggregator>()?;
 
     engine::add_to_module(m)?;
 
diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs
index ffb3c93f27..ccb15296aa 100644
--- a/lib/bindings/python/rust/llm/kv.rs
+++ b/lib/bindings/python/rust/llm/kv.rs
@@ -13,7 +13,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
+
 use super::*;
+use llm_rs::kv_router::indexer::KvIndexerInterface;
+use tracing;
 
 #[pyclass]
 pub(crate) struct KvRouter {
@@ -106,3 +110,160 @@ impl KvMetricsPublisher {
             .map_err(to_pyerr)
     }
 }
+
+#[pyclass]
+#[derive(Clone)]
+pub(crate) struct OverlapScores {
+    inner: llm_rs::kv_router::indexer::OverlapScores,
+}
+
+#[pymethods]
+impl OverlapScores {
+    #[getter]
+    fn scores(&self) -> HashMap {
+        self.inner.scores.clone()
+    }
+
+    #[getter]
+    fn frequencies(&self) -> Vec {
+        self.inner.frequencies.clone()
+    }
+}
+
+#[pyclass]
+pub(crate) struct KvIndexer {
+    inner: Arc<llm_rs::kv_router::indexer::KvIndexer>,
+}
+
+#[pymethods]
+impl KvIndexer {
+    #[new]
+    fn new(component: Component) -> PyResult<Self> {
+        let runtime = pyo3_async_runtimes::tokio::get_runtime();
+        runtime.block_on(async {
+            let kv_subject = component
+                .inner
+                .event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT);
+            let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
+                llm_rs::kv_router::indexer::KvIndexer::new(
+                    component.inner.drt().runtime().child_token(),
+                )
+                .into();
+            let mut kv_events_rx = component
+                .inner
+                .drt()
+                .nats_client()
+                .client()
+                .subscribe(kv_subject)
+                .await
+                .map_err(to_pyerr)?;
+            let kv_events_tx = inner.event_sender();
+
+            // [FIXME] Subscribing to KV events is functionality added on top of the indexer here;
+            // it should probably be expressed as a trait and implemented on it, i.e. AsyncEngine style.
+            tokio::spawn(async move {
+                while let Some(event) = kv_events_rx.next().await {
+                    let event: llm_rs::kv_router::indexer::RouterEvent =
+                        serde_json::from_slice(&event.payload).unwrap();
+                    tracing::debug!("received kv event: {:?}", event);
+                    if let Err(e) = kv_events_tx.send(event).await {
+                        tracing::trace!(
+                            "failed to send kv event to indexer; shutting down: {:?}",
+                            e
+                        );
+                    }
+                }
+            });
+            Ok(Self { inner })
+        })
+    }
+
+    fn find_matches_for_request<'p>(
+        &self,
+        py: Python<'p>,
+        token_ids: Vec,
+        _lora_id: u64,
+    ) -> PyResult<Bound<'p, PyAny>> {
+        let indexer = self.inner.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            let rs_overlap_scores = indexer
+                .find_matches_for_request(token_ids.as_slice())
+                .await
+                .map_err(to_pyerr)?;
+            Ok(OverlapScores {
+                inner: rs_overlap_scores,
+            })
+        })
+    }
+}
+
+#[pyclass]
+#[derive(Clone)]
+pub(crate) struct EndpointKvMetrics {
+    #[pyo3(get, set)]
+    pub worker_id: i64,
+    #[pyo3(get, set)]
+    pub request_active_slots: u64,
+    #[pyo3(get, set)]
+    pub request_total_slots: u64,
+    #[pyo3(get, set)]
+    pub kv_active_blocks: u64,
+    #[pyo3(get, set)]
+    pub kv_total_blocks: u64,
+}
+
+#[pyclass]
+#[derive(Clone)]
+pub(crate) struct AggregatedMetrics {
+    #[pyo3(get, set)]
+    pub endpoints: Vec<EndpointKvMetrics>,
+    #[pyo3(get, set)]
+    pub load_avg: f64,
+    #[pyo3(get, set)]
+    pub load_std: f64,
+}
+
+#[pyclass]
+pub(crate) struct KvMetricsAggregator {
+    inner: Arc<llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator>,
+}
+
+#[pymethods]
+impl KvMetricsAggregator {
+    #[new]
+    fn new(component: Component) -> PyResult<Self> {
+        let runtime = pyo3_async_runtimes::tokio::get_runtime();
+        runtime.block_on(async {
+            let inner = llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator::new(
+                component.inner.clone(),
+                component.inner.drt().runtime().child_token(),
+            )
+            .await;
+            Ok(Self {
+                inner: inner.into(),
+            })
+        })
+    }
+
+    fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
+        let endpoints = self.inner.get_endpoints();
+        let endpoint_kv_metrics = endpoints
+            .endpoints
+            .iter()
+            .map(|x| EndpointKvMetrics {
+                worker_id: x.worker_id(),
+                request_active_slots: x.data.request_active_slots,
+                request_total_slots: x.data.request_total_slots,
+                kv_active_blocks: x.data.kv_active_blocks,
+                kv_total_blocks: x.data.kv_total_blocks,
+            })
+            .collect();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            Ok(AggregatedMetrics {
+                endpoints: endpoint_kv_metrics,
+                load_avg: endpoints.load_avg,
+                load_std: endpoints.load_std,
+            })
+        })
+    }
+}
diff --git a/lib/bindings/python/src/dynemo/_core.pyi b/lib/bindings/python/src/dynemo/_core.pyi
index 766e380233..cd26f95c89 100644
--- a/lib/bindings/python/src/dynemo/_core.pyi
+++ b/lib/bindings/python/src/dynemo/_core.pyi
@@ -233,3 +233,55 @@ class Backend:
         Start the backend engine and requests to the downstream LLM engine
         """
         ...
+
+
+class OverlapScores:
+    """
+    A collection of prefix-matching scores of workers for the given token ids.
+    'scores' maps a worker id to its score, i.e. the number of matching KV blocks.
+    """
+
+    ...
+
+class KvIndexer:
+    """
+    A KV Indexer that tracks KV events emitted by workers. Events include add_block and remove_block.
+    """
+
+    ...
+
+    def __init__(self, component: Component) -> None:
+        """
+        Create a `KvIndexer` object
+        """
+
+    def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
+        """
+        Return the overlap scores of workers for the given token ids.
+        """
+        ...
+
+class AggregatedMetrics:
+    """
+    A collection of aggregated KV metrics from the endpoints.
+    """
+
+    ...
+
+class KvMetricsAggregator:
+    """
+    A metrics aggregator that collects KV metrics from the endpoints.
+    """
+
+    ...
+
+    def __init__(self, component: Component) -> None:
+        """
+        Create a `KvMetricsAggregator` object
+        """
+
+    def get_metrics(self) -> AggregatedMetrics:
+        """
+        Return the aggregated metrics of the endpoints.
+        """
+        ...
diff --git a/lib/bindings/python/src/dynemo/llm/__init__.py b/lib/bindings/python/src/dynemo/llm/__init__.py
index 2c9abcc65e..9f6ce1e45b 100644
--- a/lib/bindings/python/src/dynemo/llm/__init__.py
+++ b/lib/bindings/python/src/dynemo/llm/__init__.py
@@ -13,5 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dynemo._core import KvIndexer as KvIndexer
+from dynemo._core import KvMetricsAggregator as KvMetricsAggregator
 from dynemo._core import KvMetricsPublisher as KvMetricsPublisher
 from dynemo._core import KvRouter as KvRouter
diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs
index 433f6589e1..cc2d43b0c6 100644
--- a/lib/llm/src/kv_router.rs
+++ b/lib/llm/src/kv_router.rs
@@ -21,6 +21,7 @@ use tokio_util::sync::CancellationToken;
 use tracing;
 
 pub mod indexer;
+pub mod metrics_aggregator;
 pub mod protocols;
 pub mod publisher;
 pub mod scheduler;
diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs
new file mode 100644
index 0000000000..9233c61017
--- /dev/null
+++ b/lib/llm/src/kv_router/metrics_aggregator.rs
@@ -0,0 +1,149 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::{Arc, Mutex};
+
+pub use crate::kv_router::protocols::ForwardPassMetrics;
+
+use crate::kv_router::scheduler::{Endpoint, Service};
+use crate::kv_router::ProcessedEndpoints;
+use dynemo_runtime::component::Component;
+use std::time::Duration;
+use tokio_util::sync::CancellationToken;
+
+pub struct KvMetricsAggregator {
+    pub service_name: String,
+    pub endpoints: Arc<Mutex<ProcessedEndpoints>>,
+}
+
+impl KvMetricsAggregator {
+    pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
+        let (ep_tx, mut ep_rx) = tokio::sync::mpsc::channel(128);
+
+        tokio::spawn(collect_endpoints(
+            component.drt().nats_client().clone(),
+            component.service_name(),
+            ep_tx,
+            cancellation_token.clone(),
+        ));
+
+        tracing::trace!("awaiting the start of the background endpoint subscriber");
+        let endpoints = Arc::new(Mutex::new(ProcessedEndpoints::default()));
+        let endpoints_clone = endpoints.clone();
+        tokio::spawn(async move {
+            tracing::debug!("metrics aggregator background task started");
+            loop {
+                match ep_rx.recv().await {
+                    Some(endpoints) => match endpoints_clone.lock() {
+                        Ok(mut shared_endpoint) => {
+                            *shared_endpoint = endpoints;
+                        }
+                        Err(e) => {
+                            tracing::error!("Failed to acquire lock on endpoints: {:?}", e);
+                        }
+                    },
+                    None => {
+                        tracing::warn!("endpoint subscriber shutdown");
+                        break;
+                    }
+                };
+            }
+
+            tracing::trace!("background endpoint subscriber shutting down");
+        });
+        Self {
+            service_name: component.service_name(),
+            endpoints,
+        }
+    }
+
+    pub fn get_endpoints(&self) -> ProcessedEndpoints {
+        match self.endpoints.lock() {
+            Ok(endpoints) => endpoints.clone(),
+            Err(e) => {
+                tracing::error!("Failed to acquire lock on endpoints: {:?}", e);
+                ProcessedEndpoints::default()
+            }
+        }
+    }
+}
+
+async fn collect_endpoints(
+    nats_client: dynemo_runtime::transports::nats::Client,
+    service_name: String,
+    ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
+    cancel: CancellationToken,
+) {
+    loop {
+        tokio::select! {
+            _ = cancel.cancelled() => {
+                tracing::debug!("cancellation token triggered");
+                break;
+            }
+            _ = tokio::time::sleep(Duration::from_secs(1)) => {
+                tracing::trace!("collecting endpoints for service: {}", service_name);
+            }
+        }
+
+        let values = match nats_client
+            .get_endpoints(&service_name, Duration::from_secs(1))
+            .await
+        {
+            Ok(v) => v,
+            Err(e) => {
+                tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
+                continue;
+            }
+        };
+
+        tracing::debug!("values: {:?}", values);
+        let services: Vec<Service> = values
+            .into_iter()
+            .filter(|v| !v.is_empty())
+            .filter_map(|v| match serde_json::from_slice::<Service>(&v) {
+                Ok(service) => Some(service),
+                Err(e) => {
+                    tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
+                    None
+                }
+            })
+            .collect();
+        tracing::debug!("services: {:?}", services);
+
+        let endpoints: Vec<Endpoint> = services
+            .into_iter()
+            .flat_map(|s| s.endpoints)
+            .filter(|s| s.data.is_some())
+            .map(|s| Endpoint {
+                name: s.name,
+                subject: s.subject,
+                data: s.data.unwrap(),
+            })
+            .collect();
+        tracing::debug!("endpoints: {:?}", endpoints);
+
+        tracing::trace!(
+            "found {} endpoints for service: {}",
+            endpoints.len(),
+            service_name
+        );
+
+        let processed = ProcessedEndpoints::new(endpoints);
+        if ep_tx.send(processed).await.is_err() {
+            tracing::trace!("failed to send processed endpoints; shutting down");
+            break;
+        }
+    }
+}
diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs
index 970acdb225..5b837dab5c 100644
--- a/lib/llm/src/kv_router/scoring.rs
+++ b/lib/llm/src/kv_router/scoring.rs
@@ -20,7 +20,7 @@ use std::collections::HashSet;
 
 use crate::kv_router::scheduler::Endpoint;
 
-#[derive(Debug, Default, Serialize, Deserialize)]
+#[derive(Debug, Default, Serialize, Deserialize, Clone)]
 pub struct ProcessedEndpoints {
     pub endpoints: Vec<Endpoint>,
     pub worker_ids: Vec,
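
Usage note (not part of the patch): a minimal sketch of how the new Python bindings (KvIndexer, KvMetricsAggregator) might be driven directly, assuming a `kv_listener` component handle obtained the same way as in examples/python_rs/llm/vllm/kv_router/router.py. The helper name and token ids below are illustrative only; the custom routing path in router.py itself is enabled with the new `--custom-router` flag.

    from dynemo.llm import KvIndexer, KvMetricsAggregator


    async def inspect_kv_state(kv_listener, token_ids):
        # Illustrative helper (not part of the patch): report per-worker prefix
        # overlap and KV usage using the bindings added above.
        indexer = KvIndexer(kv_listener)
        aggregator = KvMetricsAggregator(kv_listener)

        scores = await indexer.find_matches_for_request(token_ids, 0)  # lora_id is currently unused
        metrics = await aggregator.get_metrics()

        for endpoint in metrics.endpoints:
            overlap = scores.scores.get(endpoint.worker_id, 0)
            print(
                f"worker={endpoint.worker_id} overlap={overlap} "
                f"slots={endpoint.request_active_slots}/{endpoint.request_total_slots} "
                f"kv={endpoint.kv_active_blocks}/{endpoint.kv_total_blocks}"
            )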