Skip to content

Commit

Permalink
Accept ranges when fetching messages from store (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
gferon authored Mar 8, 2023
1 parent e0dcf6d commit 5476eef
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 28 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Changed

- Changed (and fixed) the behaviour of the iterator returned by `SledStore::messages` (#119)
* The iterator yields elements in chronological order (used to be reversed)
* The iterator now implements `DoubleEndedIterator` which means you it can be reversed or consumed from the end
* The method now accepts the full range syntax, like `0..=1678295210` or `..` for all messages

[unreleased]: https://github.com/whisperfish/presage/compare/0.4.0...HEAD
2 changes: 1 addition & 1 deletion examples/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ async fn run<C: Store + MessageStore>(subcommand: Cmd, config_store: C) -> anyho
(_, Some(uuid)) => Thread::Contact(uuid),
_ => unreachable!(),
};
let iter = config_store.messages(&thread, from)?;
let iter = config_store.messages(&thread, from.unwrap_or(0)..)?;
for msg in iter.filter_map(Result::ok) {
println!("{:?}: {:?}", msg.metadata.sender, msg);
}
Expand Down
5 changes: 3 additions & 2 deletions src/manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
fmt,
ops::RangeBounds,
sync::Arc,
time::{Duration, UNIX_EPOCH},
};
Expand Down Expand Up @@ -677,9 +678,9 @@ impl<C: Store> Manager<C, Registered> {
pub fn messages(
&self,
thread: &Thread,
from: Option<u64>,
range: impl RangeBounds<u64>,
) -> Result<impl Iterator<Item = Result<Content, Error>>, Error> {
self.config_store.messages(thread, from)
self.config_store.messages(thread, range)
}

async fn receive_messages_encrypted(
Expand Down
10 changes: 7 additions & 3 deletions src/store/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt;
use std::{fmt, ops::RangeBounds};

use crate::{manager::Registered, Error, GroupMasterKeyBytes};
use libsignal_service::{
Expand Down Expand Up @@ -168,6 +168,10 @@ pub trait MessageStore {
/// Retrieve a message from a [Thread] by its timestamp.
fn message(&self, thread: &Thread, timestamp: u64) -> Result<Option<Content>, Error>;

/// Retrieve a message from a [Thread].
fn messages(&self, thread: &Thread, from: Option<u64>) -> Result<Self::MessagesIter, Error>;
/// Retrieve all messages from a [Thread] within a range in time
fn messages(
&self,
thread: &Thread,
range: impl RangeBounds<u64>,
) -> Result<Self::MessagesIter, Error>;
}
165 changes: 143 additions & 22 deletions src/store/sled.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
ops::Range,
ops::{Bound, Range, RangeBounds, RangeFull},
path::Path,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
Expand Down Expand Up @@ -27,7 +27,7 @@ use matrix_sdk_store_encryption::StoreCipher;
use prost::Message;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sha2::{Digest, Sha256};
use sled::Batch;
use sled::{Batch, IVec};

use super::{ContactsStore, GroupsStore, MessageStore, StateStore};
use crate::{
Expand Down Expand Up @@ -302,9 +302,8 @@ fn migrate(

impl StateStore<Registered> for SledStore {
fn load_state(&self) -> Result<Registered, Error> {
Ok(self
.get(SLED_TREE_STATE, SLED_KEY_REGISTRATION)?
.ok_or(Error::NotYetRegisteredError)?)
self.get(SLED_TREE_STATE, SLED_KEY_REGISTRATION)?
.ok_or(Error::NotYetRegisteredError)
}

fn save_state(&mut self, state: &Registered) -> Result<(), Error> {
Expand Down Expand Up @@ -828,33 +827,43 @@ impl MessageStore for SledStore {
}
}

fn messages(&self, thread: &Thread, from: Option<u64>) -> Result<Self::MessagesIter, Error> {
fn messages(
&self,
thread: &Thread,
range: impl RangeBounds<u64>,
) -> Result<Self::MessagesIter, Error> {
let tree_thread = self.tree(self.messages_thread_tree_name(thread))?;
debug!("{} messages in this tree", tree_thread.len());
let iter = if let Some(from) = from {
tree_thread.range(from.to_be_bytes()..)
} else {
tree_thread.range::<&[u8], std::ops::RangeFull>(..)

let iter = match (range.start_bound(), range.end_bound()) {
(Bound::Included(start), Bound::Unbounded) => tree_thread.range(start.to_be_bytes()..),
(Bound::Included(start), Bound::Excluded(end)) => {
tree_thread.range(start.to_be_bytes()..end.to_be_bytes())
}
(Bound::Included(start), Bound::Included(end)) => {
tree_thread.range(start.to_be_bytes()..=end.to_be_bytes())
}
(Bound::Unbounded, Bound::Included(end)) => tree_thread.range(..=end.to_be_bytes()),
(Bound::Unbounded, Bound::Excluded(end)) => tree_thread.range(..end.to_be_bytes()),
(Bound::Unbounded, Bound::Unbounded) => tree_thread.range::<[u8; 8], RangeFull>(..),
(Bound::Excluded(_), _) => unreachable!("range that excludes the initial value"),
};

Ok(SledMessagesIter {
cipher: self.cipher.clone(),
iter: iter.rev(),
iter,
})
}
}

pub struct SledMessagesIter {
cipher: Option<Arc<StoreCipher>>,
iter: std::iter::Rev<sled::Iter>,
iter: sled::Iter,
}

impl Iterator for SledMessagesIter {
type Item = Result<Content, Error>;

fn next(&mut self) -> Option<Self::Item> {
self.iter
.next()?
.map_err(Error::from)
impl SledMessagesIter {
fn decode(&self, elem: Result<(IVec, IVec), sled::Error>) -> Option<Result<Content, Error>> {
elem.map_err(Error::from)
.and_then(|(_, value)| {
self.cipher.as_ref().map_or_else(
|| serde_json::from_slice(&value).map_err(Error::from),
Expand All @@ -866,16 +875,42 @@ impl Iterator for SledMessagesIter {
}
}

impl Iterator for SledMessagesIter {
type Item = Result<Content, Error>;

fn next(&mut self) -> Option<Self::Item> {
let elem = self.iter.next()?;
self.decode(elem)
}
}

impl DoubleEndedIterator for SledMessagesIter {
fn next_back(&mut self) -> Option<Self::Item> {
let elem = self.iter.next_back()?;
self.decode(elem)
}
}

#[cfg(test)]
mod tests {
use core::fmt;

use libsignal_service::prelude::protocol::{
self, Direction, IdentityKeyStore, PreKeyRecord, PreKeyStore, SessionRecord, SessionStore,
SignedPreKeyRecord, SignedPreKeyStore,
use libsignal_service::{
content::{ContentBody, Metadata},
prelude::{
protocol::{
self, Direction, IdentityKeyStore, PreKeyRecord, PreKeyStore, SessionRecord,
SessionStore, SignedPreKeyRecord, SignedPreKeyStore,
},
Uuid,
},
proto::DataMessage,
ServiceAddress,
};
use quickcheck::{Arbitrary, Gen};

use crate::{MessageStore, Thread};

use super::SledStore;

#[derive(Debug, Clone)]
Expand All @@ -884,6 +919,9 @@ mod tests {
#[derive(Clone)]
struct KeyPair(protocol::KeyPair);

#[derive(Debug, Clone)]
struct Content(libsignal_service::prelude::Content);

impl fmt::Debug for KeyPair {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "{}", base64::encode(self.0.public_key.serialize()))
Expand All @@ -905,6 +943,41 @@ mod tests {
}
}

impl Arbitrary for Content {
fn arbitrary(g: &mut Gen) -> Self {
let timestamp: u64 = Arbitrary::arbitrary(g);
let contacts = [
Uuid::from_u128(Arbitrary::arbitrary(g)),
Uuid::from_u128(Arbitrary::arbitrary(g)),
Uuid::from_u128(Arbitrary::arbitrary(g)),
];
let metadata = Metadata {
sender: ServiceAddress {
uuid: *g.choose(&contacts).unwrap(),
},
sender_device: Arbitrary::arbitrary(g),
timestamp,
needs_receipt: Arbitrary::arbitrary(g),
unidentified_sender: Arbitrary::arbitrary(g),
};
let content_body = ContentBody::DataMessage(DataMessage {
body: Arbitrary::arbitrary(g),
timestamp: Some(timestamp),
..Default::default()
});
Self(libsignal_service::prelude::Content::from_body(
content_body,
metadata,
))
}
}

impl Arbitrary for Thread {
fn arbitrary(g: &mut Gen) -> Self {
Self::Contact(Uuid::from_u128(Arbitrary::arbitrary(g)))
}
}

#[quickcheck_async::tokio]
async fn test_save_get_trust_identity(addr: ProtocolAddress, key_pair: KeyPair) -> bool {
let mut db = SledStore::temporary().unwrap();
Expand Down Expand Up @@ -971,4 +1044,52 @@ mod tests {
.unwrap()
== signed_pre_key_record.serialize().unwrap()
}

fn content_with_timestamp(content: &Content, ts: u64) -> libsignal_service::prelude::Content {
libsignal_service::prelude::Content {
metadata: Metadata {
timestamp: ts,
..content.0.metadata.clone()
},
body: content.0.body.clone(),
}
}

#[quickcheck_async::tokio]
async fn test_store_messages(thread: Thread, content: Content) -> anyhow::Result<()> {
let mut db = SledStore::temporary()?;
db.save_message(&thread, content_with_timestamp(&content, 1678295210))?;
db.save_message(&thread, content_with_timestamp(&content, 1678295220))?;
db.save_message(&thread, content_with_timestamp(&content, 1678295230))?;
db.save_message(&thread, content_with_timestamp(&content, 1678295240))?;
db.save_message(&thread, content_with_timestamp(&content, 1678280000))?;

assert_eq!(db.messages(&thread, ..).unwrap().count(), 5);
assert_eq!(db.messages(&thread, 0..).unwrap().count(), 5);
assert_eq!(db.messages(&thread, 1678280000..).unwrap().count(), 5);

assert_eq!(db.messages(&thread, 0..1678280000)?.count(), 0);
assert_eq!(db.messages(&thread, 0..1678295210)?.count(), 1);
assert_eq!(db.messages(&thread, 1678295210..1678295240)?.count(), 3);
assert_eq!(db.messages(&thread, 1678295210..=1678295240)?.count(), 4);

assert_eq!(
db.messages(&thread, 0..=1678295240)?
.next()
.unwrap()?
.metadata
.timestamp,
1678280000
);
assert_eq!(
db.messages(&thread, 0..=1678295240)?
.next_back()
.unwrap()?
.metadata
.timestamp,
1678295240
);

Ok(())
}
}

0 comments on commit 5476eef

Please sign in to comment.