Skip to content

Commit

Permalink
feat: Serde Serialize and Deserialize traits for the RbTree in the ic…
Browse files Browse the repository at this point in the history
…-certified-map crate. (#399)

* This change adds the serde Serialize and Deserialize traits to the RbTree.

* Manual implementation of Serialize and Deserialize for the RbTree.

* replace 'static lifetime bounds with a custom lifetime 't.

* Added bincode serialization in the serde test.

* CandidType for the RbTree.

* update CHANGELOG with CandidType.
  • Loading branch information
levifeldman authored Jan 12, 2024
1 parent c01fc17 commit 1da310b
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 41 deletions.
3 changes: 3 additions & 0 deletions library/ic-certified-map/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- Implement CandidType, Serialize, and Deserialize for the RbTree.

## [0.4.0] - 2023-07-13

### Changed
Expand Down
3 changes: 2 additions & 1 deletion library/ic-certified-map/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ include = ["src", "Cargo.toml", "CHANGELOG.md", "LICENSE", "README.md"]
serde.workspace = true
serde_bytes.workspace = true
sha2.workspace = true
candid.workspace = true

[dev-dependencies]
hex.workspace = true
serde_cbor = "0.11"
ic-cdk.workspace = true
candid.workspace = true
bincode = "1.3.3"
166 changes: 126 additions & 40 deletions library/ic-certified-map/src/rbtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl AsHashTree for Hash {
}
}

impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> AsHashTree for RbTree<K, V> {
impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> AsHashTree for RbTree<K, V> {
fn root_hash(&self) -> Hash {
match self.root.as_ref() {
None => Empty.reconstruct(),
Expand Down Expand Up @@ -102,7 +102,7 @@ struct Node<K, V> {
subtree_hash: Hash,
}

impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> Node<K, V> {
impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> Node<K, V> {
fn new(key: K, value: V) -> Box<Node<K, V>> {
let value_hash = value.root_hash();
let data_hash = labeled_hash(key.as_ref(), &value_hash);
Expand Down Expand Up @@ -274,47 +274,47 @@ pub struct RbTree<K, V> {
root: NodeRef<K, V>,
}

impl<K, V> PartialEq for RbTree<K, V>
impl<'t, K, V> PartialEq for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + PartialEq,
V: 'static + AsHashTree + PartialEq,
K: 't + AsRef<[u8]> + PartialEq,
V: 't + AsHashTree + PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.iter().eq(other.iter())
}
}

impl<K, V> Eq for RbTree<K, V>
impl<'t, K, V> Eq for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + Eq,
V: 'static + AsHashTree + Eq,
K: 't + AsRef<[u8]> + Eq,
V: 't + AsHashTree + Eq,
{
}

impl<K, V> PartialOrd for RbTree<K, V>
impl<'t, K, V> PartialOrd for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + PartialOrd,
V: 'static + AsHashTree + PartialOrd,
K: 't + AsRef<[u8]> + PartialOrd,
V: 't + AsHashTree + PartialOrd,
{
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.iter().partial_cmp(other.iter())
}
}

impl<K, V> Ord for RbTree<K, V>
impl<'t, K, V> Ord for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + Ord,
V: 'static + AsHashTree + Ord,
K: 't + AsRef<[u8]> + Ord,
V: 't + AsHashTree + Ord,
{
fn cmp(&self, other: &Self) -> Ordering {
self.iter().cmp(other.iter())
}
}

impl<K, V> std::iter::FromIterator<(K, V)> for RbTree<K, V>
impl<'t, K, V> std::iter::FromIterator<(K, V)> for RbTree<K, V>
where
K: 'static + AsRef<[u8]>,
V: 'static + AsHashTree,
K: 't + AsRef<[u8]>,
V: 't + AsHashTree,
{
fn from_iter<T>(iter: T) -> Self
where
Expand All @@ -328,10 +328,10 @@ where
}
}

impl<K, V> std::fmt::Debug for RbTree<K, V>
impl<'t, K, V> std::fmt::Debug for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + std::fmt::Debug,
V: 'static + AsHashTree + std::fmt::Debug,
K: 't + AsRef<[u8]> + std::fmt::Debug,
V: 't + AsHashTree + std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
Expand Down Expand Up @@ -359,7 +359,7 @@ impl<K, V> RbTree<K, V> {
}
}

impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> RbTree<K, V> {
/// Looks up the key in the map and returns the associated value, if there is one.
pub fn get(&self, key: &[u8]) -> Option<&V> {
let mut root = self.root.as_ref();
Expand All @@ -375,7 +375,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {

/// Updates the value corresponding to the specified key.
pub fn modify(&mut self, key: &[u8], f: impl FnOnce(&mut V)) {
fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
h: &mut NodeRef<K, V>,
k: &[u8],
f: impl FnOnce(&mut V),
Expand Down Expand Up @@ -506,7 +506,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
lo: KeyBound<'a>,
f: fn(&'a Node<K, V>) -> HashTree<'a>,
) -> HashTree<'a> {
fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
n: &'a NodeRef<K, V>,
lo: KeyBound<'a>,
f: fn(&'a Node<K, V>) -> HashTree<'a>,
Expand Down Expand Up @@ -543,7 +543,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
hi: KeyBound<'a>,
f: fn(&'a Node<K, V>) -> HashTree<'a>,
) -> HashTree<'a> {
fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
n: &'a NodeRef<K, V>,
hi: KeyBound<'a>,
f: fn(&'a Node<K, V>) -> HashTree<'a>,
Expand Down Expand Up @@ -587,7 +587,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
lo.as_ref(),
hi.as_ref()
);
fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
n: &'a NodeRef<K, V>,
lo: KeyBound<'a>,
hi: KeyBound<'a>,
Expand Down Expand Up @@ -645,7 +645,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}

fn lower_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
fn go<'a, K: 'static + AsRef<[u8]>, V>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
n: &'a NodeRef<K, V>,
key: &[u8],
) -> Option<KeyBound<'a>> {
Expand All @@ -662,7 +662,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}

fn upper_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
fn go<'a, K: 'static + AsRef<[u8]>, V>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
n: &'a NodeRef<K, V>,
key: &[u8],
) -> Option<KeyBound<'a>> {
Expand All @@ -685,7 +685,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}
&x[0..p.len()] == p
}
fn go<'a, K: 'static + AsRef<[u8]>, V>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
n: &'a NodeRef<K, V>,
prefix: &[u8],
) -> Option<KeyBound<'a>> {
Expand All @@ -706,7 +706,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
key: &[u8],
f: impl FnOnce(&'a V) -> HashTree<'a>,
) -> Option<HashTree<'a>> {
fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
n: &'a NodeRef<K, V>,
key: &[u8],
f: impl FnOnce(&'a V) -> HashTree<'a>,
Expand Down Expand Up @@ -740,7 +740,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {

/// Inserts a key-value entry into the map.
pub fn insert(&mut self, key: K, value: V) {
fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
h: NodeRef<K, V>,
k: K,
v: V,
Expand Down Expand Up @@ -778,7 +778,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {

/// Removes the specified key from the map.
pub fn delete(&mut self, key: &[u8]) {
fn move_red_left<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn move_red_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
flip_colors(&mut h);
Expand All @@ -790,7 +790,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
h
}

fn move_red_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn move_red_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
flip_colors(&mut h);
Expand All @@ -802,7 +802,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}

#[inline]
fn min<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: &mut Box<Node<K, V>>,
) -> &mut Box<Node<K, V>> {
while h.left.is_some() {
Expand All @@ -811,7 +811,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
h
}

fn delete_min<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn delete_min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> NodeRef<K, V> {
if h.left.is_none() {
Expand All @@ -827,7 +827,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
Some(balance(h))
}

fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
key: &[u8],
) -> NodeRef<K, V> {
Expand Down Expand Up @@ -888,6 +888,94 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}
}

use candid::CandidType;

impl<'t, K, V> CandidType for RbTree<K, V>
where
K: CandidType + AsRef<[u8]> + 't,
V: CandidType + AsHashTree + 't,
{
fn _ty() -> candid::types::internal::Type {
<Vec<(&K, &V)> as CandidType>::_ty()
}
fn idl_serialize<S: candid::types::Serializer>(&self, serializer: S) -> Result<(), S::Error> {
let collect_as_vec = self.iter().collect::<Vec<(&K, &V)>>();
<Vec<(&K, &V)> as CandidType>::idl_serialize(&collect_as_vec, serializer)
}
}

use serde::{
de::{Deserialize, Deserializer, MapAccess, Visitor},
ser::{Serialize, SerializeMap, Serializer},
};
use std::marker::PhantomData;

impl<'t, K, V> Serialize for RbTree<K, V>
where
K: Serialize + AsRef<[u8]> + 't,
V: Serialize + AsHashTree + 't,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map = serializer.serialize_map(Some(self.iter().count()))?;
for (k, v) in self.iter() {
map.serialize_entry(k, v)?;
}
map.end()
}
}

// The PhantomData keeps the compiler from complaining about unused generic type parameters.
struct RbTreeSerdeVisitor<K, V> {
marker: PhantomData<fn() -> RbTree<K, V>>,
}

impl<K, V> RbTreeSerdeVisitor<K, V> {
fn new() -> Self {
RbTreeSerdeVisitor {
marker: PhantomData,
}
}
}

impl<'de, 't, K, V> Visitor<'de> for RbTreeSerdeVisitor<K, V>
where
K: Deserialize<'de> + AsRef<[u8]> + 't,
V: Deserialize<'de> + AsHashTree + 't,
{
type Value = RbTree<K, V>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("a map")
}

fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut t = RbTree::<K, V>::new();
while let Some((key, value)) = access.next_entry()? {
t.insert(key, value);
}
Ok(t)
}
}

impl<'de, 't, K, V> Deserialize<'de> for RbTree<K, V>
where
K: Deserialize<'de> + AsRef<[u8]> + 't,
V: Deserialize<'de> + AsHashTree + 't,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_map(RbTreeSerdeVisitor::new())
}
}

fn three_way_fork<'a>(l: HashTree<'a>, m: HashTree<'a>, r: HashTree<'a>) -> HashTree<'a> {
match (l, m, r) {
(Empty, m, Empty) => m,
Expand All @@ -906,9 +994,7 @@ fn is_red<K, V>(x: &NodeRef<K, V>) -> bool {
x.as_ref().map(|h| h.color == Color::Red).unwrap_or(false)
}

fn balance<K: AsRef<[u8]> + 'static, V: AsHashTree + 'static>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
fn balance<'t, K: AsRef<[u8]> + 't, V: AsHashTree + 't>(mut h: Box<Node<K, V>>) -> Box<Node<K, V>> {
if is_red(&h.right) && !is_red(&h.left) {
h = rotate_left(h);
}
Expand All @@ -922,7 +1008,7 @@ fn balance<K: AsRef<[u8]> + 'static, V: AsHashTree + 'static>(
}

/// Make a left-leaning link lean to the right.
fn rotate_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn rotate_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
debug_assert!(is_red(&h.left));
Expand All @@ -939,7 +1025,7 @@ fn rotate_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
x
}

fn rotate_left<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn rotate_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
debug_assert!(is_red(&h.right));
Expand Down
Loading

0 comments on commit 1da310b

Please sign in to comment.