diff --git a/mm2src/coins/lightning/ln_sql.rs b/mm2src/coins/lightning/ln_sql.rs index 64bbb52ea7..4b5dd2f7c8 100644 --- a/mm2src/coins/lightning/ln_sql.rs +++ b/mm2src/coins/lightning/ln_sql.rs @@ -609,11 +609,14 @@ pub struct SqliteLightningDB { } impl SqliteLightningDB { - pub fn new(ticker: String, sqlite_connection: SqliteConnShared) -> Self { - Self { - db_ticker: ticker.replace('-', "_"), + pub fn new(ticker: String, sqlite_connection: SqliteConnShared) -> Result { + let db_ticker = ticker.replace('-', "_"); + validate_table_name(&db_ticker)?; + + Ok(Self { + db_ticker, sqlite_connection, - } + }) } } @@ -1047,7 +1050,7 @@ mod tests { use super::*; use crate::lightning::ln_db::DBChannelDetails; use common::{block_on, new_uuid}; - use db_common::sqlite::rusqlite::Connection; + use db_common::sqlite::rusqlite::{self, Connection}; use rand::distributions::Alphanumeric; use rand::{Rng, RngCore}; use secp256k1v24::{Secp256k1, SecretKey}; @@ -1157,7 +1160,8 @@ mod tests { let db = SqliteLightningDB::new( "init_sql_collection".into(), Arc::new(Mutex::new(Connection::open_in_memory().unwrap())), - ); + ) + .unwrap(); let initialized = block_on(db.is_db_initialized()).unwrap(); assert!(!initialized); @@ -1174,7 +1178,8 @@ mod tests { let db = SqliteLightningDB::new( "add_get_channel".into(), Arc::new(Mutex::new(Connection::open_in_memory().unwrap())), - ); + ) + .unwrap(); block_on(db.init_db()).unwrap(); @@ -1282,7 +1287,8 @@ mod tests { let db = SqliteLightningDB::new( "add_get_payment".into(), Arc::new(Mutex::new(Connection::open_in_memory().unwrap())), - ); + ) + .unwrap(); block_on(db.init_db()).unwrap(); @@ -1371,7 +1377,8 @@ mod tests { let db = SqliteLightningDB::new( "test_get_payments_by_filter".into(), Arc::new(Mutex::new(Connection::open_in_memory().unwrap())), - ); + ) + .unwrap(); block_on(db.init_db()).unwrap(); @@ -1485,12 +1492,48 @@ mod tests { assert_eq!(expected_payments, actual_payments); } + #[test] + fn test_invalid_lightning_db_name() { + let db = SqliteLightningDB::new("123".into(), Mutex::new(Connection::open_in_memory().unwrap()).into()); + + let expected = || { + SqlError::SqliteFailure( + rusqlite::ffi::Error { + code: rusqlite::ErrorCode::ApiMisuse, + extended_code: rusqlite::ffi::SQLITE_MISUSE, + }, + None, + ) + }; + + assert_eq!(db.err(), Some(expected())); + + let db = SqliteLightningDB::new( + "t".repeat(u8::MAX as usize + 1), + Mutex::new(Connection::open_in_memory().unwrap()).into(), + ); + + assert_eq!(db.err(), Some(expected())); + + let db = SqliteLightningDB::new( + "PROCEDURE".to_owned(), + Mutex::new(Connection::open_in_memory().unwrap()).into(), + ); + + assert_eq!(db.err(), Some(expected())); + + let db = SqliteLightningDB::new(String::new(), Mutex::new(Connection::open_in_memory().unwrap()).into()); + + assert_eq!(db.err(), Some(expected())); + } + #[test] fn test_get_channels_by_filter() { let db = SqliteLightningDB::new( "test_get_channels_by_filter".into(), Arc::new(Mutex::new(Connection::open_in_memory().unwrap())), - ); + ) + .unwrap(); block_on(db.init_db()).unwrap(); diff --git a/mm2src/coins/lightning/ln_utils.rs b/mm2src/coins/lightning/ln_utils.rs index 9e12e52bf7..693b7c3a4f 100644 --- a/mm2src/coins/lightning/ln_utils.rs +++ b/mm2src/coins/lightning/ln_utils.rs @@ -76,7 +76,7 @@ pub async fn init_db(ctx: &MmArc, ticker: String) -> EnableLightningResult SqlResult<()> { /// It disallows any characters in the table name that may lead to SQL injection, only /// allowing alphanumeric characters and underscores. pub fn validate_table_name(table_name: &str) -> SqlResult<()> { + let table_name = table_name.trim(); + + const RESERVED_KEYWORDS: &[&str] = &[ + "SELECT", + "INSERT", + "UPDATE", + "DELETE", + "FROM", + "WHERE", + "JOIN", + "INNER", + "OUTER", + "LEFT", + "RIGHT", + "ON", + "CREATE", + "ALTER", + "DROP", + "TABLE", + "INDEX", + "VIEW", + "TRIGGER", + "PROCEDURE", + "FUNCTION", + "DATABASE", + "AND", + "OR", + "NOT", + "NULL", + "IS", + "IN", + "EXISTS", + "BETWEEN", + "LIKE", + "UNION", + "ALL", + "ANY", + "AS", + "DISTINCT", + "GROUP", + "BY", + "ORDER", + "HAVING", + "LIMIT", + "OFFSET", + "VALUES", + "INTO", + "PRIMARY", + "FOREIGN", + "KEY", + "REFERENCES", + ]; + + let validation_error = || { + SqlError::SqliteFailure( + rusqlite::ffi::Error { + code: rusqlite::ErrorCode::ApiMisuse, + extended_code: rusqlite::ffi::SQLITE_MISUSE, + }, + None, + ) + }; + + if table_name.is_empty() { + log::error!("Table name can not be empty."); + return Err(validation_error()); + } + + if RESERVED_KEYWORDS.contains(&table_name.to_uppercase().as_str()) { + log::error!("{table_name} is a reserved SQLite keyword and can not be used as a table name."); + return Err(validation_error()); + } + + if table_name.len() > u8::MAX as usize { + log::error!("{table_name} length can not be greater than {}.", u8::MAX); + return Err(validation_error()); + } + // As per https://stackoverflow.com/a/3247553, tables can't be the target of parameter substitution. // So we have to use a plain concatenation disallowing any characters in the table name that may lead to SQL injection. validate_ident_impl(table_name, |c| c.is_alphanumeric() || c == '_') @@ -346,9 +424,32 @@ fn validate_ident_impl(ident: &str, is_valid: F) -> SqlResult<()> where F: Fn(char) -> bool, { + let ident = ident.trim(); + + let validation_error = || { + SqlError::SqliteFailure( + rusqlite::ffi::Error { + code: rusqlite::ErrorCode::ApiMisuse, + extended_code: rusqlite::ffi::SQLITE_MISUSE, + }, + None, + ) + }; + + if ident.is_empty() { + log::error!("Ident can not be empty."); + return Err(validation_error()); + } + + if ident.as_bytes()[0].is_ascii_digit() { + log::error!("{ident} starts with number."); + return Err(validation_error()); + } + if ident.chars().all(is_valid) { Ok(()) } else { - Err(SqlError::InvalidParameterName(ident.to_string())) + log::error!("{ident} is not valid."); + Err(validation_error()) } }