diff --git a/src/heavykeeper.rs b/src/heavykeeper.rs index 56a322d..60c7e6f 100644 --- a/src/heavykeeper.rs +++ b/src/heavykeeper.rs @@ -18,18 +18,19 @@ struct Bucket { } #[derive(Clone, PartialEq, Eq, Debug)] -pub struct Node { +pub struct Node { pub item: T, pub count: u64, + pub data: A, } -impl Ord for Node { +impl Ord for Node { fn cmp(&self, other: &Self) -> std::cmp::Ordering { other.count.cmp(&self.count) // Reverse ordering for min-heap } } -impl PartialOrd for Node { +impl PartialOrd for Node { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } @@ -66,14 +67,14 @@ pub enum BuilderError { MissingField { field: String }, } -pub struct TopK { +pub struct TopK { top_items: usize, width: usize, depth: usize, decay: f64, decay_thresholds: Vec, buckets: Vec>, - priority_queue: TopKQueue, + priority_queue: TopKQueue, hasher: RandomState, random: Box, } @@ -99,7 +100,7 @@ fn precompute_decay_thresholds(decay: f64, num_entries: usize) -> Vec { thresholds } -impl TopK { +impl TopK { pub fn builder() -> Builder { Builder::new() } @@ -243,7 +244,7 @@ impl TopK { } } - pub fn add(&mut self, item: &Q, increment: u64) -> Option<(u64, Option)> + pub fn add(&mut self, item: &Q, increment: u64) -> Option<(u64, A)> where T: Borrow, Q: Hash + Eq + ToOwned + ?Sized, @@ -306,17 +307,18 @@ impl TopK { } // Clone the item here since we need to store it in the priority queue - let evicted = self.priority_queue.upsert(item.to_owned(), max_count); - Some((max_count, evicted)) + let data = self.priority_queue.upsert(item.to_owned(), max_count)?; + Some((max_count, data)) } - pub fn list(&self) -> Vec> { + pub fn list(&self) -> Vec> { let mut nodes = self .priority_queue .iter() - .map(|(item, count)| Node { + .map(|(item, count, associated_data)| Node { item: item.clone(), count, + data: associated_data.clone(), }) .collect::>(); nodes.sort(); @@ -347,9 +349,10 @@ impl TopK { let mut nodes = self .priority_queue .iter() - .map(|(item, count)| Node { + .map(|(item, count, associated_data)| Node { item: item.clone(), count, + data: associated_data.clone(), }) .collect::>(); @@ -405,7 +408,7 @@ impl TopK { } // Merge priority queues - for (item, count) in other.priority_queue.iter() { + for (item, count, _associated_data) in other.priority_queue.iter() { let self_count = self.priority_queue.get(item).unwrap_or(0); self.priority_queue.upsert(item.clone(), self_count + count); } @@ -720,9 +723,10 @@ mod tests { let nodes = topk .priority_queue .iter() - .map(|(item, count)| Node { + .map(|(item, count, associated_data)| Node { item: item.clone(), count, + data: associated_data.clone(), }) .collect::>(); @@ -840,9 +844,10 @@ mod tests { let top_items = topk .priority_queue .iter() - .map(|(item, count)| Node { + .map(|(item, count, associated_data)| Node { item: std::str::from_utf8(item).unwrap().to_string().into_bytes(), count, + data: associated_data.clone(), }) .collect::>(); @@ -1014,9 +1019,10 @@ mod tests { let top_items = topk .priority_queue .iter() - .map(|(item, count)| Node { + .map(|(item, count, associated_data)| Node { item: std::str::from_utf8(item).unwrap().to_string().into_bytes(), count, + data: associated_data.clone(), }) .collect::>(); diff --git a/src/priority_queue.rs b/src/priority_queue.rs index 9827898..1c74ead 100644 --- a/src/priority_queue.rs +++ b/src/priority_queue.rs @@ -4,16 +4,16 @@ use std::collections::HashMap; use std::hash::Hash; /// A specialized priority queue for HeavyKeeper that maintains top-k items by count -pub(crate) struct TopKQueue { - items: HashMap, // item -> (count, heap_index) - heap: Vec<(u64, usize, usize)>, // (count, sequence, item_index) - item_store: Vec, // Store actual items here - free_slots: Vec, // Track free slots in item_store +pub(crate) struct TopKQueue { + items: HashMap, // item -> (count, heap_index, associated_data) + heap: Vec<(u64, usize, usize)>, // (count, sequence, item_index) + item_store: Vec, // Store actual items here + free_slots: Vec, // Track free slots in item_store capacity: usize, sequence: usize, } -impl TopKQueue { +impl TopKQueue { pub(crate) fn with_capacity_and_hasher(capacity: usize, hasher: RandomState) -> Self { Self { items: HashMap::with_capacity_and_hasher(capacity, hasher), @@ -39,7 +39,7 @@ impl TopKQueue { T: Borrow, Q: Hash + Eq + ToOwned + ?Sized, { - self.items.get(item).map(|(count, _)| *count) + self.items.get(item).map(|(count, _, _)| *count) } pub(crate) fn min_count(&self) -> u64 { @@ -52,22 +52,23 @@ impl TopKQueue { self.items.len() >= self.capacity } - /// Returns `Some(k)` if the existing item `k` is evicted, otherwise None - pub(crate) fn upsert(&mut self, item: T, count: u64) -> Option { + /// Returns cloned associated data if item is in top-k, None otherwise + pub(crate) fn upsert(&mut self, item: T, count: u64) -> Option { // Fast path: update existing item - if let Some((old_count, pos)) = self.items.get_mut(&item) { + if let Some((old_count, pos, data)) = self.items.get_mut(&item) { if count == *old_count { - return None; + return Some(data.clone()); } *old_count = count; - // Update heap - no need to clone item + // Clone data first, then do heap operations + let data_clone = data.clone(); let pos = *pos; let item_idx = self.heap[pos].2; self.heap[pos] = (count, self.heap[pos].1, item_idx); self.sift_down(pos); self.sift_up(pos); - return None; + return Some(data_clone); } // For new items, if we have space just add it @@ -85,9 +86,10 @@ impl TopKQueue { }; self.heap.push((count, self.sequence, item_idx)); - self.items.insert(item, (count, pos)); + let inserted_entry = self.items.entry(item).or_insert((count, pos, A::default())); + let data = inserted_entry.2.clone(); self.sift_up(pos); - return None; + return Some(data); } // Queue is full - check if new count beats minimum @@ -96,20 +98,21 @@ impl TopKQueue { // Reuse the item slot let old_item = std::mem::replace(&mut self.item_store[item_idx], item.clone()); self.items.remove(&old_item); - self.items.insert(item, (count, 0)); + let inserted_entry = self.items.entry(item).or_insert((count, 0, A::default())); + let data = inserted_entry.2.clone(); self.sequence += 1; self.heap[0] = (count, self.sequence, item_idx); self.sift_down(0); - return Some(old_item); + return Some(data); } } None } - pub(crate) fn iter(&self) -> impl Iterator { - let mut items: Vec<_> = self.items.iter().map(|(k, v)| (k, v.0)).collect(); + pub(crate) fn iter(&self) -> impl Iterator { + let mut items: Vec<_> = self.items.iter().map(|(k, v)| (k, v.0, &v.2)).collect(); // Sort by count descending, then by sequence ascending - items.sort_unstable_by(|(k1, v1), (k2, v2)| { + items.sort_unstable_by(|(k1, v1, _), (k2, v2, _)| { match v2.cmp(v1) { std::cmp::Ordering::Equal => { // For equal counts, compare sequence numbers @@ -189,10 +192,10 @@ impl TopKQueue { let item_j = &self.item_store[item_idx_j]; // Update the positions in the items map - if let Some((_, pos_i)) = self.items.get_mut(item_i) { + if let Some((_, pos_i, _)) = self.items.get_mut(item_i) { *pos_i = i; } - if let Some((_, pos_j)) = self.items.get_mut(item_j) { + if let Some((_, pos_j, _)) = self.items.get_mut(item_j) { *pos_j = j; } } @@ -202,7 +205,7 @@ impl TopKQueue { entry.0 = (entry.0 as f64 * scale_factor) as u64; } - for (_item, (count, _)) in self.items.iter_mut() { + for (_item, (count, _, _)) in self.items.iter_mut() { *count = (*count as f64 * scale_factor) as u64; } } @@ -219,7 +222,7 @@ mod tests { queue.upsert("b", 2); let items: Vec<_> = queue.iter().collect(); - assert_eq!(items, vec![(&"b", 2), (&"a", 1)]); + assert_eq!(items, vec![(&"b", 2, &()), (&"a", 1, &())]); } #[test] @@ -230,7 +233,7 @@ mod tests { queue.upsert("a", 3); // Update a's count let items: Vec<_> = queue.iter().collect(); - assert_eq!(items, vec![(&"a", 3), (&"b", 2)]); + assert_eq!(items, vec![(&"a", 3, &()), (&"b", 2, &())]); } #[test] @@ -253,7 +256,7 @@ mod tests { assert_eq!(queue.heap.len(), 2, "Expected 2 items"); let items: Vec<_> = queue.iter().collect(); - assert_eq!(items, vec![(&"c", 6), (&"a", 5)]); + assert_eq!(items, vec![(&"c", 6, &()), (&"a", 5, &())]); } #[test] @@ -266,7 +269,7 @@ mod tests { queue.upsert("c", 1); let items: Vec<_> = queue.iter().collect(); - assert_eq!(items, vec![(&"a", 1), (&"b", 1), (&"c", 1)]); + assert_eq!(items, vec![(&"a", 1, &()), (&"b", 1, &()), (&"c", 1, &())]); } #[test] @@ -303,7 +306,7 @@ mod tests { assert_eq!(queue.len(), 2, "Queue should maintain capacity"); let items: Vec<_> = queue.iter().collect(); - assert_eq!(items, vec![(&"e", 5), (&"d", 4)]); + assert_eq!(items, vec![(&"e", 5, &()), (&"d", 4, &())]); } #[test] @@ -320,7 +323,7 @@ mod tests { assert_eq!(queue.len(), 2); let items: Vec<_> = queue.iter().collect(); - assert_eq!(items, vec![(&"a", 99), (&"b", 50)]); + assert_eq!(items, vec![(&"a", 99, &()), (&"b", 50, &())]); } #[test]