diff --git a/ipld/hamt/tests/hamt_tests.rs b/ipld/hamt/tests/hamt_tests.rs index 926c6f70d..9f58d561a 100644 --- a/ipld/hamt/tests/hamt_tests.rs +++ b/ipld/hamt/tests/hamt_tests.rs @@ -173,6 +173,32 @@ fn test_load(factory: HamtFactory) { assert_eq!(c3, c2); } +// Make sure we correctly set the root _and_ the cached root cid. +fn test_set_root(factory: HamtFactory) { + let store = MemoryBlockstore::default(); + + let mut hamt: Hamt<_, _, usize> = factory.new(&store); + hamt.set(1, "world".to_string()).unwrap(); + + assert_eq!(hamt.get(&1).unwrap(), Some(&"world".to_string())); + let c1 = hamt.flush().unwrap(); + + hamt.set(2, "world2".to_string()).unwrap(); + assert_eq!(hamt.get(&2).unwrap(), Some(&"world2".to_string())); + + let c2 = hamt.flush().unwrap(); + + let mut new_hamt: Hamt<_, String, usize> = factory.load(&c1, &store).unwrap(); + assert_eq!(new_hamt.get(&1).unwrap(), Some(&"world".to_string())); + assert_eq!(new_hamt.get(&2).unwrap(), None); + + new_hamt.set_root(&c2).unwrap(); + assert_eq!(new_hamt.get(&2).unwrap(), Some(&"world2".to_string())); + + let c3 = new_hamt.flush().unwrap(); + assert_eq!(c2, c3); +} + fn test_set_if_absent(factory: HamtFactory, stats: Option, mut cids: CidChecker) { let mem = MemoryBlockstore::default(); let store = TrackingBlockstore::new(&mem); @@ -954,6 +980,11 @@ mod test_default { super::test_load(HamtFactory::default()) } + #[test] + fn test_set_root() { + super::test_set_root(HamtFactory::default()) + } + #[test] fn test_set_if_absent() { #[rustfmt::skip] diff --git a/ipld/kamt/src/kamt.rs b/ipld/kamt/src/kamt.rs index 2583e7c9f..e964446b1 100644 --- a/ipld/kamt/src/kamt.rs +++ b/ipld/kamt/src/kamt.rs @@ -101,6 +101,7 @@ where /// Sets the root based on the Cid of the root node using the Kamt store pub fn set_root(&mut self, cid: &Cid) -> Result<(), Error> { self.root = Node::load(&self.conf, &self.store, cid, 0)?; + self.flushed_cid = Some(*cid); Ok(()) } diff --git a/ipld/kamt/tests/kamt_tests.rs b/ipld/kamt/tests/kamt_tests.rs index 71dadea4b..817c1bcae 100644 --- a/ipld/kamt/tests/kamt_tests.rs +++ b/ipld/kamt/tests/kamt_tests.rs @@ -111,6 +111,36 @@ fn test_load(factory: KamtFactory) { assert_eq!(c3, c2); } +// Make sure we correctly set the root _and_ the cached root cid. +fn test_set_root(factory: KamtFactory) { + let store = MemoryBlockstore::default(); + + let mut kamt: HKamt<_, _> = factory.new(&store); + kamt.set(1, "world".to_string()).unwrap(); + + // Record a kamt root with one entry. + assert_eq!(kamt.get(&1).unwrap(), Some(&"world".to_string())); + let c1 = kamt.flush().unwrap(); + + // Record a second kamt root with 2 entries. + kamt.set(2, "world2".to_string()).unwrap(); + assert_eq!(kamt.get(&2).unwrap(), Some(&"world2".to_string())); + let c2 = kamt.flush().unwrap(); + + // Re-load the original kamt with one entry. + let mut new_kamt: HKamt<_, String> = factory.load(&c1, &store).unwrap(); + assert_eq!(new_kamt.get(&1).unwrap(), Some(&"world".to_string())); + assert_eq!(new_kamt.get(&2).unwrap(), None); + + // Try to update it to the new kamt by setting its root manually. + new_kamt.set_root(&c2).unwrap(); + assert_eq!(new_kamt.get(&2).unwrap(), Some(&"world2".to_string())); + + // Flush the new kamt and make sure it matches the root we just set. + let c3 = new_kamt.flush().unwrap(); + assert_eq!(c2, c3); +} + fn test_set_if_absent(factory: KamtFactory) { let store = MemoryBlockstore::default(); @@ -370,6 +400,11 @@ macro_rules! test_kamt_mod { super::test_load($factory) } + #[test] + fn test_set_root() { + super::test_set_root($factory) + } + #[test] fn test_set_if_absent() { super::test_set_if_absent($factory)