Skip to content

Commit

Permalink
Merge branch 'spiceai' into qianqian/disable-pg-federation
Browse files Browse the repository at this point in the history
  • Loading branch information
phillipleblanc authored Sep 6, 2024
2 parents 6680241 + 154eabd commit ce683ca
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 38 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
pull_request:
branches:
- main
- spiceai

jobs:
lint:
Expand Down
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ tonic = { version = "0.11", optional = true } # pinned for arrow-flight compat
datafusion-federation = "0.1"
datafusion-federation-sql = { git = "https://github.com/spiceai/datafusion-federation.git", rev = "b6682948d07cc3155edb3dfbf03f8b55570fc1d2" }
itertools = "0.13.0"
dyn-clone = { version = "1.0.17", optional = true }
geo-types = "0.7.13"

[dev-dependencies]
Expand All @@ -78,7 +79,7 @@ tokio-stream = { version = "0.1.15", features = ["net"] }
mysql = ["dep:mysql_async", "dep:async-stream"]
postgres = ["dep:tokio-postgres", "dep:uuid", "dep:postgres-native-tls", "dep:bb8", "dep:bb8-postgres", "dep:native-tls", "dep:pem", "dep:async-stream"]
sqlite = ["dep:rusqlite", "dep:tokio-rusqlite"]
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid"]
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid", "dep:dyn-clone", "dep:async-stream"]
flight = [
"dep:arrow-array",
"dep:arrow-flight",
Expand All @@ -98,4 +99,5 @@ sqlite-federation = ["sqlite"]
postgres-federation = ["postgres"]

[patch.crates-io]
datafusion-federation = { git = "https://github.com/spiceai/datafusion-federation.git", rev = "b6682948d07cc3155edb3dfbf03f8b55570fc1d2" }
datafusion-federation = { git = "https://github.com/spiceai/datafusion-federation.git", rev = "b6682948d07cc3155edb3dfbf03f8b55570fc1d2" }
duckdb = { git = "https://github.com/spiceai/duckdb-rs.git", rev = "f2ca47d094a5636df8b9f3792b2f474a7b210dc1" }
32 changes: 22 additions & 10 deletions src/duckdb.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::sql::db_connection_pool::{
self,
dbconnection::{
duckdbconn::{flatten_table_function_name, is_table_function, DuckDbConnection},
duckdbconn::{
flatten_table_function_name, is_table_function, DuckDBParameter, DuckDbConnection,
},
get_schema, DbConnection,
},
duckdbpool::DuckDbConnectionPool,
Expand All @@ -25,7 +27,7 @@ use datafusion::{
logical_expr::CreateExternalTable,
sql::TableReference,
};
use duckdb::{AccessMode, DuckdbConnectionManager, ToSql, Transaction};
use duckdb::{AccessMode, DuckdbConnectionManager, Transaction};
use itertools::Itertools;
use snafu::prelude::*;
use std::{cmp, collections::HashMap, sync::Arc};
Expand Down Expand Up @@ -177,7 +179,7 @@ impl Default for DuckDBTableProviderFactory {
}
}

type DynDuckDbConnectionPool = dyn DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>
type DynDuckDbConnectionPool = dyn DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
+ Send
+ Sync;

Expand Down Expand Up @@ -265,7 +267,12 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
));

#[cfg(feature = "duckdb-federation")]
let read_provider = Arc::new(read_provider.create_federated_table_provider()?);
let read_provider: Arc<dyn TableProvider> = if mode == Mode::File {
// federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances
Arc::new(read_provider.create_federated_table_provider()?)
} else {
read_provider
};

Ok(DuckDBTableWriter::create(
read_provider,
Expand Down Expand Up @@ -317,18 +324,18 @@ impl DuckDB {
pub fn connect_sync(
&self,
) -> Result<
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>>,
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
> {
Arc::clone(&self.pool)
.connect_sync()
.context(DbConnectionSnafu)
}

pub fn duckdb_conn<'a>(
db_connection: &'a mut Box<
dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>,
pub fn duckdb_conn(
db_connection: &mut Box<
dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>,
>,
) -> Result<&'a mut DuckDbConnection> {
) -> Result<&mut DuckDbConnection> {
db_connection
.as_any_mut()
.downcast_mut::<DuckDbConnection>()
Expand Down Expand Up @@ -441,7 +448,12 @@ impl DuckDBTableFactory {
));

#[cfg(feature = "duckdb-federation")]
let table_provider = Arc::new(table_provider.create_federated_table_provider()?);
let table_provider: Arc<dyn TableProvider> = if self.pool.mode() == Mode::File {
// federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances
Arc::new(table_provider.create_federated_table_provider()?)
} else {
table_provider
};

Ok(table_provider)
}
Expand Down
107 changes: 96 additions & 11 deletions src/sql/db_connection_pool/dbconnection/duckdbconn.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use std::any::Any;

use arrow::array::RecordBatch;
use async_stream::stream;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::sql::sqlparser::ast::TableFactor;
use datafusion::sql::sqlparser::parser::Parser;
use datafusion::sql::sqlparser::{dialect::DuckDbDialect, tokenizer::Tokenizer};
use datafusion::sql::TableReference;
use duckdb::DuckdbConnectionManager;
use duckdb::ToSql;
use dyn_clone::DynClone;
use snafu::{prelude::*, ResultExt};
use tokio::sync::mpsc::Sender;

use super::DbConnection;
use super::Result;
Expand All @@ -20,7 +24,22 @@ use super::SyncDbConnection;
pub enum Error {
#[snafu(display("DuckDBError: {source}"))]
DuckDBError { source: duckdb::Error },

#[snafu(display("ChannelError: {message}"))]
ChannelError { message: String },
}

/// A DuckDB query parameter that can be cloned and moved across threads.
///
/// Combines `ToSql` with `Sync + Send + DynClone` so parameters can be held as owned
/// trait objects rather than borrowed `&dyn ToSql` references.
pub trait DuckDBSyncParameter: ToSql + Sync + Send + DynClone {
/// Returns this parameter as a plain `&dyn ToSql`, the form passed to DuckDB statements.
fn as_input_parameter(&self) -> &dyn ToSql;
}

// Blanket implementation: any type meeting the bounds is usable as a parameter.
impl<T: ToSql + Sync + Send + DynClone> DuckDBSyncParameter for T {
fn as_input_parameter(&self) -> &dyn ToSql {
self
}
}
// Lets `Box<dyn DuckDBSyncParameter>` be cloned via `dyn_clone`.
dyn_clone::clone_trait_object!(DuckDBSyncParameter);
/// Owned, boxed query parameter used as the parameter type of the DuckDB connection and pool.
pub type DuckDBParameter = Box<dyn DuckDBSyncParameter>;

pub struct DuckDbConnection {
pub conn: r2d2::PooledConnection<DuckdbConnectionManager>,
Expand All @@ -34,7 +53,7 @@ impl DuckDbConnection {
}
}

impl<'a> DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'a dyn ToSql>
impl DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
for DuckDbConnection
{
fn as_any(&self) -> &dyn Any {
Expand All @@ -47,13 +66,14 @@ impl<'a> DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'a dyn T

fn as_sync(
&self,
) -> Option<&dyn SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'a dyn ToSql>>
{
) -> Option<
&dyn SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>,
> {
Some(self)
}
}

impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &dyn ToSql>
impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
for DuckDbConnection
{
fn new(conn: r2d2::PooledConnection<DuckdbConnectionManager>) -> Self {
Expand Down Expand Up @@ -83,23 +103,88 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &dyn ToSq
fn query_arrow(
&self,
sql: &str,
params: &[&dyn ToSql],
params: &[DuckDBParameter],
_projected_schema: Option<SchemaRef>,
) -> Result<SendableRecordBatchStream> {
let mut stmt = self.conn.prepare(sql).context(DuckDBSnafu)?;
let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);

let fetch_schema_sql =
format!("WITH fetch_schema AS ({sql}) SELECT * FROM fetch_schema LIMIT 0");
let mut stmt = self
.conn
.prepare(&fetch_schema_sql)
.boxed()
.context(super::UnableToGetSchemaSnafu)?;

let result: duckdb::Arrow<'_> = stmt
.query_arrow([])
.boxed()
.context(super::UnableToGetSchemaSnafu)?;

let result: duckdb::Arrow<'_> = stmt.query_arrow(params).context(DuckDBSnafu)?;
let schema = result.get_schema();
let recs: Vec<RecordBatch> = result.collect();
Ok(Box::pin(MemoryStream::try_new(recs, schema, None)?))

let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();

let conn = self.conn.try_clone()?;
let sql = sql.to_string();

let cloned_schema = schema.clone();

let join_handle = tokio::task::spawn_blocking(move || {
let mut stmt = conn.prepare(&sql).context(DuckDBSnafu)?;
let params: &[&dyn ToSql] = &params
.iter()
.map(|f| f.as_input_parameter())
.collect::<Vec<_>>();
let result: duckdb::ArrowStream<'_> = stmt
.stream_arrow(params, cloned_schema)
.context(DuckDBSnafu)?;
for i in result {
blocking_channel_send(&batch_tx, i)?;
}

Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
});

let output_stream = stream! {
while let Some(batch) = batch_rx.recv().await {
yield Ok(batch);
}

if let Err(e) = join_handle.await {
yield Err(DataFusionError::Execution(format!(
"Failed to execute DuckDB query: {e}"
)))
}
};

Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
output_stream,
)))
}

fn execute(&self, sql: &str, params: &[&dyn ToSql]) -> Result<u64> {
fn execute(&self, sql: &str, params: &[DuckDBParameter]) -> Result<u64> {
let params: &[&dyn ToSql] = &params
.iter()
.map(|f| f.as_input_parameter())
.collect::<Vec<_>>();

let rows_modified = self.conn.execute(sql, params).context(DuckDBSnafu)?;
Ok(rows_modified as u64)
}
}

fn blocking_channel_send<T>(channel: &Sender<T>, item: T) -> Result<()> {
match channel.blocking_send(item) {
Ok(()) => Ok(()),
Err(e) => Err(Error::ChannelError {
message: format!("{e}"),
}
.into()),
}
}

#[must_use]
pub fn flatten_table_function_name(table_reference: &TableReference) -> String {
let table_name = table_reference.table();
Expand Down
20 changes: 14 additions & 6 deletions src/sql/db_connection_pool/duckdbpool.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use async_trait::async_trait;
use duckdb::{vtab::arrow::ArrowVTab, AccessMode, DuckdbConnectionManager, ToSql};
use duckdb::{vtab::arrow::ArrowVTab, AccessMode, DuckdbConnectionManager};
use snafu::{prelude::*, ResultExt};
use std::sync::Arc;

use super::{DbConnectionPool, Result};
use super::{dbconnection::duckdbconn::DuckDBParameter, DbConnectionPool, Mode, Result};
use crate::sql::db_connection_pool::{
dbconnection::{duckdbconn::DuckDbConnection, DbConnection, SyncDbConnection},
JoinPushDown,
Expand Down Expand Up @@ -36,6 +36,7 @@ pub struct DuckDbConnectionPool {
pool: Arc<r2d2::Pool<DuckdbConnectionManager>>,
join_push_down: JoinPushDown,
attached_databases: Vec<Arc<str>>,
mode: Mode,
}

impl DuckDbConnectionPool {
Expand Down Expand Up @@ -71,6 +72,7 @@ impl DuckDbConnectionPool {
// There can't be any other tables that share the same context for an in-memory DuckDB.
join_push_down: JoinPushDown::Disallow,
attached_databases: Vec::new(),
mode: Mode::Memory,
})
}

Expand Down Expand Up @@ -108,6 +110,7 @@ impl DuckDbConnectionPool {
// Allow join-push down for any other instances that connect to the same underlying file.
join_push_down: JoinPushDown::AllowedFor(path.to_string()),
attached_databases: Vec::new(),
mode: Mode::File,
})
}

Expand All @@ -134,23 +137,28 @@ impl DuckDbConnectionPool {
pub fn connect_sync(
self: Arc<Self>,
) -> Result<
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>>,
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
> {
let pool = Arc::clone(&self.pool);
let conn: r2d2::PooledConnection<DuckdbConnectionManager> =
pool.get().context(ConnectionPoolSnafu)?;
Ok(Box::new(DuckDbConnection::new(conn)))
}

/// Returns the `Mode` stored on this pool when it was constructed
/// (the visible constructors set `Mode::Memory` or `Mode::File`).
#[must_use]
pub fn mode(&self) -> Mode {
self.mode
}
}

#[async_trait]
impl DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>
impl DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
for DuckDbConnectionPool
{
async fn connect(
&self,
) -> Result<
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>>,
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
> {
let pool = Arc::clone(&self.pool);
let conn: r2d2::PooledConnection<DuckdbConnectionManager> =
Expand Down Expand Up @@ -225,7 +233,6 @@ fn extract_db_name(file_path: Arc<str>) -> Result<String> {

#[cfg(test)]
mod test {

use rand::Rng;

use super::*;
Expand Down Expand Up @@ -265,6 +272,7 @@ mod test {
}

#[tokio::test]
#[cfg(feature = "duckdb-federation")]
async fn test_duckdb_connection_pool_with_attached_databases() {
let db_base_name = random_db_name();
let db_attached_name = random_db_name();
Expand Down
Loading

0 comments on commit ce683ca

Please sign in to comment.