Skip to content

Commit 107ca09

Browse files
committed
should be functional
1 parent e6b4c3b commit 107ca09

File tree

9 files changed

+469
-261
lines changed

9 files changed

+469
-261
lines changed

components/metrics/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ use std::net::SocketAddr;
8484
use std::time::Duration as StdDuration;
8585

8686
use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics};
87-
use dynamo_llm::kv_router::scheduler::Endpoint;
87+
use dynamo_llm::kv_router::scoring::Endpoint;
8888
use dynamo_llm::kv_router::scoring::ProcessedEndpoints;
8989

9090
use dynamo_runtime::{

lib/bindings/python/Cargo.lock

Lines changed: 16 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/llm/src/kv_router.rs

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ use dynamo_runtime::{
1515
protocols::annotated::Annotated,
1616
};
1717
use futures::stream::{self, StreamExt};
18-
use tokio::sync::Mutex;
1918

2019
pub mod approx;
2120
pub mod indexer;
@@ -138,10 +137,6 @@ pub struct KvRouter {
138137
scheduler: KvScheduler,
139138

140139
block_size: u32,
141-
142-
// To ensure blocking reads / writes
143-
// TODO: benchmark tradeoffs
144-
find_best_match_mutex: Mutex<()>,
145140
}
146141

147142
impl KvRouter {
@@ -178,13 +173,8 @@ impl KvRouter {
178173
))
179174
};
180175

181-
let scheduler = KvScheduler::start(
182-
component.namespace().clone(),
183-
block_size,
184-
instances_rx,
185-
selector,
186-
)
187-
.await?;
176+
let scheduler =
177+
KvScheduler::start(component.clone(), block_size, instances_rx, selector).await?;
188178

189179
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
190180
// error checking below will be different.
@@ -218,7 +208,6 @@ impl KvRouter {
218208
indexer,
219209
scheduler,
220210
block_size,
221-
find_best_match_mutex: Mutex::new(()), // Add this
222211
})
223212
}
224213

@@ -230,28 +219,19 @@ impl KvRouter {
230219
context_id: &str,
231220
tokens: &[u32],
232221
) -> anyhow::Result<(i64, u32)> {
233-
// Acquire mutex to serialize access
234-
// TODO: may as well make all the subroutines synchronous if benchmarking favors this
235-
let _guard = self.find_best_match_mutex.lock().await;
236-
237222
let isl_tokens = tokens.len();
238223

239224
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
240-
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
241225

242226
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
243227

244228
let best_worker_id = self
245229
.scheduler
246-
.schedule(
247-
context_id.to_string(),
248-
isl_tokens,
249-
seq_hashes.clone(),
250-
overlap_scores.clone(),
251-
)
230+
.schedule(context_id.to_string(), isl_tokens, overlap_scores.clone())
252231
.await?;
253232

254233
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
234+
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
255235
indexer
256236
.process_routing_decision(best_worker_id, block_hashes, seq_hashes)
257237
.await
@@ -267,15 +247,10 @@ impl KvRouter {
267247
}
268248

269249
/// Mark the prefill phase of a request as completed
270-
pub async fn mark_prefill_completed(&self, request_id: &String) {
250+
pub async fn mark_prefill_completed(&self, request_id: &str) {
271251
self.scheduler.mark_prefill_completed(request_id).await
272252
}
273253

274-
/// Free all blocks associated with a request
275-
pub async fn free(&self, request_id: &String) {
276-
self.scheduler.free(request_id).await
277-
}
278-
279254
/// Get the block size this router was configured with
280255
pub fn block_size(&self) -> u32 {
281256
self.block_size
@@ -362,8 +337,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
362337
while let Some(item) = response_stream.next().await {
363338
yield item;
364339
}
365-
366-
chooser.free(&context_id).await;
367340
});
368341
Ok(ResponseStream::new(wrapped_stream, stream_context))
369342
}

lib/llm/src/kv_router/metrics_aggregator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use std::sync::Once;
1818
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
1919
use crate::kv_router::KV_METRICS_ENDPOINT;
2020

21-
use crate::kv_router::scheduler::Endpoint;
21+
use crate::kv_router::scoring::Endpoint;
2222
use crate::kv_router::ProcessedEndpoints;
2323
use dynamo_runtime::component::Component;
2424
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};

lib/llm/src/kv_router/prefill_counter.rs

Lines changed: 120 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,41 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use anyhow::Result;
5-
use dynamo_runtime::component::{Component, Namespace};
5+
use dynamo_runtime::component::Component;
66
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
7+
use futures::stream::FuturesUnordered;
78
use futures::StreamExt;
89
use std::sync::atomic::{AtomicUsize, Ordering};
910
use std::sync::Arc;
1011
use tokio::sync::RwLock;
12+
// No Mutex import is needed here: the per-request token map is a DashMap
1113

1214
use super::protocols::{PrefillEvent, PrefillEventData};
1315
use crate::kv_router::PREFILL_SUBJECT;
1416
use dashmap::DashMap;
17+
use std::collections::HashMap;
18+
use std::hash::Hash;
19+
20+
pub fn get_snapshot<K, V>(state: &DashMap<K, V>) -> HashMap<K, V>
21+
where
22+
K: Clone + Hash + Eq,
23+
V: Copy,
24+
{
25+
state
26+
.iter()
27+
.map(|entry| (entry.key().clone(), *entry.value()))
28+
.collect()
29+
}
1530

1631
/// A counter that tracks pending prefill tokens for each request.
1732
///
1833
/// This struct maintains a local hashmap of request_id to token count,
1934
/// a running sum of all tokens, and subscribes to prefill events over NATS
2035
/// to keep the counts synchronized across components.
36+
#[derive(Clone)]
2137
pub struct PrefillCounter {
2238
state: Arc<RwLock<PrefillCounterState>>,
23-
namespace: Namespace,
39+
component: Component,
2440
}
2541

2642
struct PrefillCounterState {
@@ -78,14 +94,14 @@ impl PrefillCounter {
7894

7995
let counter = Self {
8096
state: state.clone(),
81-
namespace: component.namespace().clone(),
97+
component: component.clone(),
8298
};
8399

84100
let state_clone = state.clone();
85-
let namespace_clone = counter.namespace.clone();
101+
let component_clone = component.clone();
86102

87103
tokio::spawn(async move {
88-
if let Err(e) = Self::subscribe_to_events(state_clone, namespace_clone).await {
104+
if let Err(e) = Self::subscribe_to_events(state_clone, component_clone).await {
89105
tracing::error!("Error in prefill events subscription: {}", e);
90106
}
91107
});
@@ -97,9 +113,9 @@ impl PrefillCounter {
97113
/// TODO: find a way to ignore events that were published by this same instance
98114
async fn subscribe_to_events(
99115
state: Arc<RwLock<PrefillCounterState>>,
100-
namespace: Namespace,
116+
component: Component,
101117
) -> Result<()> {
102-
let mut subscriber = namespace
118+
let mut subscriber = component
103119
.subscribe_with_type::<PrefillEvent>(PREFILL_SUBJECT)
104120
.await?;
105121

@@ -135,7 +151,7 @@ impl PrefillCounter {
135151
.tokens_map
136152
.insert(event.request_id.clone(), new_tokens);
137153
}
138-
PrefillEventData::CompletePrefill(_) => {
154+
PrefillEventData::CompletePrefill => {
139155
let state_read = state.read().await;
140156
if !state_read.contains_key(&event.request_id) {
141157
continue;
@@ -164,7 +180,7 @@ impl PrefillCounter {
164180
PrefillEventData::NewPrefill(tokens)
165181
},
166182
};
167-
self.namespace.publish(PREFILL_SUBJECT, &event).await?;
183+
self.component.publish(PREFILL_SUBJECT, &event).await?;
168184

169185
Ok(old_value)
170186
}
@@ -173,12 +189,12 @@ impl PrefillCounter {
173189
let state = self.state.write().await;
174190
let removed_tokens = state.remove(request_id);
175191

176-
if let Some(tokens) = removed_tokens {
192+
if removed_tokens.is_some() {
177193
let event = PrefillEvent {
178194
request_id: request_id.to_string(),
179-
data: PrefillEventData::CompletePrefill(tokens),
195+
data: PrefillEventData::CompletePrefill,
180196
};
181-
self.namespace.publish(PREFILL_SUBJECT, &event).await?;
197+
self.component.publish(PREFILL_SUBJECT, &event).await?;
182198
}
183199

184200
Ok(removed_tokens)
@@ -203,6 +219,92 @@ impl PrefillCounter {
203219
let state = self.state.read().await;
204220
state.tokens_map.is_empty()
205221
}
222+
223+
/// Returns a snapshot of the current state as a HashMap
224+
pub async fn snapshot(&self) -> HashMap<String, usize> {
225+
let state = self.state.read().await;
226+
get_snapshot(&state.tokens_map)
227+
}
228+
}
229+
230+
/// A collection of PrefillCounters for multiple workers
231+
pub struct PrefillCountersMultiWorker {
232+
pub counters: DashMap<i64, PrefillCounter>,
233+
pub request_to_workers: DashMap<String, i64>,
234+
component: Component,
235+
}
236+
237+
impl PrefillCountersMultiWorker {
238+
pub fn new(component: Component) -> Self {
239+
Self {
240+
counters: DashMap::new(),
241+
request_to_workers: DashMap::new(),
242+
component,
243+
}
244+
}
245+
246+
pub async fn add_prefill(
247+
&self,
248+
worker_id: i64,
249+
request_id: String,
250+
new_tokens: usize,
251+
) -> Result<()> {
252+
if let Some(existing_worker_id) = self.request_to_workers.get(&request_id) {
253+
tracing::warn!(
254+
"Request {} already exists for worker {}, but trying to add to worker {}",
255+
request_id,
256+
*existing_worker_id,
257+
worker_id
258+
);
259+
}
260+
self.request_to_workers
261+
.insert(request_id.clone(), worker_id);
262+
263+
if let Some(counter) = self.counters.get(&worker_id) {
264+
counter.insert(request_id, new_tokens).await?;
265+
} else {
266+
tracing::warn!(
267+
"Worker {} does not exist, creating new PrefillCounter",
268+
worker_id
269+
);
270+
let new_counter = PrefillCounter::new(self.component.clone());
271+
new_counter.insert(request_id, new_tokens).await?;
272+
self.counters.insert(worker_id, new_counter);
273+
}
274+
275+
Ok(())
276+
}
277+
278+
pub async fn remove_prefill(&self, request_id: &str) -> Result<Option<usize>> {
279+
let Some((_request_id, worker_id)) = self.request_to_workers.remove(request_id) else {
280+
tracing::warn!("Request {} not found", request_id);
281+
return Ok(None);
282+
};
283+
284+
if let Some(counter) = self.counters.get(&worker_id) {
285+
counter.remove(request_id).await
286+
} else {
287+
tracing::warn!(
288+
"Worker {} not found in counters for request {}",
289+
worker_id,
290+
request_id
291+
);
292+
Ok(None)
293+
}
294+
}
295+
296+
/// Get the running sums for all workers as a HashMap<i64, usize>
297+
pub async fn running_sums(&self) -> HashMap<i64, usize> {
298+
let futures = FuturesUnordered::new();
299+
300+
for entry in self.counters.iter() {
301+
let worker_id = *entry.key();
302+
let counter = entry.value().clone();
303+
futures.push(async move { (worker_id, counter.running_sum().await) });
304+
}
305+
306+
futures.collect::<HashMap<_, _>>().await
307+
}
206308
}
207309

208310
#[cfg(test)]
@@ -222,22 +324,17 @@ mod integration_tests {
222324
let runtime = Runtime::from_current()?;
223325
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
224326

225-
// Create namespace and components for two counters
327+
// Create namespace and a single component
226328
let namespace = distributed.namespace("test_prefill_counter")?;
227-
let component1 = namespace
228-
.component("counter1")?
229-
.service_builder()
230-
.create()
231-
.await?;
232-
let component2 = namespace
233-
.component("counter2")?
329+
let component = namespace
330+
.component("shared_counter")?
234331
.service_builder()
235332
.create()
236333
.await?;
237334

238-
// Create two PrefillCounter instances
239-
let counter1 = PrefillCounter::new(component1);
240-
let counter2 = PrefillCounter::new(component2);
335+
// Create two PrefillCounter instances using the same component (cloned)
336+
let counter1 = PrefillCounter::new(component.clone());
337+
let counter2 = PrefillCounter::new(component.clone());
241338

242339
// Give some time for subscribers to initialize
243340
tokio::time::sleep(Duration::from_millis(2000)).await;

0 commit comments

Comments
 (0)