Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions src/heavykeeper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@ struct Bucket {
}

#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Node<T> {
pub struct Node<T, A = ()> {
pub item: T,
pub count: u64,
pub data: A,
}

impl<T: Ord> Ord for Node<T> {
impl<T: Ord, A: PartialEq + Eq> Ord for Node<T, A> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
other.count.cmp(&self.count) // Reverse ordering for min-heap
}
}

impl<T: Ord> PartialOrd for Node<T> {
impl<T: Ord, A: PartialEq + Eq> PartialOrd for Node<T, A> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
Expand Down Expand Up @@ -66,14 +67,14 @@ pub enum BuilderError {
MissingField { field: String },
}

pub struct TopK<T: Ord + Clone + Hash + Debug> {
pub struct TopK<T: Ord + Clone + Hash + Debug, A: Clone + Default + PartialEq + Eq = ()> {
top_items: usize,
width: usize,
depth: usize,
decay: f64,
decay_thresholds: Vec<u64>,
buckets: Vec<Vec<Bucket>>,
priority_queue: TopKQueue<T>,
priority_queue: TopKQueue<T, A>,
hasher: RandomState,
random: Box<dyn RngCore + Send>,
}
Expand All @@ -99,7 +100,7 @@ fn precompute_decay_thresholds(decay: f64, num_entries: usize) -> Vec<u64> {
thresholds
}

impl<T: Ord + Clone + Hash + Debug> TopK<T> {
impl<T: Ord + Clone + Hash + Debug, A: Clone + Default + PartialEq + Eq> TopK<T, A> {
pub fn builder() -> Builder<T> {
Builder::new()
}
Expand Down Expand Up @@ -243,7 +244,7 @@ impl<T: Ord + Clone + Hash + Debug> TopK<T> {
}
}

pub fn add<Q>(&mut self, item: &Q, increment: u64) -> Option<(u64, Option<T>)>
pub fn add<Q>(&mut self, item: &Q, increment: u64) -> Option<(u64, A)>
where
T: Borrow<Q>,
Q: Hash + Eq + ToOwned<Owned = T> + ?Sized,
Expand Down Expand Up @@ -306,17 +307,18 @@ impl<T: Ord + Clone + Hash + Debug> TopK<T> {
}

// Clone the item here since we need to store it in the priority queue
let evicted = self.priority_queue.upsert(item.to_owned(), max_count);
Some((max_count, evicted))
let data = self.priority_queue.upsert(item.to_owned(), max_count)?;
Some((max_count, data))
}

pub fn list(&self) -> Vec<Node<T>> {
pub fn list(&self) -> Vec<Node<T, A>> {
let mut nodes = self
.priority_queue
.iter()
.map(|(item, count)| Node {
.map(|(item, count, associated_data)| Node {
item: item.clone(),
count,
data: associated_data.clone(),
})
.collect::<Vec<_>>();
nodes.sort();
Expand Down Expand Up @@ -347,9 +349,10 @@ impl<T: Ord + Clone + Hash + Debug> TopK<T> {
let mut nodes = self
.priority_queue
.iter()
.map(|(item, count)| Node {
.map(|(item, count, associated_data)| Node {
item: item.clone(),
count,
data: associated_data.clone(),
})
.collect::<Vec<_>>();

Expand Down Expand Up @@ -405,7 +408,7 @@ impl<T: Ord + Clone + Hash + Debug> TopK<T> {
}

// Merge priority queues
for (item, count) in other.priority_queue.iter() {
for (item, count, _associated_data) in other.priority_queue.iter() {
let self_count = self.priority_queue.get(item).unwrap_or(0);
self.priority_queue.upsert(item.clone(), self_count + count);
}
Expand Down Expand Up @@ -720,9 +723,10 @@ mod tests {
let nodes = topk
.priority_queue
.iter()
.map(|(item, count)| Node {
.map(|(item, count, associated_data)| Node {
item: item.clone(),
count,
data: associated_data.clone(),
})
.collect::<Vec<_>>();

Expand Down Expand Up @@ -840,9 +844,10 @@ mod tests {
let top_items = topk
.priority_queue
.iter()
.map(|(item, count)| Node {
.map(|(item, count, associated_data)| Node {
item: std::str::from_utf8(item).unwrap().to_string().into_bytes(),
count,
data: associated_data.clone(),
})
.collect::<Vec<_>>();

Expand Down Expand Up @@ -1014,9 +1019,10 @@ mod tests {
let top_items = topk
.priority_queue
.iter()
.map(|(item, count)| Node {
.map(|(item, count, associated_data)| Node {
item: std::str::from_utf8(item).unwrap().to_string().into_bytes(),
count,
data: associated_data.clone(),
})
.collect::<Vec<_>>();

Expand Down
61 changes: 32 additions & 29 deletions src/priority_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ use std::collections::HashMap;
use std::hash::Hash;

/// A specialized priority queue for HeavyKeeper that maintains top-k items by count
pub(crate) struct TopKQueue<T> {
items: HashMap<T, (u64, usize), RandomState>, // item -> (count, heap_index)
heap: Vec<(u64, usize, usize)>, // (count, sequence, item_index)
item_store: Vec<T>, // Store actual items here
free_slots: Vec<usize>, // Track free slots in item_store
pub(crate) struct TopKQueue<T, A: PartialEq + Eq = ()> {
items: HashMap<T, (u64, usize, A), RandomState>, // item -> (count, heap_index, associated_data)
heap: Vec<(u64, usize, usize)>, // (count, sequence, item_index)
item_store: Vec<T>, // Store actual items here
free_slots: Vec<usize>, // Track free slots in item_store
capacity: usize,
sequence: usize,
}

impl<T: Ord + Clone + Hash + PartialEq> TopKQueue<T> {
impl<T: Ord + Clone + Hash + PartialEq, A: Clone + Default + PartialEq + Eq> TopKQueue<T, A> {
pub(crate) fn with_capacity_and_hasher(capacity: usize, hasher: RandomState) -> Self {
Self {
items: HashMap::with_capacity_and_hasher(capacity, hasher),
Expand All @@ -39,7 +39,7 @@ impl<T: Ord + Clone + Hash + PartialEq> TopKQueue<T> {
T: Borrow<Q>,
Q: Hash + Eq + ToOwned<Owned = T> + ?Sized,
{
self.items.get(item).map(|(count, _)| *count)
self.items.get(item).map(|(count, _, _)| *count)
}

pub(crate) fn min_count(&self) -> u64 {
Expand All @@ -52,22 +52,23 @@ impl<T: Ord + Clone + Hash + PartialEq> TopKQueue<T> {
self.items.len() >= self.capacity
}

/// Returns `Some(k)` if the existing item `k` is evicted, otherwise None
pub(crate) fn upsert(&mut self, item: T, count: u64) -> Option<T> {
/// Returns cloned associated data if item is in top-k, None otherwise
pub(crate) fn upsert(&mut self, item: T, count: u64) -> Option<A> {
// Fast path: update existing item
if let Some((old_count, pos)) = self.items.get_mut(&item) {
if let Some((old_count, pos, data)) = self.items.get_mut(&item) {
if count == *old_count {
return None;
return Some(data.clone());
}
*old_count = count;

// Update heap - no need to clone item
// Clone data first, then do heap operations
let data_clone = data.clone();
let pos = *pos;
let item_idx = self.heap[pos].2;
self.heap[pos] = (count, self.heap[pos].1, item_idx);
self.sift_down(pos);
self.sift_up(pos);
return None;
return Some(data_clone);
}

// For new items, if we have space just add it
Expand All @@ -85,9 +86,10 @@ impl<T: Ord + Clone + Hash + PartialEq> TopKQueue<T> {
};

self.heap.push((count, self.sequence, item_idx));
self.items.insert(item, (count, pos));
let inserted_entry = self.items.entry(item).or_insert((count, pos, A::default()));
let data = inserted_entry.2.clone();
self.sift_up(pos);
return None;
return Some(data);
}

// Queue is full - check if new count beats minimum
Expand All @@ -96,20 +98,21 @@ impl<T: Ord + Clone + Hash + PartialEq> TopKQueue<T> {
// Reuse the item slot
let old_item = std::mem::replace(&mut self.item_store[item_idx], item.clone());
self.items.remove(&old_item);
self.items.insert(item, (count, 0));
let inserted_entry = self.items.entry(item).or_insert((count, 0, A::default()));
let data = inserted_entry.2.clone();
self.sequence += 1;
self.heap[0] = (count, self.sequence, item_idx);
self.sift_down(0);
return Some(old_item);
return Some(data);
}
}
None
}

pub(crate) fn iter(&self) -> impl Iterator<Item = (&T, u64)> {
let mut items: Vec<_> = self.items.iter().map(|(k, v)| (k, v.0)).collect();
pub(crate) fn iter(&self) -> impl Iterator<Item = (&T, u64, &A)> {
let mut items: Vec<_> = self.items.iter().map(|(k, v)| (k, v.0, &v.2)).collect();
// Sort by count descending, then by sequence ascending
items.sort_unstable_by(|(k1, v1), (k2, v2)| {
items.sort_unstable_by(|(k1, v1, _), (k2, v2, _)| {
match v2.cmp(v1) {
std::cmp::Ordering::Equal => {
// For equal counts, compare sequence numbers
Expand Down Expand Up @@ -189,10 +192,10 @@ impl<T: Ord + Clone + Hash + PartialEq> TopKQueue<T> {
let item_j = &self.item_store[item_idx_j];

// Update the positions in the items map
if let Some((_, pos_i)) = self.items.get_mut(item_i) {
if let Some((_, pos_i, _)) = self.items.get_mut(item_i) {
*pos_i = i;
}
if let Some((_, pos_j)) = self.items.get_mut(item_j) {
if let Some((_, pos_j, _)) = self.items.get_mut(item_j) {
*pos_j = j;
}
}
Expand All @@ -202,7 +205,7 @@ impl<T: Ord + Clone + Hash + PartialEq> TopKQueue<T> {
entry.0 = (entry.0 as f64 * scale_factor) as u64;
}

for (_item, (count, _)) in self.items.iter_mut() {
for (_item, (count, _, _)) in self.items.iter_mut() {
*count = (*count as f64 * scale_factor) as u64;
}
}
Expand All @@ -219,7 +222,7 @@ mod tests {
queue.upsert("b", 2);

let items: Vec<_> = queue.iter().collect();
assert_eq!(items, vec![(&"b", 2), (&"a", 1)]);
assert_eq!(items, vec![(&"b", 2, &()), (&"a", 1, &())]);
}

#[test]
Expand All @@ -230,7 +233,7 @@ mod tests {
queue.upsert("a", 3); // Update a's count

let items: Vec<_> = queue.iter().collect();
assert_eq!(items, vec![(&"a", 3), (&"b", 2)]);
assert_eq!(items, vec![(&"a", 3, &()), (&"b", 2, &())]);
}

#[test]
Expand All @@ -253,7 +256,7 @@ mod tests {
assert_eq!(queue.heap.len(), 2, "Expected 2 items");

let items: Vec<_> = queue.iter().collect();
assert_eq!(items, vec![(&"c", 6), (&"a", 5)]);
assert_eq!(items, vec![(&"c", 6, &()), (&"a", 5, &())]);
}

#[test]
Expand All @@ -266,7 +269,7 @@ mod tests {
queue.upsert("c", 1);

let items: Vec<_> = queue.iter().collect();
assert_eq!(items, vec![(&"a", 1), (&"b", 1), (&"c", 1)]);
assert_eq!(items, vec![(&"a", 1, &()), (&"b", 1, &()), (&"c", 1, &())]);
}

#[test]
Expand Down Expand Up @@ -303,7 +306,7 @@ mod tests {
assert_eq!(queue.len(), 2, "Queue should maintain capacity");

let items: Vec<_> = queue.iter().collect();
assert_eq!(items, vec![(&"e", 5), (&"d", 4)]);
assert_eq!(items, vec![(&"e", 5, &()), (&"d", 4, &())]);
}

#[test]
Expand All @@ -320,7 +323,7 @@ mod tests {
assert_eq!(queue.len(), 2);

let items: Vec<_> = queue.iter().collect();
assert_eq!(items, vec![(&"a", 99), (&"b", 50)]);
assert_eq!(items, vec![(&"a", 99, &()), (&"b", 50, &())]);
}

#[test]
Expand Down
Loading