22// SPDX-License-Identifier: Apache-2.0
33
44use anyhow:: Result ;
5- use dynamo_runtime:: component:: { Component , Namespace } ;
5+ use dynamo_runtime:: component:: Component ;
66use dynamo_runtime:: traits:: events:: { EventPublisher , EventSubscriber } ;
7+ use futures:: stream:: FuturesUnordered ;
78use futures:: StreamExt ;
89use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
910use std:: sync:: Arc ;
1011use tokio:: sync:: RwLock ;
12+ // No Mutex import is needed here: the per-request token map is a DashMap.
1113
1214use super :: protocols:: { PrefillEvent , PrefillEventData } ;
1315use crate :: kv_router:: PREFILL_SUBJECT ;
1416use dashmap:: DashMap ;
17+ use std:: collections:: HashMap ;
18+ use std:: hash:: Hash ;
19+
20+ pub fn get_snapshot < K , V > ( state : & DashMap < K , V > ) -> HashMap < K , V >
21+ where
22+ K : Clone + Hash + Eq ,
23+ V : Copy ,
24+ {
25+ state
26+ . iter ( )
27+ . map ( |entry| ( entry. key ( ) . clone ( ) , * entry. value ( ) ) )
28+ . collect ( )
29+ }
1530
1631/// A counter that tracks pending prefill tokens for each request.
1732///
1833/// This struct maintains a local hashmap of request_id to token count,
1934/// a running sum of all tokens, and subscribes to prefill events over NATS
2035/// to keep the counts synchronized across components.
36+ #[ derive( Clone ) ]
2137pub struct PrefillCounter {
2238 state : Arc < RwLock < PrefillCounterState > > ,
23- namespace : Namespace ,
39+ component : Component ,
2440}
2541
2642struct PrefillCounterState {
@@ -78,14 +94,14 @@ impl PrefillCounter {
7894
7995 let counter = Self {
8096 state : state. clone ( ) ,
81- namespace : component. namespace ( ) . clone ( ) ,
97+ component : component. clone ( ) ,
8298 } ;
8399
84100 let state_clone = state. clone ( ) ;
85- let namespace_clone = counter . namespace . clone ( ) ;
101+ let component_clone = component . clone ( ) ;
86102
87103 tokio:: spawn ( async move {
88- if let Err ( e) = Self :: subscribe_to_events ( state_clone, namespace_clone ) . await {
104+ if let Err ( e) = Self :: subscribe_to_events ( state_clone, component_clone ) . await {
89105 tracing:: error!( "Error in prefill events subscription: {}" , e) ;
90106 }
91107 } ) ;
@@ -97,9 +113,9 @@ impl PrefillCounter {
97113 /// TODO: somehow try to block events that are sent by itself
98114 async fn subscribe_to_events (
99115 state : Arc < RwLock < PrefillCounterState > > ,
100- namespace : Namespace ,
116+ component : Component ,
101117 ) -> Result < ( ) > {
102- let mut subscriber = namespace
118+ let mut subscriber = component
103119 . subscribe_with_type :: < PrefillEvent > ( PREFILL_SUBJECT )
104120 . await ?;
105121
@@ -135,7 +151,7 @@ impl PrefillCounter {
135151 . tokens_map
136152 . insert ( event. request_id . clone ( ) , new_tokens) ;
137153 }
138- PrefillEventData :: CompletePrefill ( _ ) => {
154+ PrefillEventData :: CompletePrefill => {
139155 let state_read = state. read ( ) . await ;
140156 if !state_read. contains_key ( & event. request_id ) {
141157 continue ;
@@ -164,7 +180,7 @@ impl PrefillCounter {
164180 PrefillEventData :: NewPrefill ( tokens)
165181 } ,
166182 } ;
167- self . namespace . publish ( PREFILL_SUBJECT , & event) . await ?;
183+ self . component . publish ( PREFILL_SUBJECT , & event) . await ?;
168184
169185 Ok ( old_value)
170186 }
@@ -173,12 +189,12 @@ impl PrefillCounter {
173189 let state = self . state . write ( ) . await ;
174190 let removed_tokens = state. remove ( request_id) ;
175191
176- if let Some ( tokens ) = removed_tokens {
192+ if removed_tokens . is_some ( ) {
177193 let event = PrefillEvent {
178194 request_id : request_id. to_string ( ) ,
179- data : PrefillEventData :: CompletePrefill ( tokens ) ,
195+ data : PrefillEventData :: CompletePrefill ,
180196 } ;
181- self . namespace . publish ( PREFILL_SUBJECT , & event) . await ?;
197+ self . component . publish ( PREFILL_SUBJECT , & event) . await ?;
182198 }
183199
184200 Ok ( removed_tokens)
@@ -203,6 +219,92 @@ impl PrefillCounter {
203219 let state = self . state . read ( ) . await ;
204220 state. tokens_map . is_empty ( )
205221 }
222+
223+ /// Returns a snapshot of the current state as a HashMap
224+ pub async fn snapshot ( & self ) -> HashMap < String , usize > {
225+ let state = self . state . read ( ) . await ;
226+ get_snapshot ( & state. tokens_map )
227+ }
228+ }
229+
230+ /// A collection of PrefillCounters for multiple workers
231+ pub struct PrefillCountersMultiWorker {
232+ pub counters : DashMap < i64 , PrefillCounter > ,
233+ pub request_to_workers : DashMap < String , i64 > ,
234+ component : Component ,
235+ }
236+
237+ impl PrefillCountersMultiWorker {
238+ pub fn new ( component : Component ) -> Self {
239+ Self {
240+ counters : DashMap :: new ( ) ,
241+ request_to_workers : DashMap :: new ( ) ,
242+ component,
243+ }
244+ }
245+
246+ pub async fn add_prefill (
247+ & self ,
248+ worker_id : i64 ,
249+ request_id : String ,
250+ new_tokens : usize ,
251+ ) -> Result < ( ) > {
252+ if let Some ( existing_worker_id) = self . request_to_workers . get ( & request_id) {
253+ tracing:: warn!(
254+ "Request {} already exists for worker {}, but trying to add to worker {}" ,
255+ request_id,
256+ * existing_worker_id,
257+ worker_id
258+ ) ;
259+ }
260+ self . request_to_workers
261+ . insert ( request_id. clone ( ) , worker_id) ;
262+
263+ if let Some ( counter) = self . counters . get ( & worker_id) {
264+ counter. insert ( request_id, new_tokens) . await ?;
265+ } else {
266+ tracing:: warn!(
267+ "Worker {} does not exist, creating new PrefillCounter" ,
268+ worker_id
269+ ) ;
270+ let new_counter = PrefillCounter :: new ( self . component . clone ( ) ) ;
271+ new_counter. insert ( request_id, new_tokens) . await ?;
272+ self . counters . insert ( worker_id, new_counter) ;
273+ }
274+
275+ Ok ( ( ) )
276+ }
277+
278+ pub async fn remove_prefill ( & self , request_id : & str ) -> Result < Option < usize > > {
279+ let Some ( ( _request_id, worker_id) ) = self . request_to_workers . remove ( request_id) else {
280+ tracing:: warn!( "Request {} not found" , request_id) ;
281+ return Ok ( None ) ;
282+ } ;
283+
284+ if let Some ( counter) = self . counters . get ( & worker_id) {
285+ counter. remove ( request_id) . await
286+ } else {
287+ tracing:: warn!(
288+ "Worker {} not found in counters for request {}" ,
289+ worker_id,
290+ request_id
291+ ) ;
292+ Ok ( None )
293+ }
294+ }
295+
296+ /// Get the running sums for all workers as a HashMap<i64, usize>
297+ pub async fn running_sums ( & self ) -> HashMap < i64 , usize > {
298+ let futures = FuturesUnordered :: new ( ) ;
299+
300+ for entry in self . counters . iter ( ) {
301+ let worker_id = * entry. key ( ) ;
302+ let counter = entry. value ( ) . clone ( ) ;
303+ futures. push ( async move { ( worker_id, counter. running_sum ( ) . await ) } ) ;
304+ }
305+
306+ futures. collect :: < HashMap < _ , _ > > ( ) . await
307+ }
206308}
207309
208310#[ cfg( test) ]
@@ -222,22 +324,17 @@ mod integration_tests {
222324 let runtime = Runtime :: from_current ( ) ?;
223325 let distributed = DistributedRuntime :: from_settings ( runtime. clone ( ) ) . await ?;
224326
225- // Create namespace and components for two counters
327+ // Create namespace and a single component
226328 let namespace = distributed. namespace ( "test_prefill_counter" ) ?;
227- let component1 = namespace
228- . component ( "counter1" ) ?
229- . service_builder ( )
230- . create ( )
231- . await ?;
232- let component2 = namespace
233- . component ( "counter2" ) ?
329+ let component = namespace
330+ . component ( "shared_counter" ) ?
234331 . service_builder ( )
235332 . create ( )
236333 . await ?;
237334
238- // Create two PrefillCounter instances
239- let counter1 = PrefillCounter :: new ( component1 ) ;
240- let counter2 = PrefillCounter :: new ( component2 ) ;
335+ // Create two PrefillCounter instances using the same component (cloned)
336+ let counter1 = PrefillCounter :: new ( component . clone ( ) ) ;
337+ let counter2 = PrefillCounter :: new ( component . clone ( ) ) ;
241338
242339 // Give some time for subscribers to initialize
243340 tokio:: time:: sleep ( Duration :: from_millis ( 2000 ) ) . await ;
0 commit comments