diff --git a/src/binary_heap.rs b/src/binary_heap.rs index cf5b12c125..aa56c57b72 100644 --- a/src/binary_heap.rs +++ b/src/binary_heap.rs @@ -12,6 +12,7 @@ use core::{ fmt, marker::PhantomData, mem::{self, ManuallyDrop}, + ops::{Deref, DerefMut}, ptr, slice, }; @@ -226,6 +227,44 @@ where self.0.data.as_slice().get(0) } + /// Returns a mutable reference to the greatest item in the binary heap, or + /// `None` if it is empty. + /// + /// Note: If the `PeekMut` value is leaked, the heap may be in an + /// inconsistent state. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// use heapless::binary_heap::{BinaryHeap, Max}; + /// use heapless::consts::*; + /// + /// let mut heap: BinaryHeap<_, U8, Max> = BinaryHeap::new(); + /// assert!(heap.peek_mut().is_none()); + /// + /// heap.push(1); + /// heap.push(5); + /// heap.push(2); + /// { + /// let mut val = heap.peek_mut().unwrap(); + /// *val = 0; + /// } + /// + /// assert_eq!(heap.peek(), Some(&2)); + /// ``` + pub fn peek_mut(&mut self) -> Option> { + if self.is_empty() { + None + } else { + Some(PeekMut { + heap: self, + sift: true, + }) + } + } + /// Removes the *top* (greatest if max-heap, smallest if min-heap) item from the binary heap and /// returns it, or None if it is empty. /// @@ -390,6 +429,78 @@ impl<'a, T> Hole<'a, T> { } } +/// Structure wrapping a mutable reference to the greatest item on a +/// `BinaryHeap`. +/// +/// This `struct` is created by the [`peek_mut`] method on [`BinaryHeap`]. See +/// its documentation for more. +/// +/// [`peek_mut`]: struct.BinaryHeap.html#method.peek_mut +/// [`BinaryHeap`]: struct.BinaryHeap.html +pub struct PeekMut<'a, T, N, K> +where + T: Ord, + N: ArrayLength, + K: Kind, +{ + heap: &'a mut BinaryHeap, + sift: bool, +} + +impl Drop for PeekMut<'_, T, N, K> +where + T: Ord, + N: ArrayLength, + K: Kind, +{ + fn drop(&mut self) { + if self.sift { + self.heap.sift_down_to_bottom(0); + } + } +} + +impl Deref for PeekMut<'_, T, N, K> +where + T: Ord, + N: ArrayLength, + K: Kind, +{ + type Target = T; + fn deref(&self) -> &T { + debug_assert!(!self.heap.is_empty()); + // SAFE: PeekMut is only instantiated for non-empty heaps + unsafe { self.heap.0.data.as_slice().get_unchecked(0) } + } +} + +impl DerefMut for PeekMut<'_, T, N, K> +where + T: Ord, + N: ArrayLength, + K: Kind, +{ + fn deref_mut(&mut self) -> &mut T { + debug_assert!(!self.heap.is_empty()); + // SAFE: PeekMut is only instantiated for non-empty heaps + unsafe { self.heap.0.data.as_mut_slice().get_unchecked_mut(0) } + } +} + +impl<'a, T, N, K> PeekMut<'a, T, N, K> +where + T: Ord, + N: ArrayLength, + K: Kind, +{ + /// Removes the peeked value from the heap and returns it. + pub fn pop(mut this: PeekMut<'a, T, N, K>) -> T { + let value = this.heap.pop().unwrap(); + this.sift = false; + value + } +} + impl<'a, T> Drop for Hole<'a, T> { #[inline] fn drop(&mut self) { @@ -510,6 +621,22 @@ mod tests { assert_eq!(heap.pop(), Some(36)); assert_eq!(heap.pop(), Some(100)); assert_eq!(heap.pop(), None); + + assert!(heap.peek_mut().is_none()); + + heap.push(1).unwrap(); + heap.push(2).unwrap(); + heap.push(10).unwrap(); + + { + let mut val = heap.peek_mut().unwrap(); + *val = 7; + } + + assert_eq!(heap.pop(), Some(2)); + assert_eq!(heap.pop(), Some(7)); + assert_eq!(heap.pop(), Some(10)); + assert_eq!(heap.pop(), None); } #[test] @@ -546,5 +673,21 @@ mod tests { assert_eq!(heap.pop(), Some(2)); assert_eq!(heap.pop(), Some(1)); assert_eq!(heap.pop(), None); + + assert!(heap.peek_mut().is_none()); + + heap.push(1).unwrap(); + heap.push(9).unwrap(); + heap.push(10).unwrap(); + + { + let mut val = heap.peek_mut().unwrap(); + *val = 7; + } + + assert_eq!(heap.pop(), Some(9)); + assert_eq!(heap.pop(), Some(7)); + assert_eq!(heap.pop(), Some(1)); + assert_eq!(heap.pop(), None); } }