diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..3e9a20a5c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,17 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Changed + +- Changed (and fixed) the behaviour of the iterator returned by `SledStore::messages` (#119) + * The iterator yields elements in chronological order (used to be reversed) + * The iterator now implements `DoubleEndedIterator` which means you it can be reversed or consumed from the end + * The method now accepts the full range syntax, like `0..=1678295210` or `..` for all messages + +[unreleased]: https://github.com/whisperfish/presage/compare/0.4.0...HEAD diff --git a/examples/cli.rs b/examples/cli.rs index 4240e5cbe..1c0e06c6c 100644 --- a/examples/cli.rs +++ b/examples/cli.rs @@ -603,7 +603,7 @@ async fn run(subcommand: Cmd, config_store: C) -> anyho (_, Some(uuid)) => Thread::Contact(uuid), _ => unreachable!(), }; - let iter = config_store.messages(&thread, from)?; + let iter = config_store.messages(&thread, from.unwrap_or(0)..)?; for msg in iter.filter_map(Result::ok) { println!("{:?}: {:?}", msg.metadata.sender, msg); } diff --git a/src/manager.rs b/src/manager.rs index 15578f2f1..7b1ffa9f2 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -1,5 +1,6 @@ use std::{ fmt, + ops::RangeBounds, sync::Arc, time::{Duration, UNIX_EPOCH}, }; @@ -677,9 +678,9 @@ impl Manager { pub fn messages( &self, thread: &Thread, - from: Option, + range: impl RangeBounds, ) -> Result>, Error> { - self.config_store.messages(thread, from) + self.config_store.messages(thread, range) } async fn receive_messages_encrypted( diff --git a/src/store/mod.rs b/src/store/mod.rs index a30bef7c6..6b5988897 100644 --- a/src/store/mod.rs +++ b/src/store/mod.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::{fmt, ops::RangeBounds}; use crate::{manager::Registered, Error, GroupMasterKeyBytes}; use libsignal_service::{ @@ -168,6 +168,10 @@ pub trait MessageStore { /// Retrieve a message from a [Thread] by its timestamp. fn message(&self, thread: &Thread, timestamp: u64) -> Result, Error>; - /// Retrieve a message from a [Thread]. - fn messages(&self, thread: &Thread, from: Option) -> Result; + /// Retrieve all messages from a [Thread] within a range in time + fn messages( + &self, + thread: &Thread, + range: impl RangeBounds, + ) -> Result; } diff --git a/src/store/sled.rs b/src/store/sled.rs index d90964b65..a029f02cd 100644 --- a/src/store/sled.rs +++ b/src/store/sled.rs @@ -1,5 +1,5 @@ use std::{ - ops::Range, + ops::{Bound, Range, RangeBounds, RangeFull}, path::Path, sync::Arc, time::{SystemTime, UNIX_EPOCH}, @@ -27,7 +27,7 @@ use matrix_sdk_store_encryption::StoreCipher; use prost::Message; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use sled::Batch; +use sled::{Batch, IVec}; use super::{ContactsStore, GroupsStore, MessageStore, StateStore}; use crate::{ @@ -302,9 +302,8 @@ fn migrate( impl StateStore for SledStore { fn load_state(&self) -> Result { - Ok(self - .get(SLED_TREE_STATE, SLED_KEY_REGISTRATION)? - .ok_or(Error::NotYetRegisteredError)?) + self.get(SLED_TREE_STATE, SLED_KEY_REGISTRATION)? + .ok_or(Error::NotYetRegisteredError) } fn save_state(&mut self, state: &Registered) -> Result<(), Error> { @@ -828,33 +827,43 @@ impl MessageStore for SledStore { } } - fn messages(&self, thread: &Thread, from: Option) -> Result { + fn messages( + &self, + thread: &Thread, + range: impl RangeBounds, + ) -> Result { let tree_thread = self.tree(self.messages_thread_tree_name(thread))?; debug!("{} messages in this tree", tree_thread.len()); - let iter = if let Some(from) = from { - tree_thread.range(from.to_be_bytes()..) - } else { - tree_thread.range::<&[u8], std::ops::RangeFull>(..) + + let iter = match (range.start_bound(), range.end_bound()) { + (Bound::Included(start), Bound::Unbounded) => tree_thread.range(start.to_be_bytes()..), + (Bound::Included(start), Bound::Excluded(end)) => { + tree_thread.range(start.to_be_bytes()..end.to_be_bytes()) + } + (Bound::Included(start), Bound::Included(end)) => { + tree_thread.range(start.to_be_bytes()..=end.to_be_bytes()) + } + (Bound::Unbounded, Bound::Included(end)) => tree_thread.range(..=end.to_be_bytes()), + (Bound::Unbounded, Bound::Excluded(end)) => tree_thread.range(..end.to_be_bytes()), + (Bound::Unbounded, Bound::Unbounded) => tree_thread.range::<[u8; 8], RangeFull>(..), + (Bound::Excluded(_), _) => unreachable!("range that excludes the initial value"), }; + Ok(SledMessagesIter { cipher: self.cipher.clone(), - iter: iter.rev(), + iter, }) } } pub struct SledMessagesIter { cipher: Option>, - iter: std::iter::Rev, + iter: sled::Iter, } -impl Iterator for SledMessagesIter { - type Item = Result; - - fn next(&mut self) -> Option { - self.iter - .next()? - .map_err(Error::from) +impl SledMessagesIter { + fn decode(&self, elem: Result<(IVec, IVec), sled::Error>) -> Option> { + elem.map_err(Error::from) .and_then(|(_, value)| { self.cipher.as_ref().map_or_else( || serde_json::from_slice(&value).map_err(Error::from), @@ -866,16 +875,42 @@ impl Iterator for SledMessagesIter { } } +impl Iterator for SledMessagesIter { + type Item = Result; + + fn next(&mut self) -> Option { + let elem = self.iter.next()?; + self.decode(elem) + } +} + +impl DoubleEndedIterator for SledMessagesIter { + fn next_back(&mut self) -> Option { + let elem = self.iter.next_back()?; + self.decode(elem) + } +} + #[cfg(test)] mod tests { use core::fmt; - use libsignal_service::prelude::protocol::{ - self, Direction, IdentityKeyStore, PreKeyRecord, PreKeyStore, SessionRecord, SessionStore, - SignedPreKeyRecord, SignedPreKeyStore, + use libsignal_service::{ + content::{ContentBody, Metadata}, + prelude::{ + protocol::{ + self, Direction, IdentityKeyStore, PreKeyRecord, PreKeyStore, SessionRecord, + SessionStore, SignedPreKeyRecord, SignedPreKeyStore, + }, + Uuid, + }, + proto::DataMessage, + ServiceAddress, }; use quickcheck::{Arbitrary, Gen}; + use crate::{MessageStore, Thread}; + use super::SledStore; #[derive(Debug, Clone)] @@ -884,6 +919,9 @@ mod tests { #[derive(Clone)] struct KeyPair(protocol::KeyPair); + #[derive(Debug, Clone)] + struct Content(libsignal_service::prelude::Content); + impl fmt::Debug for KeyPair { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "{}", base64::encode(self.0.public_key.serialize())) @@ -905,6 +943,41 @@ mod tests { } } + impl Arbitrary for Content { + fn arbitrary(g: &mut Gen) -> Self { + let timestamp: u64 = Arbitrary::arbitrary(g); + let contacts = [ + Uuid::from_u128(Arbitrary::arbitrary(g)), + Uuid::from_u128(Arbitrary::arbitrary(g)), + Uuid::from_u128(Arbitrary::arbitrary(g)), + ]; + let metadata = Metadata { + sender: ServiceAddress { + uuid: *g.choose(&contacts).unwrap(), + }, + sender_device: Arbitrary::arbitrary(g), + timestamp, + needs_receipt: Arbitrary::arbitrary(g), + unidentified_sender: Arbitrary::arbitrary(g), + }; + let content_body = ContentBody::DataMessage(DataMessage { + body: Arbitrary::arbitrary(g), + timestamp: Some(timestamp), + ..Default::default() + }); + Self(libsignal_service::prelude::Content::from_body( + content_body, + metadata, + )) + } + } + + impl Arbitrary for Thread { + fn arbitrary(g: &mut Gen) -> Self { + Self::Contact(Uuid::from_u128(Arbitrary::arbitrary(g))) + } + } + #[quickcheck_async::tokio] async fn test_save_get_trust_identity(addr: ProtocolAddress, key_pair: KeyPair) -> bool { let mut db = SledStore::temporary().unwrap(); @@ -971,4 +1044,52 @@ mod tests { .unwrap() == signed_pre_key_record.serialize().unwrap() } + + fn content_with_timestamp(content: &Content, ts: u64) -> libsignal_service::prelude::Content { + libsignal_service::prelude::Content { + metadata: Metadata { + timestamp: ts, + ..content.0.metadata.clone() + }, + body: content.0.body.clone(), + } + } + + #[quickcheck_async::tokio] + async fn test_store_messages(thread: Thread, content: Content) -> anyhow::Result<()> { + let mut db = SledStore::temporary()?; + db.save_message(&thread, content_with_timestamp(&content, 1678295210))?; + db.save_message(&thread, content_with_timestamp(&content, 1678295220))?; + db.save_message(&thread, content_with_timestamp(&content, 1678295230))?; + db.save_message(&thread, content_with_timestamp(&content, 1678295240))?; + db.save_message(&thread, content_with_timestamp(&content, 1678280000))?; + + assert_eq!(db.messages(&thread, ..).unwrap().count(), 5); + assert_eq!(db.messages(&thread, 0..).unwrap().count(), 5); + assert_eq!(db.messages(&thread, 1678280000..).unwrap().count(), 5); + + assert_eq!(db.messages(&thread, 0..1678280000)?.count(), 0); + assert_eq!(db.messages(&thread, 0..1678295210)?.count(), 1); + assert_eq!(db.messages(&thread, 1678295210..1678295240)?.count(), 3); + assert_eq!(db.messages(&thread, 1678295210..=1678295240)?.count(), 4); + + assert_eq!( + db.messages(&thread, 0..=1678295240)? + .next() + .unwrap()? + .metadata + .timestamp, + 1678280000 + ); + assert_eq!( + db.messages(&thread, 0..=1678295240)? + .next_back() + .unwrap()? + .metadata + .timestamp, + 1678295240 + ); + + Ok(()) + } }