75 changes: 71 additions & 4 deletions examples/python_rs/llm/vllm/kv_router/router.py
@@ -23,7 +23,7 @@
from common.protocol import Tokens
from vllm.logger import logger as vllm_logger

from dynemo.llm import KvRouter
from dynemo.llm import KvIndexer, KvMetricsAggregator, KvRouter
from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker

WorkerId = str
@@ -78,6 +78,60 @@ async def generate(self, request) -> AsyncIterator[WorkerId]:
)


class CustomRouter:
"""
    Request handler for the generate endpoint: selects a worker using KV prefix-overlap scores and per-endpoint KV metrics
"""

def __init__(
self,
indexer: KvIndexer,
metrics_aggregator: KvMetricsAggregator,
):
self.indexer = indexer
self.metrics_aggregator = metrics_aggregator

def _cost_function(self, scores, metrics):
# naive cost function for demonstration purposes
current_best = ("", 0)
for worker_id, score in scores.scores.items():
if score > current_best[1]:
current_best = (worker_id, score)
for endpoint in metrics.endpoints:
if endpoint.worker_id == current_best[0]:
print(f"Metrics of endpoint: {endpoint.worker_id}")
print(
f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}"
)
print(
f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}"
)
return current_best[0]

@dynemo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
try:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
metrics = await self.metrics_aggregator.get_metrics()
worker_id = self._cost_function(scores, metrics)

            # [NOTE][TODO] The scheduler can now surface more kinds of errors, so we
            # catch all exceptions here and log them. Switch to catching specific
            # router exception types once dedicated ones exist.
        except Exception as e:
            worker_id = ""
            vllm_logger.exception(f"Error during worker selection: {e}")

vllm_logger.info(f"Scheduling to worker_id: {worker_id}")

yield str(worker_id)


@dynemo_worker()
async def worker(runtime: DistributedRuntime, args: Namespace):
"""
Expand Down Expand Up @@ -116,10 +170,17 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
router_component = runtime.namespace("dynemo").component("router")
await router_component.create_service()

router = KvRouter(runtime, kv_listener)

endpoint = router_component.endpoint("generate")
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)

if args.custom_router:
indexer = KvIndexer(kv_listener)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(indexer, metrics_aggregator).generate
)
else:
router = KvRouter(runtime, kv_listener)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)


if __name__ == "__main__":
Expand Down Expand Up @@ -147,6 +208,12 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
    parser.add_argument(
        "--custom-router",
        action="store_true",
        help="Use the custom router (KvIndexer + KvMetricsAggregator) instead of the built-in KvRouter",
    )
args = parser.parse_args()

asyncio.run(worker(args))
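
The _cost_function above only picks the worker with the highest overlap score and prints the endpoint metrics. A minimal sketch of a variant that also folds KV-cache pressure into the decision is shown below; the kv_aware_cost_function name, the kv_weight parameter, and the SimpleNamespace stand-ins (which merely mimic the shapes of the OverlapScores and AggregatedMetrics objects used in this diff) are assumptions for illustration, not part of the PR.

# Sketch only: trade off prefix overlap against KV-cache pressure.
from types import SimpleNamespace


def kv_aware_cost_function(scores, metrics, kv_weight=1.0):
    # Map worker_id -> fraction of KV blocks currently in use (skip unknown totals).
    kv_usage = {
        e.worker_id: e.kv_active_blocks / e.kv_total_blocks
        for e in metrics.endpoints
        if e.kv_total_blocks
    }
    best_worker, best_value = "", float("-inf")
    for worker_id, overlap in scores.scores.items():
        # Reward prefix overlap, penalize workers whose KV cache is nearly full.
        value = overlap - kv_weight * kv_usage.get(worker_id, 0.0)
        if value > best_value:
            best_worker, best_value = worker_id, value
    return best_worker


if __name__ == "__main__":
    scores = SimpleNamespace(scores={0: 5, 1: 4})
    metrics = SimpleNamespace(
        endpoints=[
            SimpleNamespace(worker_id=0, kv_active_blocks=90, kv_total_blocks=100),
            SimpleNamespace(worker_id=1, kv_active_blocks=10, kv_total_blocks=100),
        ]
    )
    # Worker 0 has more overlap but a nearly full KV cache; with kv_weight=2.0 the
    # less loaded worker 1 wins (5 - 2*0.9 = 3.2 vs 4 - 2*0.1 = 3.8).
    print(kv_aware_cost_function(scores, metrics, kv_weight=2.0))
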
5 changes: 5 additions & 0 deletions lib/bindings/python/rust/lib.rs
@@ -70,6 +70,11 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::model_card::ModelDeploymentCard>()?;
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
m.add_class::<llm::backend::Backend>()?;
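    // KV-aware router bindings: prefix-overlap indexing and per-endpoint KV metrics aggregation.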
m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvIndexer>()?;
m.add_class::<llm::kv::EndpointKvMetrics>()?;
m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?;

engine::add_to_module(m)?;

161 changes: 161 additions & 0 deletions lib/bindings/python/rust/llm/kv.rs
@@ -13,7 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use super::*;
use llm_rs::kv_router::indexer::KvIndexerInterface;
use tracing;

#[pyclass]
pub(crate) struct KvRouter {
@@ -106,3 +110,160 @@ impl KvMetricsPublisher {
.map_err(to_pyerr)
}
}

#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
inner: llm_rs::kv_router::indexer::OverlapScores,
}

#[pymethods]
impl OverlapScores {
#[getter]
fn scores(&self) -> HashMap<llm_rs::kv_router::indexer::WorkerId, u32> {
self.inner.scores.clone()
}

#[getter]
fn frequencies(&self) -> Vec<usize> {
self.inner.frequencies.clone()
}
}

#[pyclass]
pub(crate) struct KvIndexer {
inner: Arc<llm_rs::kv_router::indexer::KvIndexer>,
}

#[pymethods]
impl KvIndexer {
#[new]
fn new(component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let kv_subject = component
.inner
.event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT);
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(),
)
.into();
let mut kv_events_rx = component
.inner
.drt()
.nats_client()
.client()
.subscribe(kv_subject)
.await
.map_err(to_pyerr)?;
let kv_events_tx = inner.event_sender();

            // [FIXME] This bolts KV-event subscription onto the indexer here; it should
            // probably be expressed as a trait (AsyncEngine style) and implemented on the indexer.
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: llm_rs::kv_router::indexer::RouterEvent =
serde_json::from_slice(&event.payload).unwrap();
[Review comment from a Contributor] We should avoid unwrap() - this will panic and crash if deserialization of the event fails.
[Review comment from a Contributor] @GuanLuo, @alec-flowers - can you fix?
tracing::debug!("received kv event: {:?}", event);
if let Err(e) = kv_events_tx.send(event).await {
tracing::trace!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
}
}
});
Ok(Self { inner })
})
}

fn find_matches_for_request<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
_lora_id: u64,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let rs_overlap_scores = indexer
.find_matches_for_request(token_ids.as_slice())
.await
.map_err(to_pyerr)?;
Ok(OverlapScores {
inner: rs_overlap_scores,
})
})
}
}

#[pyclass]
#[derive(Clone)]
pub(crate) struct EndpointKvMetrics {
#[pyo3(get, set)]
pub worker_id: i64,
#[pyo3(get, set)]
pub request_active_slots: u64,
#[pyo3(get, set)]
pub request_total_slots: u64,
#[pyo3(get, set)]
pub kv_active_blocks: u64,
#[pyo3(get, set)]
pub kv_total_blocks: u64,
}

#[pyclass]
#[derive(Clone)]
pub(crate) struct AggregatedMetrics {
#[pyo3(get, set)]
pub endpoints: Vec<EndpointKvMetrics>,
#[pyo3(get, set)]
pub load_avg: f64,
#[pyo3(get, set)]
pub load_std: f64,
}

#[pyclass]
pub(crate) struct KvMetricsAggregator {
inner: Arc<llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator>,
}

#[pymethods]
impl KvMetricsAggregator {
#[new]
fn new(component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator::new(
component.inner.clone(),
component.inner.drt().runtime().child_token(),
)
.await;
Ok(Self {
inner: inner.into(),
})
})
}

fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
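        // Take a synchronous snapshot of the per-endpoint metrics, convert them to the
        // Python-facing EndpointKvMetrics type, then return them through an awaitable.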
let endpoints = self.inner.get_endpoints();
let endpoint_kv_metrics = endpoints
.endpoints
.iter()
.map(|x| EndpointKvMetrics {
worker_id: x.worker_id(),
request_active_slots: x.data.request_active_slots,
request_total_slots: x.data.request_total_slots,
kv_active_blocks: x.data.kv_active_blocks,
kv_total_blocks: x.data.kv_total_blocks,
})
.collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(AggregatedMetrics {
endpoints: endpoint_kv_metrics,
load_avg: endpoints.load_avg,
load_std: endpoints.load_std,
})
})
}
}
52 changes: 52 additions & 0 deletions lib/bindings/python/src/dynemo/_core.pyi
@@ -233,3 +233,55 @@ class Backend:
        Start the backend engine and pass requests to the downstream LLM engine
"""
...


class OverlapScores:
"""
    Prefix-matching scores of workers for a given list of token ids.
    'scores' maps each worker id to its score, i.e. the number of matching KV blocks.
"""

...

class KvIndexer:
"""
A KV Indexer that tracks KV Events emitted by workers. Events include add_block and remove_block.
"""

...

def __init__(self, component: Component) -> None:
"""
Create a `KvIndexer` object
"""

def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
"""
        Return the overlap scores of workers for the given token ids (lora_id is currently unused).
"""
...

class AggregatedMetrics:
"""
    Aggregated KV metrics collected from the endpoints, plus overall load statistics.
"""

...

class KvMetricsAggregator:
"""
    A metrics aggregator that collects KV metrics from the endpoints.
"""

...

def __init__(self, component: Component) -> None:
"""
Create a `KvMetricsAggregator` object
"""

def get_metrics(self) -> AggregatedMetrics:
"""
Return the aggregated metrics of the endpoints.
"""
...
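
For reference, both find_matches_for_request and get_metrics are awaited from Python, as the CustomRouter handler earlier in this diff does. A minimal usage sketch follows; the kv_listener Component is assumed to be created as in worker(), the token ids are placeholders, and the inspect_workers helper is hypothetical.

# Sketch only: querying the indexer and metrics aggregator directly.
from dynemo.llm import KvIndexer, KvMetricsAggregator


async def inspect_workers(kv_listener, token_ids):
    indexer = KvIndexer(kv_listener)
    metrics_aggregator = KvMetricsAggregator(kv_listener)

    scores = await indexer.find_matches_for_request(token_ids, 0)  # lora_id unused
    metrics = await metrics_aggregator.get_metrics()

    for worker_id, overlap in scores.scores.items():
        print(f"worker {worker_id}: {overlap} matching blocks")
    for endpoint in metrics.endpoints:
        print(
            f"worker {endpoint.worker_id}: "
            f"{endpoint.kv_active_blocks}/{endpoint.kv_total_blocks} KV blocks in use"
        )
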
2 changes: 2 additions & 0 deletions lib/bindings/python/src/dynemo/llm/__init__.py
@@ -13,5 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dynemo._core import KvIndexer as KvIndexer
from dynemo._core import KvMetricsAggregator as KvMetricsAggregator
from dynemo._core import KvMetricsPublisher as KvMetricsPublisher
from dynemo._core import KvRouter as KvRouter
1 change: 1 addition & 0 deletions lib/llm/src/kv_router.rs
@@ -21,6 +21,7 @@ use tokio_util::sync::CancellationToken;
use tracing;

pub mod indexer;
pub mod metrics_aggregator;
pub mod protocols;
pub mod publisher;
pub mod scheduler;