From 8388e5df8161cb04788bb78e73bdeb721eaba49d Mon Sep 17 00:00:00 2001 From: Tim Murison Date: Thu, 25 Jun 2020 00:24:08 -0400 Subject: [PATCH 1/3] Sqlite Collation Support Adds a method create_collation to SqliteConnection. Adds a unit test confirming the collation works as expected. --- sqlx-core/src/sqlite/connection/collation.rs | 80 ++++++++++++++++++++ sqlx-core/src/sqlite/connection/mod.rs | 11 +++ tests/sqlite/sqlite.rs | 36 ++++++++- 3 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 sqlx-core/src/sqlite/connection/collation.rs diff --git a/sqlx-core/src/sqlite/connection/collation.rs b/sqlx-core/src/sqlite/connection/collation.rs new file mode 100644 index 0000000000..a498228dc5 --- /dev/null +++ b/sqlx-core/src/sqlite/connection/collation.rs @@ -0,0 +1,80 @@ +use std::cmp::Ordering; +use std::ffi::CString; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::{catch_unwind, UnwindSafe}; +use std::slice; + +use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8}; + +use crate::error::Error; +use crate::sqlite::connection::handle::ConnectionHandle; +use crate::sqlite::SqliteError; + +unsafe extern "C" fn free_boxed_value(p: *mut c_void) { + drop(Box::from_raw(p as *mut T)); +} + +pub(crate) fn create_collation( + handle: &ConnectionHandle, + name: &str, + collation: F +) -> Result<(), Error> + where F: Fn(&str, &str) -> Ordering + Send + Sync + UnwindSafe + 'static +{ + unsafe extern "C" fn call_boxed_closure( + arg1: *mut c_void, + arg2: c_int, + arg3: *const c_void, + arg4: c_int, + arg5: *const c_void, + ) -> c_int + where + C: Fn(&str, &str) -> Ordering, + { + let r = catch_unwind(|| { + let boxed_f: *mut C = arg1 as *mut C; + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + let s1 = { + let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize); + String::from_utf8_lossy(c_slice) + }; + let s2 = { + let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize); + String::from_utf8_lossy(c_slice) + }; + (*boxed_f)(s1.as_ref(), s2.as_ref()) + }); + let t = match r { + Err(_) => { + return -1; // FIXME How ? + } + Ok(r) => r, + }; + + match t { + Ordering::Less => -1, + Ordering::Equal => 0, + Ordering::Greater => 1, + } + } + + let boxed_f: *mut F = Box::into_raw(Box::new(collation)); + let c_name = CString::new(name).map_err(|_| err_protocol!("Invalid collation name: {}", name))?; + let flags = SQLITE_UTF8; + let r = unsafe { + sqlite3_create_collation_v2( + handle.as_ptr(), + c_name.as_ptr(), + flags, + boxed_f as *mut c_void, + Some(call_boxed_closure::), + Some(free_boxed_value::), + ) + }; + + if r == SQLITE_OK { + Ok(()) + } else { + Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) + } +} \ No newline at end of file diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index d70c1c8909..8df030af5c 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -1,4 +1,6 @@ +use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; +use std::panic::UnwindSafe; use std::sync::Arc; use futures_core::future::BoxFuture; @@ -13,6 +15,7 @@ use crate::sqlite::connection::establish::establish; use crate::sqlite::statement::{SqliteStatement, StatementWorker}; use crate::sqlite::{Sqlite, SqliteConnectOptions}; +mod collation; mod establish; mod executor; mod handle; @@ -39,6 +42,14 @@ impl SqliteConnection { pub fn as_raw_handle(&mut self) -> *mut sqlite3 { self.handle.as_ptr() } + + pub fn create_collation(&mut self, + name: &str, + collation: impl Fn(&str, &str) -> Ordering + Send + Sync + UnwindSafe + 'static + ) -> &mut SqliteConnection { + collation::create_collation(&self.handle, name, collation); + self + } } impl Debug for SqliteConnection { diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 1630773b0f..c0642b29ea 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1,6 +1,6 @@ use futures::TryStreamExt; use sqlx::{ - query, sqlite::Sqlite, Connect, Connection, Executor, Row, SqliteConnection, SqlitePool, + query, sqlite::Sqlite, sqlite::SqliteRow, Connect, Connection, Executor, Row, SqliteConnection, SqlitePool, }; use sqlx_test::new; @@ -269,3 +269,37 @@ SELECT id, text FROM _sqlx_test; Ok(()) } + +#[sqlx_macros::test] +async fn it_supports_collations() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.create_collation("test_collation", |l, r| { + l.cmp(r).reverse() + }); + + + let _ = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL COLLATE test_collation) + "#, + ) + .await?; + + sqlx::query("INSERT INTO users (name) VALUES (?)") + .bind("a") + .execute(&mut conn) + .await?; + sqlx::query("INSERT INTO users (name) VALUES (?)") + .bind("b") + .execute(&mut conn) + .await?; + + let row: SqliteRow = conn.fetch_one("SELECT name FROM users ORDER BY name ASC").await?; + let name: &str = row.try_get(0)?; + + assert_eq!(name, "b"); + + Ok(()) +} From 4e174549513c21dee7cb7a7bbec7a1f2abccc7ba Mon Sep 17 00:00:00 2001 From: Tim Murison Date: Thu, 25 Jun 2020 00:47:13 -0400 Subject: [PATCH 2/3] Fix formatting --- sqlx-core/src/sqlite/connection/collation.rs | 12 +++++++----- sqlx-core/src/sqlite/connection/mod.rs | 5 +++-- tests/sqlite/sqlite.rs | 12 ++++++------ 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/sqlx-core/src/sqlite/connection/collation.rs b/sqlx-core/src/sqlite/connection/collation.rs index a498228dc5..e08da84b87 100644 --- a/sqlx-core/src/sqlite/connection/collation.rs +++ b/sqlx-core/src/sqlite/connection/collation.rs @@ -17,9 +17,10 @@ unsafe extern "C" fn free_boxed_value(p: *mut c_void) { pub(crate) fn create_collation( handle: &ConnectionHandle, name: &str, - collation: F + collation: F, ) -> Result<(), Error> - where F: Fn(&str, &str) -> Ordering + Send + Sync + UnwindSafe + 'static +where + F: Fn(&str, &str) -> Ordering + Send + Sync + UnwindSafe + 'static, { unsafe extern "C" fn call_boxed_closure( arg1: *mut c_void, @@ -59,7 +60,8 @@ pub(crate) fn create_collation( } let boxed_f: *mut F = Box::into_raw(Box::new(collation)); - let c_name = CString::new(name).map_err(|_| err_protocol!("Invalid collation name: {}", name))?; + let c_name = + CString::new(name).map_err(|_| err_protocol!("Invalid collation name: {}", name))?; let flags = SQLITE_UTF8; let r = unsafe { sqlite3_create_collation_v2( @@ -76,5 +78,5 @@ pub(crate) fn create_collation( Ok(()) } else { Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) - } -} \ No newline at end of file + } +} diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index 8df030af5c..080fc0b8dd 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -43,9 +43,10 @@ impl SqliteConnection { self.handle.as_ptr() } - pub fn create_collation(&mut self, + pub fn create_collation( + &mut self, name: &str, - collation: impl Fn(&str, &str) -> Ordering + Send + Sync + UnwindSafe + 'static + collation: impl Fn(&str, &str) -> Ordering + Send + Sync + UnwindSafe + 'static, ) -> &mut SqliteConnection { collation::create_collation(&self.handle, name, collation); self diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index c0642b29ea..7c9b910ff9 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1,6 +1,7 @@ use futures::TryStreamExt; use sqlx::{ - query, sqlite::Sqlite, sqlite::SqliteRow, Connect, Connection, Executor, Row, SqliteConnection, SqlitePool, + query, sqlite::Sqlite, sqlite::SqliteRow, Connect, Connection, Executor, Row, SqliteConnection, + SqlitePool, }; use sqlx_test::new; @@ -274,10 +275,7 @@ SELECT id, text FROM _sqlx_test; async fn it_supports_collations() -> anyhow::Result<()> { let mut conn = new::().await?; - conn.create_collation("test_collation", |l, r| { - l.cmp(r).reverse() - }); - + conn.create_collation("test_collation", |l, r| l.cmp(r).reverse()); let _ = conn .execute( @@ -296,7 +294,9 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL COLLATE .execute(&mut conn) .await?; - let row: SqliteRow = conn.fetch_one("SELECT name FROM users ORDER BY name ASC").await?; + let row: SqliteRow = conn + .fetch_one("SELECT name FROM users ORDER BY name ASC") + .await?; let name: &str = row.try_get(0)?; assert_eq!(name, "b"); From 747ab33a239a387e9e6e84e45b17dad7cbb05594 Mon Sep 17 00:00:00 2001 From: Tim Murison Date: Sat, 27 Jun 2020 20:33:38 -0400 Subject: [PATCH 3/3] Address feedback --- sqlx-core/src/sqlite/connection/collation.rs | 38 ++++++++------------ sqlx-core/src/sqlite/connection/mod.rs | 8 ++--- tests/sqlite/sqlite.rs | 4 ++- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/sqlx-core/src/sqlite/connection/collation.rs b/sqlx-core/src/sqlite/connection/collation.rs index e08da84b87..353fa7252c 100644 --- a/sqlx-core/src/sqlite/connection/collation.rs +++ b/sqlx-core/src/sqlite/connection/collation.rs @@ -1,8 +1,8 @@ use std::cmp::Ordering; use std::ffi::CString; use std::os::raw::{c_char, c_int, c_void}; -use std::panic::{catch_unwind, UnwindSafe}; use std::slice; +use std::str::from_utf8_unchecked; use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8}; @@ -17,10 +17,10 @@ unsafe extern "C" fn free_boxed_value(p: *mut c_void) { pub(crate) fn create_collation( handle: &ConnectionHandle, name: &str, - collation: F, + compare: F, ) -> Result<(), Error> where - F: Fn(&str, &str) -> Ordering + Send + Sync + UnwindSafe + 'static, + F: Fn(&str, &str) -> Ordering + Send + Sync + 'static, { unsafe extern "C" fn call_boxed_closure( arg1: *mut c_void, @@ -32,25 +32,17 @@ where where C: Fn(&str, &str) -> Ordering, { - let r = catch_unwind(|| { - let boxed_f: *mut C = arg1 as *mut C; - assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - let s1 = { - let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize); - String::from_utf8_lossy(c_slice) - }; - let s2 = { - let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize); - String::from_utf8_lossy(c_slice) - }; - (*boxed_f)(s1.as_ref(), s2.as_ref()) - }); - let t = match r { - Err(_) => { - return -1; // FIXME How ? - } - Ok(r) => r, + let boxed_f: *mut C = arg1 as *mut C; + debug_assert!(!boxed_f.is_null()); + let s1 = { + let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize); + from_utf8_unchecked(c_slice) }; + let s2 = { + let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize); + from_utf8_unchecked(c_slice) + }; + let t = (*boxed_f)(s1, s2); match t { Ordering::Less => -1, @@ -59,9 +51,9 @@ where } } - let boxed_f: *mut F = Box::into_raw(Box::new(collation)); + let boxed_f: *mut F = Box::into_raw(Box::new(compare)); let c_name = - CString::new(name).map_err(|_| err_protocol!("Invalid collation name: {}", name))?; + CString::new(name).map_err(|_| err_protocol!("invalid collation name: {}", name))?; let flags = SQLITE_UTF8; let r = unsafe { sqlite3_create_collation_v2( diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index af01d63523..1db8510b67 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -1,6 +1,5 @@ use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; -use std::panic::UnwindSafe; use std::sync::Arc; use futures_core::future::BoxFuture; @@ -47,10 +46,9 @@ impl SqliteConnection { pub fn create_collation( &mut self, name: &str, - collation: impl Fn(&str, &str) -> Ordering + Send + Sync + UnwindSafe + 'static, - ) -> &mut SqliteConnection { - collation::create_collation(&self.handle, name, collation); - self + compare: impl Fn(&str, &str) -> Ordering + Send + Sync + 'static, + ) -> Result<(), Error> { + collation::create_collation(&self.handle, name, compare) } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 13406315c8..44f1f4a7be 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -275,7 +275,7 @@ SELECT id, text FROM _sqlx_test; async fn it_supports_collations() -> anyhow::Result<()> { let mut conn = new::().await?; - conn.create_collation("test_collation", |l, r| l.cmp(r).reverse()); + conn.create_collation("test_collation", |l, r| l.cmp(r).reverse())?; let _ = conn .execute( @@ -300,6 +300,8 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL COLLATE let name: &str = row.try_get(0)?; assert_eq!(name, "b"); + + Ok(()) } #[sqlx_macros::test]