diff --git a/crates/matrix-sdk-sqlite/src/event_cache_store.rs b/crates/matrix-sdk-sqlite/src/event_cache_store.rs index e44b1b62fa5..c4ae0d1dda3 100644 --- a/crates/matrix-sdk-sqlite/src/event_cache_store.rs +++ b/crates/matrix-sdk-sqlite/src/event_cache_store.rs @@ -165,7 +165,7 @@ impl SqliteEventCacheStore { }) } - // Acquire a connection for executing read operations. + /// Acquire a connection for executing read operations. #[instrument(skip_all)] async fn read(&self) -> Result { trace!("Taking a `read` connection"); @@ -182,7 +182,7 @@ impl SqliteEventCacheStore { Ok(connection) } - // Acquire a connection for executing write operations. + /// Acquire a connection for executing write operations. #[instrument(skip_all)] async fn write(&self) -> Result> { trace!("Taking a `write` connection"); diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 115893089f7..cca55782817 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -17,7 +17,7 @@ use matrix_sdk_base::{ QueuedRequestKind, RoomLoadSettings, SentRequestKey, StoredThreadSubscription, ThreadSubscriptionStatus, }, - MinimalRoomMemberEvent, RoomInfo, RoomMemberships, RoomState, StateChanges, StateStore, + timer, MinimalRoomMemberEvent, RoomInfo, RoomMemberships, RoomState, StateChanges, StateStore, StateStoreDataKey, StateStoreDataValue, ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK, }; use matrix_sdk_store_encryption::StoreCipher; @@ -39,8 +39,11 @@ use ruma::{ }; use rusqlite::{OptionalExtension, Transaction}; use serde::{Deserialize, Serialize}; -use tokio::fs; -use tracing::{debug, trace, warn}; +use tokio::{ + fs, + sync::{Mutex, OwnedMutexGuard}, +}; +use tracing::{debug, instrument, trace, warn}; use crate::{ error::{Error, Result}, @@ -81,7 +84,15 @@ const DATABASE_VERSION: u8 = 14; #[derive(Clone)] pub struct SqliteStateStore { store_cipher: Option>, + + /// The pool of connections. pool: SqlitePool, + + /// We make the difference between connections for read operations, and for + /// write operations. We keep a single connection apart from write + /// operations. All other connections are used for read operations. The + /// lock is used to ensure there is one owner at a time. + write_connection: Arc>, } #[cfg(not(tarpaulin_include))] @@ -146,8 +157,13 @@ impl SqliteStateStore { Some(s) => Some(Arc::new(conn.get_or_create_store_cipher(s).await?)), None => None, }; - let this = Self { store_cipher, pool }; - this.run_migrations(&conn, version, None).await?; + let this = Self { + store_cipher, + pool, + // Use `conn` as our selected write connections. + write_connection: Arc::new(Mutex::new(conn)), + }; + this.run_migrations(version, None).await?; Ok(this) } @@ -156,7 +172,7 @@ impl SqliteStateStore { /// version /// /// If `to` is `None`, the current database version will be used. - async fn run_migrations(&self, conn: &SqliteAsyncConn, from: u8, to: Option) -> Result<()> { + async fn run_migrations(&self, from: u8, to: Option) -> Result<()> { let to = to.unwrap_or(DATABASE_VERSION); if from < to { @@ -165,6 +181,8 @@ impl SqliteStateStore { return Ok(()); } + let conn = self.write().await; + if from < 2 && to >= 2 { let this = self.clone(); conn.with_transaction(move |txn| { @@ -443,10 +461,24 @@ impl SqliteStateStore { self.encode_key(keys::KV_BLOB, full_key) } - async fn acquire(&self) -> Result { + /// Acquire a connection for executing read operations. + #[instrument(skip_all)] + async fn read(&self) -> Result { + trace!("Taking a `read` connection"); + let _timer = timer!("connection"); + Ok(self.pool.get().await?) } + /// Acquire a connection for executing write operations. + #[instrument(skip_all)] + async fn write(&self) -> OwnedMutexGuard { + trace!("Taking a `write` connection"); + let _timer = timer!("connection"); + + self.write_connection.clone().lock_owned().await + } + fn remove_maybe_stripped_room_data( &self, txn: &Transaction<'_>, @@ -1016,7 +1048,7 @@ impl StateStore for SqliteStateStore { type Error = Error; async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result> { - self.acquire() + self.read() .await? .get_kv_blob(self.encode_state_store_data_key(key)) .await? @@ -1101,21 +1133,21 @@ impl StateStore for SqliteStateStore { )?, }; - self.acquire() - .await? + self.write() + .await .set_kv_blob(self.encode_state_store_data_key(key), serialized_value) .await } async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { - self.acquire().await?.delete_kv_blob(self.encode_state_store_data_key(key)).await + self.write().await.delete_kv_blob(self.encode_state_store_data_key(key)).await } async fn save_changes(&self, changes: &StateChanges) -> Result<()> { let changes = changes.to_owned(); let this = self.clone(); - self.acquire() - .await? + self.write() + .await .with_transaction(move |txn| { let StateChanges { sync_token, @@ -1422,7 +1454,7 @@ impl StateStore for SqliteStateStore { } async fn get_presence_event(&self, user_id: &UserId) -> Result>> { - self.acquire() + self.read() .await? .get_kv_blob(self.encode_presence_key(user_id)) .await? @@ -1439,7 +1471,7 @@ impl StateStore for SqliteStateStore { } let user_ids = user_ids.iter().map(|u| self.encode_presence_key(u)).collect(); - self.acquire() + self.read() .await? .get_kv_blobs(user_ids) .await? @@ -1468,7 +1500,7 @@ impl StateStore for SqliteStateStore { ) -> Result> { let room_id = self.encode_key(keys::STATE_EVENT, room_id); let event_type = self.encode_key(keys::STATE_EVENT, event_type.to_string()); - self.acquire() + self.read() .await? .get_maybe_stripped_state_events(room_id, event_type) .await? @@ -1498,7 +1530,7 @@ impl StateStore for SqliteStateStore { let room_id = self.encode_key(keys::STATE_EVENT, room_id); let event_type = self.encode_key(keys::STATE_EVENT, event_type.to_string()); let state_keys = state_keys.iter().map(|k| self.encode_key(keys::STATE_EVENT, k)).collect(); - self.acquire() + self.read() .await? .get_maybe_stripped_state_events_for_keys(room_id, event_type, state_keys) .await? @@ -1523,7 +1555,7 @@ impl StateStore for SqliteStateStore { let room_id = self.encode_key(keys::PROFILE, room_id); let user_ids = vec![self.encode_key(keys::PROFILE, user_id)]; - self.acquire() + self.read() .await? .get_profiles(room_id, user_ids) .await? @@ -1549,7 +1581,7 @@ impl StateStore for SqliteStateStore { .collect::>(); let user_ids = user_ids_map.keys().cloned().collect(); - self.acquire() + self.read() .await? .get_profiles(room_id, user_ids) .await? @@ -1576,7 +1608,7 @@ impl StateStore for SqliteStateStore { .into_iter() .map(|m| self.encode_key(keys::MEMBER, m.as_str())) .collect(); - self.acquire() + self.read() .await? .get_user_ids(room_id, memberships) .await? @@ -1586,7 +1618,7 @@ impl StateStore for SqliteStateStore { } async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result> { - self.acquire() + self.read() .await? .get_room_infos(match room_load_settings { RoomLoadSettings::All => None, @@ -1610,7 +1642,7 @@ impl StateStore for SqliteStateStore { )]; Ok(self - .acquire() + .read() .await? .get_display_names(room_id, names) .await? @@ -1656,8 +1688,7 @@ impl StateStore for SqliteStateStore { .collect::>(); let names = names_map.keys().cloned().collect(); - for (name, data) in - self.acquire().await?.get_display_names(room_id, names).await?.into_iter() + for (name, data) in self.read().await?.get_display_names(room_id, names).await?.into_iter() { let display_name = names_map.remove(name.as_slice()).expect("returned display names were requested"); @@ -1674,7 +1705,7 @@ impl StateStore for SqliteStateStore { event_type: GlobalAccountDataEventType, ) -> Result>> { let event_type = self.encode_key(keys::GLOBAL_ACCOUNT_DATA, event_type.to_string()); - self.acquire() + self.read() .await? .get_global_account_data(event_type) .await? @@ -1689,7 +1720,7 @@ impl StateStore for SqliteStateStore { ) -> Result>> { let room_id = self.encode_key(keys::ROOM_ACCOUNT_DATA, room_id); let event_type = self.encode_key(keys::ROOM_ACCOUNT_DATA, event_type.to_string()); - self.acquire() + self.read() .await? .get_room_account_data(room_id, event_type) .await? @@ -1711,7 +1742,7 @@ impl StateStore for SqliteStateStore { let thread = self.encode_key(keys::RECEIPT, rmp_serde::to_vec_named(&thread)?); let user_id = self.encode_key(keys::RECEIPT, user_id); - self.acquire() + self.read() .await? .get_user_receipt(room_id, receipt_type, thread, user_id) .await? @@ -1735,7 +1766,7 @@ impl StateStore for SqliteStateStore { let thread = self.encode_key(keys::RECEIPT, rmp_serde::to_vec_named(&thread)?); let event_id = self.encode_key(keys::RECEIPT, event_id); - self.acquire() + self.read() .await? .get_event_receipts(room_id, receipt_type, thread, event_id) .await? @@ -1747,18 +1778,18 @@ impl StateStore for SqliteStateStore { } async fn get_custom_value(&self, key: &[u8]) -> Result>> { - self.acquire().await?.get_kv_blob(self.encode_custom_key(key)).await + self.read().await?.get_kv_blob(self.encode_custom_key(key)).await } async fn set_custom_value_no_read(&self, key: &[u8], value: Vec) -> Result<()> { - let conn = self.acquire().await?; + let conn = self.write().await; let key = self.encode_custom_key(key); conn.set_kv_blob(key, value).await?; Ok(()) } async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { - let conn = self.acquire().await?; + let conn = self.write().await; let key = self.encode_custom_key(key); let previous = conn.get_kv_blob(key.clone()).await?; conn.set_kv_blob(key, value).await?; @@ -1766,7 +1797,7 @@ impl StateStore for SqliteStateStore { } async fn remove_custom_value(&self, key: &[u8]) -> Result>> { - let conn = self.acquire().await?; + let conn = self.write().await; let key = self.encode_custom_key(key); let previous = conn.get_kv_blob(key.clone()).await?; if previous.is_some() { @@ -1779,7 +1810,7 @@ impl StateStore for SqliteStateStore { let this = self.clone(); let room_id = room_id.to_owned(); - let conn = self.acquire().await?; + let conn = self.write().await; conn.with_transaction(move |txn| -> Result<()> { let room_info_room_id = this.encode_key(keys::ROOM_INFO, &room_id); @@ -1842,8 +1873,8 @@ impl StateStore for SqliteStateStore { // all, it carries no personal information, so this is considered fine. let created_at_ts: u64 = created_at.0.into(); - self.acquire() - .await? + self.write() + .await .with_transaction(move |txn| { txn.prepare_cached("INSERT INTO send_queue_events (room_id, room_id_val, transaction_id, content, priority, created_at) VALUES (?, ?, ?, ?, ?, ?)")?.execute((room_id_key, room_id_value, transaction_id.to_string(), content, priority, created_at_ts))?; Ok(()) @@ -1864,8 +1895,8 @@ impl StateStore for SqliteStateStore { // transaction id is neither encrypted or hashed. let transaction_id = transaction_id.to_string(); - let num_updated = self.acquire() - .await? + let num_updated = self.write() + .await .with_transaction(move |txn| { txn.prepare_cached("UPDATE send_queue_events SET wedge_reason = NULL, content = ? WHERE room_id = ? AND transaction_id = ?")?.execute((content, room_id, transaction_id)) }) @@ -1885,8 +1916,8 @@ impl StateStore for SqliteStateStore { let transaction_id = transaction_id.to_string(); let num_deleted = self - .acquire() - .await? + .write() + .await .with_transaction(move |txn| { txn.prepare_cached( "DELETE FROM send_queue_events WHERE room_id = ? AND transaction_id = ?", @@ -1908,7 +1939,7 @@ impl StateStore for SqliteStateStore { // want to maintain the insertion order, so we can sort using it. // Note 2: transaction_id is not encoded, see why in `save_send_queue_event`. let res: Vec<(String, Vec, Option>, usize, Option)> = self - .acquire() + .read() .await? .prepare( "SELECT transaction_id, content, wedge_reason, priority, created_at FROM send_queue_events WHERE room_id = ? ORDER BY priority DESC, ROWID", @@ -1921,11 +1952,13 @@ impl StateStore for SqliteStateStore { .await?; let mut requests = Vec::with_capacity(res.len()); + for entry in res { let created_at = entry .4 .and_then(UInt::new) .map_or_else(MilliSecondsSinceUnixEpoch::now, MilliSecondsSinceUnixEpoch); + requests.push(QueuedRequest { transaction_id: entry.0.into(), kind: self.deserialize_json(&entry.1)?, @@ -1952,8 +1985,8 @@ impl StateStore for SqliteStateStore { // Serialize the error to json bytes (encrypted if option is enabled) if set. let error_value = error.map(|e| self.serialize_value(&e)).transpose()?; - self.acquire() - .await? + self.write() + .await .with_transaction(move |txn| { txn.prepare_cached("UPDATE send_queue_events SET wedge_reason = ? WHERE room_id = ? AND transaction_id = ?")?.execute((error_value, room_id, transaction_id))?; Ok(()) @@ -1967,7 +2000,7 @@ impl StateStore for SqliteStateStore { // != encrypted(X), since we use a nonce in the encryption process. let res: Vec> = self - .acquire() + .read() .await? .prepare("SELECT room_id_val FROM send_queue_events", |mut stmt| { stmt.query(())?.mapped(|row| row.get(0)).collect() @@ -2000,8 +2033,8 @@ impl StateStore for SqliteStateStore { let own_txn_id = own_txn_id.to_string(); let created_at_ts: u64 = created_at.0.into(); - self.acquire() - .await? + self.write() + .await .with_transaction(move |txn| { txn.prepare_cached( r#"INSERT INTO dependent_send_queue_events @@ -2033,8 +2066,8 @@ impl StateStore for SqliteStateStore { let own_txn_id = own_transaction_id.to_string(); let num_updated = self - .acquire() - .await? + .write() + .await .with_transaction(move |txn| { txn.prepare_cached( r#"UPDATE dependent_send_queue_events @@ -2065,8 +2098,8 @@ impl StateStore for SqliteStateStore { // See comment in `save_send_queue_event`. let parent_txn_id = parent_txn_id.to_string(); - self.acquire() - .await? + self.write() + .await .with_transaction(move |txn| { Ok(txn.prepare_cached( "UPDATE dependent_send_queue_events SET parent_key = ? WHERE parent_transaction_id = ? and room_id = ?", @@ -2087,8 +2120,8 @@ impl StateStore for SqliteStateStore { let txn_id = txn_id.to_string(); let num_deleted = self - .acquire() - .await? + .write() + .await .with_transaction(move |txn| { txn.prepare_cached( "DELETE FROM dependent_send_queue_events WHERE own_transaction_id = ? AND room_id = ?", @@ -2108,7 +2141,7 @@ impl StateStore for SqliteStateStore { // Note: transaction_id is not encoded, see why in `save_send_queue_event`. let res: Vec<(String, String, Option>, Vec, Option)> = self - .acquire() + .read() .await? .prepare( "SELECT own_transaction_id, parent_transaction_id, parent_key, content, created_at FROM dependent_send_queue_events WHERE room_id = ? ORDER BY ROWID", @@ -2121,11 +2154,13 @@ impl StateStore for SqliteStateStore { .await?; let mut dependent_events = Vec::with_capacity(res.len()); + for entry in res { let created_at = entry .4 .and_then(UInt::new) .map_or_else(MilliSecondsSinceUnixEpoch::now, MilliSecondsSinceUnixEpoch); + dependent_events.push(DependentQueuedRequest { own_transaction_id: entry.0.into(), parent_transaction_id: entry.1.into(), @@ -2150,6 +2185,7 @@ impl StateStore for SqliteStateStore { trace!("not saving thread subscription because the subscription is the same"); return Ok(()); } + if !compare_thread_subscription_bump_stamps(previous.bump_stamp, &mut new.bump_stamp) { trace!("not saving thread subscription because we have a newer bump stamp"); return Ok(()); @@ -2160,8 +2196,8 @@ impl StateStore for SqliteStateStore { let thread_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, thread_id); let status = new.status.as_str(); - self.acquire() - .await? + self.write() + .await .with_transaction(move |txn| { // Try to find a previous value. txn.prepare_cached( @@ -2183,7 +2219,7 @@ impl StateStore for SqliteStateStore { let thread_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, thread_id); Ok(self - .acquire() + .read() .await? .query_row( "SELECT status, bump_stamp FROM thread_subscriptions WHERE room_id = ? AND event_id = ?", @@ -2209,8 +2245,8 @@ impl StateStore for SqliteStateStore { let room_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, room_id); let thread_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, thread_id); - self.acquire() - .await? + self.write() + .await .execute( "DELETE FROM thread_subscriptions WHERE room_id = ? AND event_id = ?", (room_id, thread_id), @@ -2369,7 +2405,7 @@ mod migration_tests { use serde::{Deserialize, Serialize}; use serde_json::json; use tempfile::{tempdir, TempDir}; - use tokio::fs; + use tokio::{fs, sync::Mutex}; use zeroize::Zeroizing; use super::{init, keys, SqliteStateStore, DATABASE_NAME}; @@ -2404,8 +2440,13 @@ mod migration_tests { .await .unwrap(), )); - let this = SqliteStateStore { store_cipher, pool }; - this.run_migrations(&conn, 1, Some(version)).await?; + let this = SqliteStateStore { + store_cipher, + pool, + // Use `conn` as our selected write connections. + write_connection: Arc::new(Mutex::new(conn)), + }; + this.run_migrations(1, Some(version)).await?; Ok(this) }