From 9e5a0e4102b904f53e1daf137c506416e8ed08c8 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Fri, 28 Feb 2025 14:18:30 -0800 Subject: [PATCH 01/13] wip: exposing lower level KV components --- lib/bindings/python/rust/llm/kv.rs | 72 +++++++++++++++ lib/llm/src/kv_router.rs | 1 + lib/llm/src/kv_router/metrics_aggregator.rs | 97 +++++++++++++++++++++ 3 files changed, 170 insertions(+) create mode 100644 lib/llm/src/kv_router/metrics_aggregator.rs diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index ffb3c93f27..941e56467f 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -13,7 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use super::*; +use tracing; +use llm_rs::kv_router::indexer::KvIndexerInterface; #[pyclass] pub(crate) struct KvRouter { @@ -106,3 +110,71 @@ impl KvMetricsPublisher { .map_err(to_pyerr) } } + +#[pyclass] +pub(crate) struct OverlapScores(pub llm_rs::kv_router::indexer::OverlapScores); + +#[pymethods] +impl OverlapScores { + fn scores(&self) -> HashMap { + self.0.scores.clone() + } + + fn frequencies(&self) -> Vec { + self.0.frequencies.clone() + } +} + +#[pyclass] +pub(crate) struct KvIndexer { + inner: Arc, +} + +#[pymethods] +impl KvIndexer { + #[new] + fn new(component: Component, token: CancellationToken) -> PyResult { + let runtime = pyo3_async_runtimes::tokio::get_runtime(); + runtime.block_on(async { + let kv_subject = component.inner.event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT); + let inner: Arc = llm_rs::kv_router::indexer::KvIndexer::new(token.inner).into(); + let mut kv_events_rx = component.inner.drt().nats_client().client().subscribe(kv_subject).await.map_err(to_pyerr)?; + let kv_events_tx = inner.event_sender(); + + // [FIXME] this is the added functionality to the indexer to subscribe to kv events, + // should have been made to a trait and implemented here? i.e. AsyncEngine style + tokio::spawn(async move { + while let Some(event) = kv_events_rx.next().await { + let event: llm_rs::kv_router::indexer::RouterEvent = serde_json::from_slice(&event.payload).unwrap(); + tracing::debug!("received kv event: {:?}", event); + if let Err(e) = kv_events_tx.send(event).await { + tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e); + } + } + }); + Ok(Self { inner }) + }) + } + + fn find_matches_for_request<'p>( + &self, + py: Python<'p>, + token_ids: Vec, + _lora_id: u64, + ) -> PyResult> { + let indexer = self.inner.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let rs_overlap_scores = indexer + .find_matches_for_request(token_ids.as_slice()) + .await + .map_err(to_pyerr)?; + Ok(OverlapScores(rs_overlap_scores)) + }) + } +} + +// [WIP] this should be a rust class for metrics subscription, not really scheduler +#[pyclass] +pub(crate) struct KvMetricsSubscriber { + inner: Arc, +} diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 433f6589e1..0014381c3f 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -25,6 +25,7 @@ pub mod protocols; pub mod publisher; pub mod scheduler; pub mod scoring; +pub mod metrics_aggregator; use crate::kv_router::{ indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs new file mode 100644 index 0000000000..8fc89fa12f --- /dev/null +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +pub use crate::kv_router::protocols::ForwardPassMetrics; + +use anyhow::Result; +use triton_distributed_runtime::pipeline::network::{ + ingress::push_endpoint::PushEndpoint, + PushWorkHandler, +}; + +use tokio::sync::watch; +use tokio_util::sync::CancellationToken; +use tracing as log; + +pub struct KvMetricsAggregator { + pub service_name: String, + + pub nats: nats::Client, + pub service_handler: Arc, + pub metrics_rx: watch::Receiver>, + pub cancellation_token: CancellationToken, +} + +/// version of crate +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); + +impl KvRoutedIngress { + pub fn builder() -> KvRoutedIngressBuilder { + KvRoutedIngressBuilder::default() + } + + pub async fn start(self) -> Result<()> { + let worker_id = self.worker_id; + + log::trace!( + worker_id, + "Starting nats service: {}:{}", + self.service_name, + VERSION + ); + + let mut metrics_rx = self.metrics_rx; + let worker_id_clone = worker_id.clone(); + + let service = self + .nats + .client() + .service_builder() + .description("A handy min max service") + .stats_handler(move |name, stats| { + log::debug!( + worker_id = worker_id_clone.as_str(), + "[IN worker?] Stats for service {}: {:?}", + name, + stats + ); + let metrics = metrics_rx.borrow_and_update().clone(); + serde_json::to_value(&*metrics).unwrap() + }) + .start(self.service_name.as_str(), VERSION) + .await + .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?; + + let group = service.group(self.service_name.as_str()); + + log::trace!(worker_id, "Starting endpoint: {}", worker_id); + + // creates an endpoint for the service + let service_endpoint = group + .endpoint(worker_id.clone()) + .await + .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?; + + let push_endpoint = PushEndpoint::builder() + .service_handler(self.service_handler) + .cancellation_token(self.cancellation_token) + .build() + .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?; + + push_endpoint.start(service_endpoint).await + } +} From 6e26ccdc3e200c68f466159314c28781d7c9c8ac Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Sun, 2 Mar 2025 01:22:34 -0800 Subject: [PATCH 02/13] wip: complete implementation --- lib/bindings/python/rust/lib.rs | 5 + lib/bindings/python/rust/llm/kv.rs | 63 ++++++- lib/bindings/python/src/dynemo/_core.pyi | 59 ++++++ .../python/src/dynemo/llm/__init__.py | 2 + lib/llm/src/kv_router/metrics_aggregator.rs | 171 ++++++++++++------ lib/llm/src/kv_router/scoring.rs | 2 +- 6 files changed, 241 insertions(+), 61 deletions(-) diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index f60afc104b..30ad7f2439 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -70,6 +70,11 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; engine::add_to_module(m)?; diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 941e56467f..bbc72754e8 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -112,6 +112,7 @@ impl KvMetricsPublisher { } #[pyclass] +#[derive(Clone)] pub(crate) struct OverlapScores(pub llm_rs::kv_router::indexer::OverlapScores); #[pymethods] @@ -173,8 +174,64 @@ impl KvIndexer { } } -// [WIP] this should be a rust class for metrics subscription, not really scheduler #[pyclass] -pub(crate) struct KvMetricsSubscriber { - inner: Arc, +#[derive(Clone)] +pub(crate) struct EndpiontKvMetrics +{ + pub worker_ids: i64, + pub request_active_slots: u64, + pub request_total_slots: u64, + pub kv_active_blocks: u64, + pub kv_total_blocks: u64, +} + +#[pyclass] +#[derive(Clone)] +pub(crate) struct AggregatedMetrics { + pub endpoints: Vec, + pub load_avg: f64, + pub load_std: f64, +} + + +#[pyclass] +pub(crate) struct KvMetricsAggregator { + inner: Arc, +} + +#[pymethods] +impl KvMetricsAggregator { + #[new] + fn new(component: Component, token: CancellationToken) -> PyResult { + let runtime = pyo3_async_runtimes::tokio::get_runtime(); + runtime.block_on(async { + let inner = llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator::new( + component.inner.clone(), + token.inner, + ) + .await; + Ok(Self { inner: inner.into() }) + }) + } + + fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult> { + let endpoints = self.inner.get_endpoints(); + let endpiont_kv_metrics = endpoints + .endpoints + .iter() + .map(|x| EndpiontKvMetrics { + worker_ids: x.worker_id(), + request_active_slots: x.data.request_active_slots, + request_total_slots: x.data.request_total_slots, + kv_active_blocks: x.data.kv_active_blocks, + kv_total_blocks: x.data.kv_total_blocks, + }) + .collect(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + Ok(AggregatedMetrics { + endpoints: endpiont_kv_metrics, + load_avg: endpoints.load_avg, + load_std: endpoints.load_std, + })}) + } } diff --git a/lib/bindings/python/src/dynemo/_core.pyi b/lib/bindings/python/src/dynemo/_core.pyi index 766e380233..a777fd1848 100644 --- a/lib/bindings/python/src/dynemo/_core.pyi +++ b/lib/bindings/python/src/dynemo/_core.pyi @@ -233,3 +233,62 @@ class Backend: Start the backend engine and requests to the downstream LLM engine """ ... +class CancellationToken: + """ + A cancellation token is used to cancel an operation + """ + + ... + +class OverlapScores: + """ + A collection of scores for a given token ids + """ + + ... + +# [WIP] fix docs +class KvIndexer: + """ + A metrics publisher will provide KV metrics to the router. + """ + + ... + + def __init__(self, component: Component, token: CancellationToken) -> None: + """ + Create a `KvIndexer` object + """ + + def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores: + """ + Similar to Component.create_service, but only service created through + this method will interact with KV router of the same component. + """ + ... + +class AggregatedMetrics: + """ + A collection of scores for a given token ids + """ + + ... + +class KvMetricsAggregator: + """ + A metrics publisher will provide KV metrics to the router. + """ + + ... + + def __init__(self, component: Component, token: CancellationToken) -> None: + """ + Create a `KvIndexer` object + """ + + def get_metrics(self) -> AggregatedMetrics: + """ + Similar to Component.create_service, but only service created through + this method will interact with KV router of the same component. + """ + ... diff --git a/lib/bindings/python/src/dynemo/llm/__init__.py b/lib/bindings/python/src/dynemo/llm/__init__.py index 2c9abcc65e..4a7f35c1ea 100644 --- a/lib/bindings/python/src/dynemo/llm/__init__.py +++ b/lib/bindings/python/src/dynemo/llm/__init__.py @@ -15,3 +15,5 @@ from dynemo._core import KvMetricsPublisher as KvMetricsPublisher from dynemo._core import KvRouter as KvRouter +from dynemo._core import KvIndexer as KvIndexer +from dynemo._core import KvMetricsAggregator as KvMetricsAggregator diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 8fc89fa12f..316925dc85 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; +use std::sync::{Arc, Mutex}; pub use crate::kv_router::protocols::ForwardPassMetrics; @@ -25,73 +25,130 @@ use triton_distributed_runtime::pipeline::network::{ use tokio::sync::watch; use tokio_util::sync::CancellationToken; -use tracing as log; +use triton_distributed_runtime::{component::Component, DistributedRuntime}; +use crate::kv_router::ProcessedEndpoints; +use crate::kv_router::scheduler::{Service, Endpoint}; +use std::time::Duration; pub struct KvMetricsAggregator { pub service_name: String, - - pub nats: nats::Client, - pub service_handler: Arc, - pub metrics_rx: watch::Receiver>, - pub cancellation_token: CancellationToken, + pub endpoints: Arc>, } -/// version of crate -pub const VERSION: &str = env!("CARGO_PKG_VERSION"); - -impl KvRoutedIngress { - pub fn builder() -> KvRoutedIngressBuilder { - KvRoutedIngressBuilder::default() +impl KvMetricsAggregator { + pub async fn new( + component: Component, + cancellation_token: CancellationToken + ) -> Self { + let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128); + + tokio::spawn(collect_endpoints( + component.drt().nats_client().clone(), + component.service_name(), + ep_tx, + cancellation_token.clone(), + )); + let mut ep_rx = ep_rx; + + tracing::trace!("awaiting the start of the background endpoint subscriber"); + let endpoints = Arc::new(Mutex::new(ProcessedEndpoints::default())); + let endpoints_clone = endpoints.clone(); + tokio::spawn(async move { + tracing::debug!("scheduler background task started"); + loop { + tracing::trace!("all workers busy; waiting for more capacity"); + match ep_rx.recv().await { + Some(endpoints) => { + let mut shared_endpoint = endpoints_clone.lock().unwrap(); + *shared_endpoint = endpoints; + }, + None => { + tracing::trace!("endpoint subscriber shutdown"); + break; + } + }; + } + + tracing::trace!("background endpoint subscriber shutting down"); + }); + Self { + service_name: component.service_name(), + endpoints: endpoints, + } } - pub async fn start(self) -> Result<()> { - let worker_id = self.worker_id; - - log::trace!( - worker_id, - "Starting nats service: {}:{}", - self.service_name, - VERSION - ); + pub fn get_endpoints(&self) -> ProcessedEndpoints { + let endpoints = self.endpoints.lock().unwrap(); + endpoints.clone() + } +} - let mut metrics_rx = self.metrics_rx; - let worker_id_clone = worker_id.clone(); - - let service = self - .nats - .client() - .service_builder() - .description("A handy min max service") - .stats_handler(move |name, stats| { - log::debug!( - worker_id = worker_id_clone.as_str(), - "[IN worker?] Stats for service {}: {:?}", - name, - stats - ); - let metrics = metrics_rx.borrow_and_update().clone(); - serde_json::to_value(&*metrics).unwrap() - }) - .start(self.service_name.as_str(), VERSION) +async fn collect_endpoints( + nats_client: triton_distributed_runtime::transports::nats::Client, + service_name: String, + ep_tx: tokio::sync::mpsc::Sender, + cancel: CancellationToken, +) { + loop { + tokio::select! { + _ = cancel.cancelled() => { + tracing::debug!("cancellation token triggered"); + break; + } + _ = tokio::time::sleep(Duration::from_secs(1)) => { + tracing::trace!("collecting endpoints for service: {}", service_name); + } + } + + let values = match nats_client + .get_endpoints(&service_name, Duration::from_secs(1)) .await - .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?; - - let group = service.group(self.service_name.as_str()); - - log::trace!(worker_id, "Starting endpoint: {}", worker_id); + { + Ok(v) => v, + Err(e) => { + tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e); + continue; + } + }; + + tracing::debug!("values: {:?}", values); + let services: Vec = values + .into_iter() + .filter(|v| !v.is_empty()) + .filter_map(|v| match serde_json::from_slice::(&v) { + Ok(service) => Some(service), + Err(e) => { + tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e); + None + } + }) + .collect(); + tracing::debug!("services: {:?}", services); + + let endpoints: Vec = services + .into_iter() + .flat_map(|s| s.endpoints) + .filter(|s| s.data.is_some()) + .map(|s| Endpoint { + name: s.name, + subject: s.subject, + data: s.data.unwrap(), + }) + .collect(); + tracing::debug!("endpoints: {:?}", endpoints); - // creates an endpoint for the service - let service_endpoint = group - .endpoint(worker_id.clone()) - .await - .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?; + tracing::trace!( + "found {} endpoints for service: {}", + endpoints.len(), + service_name + ); - let push_endpoint = PushEndpoint::builder() - .service_handler(self.service_handler) - .cancellation_token(self.cancellation_token) - .build() - .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?; + let processed = ProcessedEndpoints::new(endpoints); - push_endpoint.start(service_endpoint).await + // process endpoints into + if ep_tx.send(processed).await.is_err() { + tracing::trace!("failed to send processed endpoints; shutting down"); + break; + } } } diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index 970acdb225..5b837dab5c 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -20,7 +20,7 @@ use std::collections::HashSet; use crate::kv_router::scheduler::Endpoint; -#[derive(Debug, Default, Serialize, Deserialize)] +#[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct ProcessedEndpoints { pub endpoints: Vec, pub worker_ids: Vec, From 9f7ca082f31dd640372f73d7cef1020726ff7f45 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Mon, 3 Mar 2025 15:55:57 -0800 Subject: [PATCH 03/13] wip: verifying --- .../llm/vllm/kv_router/test_router.py | 166 ++++++++++++++++++ lib/bindings/python/rust/lib.rs | 2 +- lib/bindings/python/rust/llm/kv.rs | 20 ++- 3 files changed, 181 insertions(+), 7 deletions(-) create mode 100644 examples/python_rs/llm/vllm/kv_router/test_router.py diff --git a/examples/python_rs/llm/vllm/kv_router/test_router.py b/examples/python_rs/llm/vllm/kv_router/test_router.py new file mode 100644 index 0000000000..a69c331d2d --- /dev/null +++ b/examples/python_rs/llm/vllm/kv_router/test_router.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from argparse import Namespace +from enum import Enum +from typing import AsyncIterator + +import uvloop +from common.protocol import Tokens +from vllm.logger import logger as vllm_logger + +from triton_distributed.llm import KvRouter, KvIndexer, KvMetricsAggregator +from triton_distributed.runtime import ( + DistributedRuntime, + triton_endpoint, + triton_worker, +) + +WorkerId = str + + +class RoutingStrategy(Enum): + PREFIX = "prefix" + ROUND_ROBIN = "round_robin" + RANDOM = "random" + + +class Router: + """ + Request handler for the generate endpoint + """ + + def __init__( + self, + indexer: KvIndexer, + metrics_aggregator: KvMetricsAggregator, + routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX, + ): + vllm_logger.info( + f"Initializing KV Router with strategy: {routing_strategy.value}" + ) + self.indexer = indexer + self.metrics_aggregator = metrics_aggregator + self.routing_strategy = routing_strategy + + + @triton_endpoint(Tokens, WorkerId) + async def generate(self, request) -> AsyncIterator[WorkerId]: + lora_id = 0 + worker_id = "" + if self.routing_strategy == RoutingStrategy.PREFIX: + try: + scores = await self.indexer.find_matches_for_request(request.tokens, lora_id) + print(f"Scores: {scores.scores()}") + metrics = await self.metrics_aggregator.get_metrics() + for endpoint in metrics.endpoints(): + print(f"Endpoint: {endpoint.worker_id()}") + print(f"Endpoint: {endpoint.request_total_slots()}") + print(f"Endpoint: {endpoint.kv_total_blocks()}") + # [NOTE][TODO] Now that the scheduler may return more error messages, + # now we are catching all exceptions and logging them. Should have + # catch specific router exceptions once we have dedicated types. + except Exception as e: + vllm_logger.info(f"{e}") + worker_id = "" + vllm_logger.exception(f"Error during worker selection: {e}") + + vllm_logger.info(f"Scheduling to worker_id: {worker_id}") + + yield str(worker_id) + + else: + # TODO: Do we implement round_robin and random here? + # or just skip this router and directly enable in preprocess? + raise NotImplementedError( + f"Routing strategy {self.routing_strategy} not implemented" + ) + + +@triton_worker() +async def worker(runtime: DistributedRuntime, args: Namespace): + """ + Set up the worker clients. + Serve the triton-init.router.generate endpoint. + """ + workers_client = ( + await runtime.namespace("triton-init") + .component("vllm") + .endpoint("generate") + .client() + ) + wait_task = workers_client.wait_for_endpoints() + await asyncio.sleep(1) + + while not wait_task.done(): + vllm_logger.info("Waiting for workers to be ready...") + await asyncio.sleep(5) + + wait_task.result() + + while len(workers_client.endpoint_ids()) < args.min_workers: + vllm_logger.info( + f"Waiting for more workers... Current: {len(workers_client.endpoint_ids())}, Required: {args.min_workers}" + ) + await asyncio.sleep(5) + + vllm_logger.info( + f"Required number of workers ({args.min_workers}) are ready:\n" + + "\n".join(f"id: {id}" for id in workers_client.endpoint_ids()) + ) + + kv_listener = runtime.namespace("triton-init").component("vllm") + await kv_listener.create_service() + + router_component = runtime.namespace("triton-init").component("router") + await router_component.create_service() + + indexer = KvIndexer(kv_listener, runtime.primary_token()) + metrics_aggregator = KvMetricsAggregator(kv_listener, runtime.primary_token()) + + endpoint = router_component.endpoint("generate") + await endpoint.serve_endpoint(Router(indexer, metrics_aggregator, args.routing_strategy).generate) + + +if __name__ == "__main__": + uvloop.install() + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--routing-strategy", + type=RoutingStrategy, + default=RoutingStrategy.PREFIX, + choices=list(RoutingStrategy), + help="Routing strategy to use", + ) + parser.add_argument( + "--min-workers", + type=int, + default=1, + help="Minimum number of workers required before proceeding", + ) + parser.add_argument( + "--model-name", + type=str, + default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + help="Model that is being served", + ) + args = parser.parse_args() + + asyncio.run(worker(args)) diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 30ad7f2439..a47f09e16f 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -72,7 +72,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index bbc72754e8..73bbd6c759 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -176,24 +176,32 @@ impl KvIndexer { #[pyclass] #[derive(Clone)] -pub(crate) struct EndpiontKvMetrics +pub(crate) struct EndpointKvMetrics { + #[pyo3(get, set)] pub worker_ids: i64, + #[pyo3(get, set)] pub request_active_slots: u64, + #[pyo3(get, set)] pub request_total_slots: u64, + #[pyo3(get, set)] pub kv_active_blocks: u64, + #[pyo3(get, set)] pub kv_total_blocks: u64, } + #[pyclass] #[derive(Clone)] pub(crate) struct AggregatedMetrics { - pub endpoints: Vec, + #[pyo3(get, set)] + pub endpoints: Vec, + #[pyo3(get, set)] pub load_avg: f64, + #[pyo3(get, set)] pub load_std: f64, } - #[pyclass] pub(crate) struct KvMetricsAggregator { inner: Arc, @@ -216,10 +224,10 @@ impl KvMetricsAggregator { fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult> { let endpoints = self.inner.get_endpoints(); - let endpiont_kv_metrics = endpoints + let endpoint_kv_metrics = endpoints .endpoints .iter() - .map(|x| EndpiontKvMetrics { + .map(|x| EndpointKvMetrics { worker_ids: x.worker_id(), request_active_slots: x.data.request_active_slots, request_total_slots: x.data.request_total_slots, @@ -229,7 +237,7 @@ impl KvMetricsAggregator { .collect(); pyo3_async_runtimes::tokio::future_into_py(py, async move { Ok(AggregatedMetrics { - endpoints: endpiont_kv_metrics, + endpoints: endpoint_kv_metrics, load_avg: endpoints.load_avg, load_std: endpoints.load_std, })}) From 18f989c5361cf681f8e418a57b01a5a5f7e39a26 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Mon, 3 Mar 2025 15:56:55 -0800 Subject: [PATCH 04/13] wip: fix up --- examples/python_rs/llm/vllm/kv_router/test_router.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/python_rs/llm/vllm/kv_router/test_router.py b/examples/python_rs/llm/vllm/kv_router/test_router.py index a69c331d2d..a748e37d57 100644 --- a/examples/python_rs/llm/vllm/kv_router/test_router.py +++ b/examples/python_rs/llm/vllm/kv_router/test_router.py @@ -67,10 +67,10 @@ async def generate(self, request) -> AsyncIterator[WorkerId]: scores = await self.indexer.find_matches_for_request(request.tokens, lora_id) print(f"Scores: {scores.scores()}") metrics = await self.metrics_aggregator.get_metrics() - for endpoint in metrics.endpoints(): - print(f"Endpoint: {endpoint.worker_id()}") - print(f"Endpoint: {endpoint.request_total_slots()}") - print(f"Endpoint: {endpoint.kv_total_blocks()}") + for endpoint in metrics.endpoints: + print(f"Endpoint: {endpoint.worker_id}") + print(f"Endpoint: {endpoint.request_total_slots}") + print(f"Endpoint: {endpoint.kv_total_blocks}") # [NOTE][TODO] Now that the scheduler may return more error messages, # now we are catching all exceptions and logging them. Should have # catch specific router exceptions once we have dedicated types. From 7ddd545c967b8ca056e502581ff8e34c743eca9b Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Tue, 4 Mar 2025 10:57:09 -0800 Subject: [PATCH 05/13] fix: fix up --- lib/bindings/python/rust/llm/kv.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 73bbd6c759..360937fb5d 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -179,7 +179,7 @@ impl KvIndexer { pub(crate) struct EndpointKvMetrics { #[pyo3(get, set)] - pub worker_ids: i64, + pub worker_id: i64, #[pyo3(get, set)] pub request_active_slots: u64, #[pyo3(get, set)] @@ -228,7 +228,7 @@ impl KvMetricsAggregator { .endpoints .iter() .map(|x| EndpointKvMetrics { - worker_ids: x.worker_id(), + worker_id: x.worker_id(), request_active_slots: x.data.request_active_slots, request_total_slots: x.data.request_total_slots, kv_active_blocks: x.data.kv_active_blocks, From 5f7b8bfdf8b0bbd422ea1147e0c96f4cbdc51781 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Tue, 4 Mar 2025 11:40:32 -0800 Subject: [PATCH 06/13] style: doc and format --- .../llm/vllm/kv_router/test_router.py | 11 +++-- lib/bindings/python/rust/llm/kv.rs | 41 +++++++++++++------ lib/bindings/python/src/dynemo/_core.pyi | 17 ++++---- .../python/src/dynemo/llm/__init__.py | 4 +- lib/llm/src/kv_router.rs | 2 +- lib/llm/src/kv_router/metrics_aggregator.rs | 16 +++----- 6 files changed, 51 insertions(+), 40 deletions(-) diff --git a/examples/python_rs/llm/vllm/kv_router/test_router.py b/examples/python_rs/llm/vllm/kv_router/test_router.py index a748e37d57..f023c9e7cd 100644 --- a/examples/python_rs/llm/vllm/kv_router/test_router.py +++ b/examples/python_rs/llm/vllm/kv_router/test_router.py @@ -23,7 +23,7 @@ from common.protocol import Tokens from vllm.logger import logger as vllm_logger -from triton_distributed.llm import KvRouter, KvIndexer, KvMetricsAggregator +from triton_distributed.llm import KvIndexer, KvMetricsAggregator from triton_distributed.runtime import ( DistributedRuntime, triton_endpoint, @@ -57,14 +57,15 @@ def __init__( self.metrics_aggregator = metrics_aggregator self.routing_strategy = routing_strategy - @triton_endpoint(Tokens, WorkerId) async def generate(self, request) -> AsyncIterator[WorkerId]: lora_id = 0 worker_id = "" if self.routing_strategy == RoutingStrategy.PREFIX: try: - scores = await self.indexer.find_matches_for_request(request.tokens, lora_id) + scores = await self.indexer.find_matches_for_request( + request.tokens, lora_id + ) print(f"Scores: {scores.scores()}") metrics = await self.metrics_aggregator.get_metrics() for endpoint in metrics.endpoints: @@ -133,7 +134,9 @@ async def worker(runtime: DistributedRuntime, args: Namespace): metrics_aggregator = KvMetricsAggregator(kv_listener, runtime.primary_token()) endpoint = router_component.endpoint("generate") - await endpoint.serve_endpoint(Router(indexer, metrics_aggregator, args.routing_strategy).generate) + await endpoint.serve_endpoint( + Router(indexer, metrics_aggregator, args.routing_strategy).generate + ) if __name__ == "__main__": diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 360937fb5d..3386233189 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -16,8 +16,8 @@ use std::collections::HashMap; use super::*; -use tracing; use llm_rs::kv_router::indexer::KvIndexerInterface; +use tracing; #[pyclass] pub(crate) struct KvRouter { @@ -137,19 +137,33 @@ impl KvIndexer { fn new(component: Component, token: CancellationToken) -> PyResult { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { - let kv_subject = component.inner.event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT); - let inner: Arc = llm_rs::kv_router::indexer::KvIndexer::new(token.inner).into(); - let mut kv_events_rx = component.inner.drt().nats_client().client().subscribe(kv_subject).await.map_err(to_pyerr)?; + let kv_subject = component + .inner + .event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT); + let inner: Arc = + llm_rs::kv_router::indexer::KvIndexer::new(token.inner).into(); + let mut kv_events_rx = component + .inner + .drt() + .nats_client() + .client() + .subscribe(kv_subject) + .await + .map_err(to_pyerr)?; let kv_events_tx = inner.event_sender(); - + // [FIXME] this is the added functionality to the indexer to subscribe to kv events, // should have been made to a trait and implemented here? i.e. AsyncEngine style tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::indexer::RouterEvent = serde_json::from_slice(&event.payload).unwrap(); + let event: llm_rs::kv_router::indexer::RouterEvent = + serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("received kv event: {:?}", event); if let Err(e) = kv_events_tx.send(event).await { - tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e); + tracing::trace!( + "failed to send kv event to indexer; shutting down: {:?}", + e + ); } } }); @@ -176,8 +190,7 @@ impl KvIndexer { #[pyclass] #[derive(Clone)] -pub(crate) struct EndpointKvMetrics -{ +pub(crate) struct EndpointKvMetrics { #[pyo3(get, set)] pub worker_id: i64, #[pyo3(get, set)] @@ -190,7 +203,6 @@ pub(crate) struct EndpointKvMetrics pub kv_total_blocks: u64, } - #[pyclass] #[derive(Clone)] pub(crate) struct AggregatedMetrics { @@ -218,10 +230,12 @@ impl KvMetricsAggregator { token.inner, ) .await; - Ok(Self { inner: inner.into() }) + Ok(Self { + inner: inner.into(), + }) }) } - + fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult> { let endpoints = self.inner.get_endpoints(); let endpoint_kv_metrics = endpoints @@ -240,6 +254,7 @@ impl KvMetricsAggregator { endpoints: endpoint_kv_metrics, load_avg: endpoints.load_avg, load_std: endpoints.load_std, - })}) + }) + }) } } diff --git a/lib/bindings/python/src/dynemo/_core.pyi b/lib/bindings/python/src/dynemo/_core.pyi index a777fd1848..07a811933f 100644 --- a/lib/bindings/python/src/dynemo/_core.pyi +++ b/lib/bindings/python/src/dynemo/_core.pyi @@ -242,15 +242,14 @@ class CancellationToken: class OverlapScores: """ - A collection of scores for a given token ids + A collection of prefix matching scores of workers for a given token ids """ ... -# [WIP] fix docs class KvIndexer: """ - A metrics publisher will provide KV metrics to the router. + A KV indexer that tracks the KV block operationss of the workers. """ ... @@ -262,33 +261,31 @@ class KvIndexer: def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores: """ - Similar to Component.create_service, but only service created through - this method will interact with KV router of the same component. + Return the overlapping scores of workers for the given token ids. """ ... class AggregatedMetrics: """ - A collection of scores for a given token ids + A collection of metrics of the endpoints """ ... class KvMetricsAggregator: """ - A metrics publisher will provide KV metrics to the router. + A metrics aggregator will collect KV metrics of the endpoints. """ ... def __init__(self, component: Component, token: CancellationToken) -> None: """ - Create a `KvIndexer` object + Create a `KvMetricsAggregator` object """ def get_metrics(self) -> AggregatedMetrics: """ - Similar to Component.create_service, but only service created through - this method will interact with KV router of the same component. + Return the aggregated metrics of the endpoints. """ ... diff --git a/lib/bindings/python/src/dynemo/llm/__init__.py b/lib/bindings/python/src/dynemo/llm/__init__.py index 4a7f35c1ea..9f6ce1e45b 100644 --- a/lib/bindings/python/src/dynemo/llm/__init__.py +++ b/lib/bindings/python/src/dynemo/llm/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dynemo._core import KvMetricsPublisher as KvMetricsPublisher -from dynemo._core import KvRouter as KvRouter from dynemo._core import KvIndexer as KvIndexer from dynemo._core import KvMetricsAggregator as KvMetricsAggregator +from dynemo._core import KvMetricsPublisher as KvMetricsPublisher +from dynemo._core import KvRouter as KvRouter diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 0014381c3f..cc2d43b0c6 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -21,11 +21,11 @@ use tokio_util::sync::CancellationToken; use tracing; pub mod indexer; +pub mod metrics_aggregator; pub mod protocols; pub mod publisher; pub mod scheduler; pub mod scoring; -pub mod metrics_aggregator; use crate::kv_router::{ indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 316925dc85..d984682d46 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -19,16 +19,15 @@ pub use crate::kv_router::protocols::ForwardPassMetrics; use anyhow::Result; use triton_distributed_runtime::pipeline::network::{ - ingress::push_endpoint::PushEndpoint, - PushWorkHandler, + ingress::push_endpoint::PushEndpoint, PushWorkHandler, }; +use crate::kv_router::scheduler::{Endpoint, Service}; +use crate::kv_router::ProcessedEndpoints; +use std::time::Duration; use tokio::sync::watch; use tokio_util::sync::CancellationToken; use triton_distributed_runtime::{component::Component, DistributedRuntime}; -use crate::kv_router::ProcessedEndpoints; -use crate::kv_router::scheduler::{Service, Endpoint}; -use std::time::Duration; pub struct KvMetricsAggregator { pub service_name: String, @@ -36,10 +35,7 @@ pub struct KvMetricsAggregator { } impl KvMetricsAggregator { - pub async fn new( - component: Component, - cancellation_token: CancellationToken - ) -> Self { + pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self { let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128); tokio::spawn(collect_endpoints( @@ -61,7 +57,7 @@ impl KvMetricsAggregator { Some(endpoints) => { let mut shared_endpoint = endpoints_clone.lock().unwrap(); *shared_endpoint = endpoints; - }, + } None => { tracing::trace!("endpoint subscriber shutdown"); break; From c30b857ee78209094b05fbd1f2ef269df53bf26a Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Tue, 4 Mar 2025 11:57:34 -0800 Subject: [PATCH 07/13] style: clean up --- .../python_rs/llm/vllm/kv_router/router.py | 74 +++++++- .../llm/vllm/kv_router/test_router.py | 169 ------------------ 2 files changed, 70 insertions(+), 173 deletions(-) delete mode 100644 examples/python_rs/llm/vllm/kv_router/test_router.py diff --git a/examples/python_rs/llm/vllm/kv_router/router.py b/examples/python_rs/llm/vllm/kv_router/router.py index 29b9b06364..886cef8be7 100644 --- a/examples/python_rs/llm/vllm/kv_router/router.py +++ b/examples/python_rs/llm/vllm/kv_router/router.py @@ -23,7 +23,7 @@ from common.protocol import Tokens from vllm.logger import logger as vllm_logger -from dynemo.llm import KvRouter +from dynemo.llm import KvIndexer, KvMetricsAggregator, KvRouter from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker WorkerId = str @@ -77,6 +77,59 @@ async def generate(self, request) -> AsyncIterator[WorkerId]: f"Routing strategy {self.routing_strategy} not implemented" ) +class CustomRouter: + """ + Request handler for the generate endpoint + """ + + def __init__( + self, + indexer: KvIndexer, + metrics_aggregator: KvMetricsAggregator, + ): + self.indexer = indexer + self.metrics_aggregator = metrics_aggregator + + def _cost_function(self, scores, metrics): + # naive cost function for demonstration purposes + current_best = ("", 0) + for worker_id, score in scores.scores().items(): + if score > current_best[1]: + current_best = (worker_id, score) + for endpoint in metrics.endpoints: + if endpoint.worker_id == current_best[0]: + print(f"Metrics of endpoint: {endpoint.worker_id}") + print( + f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}" + ) + print( + f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}" + ) + return current_best[0] + + @dynemo_endpoint(Tokens, WorkerId) + async def generate(self, request) -> AsyncIterator[WorkerId]: + lora_id = 0 + worker_id = "" + try: + scores = await self.indexer.find_matches_for_request( + request.tokens, lora_id + ) + metrics = await self.metrics_aggregator.get_metrics() + worker_id = self._cost_function(scores, metrics) + + # [NOTE][TODO] Now that the scheduler may return more error messages, + # now we are catching all exceptions and logging them. Should have + # catch specific router exceptions once we have dedicated types. + except Exception as e: + vllm_logger.info(f"{e}") + worker_id = "" + vllm_logger.exception(f"Error during worker selection: {e}") + + vllm_logger.info(f"Scheduling to worker_id: {worker_id}") + + yield str(worker_id) + @dynemo_worker() async def worker(runtime: DistributedRuntime, args: Namespace): @@ -116,10 +169,17 @@ async def worker(runtime: DistributedRuntime, args: Namespace): router_component = runtime.namespace("dynemo").component("router") await router_component.create_service() - router = KvRouter(runtime, kv_listener) - endpoint = router_component.endpoint("generate") - await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate) + + if args.custom_router: + indexer = KvIndexer(kv_listener, runtime.primary_token()) + metrics_aggregator = KvMetricsAggregator(kv_listener, runtime.primary_token()) + await endpoint.serve_endpoint( + CustomRouter(indexer, metrics_aggregator).generate + ) + else: + router = KvRouter(runtime, kv_listener) + await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate) if __name__ == "__main__": @@ -147,6 +207,12 @@ async def worker(runtime: DistributedRuntime, args: Namespace): default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", help="Model that is being served", ) + parser.add_argument( + "--custom-router", + type=bool, + default=False, + help="Whether to use custom router or not", + ) args = parser.parse_args() asyncio.run(worker(args)) diff --git a/examples/python_rs/llm/vllm/kv_router/test_router.py b/examples/python_rs/llm/vllm/kv_router/test_router.py deleted file mode 100644 index f023c9e7cd..0000000000 --- a/examples/python_rs/llm/vllm/kv_router/test_router.py +++ /dev/null @@ -1,169 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import asyncio -from argparse import Namespace -from enum import Enum -from typing import AsyncIterator - -import uvloop -from common.protocol import Tokens -from vllm.logger import logger as vllm_logger - -from triton_distributed.llm import KvIndexer, KvMetricsAggregator -from triton_distributed.runtime import ( - DistributedRuntime, - triton_endpoint, - triton_worker, -) - -WorkerId = str - - -class RoutingStrategy(Enum): - PREFIX = "prefix" - ROUND_ROBIN = "round_robin" - RANDOM = "random" - - -class Router: - """ - Request handler for the generate endpoint - """ - - def __init__( - self, - indexer: KvIndexer, - metrics_aggregator: KvMetricsAggregator, - routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX, - ): - vllm_logger.info( - f"Initializing KV Router with strategy: {routing_strategy.value}" - ) - self.indexer = indexer - self.metrics_aggregator = metrics_aggregator - self.routing_strategy = routing_strategy - - @triton_endpoint(Tokens, WorkerId) - async def generate(self, request) -> AsyncIterator[WorkerId]: - lora_id = 0 - worker_id = "" - if self.routing_strategy == RoutingStrategy.PREFIX: - try: - scores = await self.indexer.find_matches_for_request( - request.tokens, lora_id - ) - print(f"Scores: {scores.scores()}") - metrics = await self.metrics_aggregator.get_metrics() - for endpoint in metrics.endpoints: - print(f"Endpoint: {endpoint.worker_id}") - print(f"Endpoint: {endpoint.request_total_slots}") - print(f"Endpoint: {endpoint.kv_total_blocks}") - # [NOTE][TODO] Now that the scheduler may return more error messages, - # now we are catching all exceptions and logging them. Should have - # catch specific router exceptions once we have dedicated types. - except Exception as e: - vllm_logger.info(f"{e}") - worker_id = "" - vllm_logger.exception(f"Error during worker selection: {e}") - - vllm_logger.info(f"Scheduling to worker_id: {worker_id}") - - yield str(worker_id) - - else: - # TODO: Do we implement round_robin and random here? - # or just skip this router and directly enable in preprocess? - raise NotImplementedError( - f"Routing strategy {self.routing_strategy} not implemented" - ) - - -@triton_worker() -async def worker(runtime: DistributedRuntime, args: Namespace): - """ - Set up the worker clients. - Serve the triton-init.router.generate endpoint. - """ - workers_client = ( - await runtime.namespace("triton-init") - .component("vllm") - .endpoint("generate") - .client() - ) - wait_task = workers_client.wait_for_endpoints() - await asyncio.sleep(1) - - while not wait_task.done(): - vllm_logger.info("Waiting for workers to be ready...") - await asyncio.sleep(5) - - wait_task.result() - - while len(workers_client.endpoint_ids()) < args.min_workers: - vllm_logger.info( - f"Waiting for more workers... Current: {len(workers_client.endpoint_ids())}, Required: {args.min_workers}" - ) - await asyncio.sleep(5) - - vllm_logger.info( - f"Required number of workers ({args.min_workers}) are ready:\n" - + "\n".join(f"id: {id}" for id in workers_client.endpoint_ids()) - ) - - kv_listener = runtime.namespace("triton-init").component("vllm") - await kv_listener.create_service() - - router_component = runtime.namespace("triton-init").component("router") - await router_component.create_service() - - indexer = KvIndexer(kv_listener, runtime.primary_token()) - metrics_aggregator = KvMetricsAggregator(kv_listener, runtime.primary_token()) - - endpoint = router_component.endpoint("generate") - await endpoint.serve_endpoint( - Router(indexer, metrics_aggregator, args.routing_strategy).generate - ) - - -if __name__ == "__main__": - uvloop.install() - - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--routing-strategy", - type=RoutingStrategy, - default=RoutingStrategy.PREFIX, - choices=list(RoutingStrategy), - help="Routing strategy to use", - ) - parser.add_argument( - "--min-workers", - type=int, - default=1, - help="Minimum number of workers required before proceeding", - ) - parser.add_argument( - "--model-name", - type=str, - default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", - help="Model that is being served", - ) - args = parser.parse_args() - - asyncio.run(worker(args)) From b2c062c657169bf1bdb60ceeefe11b0e3bfbe8a6 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Tue, 4 Mar 2025 12:08:14 -0800 Subject: [PATCH 08/13] style: fix commit check --- lib/llm/src/kv_router/metrics_aggregator.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index d984682d46..380f10652d 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -17,17 +17,11 @@ use std::sync::{Arc, Mutex}; pub use crate::kv_router::protocols::ForwardPassMetrics; -use anyhow::Result; -use triton_distributed_runtime::pipeline::network::{ - ingress::push_endpoint::PushEndpoint, PushWorkHandler, -}; - use crate::kv_router::scheduler::{Endpoint, Service}; use crate::kv_router::ProcessedEndpoints; use std::time::Duration; -use tokio::sync::watch; use tokio_util::sync::CancellationToken; -use triton_distributed_runtime::{component::Component, DistributedRuntime}; +use triton_distributed_runtime::component::Component; pub struct KvMetricsAggregator { pub service_name: String, @@ -69,7 +63,7 @@ impl KvMetricsAggregator { }); Self { service_name: component.service_name(), - endpoints: endpoints, + endpoints, } } From ff4ec6914c114003106469cbdcc6e79b27dc9331 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Wed, 5 Mar 2025 13:42:16 -0800 Subject: [PATCH 09/13] chore: address comment --- .../python_rs/llm/vllm/kv_router/router.py | 4 +-- lib/bindings/python/rust/llm/kv.rs | 25 +++++++++++------ lib/bindings/python/src/dynemo/_core.pyi | 14 ++++------ lib/llm/src/kv_router/metrics_aggregator.rs | 27 +++++++++++-------- 4 files changed, 40 insertions(+), 30 deletions(-) diff --git a/examples/python_rs/llm/vllm/kv_router/router.py b/examples/python_rs/llm/vllm/kv_router/router.py index 886cef8be7..05dc7e4915 100644 --- a/examples/python_rs/llm/vllm/kv_router/router.py +++ b/examples/python_rs/llm/vllm/kv_router/router.py @@ -172,8 +172,8 @@ async def worker(runtime: DistributedRuntime, args: Namespace): endpoint = router_component.endpoint("generate") if args.custom_router: - indexer = KvIndexer(kv_listener, runtime.primary_token()) - metrics_aggregator = KvMetricsAggregator(kv_listener, runtime.primary_token()) + indexer = KvIndexer(kv_listener) + metrics_aggregator = KvMetricsAggregator(kv_listener) await endpoint.serve_endpoint( CustomRouter(indexer, metrics_aggregator).generate ) diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 3386233189..ccb15296aa 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -113,16 +113,20 @@ impl KvMetricsPublisher { #[pyclass] #[derive(Clone)] -pub(crate) struct OverlapScores(pub llm_rs::kv_router::indexer::OverlapScores); +pub(crate) struct OverlapScores { + inner: llm_rs::kv_router::indexer::OverlapScores, +} #[pymethods] impl OverlapScores { + #[getter] fn scores(&self) -> HashMap { - self.0.scores.clone() + self.inner.scores.clone() } + #[getter] fn frequencies(&self) -> Vec { - self.0.frequencies.clone() + self.inner.frequencies.clone() } } @@ -134,14 +138,17 @@ pub(crate) struct KvIndexer { #[pymethods] impl KvIndexer { #[new] - fn new(component: Component, token: CancellationToken) -> PyResult { + fn new(component: Component) -> PyResult { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { let kv_subject = component .inner .event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT); let inner: Arc = - llm_rs::kv_router::indexer::KvIndexer::new(token.inner).into(); + llm_rs::kv_router::indexer::KvIndexer::new( + component.inner.drt().runtime().child_token(), + ) + .into(); let mut kv_events_rx = component .inner .drt() @@ -183,7 +190,9 @@ impl KvIndexer { .find_matches_for_request(token_ids.as_slice()) .await .map_err(to_pyerr)?; - Ok(OverlapScores(rs_overlap_scores)) + Ok(OverlapScores { + inner: rs_overlap_scores, + }) }) } } @@ -222,12 +231,12 @@ pub(crate) struct KvMetricsAggregator { #[pymethods] impl KvMetricsAggregator { #[new] - fn new(component: Component, token: CancellationToken) -> PyResult { + fn new(component: Component) -> PyResult { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { let inner = llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator::new( component.inner.clone(), - token.inner, + component.inner.drt().runtime().child_token(), ) .await; Ok(Self { diff --git a/lib/bindings/python/src/dynemo/_core.pyi b/lib/bindings/python/src/dynemo/_core.pyi index 07a811933f..cd26f95c89 100644 --- a/lib/bindings/python/src/dynemo/_core.pyi +++ b/lib/bindings/python/src/dynemo/_core.pyi @@ -233,28 +233,24 @@ class Backend: Start the backend engine and requests to the downstream LLM engine """ ... -class CancellationToken: - """ - A cancellation token is used to cancel an operation - """ - ... class OverlapScores: """ - A collection of prefix matching scores of workers for a given token ids + A collection of prefix matching scores of workers for a given token ids. + 'scores' is a map of worker id to the score which is the number of matching blocks. """ ... class KvIndexer: """ - A KV indexer that tracks the KV block operationss of the workers. + A KV Indexer that tracks KV Events emitted by workers. Events include add_block and remove_block. """ ... - def __init__(self, component: Component, token: CancellationToken) -> None: + def __init__(self, component: Component) -> None: """ Create a `KvIndexer` object """ @@ -279,7 +275,7 @@ class KvMetricsAggregator: ... - def __init__(self, component: Component, token: CancellationToken) -> None: + def __init__(self, component: Component) -> None: """ Create a `KvMetricsAggregator` object """ diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 380f10652d..a75c901452 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -30,7 +30,7 @@ pub struct KvMetricsAggregator { impl KvMetricsAggregator { pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self { - let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128); + let (ep_tx, mut ep_rx) = tokio::sync::mpsc::channel(128); tokio::spawn(collect_endpoints( component.drt().nats_client().clone(), @@ -38,7 +38,6 @@ impl KvMetricsAggregator { ep_tx, cancellation_token.clone(), )); - let mut ep_rx = ep_rx; tracing::trace!("awaiting the start of the background endpoint subscriber"); let endpoints = Arc::new(Mutex::new(ProcessedEndpoints::default())); @@ -46,12 +45,15 @@ impl KvMetricsAggregator { tokio::spawn(async move { tracing::debug!("scheduler background task started"); loop { - tracing::trace!("all workers busy; waiting for more capacity"); match ep_rx.recv().await { - Some(endpoints) => { - let mut shared_endpoint = endpoints_clone.lock().unwrap(); - *shared_endpoint = endpoints; - } + Some(endpoints) => match endpoints_clone.lock() { + Ok(mut shared_endpoint) => { + *shared_endpoint = endpoints; + } + Err(e) => { + tracing::error!("Failed to acquire lock on endpoints: {:?}", e); + } + }, None => { tracing::trace!("endpoint subscriber shutdown"); break; @@ -68,8 +70,13 @@ impl KvMetricsAggregator { } pub fn get_endpoints(&self) -> ProcessedEndpoints { - let endpoints = self.endpoints.lock().unwrap(); - endpoints.clone() + match self.endpoints.lock() { + Ok(endpoints) => endpoints.clone(), + Err(e) => { + tracing::error!("Failed to acquire lock on endpoints: {:?}", e); + ProcessedEndpoints::default() + } + } } } @@ -134,8 +141,6 @@ async fn collect_endpoints( ); let processed = ProcessedEndpoints::new(endpoints); - - // process endpoints into if ep_tx.send(processed).await.is_err() { tracing::trace!("failed to send processed endpoints; shutting down"); break; From a1efbd5a54c94ab0ffe3da9aeffdaca83097c467 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Wed, 5 Mar 2025 14:46:15 -0800 Subject: [PATCH 10/13] fix: rebase artifact --- lib/llm/src/kv_router/metrics_aggregator.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index a75c901452..5fc21c6b71 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -19,9 +19,9 @@ pub use crate::kv_router::protocols::ForwardPassMetrics; use crate::kv_router::scheduler::{Endpoint, Service}; use crate::kv_router::ProcessedEndpoints; +use dynemo_runtime::component::Component; use std::time::Duration; use tokio_util::sync::CancellationToken; -use triton_distributed_runtime::component::Component; pub struct KvMetricsAggregator { pub service_name: String, @@ -81,7 +81,7 @@ impl KvMetricsAggregator { } async fn collect_endpoints( - nats_client: triton_distributed_runtime::transports::nats::Client, + nats_client: dynemo_runtime::transports::nats::Client, service_name: String, ep_tx: tokio::sync::mpsc::Sender, cancel: CancellationToken, From 86b6a555b7d1a96e4a7273db939aca53819124bd Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Wed, 5 Mar 2025 15:01:59 -0800 Subject: [PATCH 11/13] style: format --- examples/python_rs/llm/vllm/kv_router/router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/python_rs/llm/vllm/kv_router/router.py b/examples/python_rs/llm/vllm/kv_router/router.py index 05dc7e4915..ef5e85e6c7 100644 --- a/examples/python_rs/llm/vllm/kv_router/router.py +++ b/examples/python_rs/llm/vllm/kv_router/router.py @@ -77,6 +77,7 @@ async def generate(self, request) -> AsyncIterator[WorkerId]: f"Routing strategy {self.routing_strategy} not implemented" ) + class CustomRouter: """ Request handler for the generate endpoint From d05c6246fc186a74045896493cc18accc38a2ef9 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Wed, 5 Mar 2025 15:30:01 -0800 Subject: [PATCH 12/13] fix: fix up --- examples/python_rs/llm/vllm/kv_router/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/python_rs/llm/vllm/kv_router/router.py b/examples/python_rs/llm/vllm/kv_router/router.py index ef5e85e6c7..a28dc66fcf 100644 --- a/examples/python_rs/llm/vllm/kv_router/router.py +++ b/examples/python_rs/llm/vllm/kv_router/router.py @@ -94,7 +94,7 @@ def __init__( def _cost_function(self, scores, metrics): # naive cost function for demonstration purposes current_best = ("", 0) - for worker_id, score in scores.scores().items(): + for worker_id, score in scores.scores.items(): if score > current_best[1]: current_best = (worker_id, score) for endpoint in metrics.endpoints: From 51e3267f8281e2a543168006ed058b20756092db Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Wed, 5 Mar 2025 15:44:03 -0800 Subject: [PATCH 13/13] chore: adddress comment --- lib/llm/src/kv_router/metrics_aggregator.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 5fc21c6b71..9233c61017 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -55,7 +55,7 @@ impl KvMetricsAggregator { } }, None => { - tracing::trace!("endpoint subscriber shutdown"); + tracing::warn!("endpoint subscriber shutdown"); break; } };