From a483a82f69a7e97e19961adb5c27f2e44a1595fb Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Tue, 28 Jan 2020 00:12:52 +0000 Subject: [PATCH] WIP - add rdbc-tokio-postgres as first impl of new traits --- Cargo.toml | 1 + rdbc-tokio-postgres/Cargo.toml | 14 +++ rdbc-tokio-postgres/src/lib.rs | 191 +++++++++++++++++++++++++++++++++ rdbc/src/lib.rs | 54 ++++++---- 4 files changed, 238 insertions(+), 22 deletions(-) create mode 100644 rdbc-tokio-postgres/Cargo.toml create mode 100644 rdbc-tokio-postgres/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 769cd1e..15b438a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ # "rdbc-mysql", # "rdbc-postgres", # "rdbc-sqlite", + "rdbc-tokio-postgres", # "rdbc-odbc", "rdbc-cli", ] diff --git a/rdbc-tokio-postgres/Cargo.toml b/rdbc-tokio-postgres/Cargo.toml new file mode 100644 index 0000000..771fae4 --- /dev/null +++ b/rdbc-tokio-postgres/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "rdbc-tokio-postgres" +version = "0.1.0" +authors = ["Ben Sully "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +async-trait = "0.1.22" +rdbc = { path = "../rdbc", version = "0.1.6" } +sqlparser = "0.5.0" +tokio = "0.2.10" +tokio-postgres = "0.5.1" diff --git a/rdbc-tokio-postgres/src/lib.rs b/rdbc-tokio-postgres/src/lib.rs new file mode 100644 index 0000000..48b56e7 --- /dev/null +++ b/rdbc-tokio-postgres/src/lib.rs @@ -0,0 +1,191 @@ +use std::{pin::Pin, sync::Arc}; + +use async_trait::async_trait; +use sqlparser::{ + dialect::PostgreSqlDialect, + tokenizer::{Token, Tokenizer, Word}, +}; +use tokio::stream::Stream; +use tokio_postgres::{types::Type, Client, NoTls, Row, Statement}; + +#[derive(Debug)] +pub enum Error { + TokioPostgres(tokio_postgres::Error), +} + +impl From for Error { + fn from(other: tokio_postgres::Error) -> Self { + Self::TokioPostgres(other) + } +} + +pub struct TokioPostgresDriver; + +#[async_trait] +impl rdbc::Driver for TokioPostgresDriver { + type Connection = TokioPostgresConnection; + type Error = Error; + + async fn connect(url: &str) -> Result { + let (client, conn) = tokio_postgres::connect(url, NoTls).await?; + tokio::spawn(conn); + Ok(TokioPostgresConnection { + inner: Arc::new(client), + }) + } +} + +pub struct TokioPostgresConnection { + inner: Arc, +} + +#[async_trait] +impl rdbc::Connection for TokioPostgresConnection { + type Statement = TokioPostgresStatement; + type Error = Error; + + async fn create(&mut self, sql: &str) -> Result { + let sql = { + let dialect = PostgreSqlDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, sql); + let tokens = tokenizer.tokenize().unwrap(); + let mut i = 0_usize; + let tokens: Vec = tokens + .iter() + .map(|t| match t { + Token::Char(c) if *c == '?' => { + i += 1; + Token::Word(Word { + value: format!("${}", i), + quote_style: None, + keyword: "".to_owned(), + }) + } + _ => t.clone(), + }) + .collect(); + tokens + .iter() + .map(|t| format!("{}", t)) + .collect::>() + .join("") + }; + let statement = self.inner.prepare(&sql).await?; + Ok(TokioPostgresStatement { + client: Arc::clone(&self.inner), + statement, + }) + } + + async fn prepare(&mut self, sql: &str) -> Result { + self.create(sql).await + } +} + +pub struct TokioPostgresStatement { + client: Arc, + statement: Statement, +} + +fn to_rdbc_type(ty: &Type) -> rdbc::DataType { + match ty { + &Type::BOOL => rdbc::DataType::Bool, + &Type::CHAR => rdbc::DataType::Char, + //TODO all types + _ => rdbc::DataType::Utf8, + } +} + +fn to_postgres_params(values: &[rdbc::Value]) -> Vec> { + values + .iter() + .map(|v| match v { + rdbc::Value::String(s) => { + Box::new(s.clone()) as Box + } + rdbc::Value::Int32(n) => Box::new(*n) as Box, + rdbc::Value::UInt32(n) => Box::new(*n) as Box, //TODO all types + }) + .collect() +} + +#[async_trait] +impl rdbc::Statement for TokioPostgresStatement { + type ResultSet = TokioPostgresResultSet; + type Error = Error; + async fn execute_query( + &mut self, + params: &[rdbc::Value], + ) -> Result { + let params = to_postgres_params(params); + let params: Vec<_> = params.into_iter().map(|p| p.as_ref()).collect(); + let rows = self + .client + .query(&self.statement, params.as_slice()) + .await? + .into_iter() + .map(|row| TokioPostgresRow { inner: row }) + .collect(); + let meta = self + .statement + .columns() + .iter() + .map(|c| rdbc::Column::new(c.name(), to_rdbc_type(c.type_()))) + .collect(); + Ok(TokioPostgresResultSet { rows, meta }) + } + async fn execute_update(&mut self, params: &[rdbc::Value]) -> Result { + todo!() + } +} + +pub struct TokioPostgresResultSet { + meta: Vec, + rows: Vec, +} + +#[async_trait] +impl rdbc::ResultSet for TokioPostgresResultSet { + type MetaData = Vec; + type Row = TokioPostgresRow; + type Error = Error; + + fn meta_data(&self) -> Result<&Self::MetaData, Self::Error> { + Ok(&self.meta) + } + + async fn batches( + &mut self, + ) -> Result>>>, Self::Error> { + let rows = std::mem::take(&mut self.rows); + Ok(Box::pin(tokio::stream::once(rows))) + } +} + +pub struct TokioPostgresRow { + inner: Row, +} + +macro_rules! impl_resultset_fns { + ($($fn: ident -> $ty: ty),*) => { + $( + fn $fn(&self, i: u64) -> Result, Self::Error> { + Some(self.inner.try_get((i - 1) as usize)).transpose().map_err(Into::into) + } + )* + } +} + +impl rdbc::Row for TokioPostgresRow { + type Error = Error; + impl_resultset_fns! { + get_i8 -> i8, + get_i16 -> i16, + get_i32 -> i32, + get_i64 -> i64, + get_f32 -> f32, + get_f64 -> f64, + get_string -> String, + get_bytes -> Vec + } +} diff --git a/rdbc/src/lib.rs b/rdbc/src/lib.rs index d265f60..3fddc6a 100644 --- a/rdbc/src/lib.rs +++ b/rdbc/src/lib.rs @@ -49,42 +49,48 @@ impl ToString for Value { } } -/// RDBC Result type -pub type Result = std::result::Result; - /// Represents database driver that can be shared between threads, and can therefore implement /// a connection pool +#[async_trait] pub trait Driver: Sync + Send { /// The type of connection created by this driver. type Connection: Connection; + type Error; + /// Create a connection to the database. Note that connections are intended to be used /// in a single thread since most database connections are not thread-safe - fn connect(url: &str) -> Result; + async fn connect(url: &str) -> Result; } /// Represents a connection to a database +#[async_trait] pub trait Connection { /// The type of statement produced by this connection. type Statement: Statement; + type Error; + /// Create a statement for execution - fn create(&mut self, sql: &str) -> Result; + async fn create(&mut self, sql: &str) -> Result; /// Create a prepared statement for execution - fn prepare(&mut self, sql: &str) -> Result; + async fn prepare(&mut self, sql: &str) -> Result; } /// Represents an executable statement +#[async_trait] pub trait Statement { /// The type of ResultSet returned by this statement. type ResultSet: ResultSet; + type Error; + /// Execute a query that is expected to return a result set, such as a `SELECT` statement - fn execute_query(&mut self, params: &[Value]) -> Result; + async fn execute_query(&mut self, params: &[Value]) -> Result; /// Execute a query that is expected to update some rows. - fn execute_update(&mut self, params: &[Value]) -> Result; + async fn execute_update(&mut self, params: &[Value]) -> Result; } /// Result set from executing a query against a statement @@ -95,36 +101,40 @@ pub trait ResultSet { /// The type of row included in this result set. type Row: Row; + type Error; + /// get meta data about this result set - fn meta_data(&self) -> Result; + fn meta_data(&self) -> Result<&Self::MetaData, Self::Error>; /// Get a stream where each item is a batch of rows. - async fn batches(&mut self) -> Result>>>>; + async fn batches(&mut self) + -> Result>>>, Self::Error>; /// Get a stream of rows. /// /// Note that the rows are actually returned from the database in batches; /// this just flattens the batches to provide a (possibly) simpler API. - async fn rows<'a>(&'a mut self) -> Result + 'a>> { + async fn rows<'a>(&'a mut self) -> Result + 'a>, Self::Error> { Ok(Box::new(self.batches().await?.map(iter).flatten())) } } pub trait Row { - fn get_i8(&self, i: u64) -> Result>; - fn get_i16(&self, i: u64) -> Result>; - fn get_i32(&self, i: u64) -> Result>; - fn get_i64(&self, i: u64) -> Result>; - fn get_f32(&self, i: u64) -> Result>; - fn get_f64(&self, i: u64) -> Result>; - fn get_string(&self, i: u64) -> Result>; - fn get_bytes(&self, i: u64) -> Result>>; + type Error; + fn get_i8(&self, i: u64) -> Result, Self::Error>; + fn get_i16(&self, i: u64) -> Result, Self::Error>; + fn get_i32(&self, i: u64) -> Result, Self::Error>; + fn get_i64(&self, i: u64) -> Result, Self::Error>; + fn get_f32(&self, i: u64) -> Result, Self::Error>; + fn get_f64(&self, i: u64) -> Result, Self::Error>; + fn get_string(&self, i: u64) -> Result, Self::Error>; + fn get_bytes(&self, i: u64) -> Result>, Self::Error>; } /// Meta data for result set pub trait MetaData { fn num_columns(&self) -> u64; - fn column_name(&self, i: u64) -> String; + fn column_name(&self, i: u64) -> &str; fn column_type(&self, i: u64) -> DataType; } @@ -166,8 +176,8 @@ impl MetaData for Vec { self.len() as u64 } - fn column_name(&self, i: u64) -> String { - self[i as usize].name.clone() + fn column_name(&self, i: u64) -> &str { + &self[i as usize].name } fn column_type(&self, i: u64) -> DataType {