diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index 6a8dbf9c089..de814ef2721 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -21,6 +21,10 @@ use std::{ use std::{ops::Deref, sync::Arc}; use eyeball::{SharedObservable, Subscriber}; +#[cfg(not(target_arch = "wasm32"))] +use eyeball_im::{Vector, VectorDiff}; +#[cfg(not(target_arch = "wasm32"))] +use futures_util::Stream; use matrix_sdk_common::instant::Instant; #[cfg(feature = "e2e-encryption")] use matrix_sdk_crypto::{ @@ -169,6 +173,13 @@ impl BaseClient { self.store.rooms_filtered(filter) } + /// Get a stream of all the rooms changes, in addition to the existing + /// rooms. + #[cfg(not(target_arch = "wasm32"))] + pub fn rooms_stream(&self) -> (Vector, impl Stream>>) { + self.store.rooms_stream() + } + /// Lookup the Room for the given RoomId, or create one, if it didn't exist /// yet in the store pub fn get_or_create_room(&self, room_id: &RoomId, room_state: RoomState) -> Room { @@ -1668,6 +1679,8 @@ mod tests { #[cfg(all(feature = "e2e-encryption", feature = "experimental-sliding-sync"))] #[async_test] async fn test_when_there_are_no_latest_encrypted_events_decrypting_them_does_nothing() { + use crate::StateChanges; + // Given a room let user_id = user_id!("@u:u.to"); let room_id = room_id!("!r:u.to"); @@ -1679,7 +1692,7 @@ mod tests { assert!(room.latest_event().is_none()); // When I tell it to do some decryption - let mut changes = crate::StateChanges::default(); + let mut changes = StateChanges::default(); client.decrypt_latest_events(&room, &mut changes).await; // Then nothing changed diff --git a/crates/matrix-sdk-base/src/latest_event.rs b/crates/matrix-sdk-base/src/latest_event.rs index 00a32fbe4c4..36a468bee7e 100644 --- a/crates/matrix-sdk-base/src/latest_event.rs +++ b/crates/matrix-sdk-base/src/latest_event.rs @@ -1,7 +1,7 @@ //! Utilities for working with events to decide whether they are suitable for //! use as a [crate::Room::latest_event]. -#![cfg(feature = "experimental-sliding-sync")] +#![cfg(any(feature = "e2e-encryption", feature = "experimental-sliding-sync"))] use matrix_sdk_common::deserialized_responses::SyncTimelineEvent; #[cfg(feature = "e2e-encryption")] diff --git a/crates/matrix-sdk-base/src/store/mod.rs b/crates/matrix-sdk-base/src/store/mod.rs index fd1da84ecaa..e1100f9f666 100644 --- a/crates/matrix-sdk-base/src/store/mod.rs +++ b/crates/matrix-sdk-base/src/store/mod.rs @@ -29,16 +29,22 @@ use std::{ sync::{Arc, RwLock as StdRwLock}, }; +#[cfg(not(target_arch = "wasm32"))] +use eyeball_im::{Vector, VectorDiff}; +#[cfg(not(target_arch = "wasm32"))] +use futures_util::Stream; use once_cell::sync::OnceCell; #[cfg(any(test, feature = "testing"))] #[macro_use] pub mod integration_tests; +mod observable_map; mod traits; #[cfg(feature = "e2e-encryption")] use matrix_sdk_crypto::store::{DynCryptoStore, IntoCryptoStore}; pub use matrix_sdk_store_encryption::Error as StoreEncryptionError; +use observable_map::ObservableMap; use ruma::{ events::{ presence::PresenceEvent, @@ -139,7 +145,7 @@ pub(crate) struct Store { /// The current sync token that should be used for the next sync call. pub(super) sync_token: Arc>>, /// All rooms the store knows about. - rooms: Arc>>, + rooms: Arc>>, /// A lock to synchronize access to the store, such that data by the sync is /// never overwritten. sync_lock: Arc>, @@ -152,7 +158,7 @@ impl Store { inner, session_meta: Default::default(), sync_token: Default::default(), - rooms: Default::default(), + rooms: Arc::new(StdRwLock::new(ObservableMap::new())), sync_lock: Default::default(), } } @@ -173,15 +179,22 @@ impl Store { session_meta: SessionMeta, roominfo_update_sender: &broadcast::Sender, ) -> Result<()> { - for info in self.inner.get_room_infos().await? { - let room = Room::restore( - &session_meta.user_id, - self.inner.clone(), - info, - roominfo_update_sender.clone(), - ); - - self.rooms.write().unwrap().insert(room.room_id().to_owned(), room); + { + let room_infos = self.inner.get_room_infos().await?; + + let mut rooms = self.rooms.write().unwrap(); + + for room_info in room_infos { + let new_room = Room::restore( + &session_meta.user_id, + self.inner.clone(), + room_info, + roominfo_update_sender.clone(), + ); + let new_room_id = new_room.room_id().to_owned(); + + rooms.insert(new_room_id, new_room); + } } let token = @@ -200,7 +213,7 @@ impl Store { /// Get all the rooms this store knows about. pub fn rooms(&self) -> Vec { - self.rooms.read().unwrap().values().cloned().collect() + self.rooms.read().unwrap().iter().cloned().collect() } /// Get all the rooms this store knows about, filtered by state. @@ -209,18 +222,25 @@ impl Store { .read() .unwrap() .iter() - .filter(|(_, room)| filter.matches(room.state())) - .map(|(_, room)| room.clone()) + .filter(|room| filter.matches(room.state())) + .cloned() .collect() } + /// Get a stream of all the rooms changes, in addition to the existing + /// rooms. + #[cfg(not(target_arch = "wasm32"))] + pub fn rooms_stream(&self) -> (Vector, impl Stream>>) { + self.rooms.read().unwrap().stream() + } + /// Get the room with the given room id. pub fn room(&self, room_id: &RoomId) -> Option { self.rooms.read().unwrap().get(room_id).cloned() } - /// Lookup the Room for the given RoomId, or create one, if it didn't exist - /// yet in the store. + /// Lookup the `Room` for the given `RoomId`, or create one, if it didn't + /// exist yet in the store pub fn get_or_create_room( &self, room_id: &RoomId, @@ -233,8 +253,7 @@ impl Store { self.rooms .write() .unwrap() - .entry(room_id.to_owned()) - .or_insert_with(|| { + .get_or_create(room_id, || { Room::new(user_id, self.inner.clone(), room_id, room_type, roominfo_update_sender) }) .clone() diff --git a/crates/matrix-sdk-base/src/store/observable_map.rs b/crates/matrix-sdk-base/src/store/observable_map.rs new file mode 100644 index 00000000000..e15124a9520 --- /dev/null +++ b/crates/matrix-sdk-base/src/store/observable_map.rs @@ -0,0 +1,299 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! An [`ObservableMap`] implementation. + +#[cfg(not(target_arch = "wasm32"))] +mod impl_non_wasm32 { + use std::{borrow::Borrow, collections::HashMap, hash::Hash}; + + use eyeball_im::{ObservableVector, Vector, VectorDiff}; + use futures_util::Stream; + + /// An observable map. + /// + /// This is an “observable map” naive implementation. Just like regular + /// hashmap, we have a redirection from a key to a position, and from a + /// position to a value. The (key, position) tuples are stored in an + /// [`HashMap`]. The (position, value) tuples are stored in an + /// [`ObservableVector`]. The (key, position) tuple is only provided for + /// fast _reading_ implementations, like `Self::get` and + /// `Self::get_or_create`. The (position, value) tuples are observable, + /// this is what interests us the most here. + /// + /// Why not implementing a new `ObservableMap` type in `eyeball-im` instead + /// of this custom implementation? Because we want to continue providing + /// `VectorDiff` when observing the changes, so that the rest of the API in + /// the Matrix Rust SDK aren't broken. Indeed, an `ObservableMap` must + /// produce `MapDiff`, which would be quite different. + /// Plus, we would like to re-use all our existing code, test, stream + /// adapters and so on. + /// + /// This is a trade-off. This implementation is simple enough for the + /// moment, and basically does the job. + #[derive(Debug)] + pub(crate) struct ObservableMap + where + V: Clone + Send + Sync + 'static, + { + /// The (key, position) tuples. + mapping: HashMap, + + /// The values where the indices are the `position` part of + /// `Self::mapping`. + values: ObservableVector, + } + + impl ObservableMap + where + K: Hash + Eq, + V: Clone + Send + Sync + 'static, + { + /// Create a new `Self`. + pub(crate) fn new() -> Self { + Self { mapping: HashMap::new(), values: ObservableVector::new() } + } + + /// Insert a new `V` in the collection. + /// + /// If the `V` value already exists, it will be updated to the new one. + pub(crate) fn insert(&mut self, key: K, value: V) -> usize { + match self.mapping.get(&key) { + Some(position) => { + self.values.set(*position, value); + + *position + } + None => { + let position = self.values.len(); + + self.values.push_back(value); + self.mapping.insert(key, position); + + position + } + } + } + + /// Reading one `V` value based on their ID, if it exists. + pub(crate) fn get(&self, key: &L) -> Option<&V> + where + K: Borrow, + L: Hash + Eq + ?Sized, + { + self.mapping.get(key).and_then(|position| self.values.get(*position)) + } + + /// Reading one `V` value based on their ID, or create a new one (by + /// using `default`). + pub(crate) fn get_or_create(&mut self, key: &L, default: F) -> &V + where + K: Borrow, + L: Hash + Eq + ?Sized + ToOwned, + F: FnOnce() -> V, + { + let position = match self.mapping.get(key) { + Some(position) => *position, + None => { + let value = default(); + let position = self.values.len(); + + self.values.push_back(value); + self.mapping.insert(key.to_owned(), position); + + position + } + }; + + self.values + .get(position) + .expect("Value should be present or has just been inserted, but it's missing") + } + + /// Return an iterator over the existing values. + pub(crate) fn iter(&self) -> impl Iterator { + self.values.iter() + } + + /// Get a [`Stream`] of the values. + pub(crate) fn stream(&self) -> (Vector, impl Stream>>) { + self.values.subscribe().into_values_and_batched_stream() + } + } +} + +#[cfg(target_arch = "wasm32")] +mod impl_wasm32 { + use std::{borrow::Borrow, collections::BTreeMap, hash::Hash}; + + /// An observable map for Wasm. It's a simple wrapper around `BTreeMap`. + #[derive(Debug)] + pub(crate) struct ObservableMap(BTreeMap) + where + V: Clone + 'static; + + impl ObservableMap + where + K: Hash + Eq + Ord, + V: Clone + 'static, + { + /// Create a new `Self`. + pub(crate) fn new() -> Self { + Self(BTreeMap::new()) + } + + /// Insert a new `V` in the collection. + /// + /// If the `V` value already exists, it will be updated to the new one. + pub(crate) fn insert(&mut self, key: K, value: V) { + self.0.insert(key, value); + } + + /// Reading one `V` value based on their ID, if it exists. + pub(crate) fn get(&self, key: &L) -> Option<&V> + where + K: Borrow, + L: Hash + Eq + Ord + ?Sized, + { + self.0.get(key) + } + + /// Reading one `V` value based on their ID, or create a new one (by + /// using `default`). + pub(crate) fn get_or_create(&mut self, key: &L, default: F) -> &V + where + K: Borrow, + L: Hash + Eq + ?Sized + ToOwned, + F: FnOnce() -> V, + { + self.0.entry(key.to_owned()).or_insert_with(default) + } + + /// Return an iterator over the existing values. + pub(crate) fn iter(&self) -> impl Iterator { + self.0.values() + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) use impl_non_wasm32::ObservableMap; +#[cfg(target_arch = "wasm32")] +pub(crate) use impl_wasm32::ObservableMap; + +#[cfg(test)] +mod tests { + #[cfg(not(target_arch = "wasm32"))] + use eyeball_im::VectorDiff; + #[cfg(not(target_arch = "wasm32"))] + use stream_assert::{assert_closed, assert_next_eq, assert_pending}; + + use super::ObservableMap; + + #[test] + fn test_insert_and_get() { + let mut map = ObservableMap::::new(); + + assert!(map.get(&'a').is_none()); + assert!(map.get(&'b').is_none()); + assert!(map.get(&'c').is_none()); + + // new items + map.insert('a', 'e'); + map.insert('b', 'f'); + + assert_eq!(map.get(&'a'), Some(&'e')); + assert_eq!(map.get(&'b'), Some(&'f')); + assert!(map.get(&'c').is_none()); + + // one new item + map.insert('c', 'g'); + + assert_eq!(map.get(&'a'), Some(&'e')); + assert_eq!(map.get(&'b'), Some(&'f')); + assert_eq!(map.get(&'c'), Some(&'g')); + + // update one item + map.insert('b', 'F'); + + assert_eq!(map.get(&'a'), Some(&'e')); + assert_eq!(map.get(&'b'), Some(&'F')); + assert_eq!(map.get(&'c'), Some(&'g')); + } + + #[test] + fn test_get_or_create() { + let mut map = ObservableMap::::new(); + + // insert one item + map.insert('b', 'f'); + + // get or create many items + assert_eq!(map.get_or_create(&'a', || 'E'), &'E'); + assert_eq!(map.get_or_create(&'b', || 'F'), &'f'); // this one already exists + assert_eq!(map.get_or_create(&'c', || 'G'), &'G'); + + assert_eq!(map.get(&'a'), Some(&'E')); + assert_eq!(map.get(&'b'), Some(&'f')); + assert_eq!(map.get(&'c'), Some(&'G')); + } + + #[test] + fn test_iter() { + let mut map = ObservableMap::::new(); + + // new items + map.insert('a', 'e'); + map.insert('b', 'f'); + map.insert('c', 'g'); + + assert_eq!( + map.iter().map(|c| c.to_ascii_uppercase()).collect::>(), + &['E', 'F', 'G'] + ); + } + + #[cfg(not(target_arch = "wasm32"))] + #[test] + fn test_stream() { + let mut map = ObservableMap::::new(); + + // insert one item + map.insert('b', 'f'); + + let (initial_values, mut stream) = map.stream(); + assert_eq!(initial_values.iter().copied().collect::>(), &['f']); + + assert_pending!(stream); + + // insert two items + map.insert('c', 'g'); + map.insert('a', 'e'); + assert_next_eq!( + stream, + vec![VectorDiff::PushBack { value: 'g' }, VectorDiff::PushBack { value: 'e' }] + ); + + assert_pending!(stream); + + // update one item + map.insert('b', 'F'); + assert_next_eq!(stream, vec![VectorDiff::Set { index: 0, value: 'F' }]); + + assert_pending!(stream); + + drop(map); + assert_closed!(stream); + } +} diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index b2a2e756f91..b8da479e8ec 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -23,7 +23,13 @@ use std::{ }; use eyeball::{SharedObservable, Subscriber}; +#[cfg(not(target_arch = "wasm32"))] +use eyeball_im::VectorDiff; use futures_core::Stream; +#[cfg(not(target_arch = "wasm32"))] +use futures_util::StreamExt; +#[cfg(not(target_arch = "wasm32"))] +use imbl::Vector; #[cfg(feature = "e2e-encryption")] use matrix_sdk_base::crypto::store::LockableCryptoStore; use matrix_sdk_base::{ @@ -913,6 +919,19 @@ impl Client { .collect() } + /// Get a stream of all the rooms, in addition to the existing rooms. + #[cfg(not(target_arch = "wasm32"))] + pub fn rooms_stream(&self) -> (Vector, impl Stream>> + '_) { + let (rooms, stream) = self.base_client().rooms_stream(); + + let map_room = |room| Room::new(self.clone(), room); + + ( + rooms.into_iter().map(map_room).collect(), + stream.map(move |diffs| diffs.into_iter().map(|diff| diff.map(map_room)).collect()), + ) + } + /// Returns the joined rooms this client knows about. pub fn joined_rooms(&self) -> Vec { self.base_client() diff --git a/crates/matrix-sdk/tests/integration/client.rs b/crates/matrix-sdk/tests/integration/client.rs index 0945b7d3be5..20ccd266ec9 100644 --- a/crates/matrix-sdk/tests/integration/client.rs +++ b/crates/matrix-sdk/tests/integration/client.rs @@ -1,6 +1,7 @@ use std::{collections::BTreeMap, time::Duration}; use assert_matches2::assert_let; +use eyeball_im::VectorDiff; use futures_util::FutureExt; use matrix_sdk::{ config::SyncSettings, @@ -1260,3 +1261,59 @@ async fn test_test_ambiguity_changes() { assert_pending!(updates); } + +#[cfg(not(target_arch = "wasm32"))] +#[async_test] +async fn test_rooms_stream() { + use futures_util::StreamExt as _; + + let (client, server) = logged_in_client_with_server().await; + let (rooms, mut rooms_stream) = client.rooms_stream(); + + assert!(rooms.is_empty()); + assert_pending!(rooms_stream); + + let room_id_1 = room_id!("!room0:matrix.org"); + let room_id_2 = room_id!("!room1:matrix.org"); + let room_id_3 = room_id!("!room2:matrix.org"); + + let payload = json!({ + "next_batch": "foo", + "rooms": { + "invite": {}, + "join": { + room_id_1: {}, + room_id_2: {}, + room_id_3: {}, + }, + "leave": {} + }, + }); + + mock_sync(&server, &payload, None).await; + + assert!(client.get_room(room_id_1).is_none()); + assert!(client.get_room(room_id_2).is_none()); + assert!(client.get_room(room_id_3).is_none()); + + client.sync_once(SyncSettings::default()).await.unwrap(); + + // Rooms are created. + assert!(client.get_room(room_id_1).is_some()); + assert!(client.get_room(room_id_2).is_some()); + assert!(client.get_room(room_id_3).is_some()); + + // We receive 3 diffs… + assert_let!(Some(diffs) = rooms_stream.next().await); + assert_eq!(diffs.len(), 3); + + // … which map to the new rooms! + assert_let!(VectorDiff::PushBack { value: room_1 } = &diffs[0]); + assert_eq!(room_1.room_id(), room_id_1); + assert_let!(VectorDiff::PushBack { value: room_2 } = &diffs[1]); + assert_eq!(room_2.room_id(), room_id_2); + assert_let!(VectorDiff::PushBack { value: room_3 } = &diffs[2]); + assert_eq!(room_3.room_id(), room_id_3); + + assert_pending!(rooms_stream); +}