Skip to content

Commit

Permalink
WIP - add rdbc-tokio-postgres as first impl of new traits
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k committed Jan 28, 2020
1 parent 2adb34a commit a483a82
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 22 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ members = [
# "rdbc-mysql",
# "rdbc-postgres",
# "rdbc-sqlite",
"rdbc-tokio-postgres",
# "rdbc-odbc",
"rdbc-cli",
]
14 changes: 14 additions & 0 deletions rdbc-tokio-postgres/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "rdbc-tokio-postgres"
version = "0.1.0"
authors = ["Ben Sully <[email protected]>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.22"
rdbc = { path = "../rdbc", version = "0.1.6" }
sqlparser = "0.5.0"
tokio = "0.2.10"
tokio-postgres = "0.5.1"
191 changes: 191 additions & 0 deletions rdbc-tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
use std::{pin::Pin, sync::Arc};

use async_trait::async_trait;
use sqlparser::{
dialect::PostgreSqlDialect,
tokenizer::{Token, Tokenizer, Word},
};
use tokio::stream::Stream;
use tokio_postgres::{types::Type, Client, NoTls, Row, Statement};

#[derive(Debug)]
pub enum Error {
TokioPostgres(tokio_postgres::Error),
}

impl From<tokio_postgres::Error> for Error {
fn from(other: tokio_postgres::Error) -> Self {
Self::TokioPostgres(other)
}
}

pub struct TokioPostgresDriver;

#[async_trait]
impl rdbc::Driver for TokioPostgresDriver {
type Connection = TokioPostgresConnection;
type Error = Error;

async fn connect(url: &str) -> Result<Self::Connection, Self::Error> {
let (client, conn) = tokio_postgres::connect(url, NoTls).await?;
tokio::spawn(conn);
Ok(TokioPostgresConnection {
inner: Arc::new(client),
})
}
}

pub struct TokioPostgresConnection {
inner: Arc<Client>,
}

#[async_trait]
impl rdbc::Connection for TokioPostgresConnection {
type Statement = TokioPostgresStatement;
type Error = Error;

async fn create(&mut self, sql: &str) -> Result<Self::Statement, Self::Error> {
let sql = {
let dialect = PostgreSqlDialect {};
let mut tokenizer = Tokenizer::new(&dialect, sql);
let tokens = tokenizer.tokenize().unwrap();
let mut i = 0_usize;
let tokens: Vec<Token> = tokens
.iter()
.map(|t| match t {
Token::Char(c) if *c == '?' => {
i += 1;
Token::Word(Word {
value: format!("${}", i),
quote_style: None,
keyword: "".to_owned(),
})
}
_ => t.clone(),
})
.collect();
tokens
.iter()
.map(|t| format!("{}", t))
.collect::<Vec<String>>()
.join("")
};
let statement = self.inner.prepare(&sql).await?;
Ok(TokioPostgresStatement {
client: Arc::clone(&self.inner),
statement,
})
}

async fn prepare(&mut self, sql: &str) -> Result<Self::Statement, Self::Error> {
self.create(sql).await
}
}

pub struct TokioPostgresStatement {
client: Arc<Client>,
statement: Statement,
}

fn to_rdbc_type(ty: &Type) -> rdbc::DataType {
match ty {
&Type::BOOL => rdbc::DataType::Bool,
&Type::CHAR => rdbc::DataType::Char,
//TODO all types
_ => rdbc::DataType::Utf8,
}
}

fn to_postgres_params(values: &[rdbc::Value]) -> Vec<Box<dyn tokio_postgres::types::ToSql + Sync>> {
values
.iter()
.map(|v| match v {
rdbc::Value::String(s) => {
Box::new(s.clone()) as Box<dyn tokio_postgres::types::ToSql + Sync>
}
rdbc::Value::Int32(n) => Box::new(*n) as Box<dyn tokio_postgres::types::ToSql + Sync>,
rdbc::Value::UInt32(n) => Box::new(*n) as Box<dyn tokio_postgres::types::ToSql + Sync>, //TODO all types
})
.collect()
}

#[async_trait]
impl rdbc::Statement for TokioPostgresStatement {
type ResultSet = TokioPostgresResultSet;
type Error = Error;
async fn execute_query(
&mut self,
params: &[rdbc::Value],
) -> Result<Self::ResultSet, Self::Error> {
let params = to_postgres_params(params);
let params: Vec<_> = params.into_iter().map(|p| p.as_ref()).collect();
let rows = self
.client
.query(&self.statement, params.as_slice())
.await?
.into_iter()
.map(|row| TokioPostgresRow { inner: row })
.collect();
let meta = self
.statement
.columns()
.iter()
.map(|c| rdbc::Column::new(c.name(), to_rdbc_type(c.type_())))
.collect();
Ok(TokioPostgresResultSet { rows, meta })
}
async fn execute_update(&mut self, params: &[rdbc::Value]) -> Result<u64, Self::Error> {
todo!()
}
}

pub struct TokioPostgresResultSet {
meta: Vec<rdbc::Column>,
rows: Vec<TokioPostgresRow>,
}

#[async_trait]
impl rdbc::ResultSet for TokioPostgresResultSet {
type MetaData = Vec<rdbc::Column>;
type Row = TokioPostgresRow;
type Error = Error;

fn meta_data(&self) -> Result<&Self::MetaData, Self::Error> {
Ok(&self.meta)
}

async fn batches(
&mut self,
) -> Result<Pin<Box<dyn Stream<Item = Vec<Self::Row>>>>, Self::Error> {
let rows = std::mem::take(&mut self.rows);
Ok(Box::pin(tokio::stream::once(rows)))
}
}

pub struct TokioPostgresRow {
inner: Row,
}

macro_rules! impl_resultset_fns {
($($fn: ident -> $ty: ty),*) => {
$(
fn $fn(&self, i: u64) -> Result<Option<$ty>, Self::Error> {
Some(self.inner.try_get((i - 1) as usize)).transpose().map_err(Into::into)
}
)*
}
}

impl rdbc::Row for TokioPostgresRow {
type Error = Error;
impl_resultset_fns! {
get_i8 -> i8,
get_i16 -> i16,
get_i32 -> i32,
get_i64 -> i64,
get_f32 -> f32,
get_f64 -> f64,
get_string -> String,
get_bytes -> Vec<u8>
}
}
54 changes: 32 additions & 22 deletions rdbc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,42 +49,48 @@ impl ToString for Value {
}
}

/// RDBC Result type
pub type Result<T> = std::result::Result<T, Error>;

/// Represents database driver that can be shared between threads, and can therefore implement
/// a connection pool
#[async_trait]
pub trait Driver: Sync + Send {
/// The type of connection created by this driver.
type Connection: Connection;

type Error;

/// Create a connection to the database. Note that connections are intended to be used
/// in a single thread since most database connections are not thread-safe
fn connect(url: &str) -> Result<Self::Connection>;
async fn connect(url: &str) -> Result<Self::Connection, Self::Error>;
}

/// Represents a connection to a database
#[async_trait]
pub trait Connection {
/// The type of statement produced by this connection.
type Statement: Statement;

type Error;

/// Create a statement for execution
fn create(&mut self, sql: &str) -> Result<Self::Statement>;
async fn create(&mut self, sql: &str) -> Result<Self::Statement, Self::Error>;

/// Create a prepared statement for execution
fn prepare(&mut self, sql: &str) -> Result<Self::Statement>;
async fn prepare(&mut self, sql: &str) -> Result<Self::Statement, Self::Error>;
}

/// Represents an executable statement
#[async_trait]
pub trait Statement {
/// The type of ResultSet returned by this statement.
type ResultSet: ResultSet;

type Error;

/// Execute a query that is expected to return a result set, such as a `SELECT` statement
fn execute_query(&mut self, params: &[Value]) -> Result<Self::ResultSet>;
async fn execute_query(&mut self, params: &[Value]) -> Result<Self::ResultSet, Self::Error>;

/// Execute a query that is expected to update some rows.
fn execute_update(&mut self, params: &[Value]) -> Result<u64>;
async fn execute_update(&mut self, params: &[Value]) -> Result<u64, Self::Error>;
}

/// Result set from executing a query against a statement
Expand All @@ -95,36 +101,40 @@ pub trait ResultSet {
/// The type of row included in this result set.
type Row: Row;

type Error;

/// get meta data about this result set
fn meta_data(&self) -> Result<Self::MetaData>;
fn meta_data(&self) -> Result<&Self::MetaData, Self::Error>;

/// Get a stream where each item is a batch of rows.
async fn batches(&mut self) -> Result<Pin<Box<dyn Stream<Item = Vec<Self::Row>>>>>;
async fn batches(&mut self)
-> Result<Pin<Box<dyn Stream<Item = Vec<Self::Row>>>>, Self::Error>;

/// Get a stream of rows.
///
/// Note that the rows are actually returned from the database in batches;
/// this just flattens the batches to provide a (possibly) simpler API.
async fn rows<'a>(&'a mut self) -> Result<Box<dyn Stream<Item = Self::Row> + 'a>> {
async fn rows<'a>(&'a mut self) -> Result<Box<dyn Stream<Item = Self::Row> + 'a>, Self::Error> {
Ok(Box::new(self.batches().await?.map(iter).flatten()))
}
}

pub trait Row {
fn get_i8(&self, i: u64) -> Result<Option<i8>>;
fn get_i16(&self, i: u64) -> Result<Option<i16>>;
fn get_i32(&self, i: u64) -> Result<Option<i32>>;
fn get_i64(&self, i: u64) -> Result<Option<i64>>;
fn get_f32(&self, i: u64) -> Result<Option<f32>>;
fn get_f64(&self, i: u64) -> Result<Option<f64>>;
fn get_string(&self, i: u64) -> Result<Option<String>>;
fn get_bytes(&self, i: u64) -> Result<Option<Vec<u8>>>;
type Error;
fn get_i8(&self, i: u64) -> Result<Option<i8>, Self::Error>;
fn get_i16(&self, i: u64) -> Result<Option<i16>, Self::Error>;
fn get_i32(&self, i: u64) -> Result<Option<i32>, Self::Error>;
fn get_i64(&self, i: u64) -> Result<Option<i64>, Self::Error>;
fn get_f32(&self, i: u64) -> Result<Option<f32>, Self::Error>;
fn get_f64(&self, i: u64) -> Result<Option<f64>, Self::Error>;
fn get_string(&self, i: u64) -> Result<Option<String>, Self::Error>;
fn get_bytes(&self, i: u64) -> Result<Option<Vec<u8>>, Self::Error>;
}

/// Meta data for result set
pub trait MetaData {
fn num_columns(&self) -> u64;
fn column_name(&self, i: u64) -> String;
fn column_name(&self, i: u64) -> &str;
fn column_type(&self, i: u64) -> DataType;
}

Expand Down Expand Up @@ -166,8 +176,8 @@ impl MetaData for Vec<Column> {
self.len() as u64
}

fn column_name(&self, i: u64) -> String {
self[i as usize].name.clone()
fn column_name(&self, i: u64) -> &str {
&self[i as usize].name
}

fn column_type(&self, i: u64) -> DataType {
Expand Down

0 comments on commit a483a82

Please sign in to comment.