diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index a8c2f42fe251..6c5db1218563 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -27,7 +27,7 @@ use datafusion::{ execution::{SendableRecordBatchStream, TaskContext}, physical_plan::{DisplayAs, ExecutionPlan, PlanProperties}, }; -use tokio::runtime::Runtime; +use tokio::runtime::Handle; use crate::{ plan_properties::FFI_PlanProperties, record_batch_stream::FFI_RecordBatchStream, @@ -72,7 +72,7 @@ unsafe impl Sync for FFI_ExecutionPlan {} pub struct ExecutionPlanPrivateData { pub plan: Arc, pub context: Arc, - pub runtime: Option>, + pub runtime: Option, } unsafe extern "C" fn properties_fn_wrapper( @@ -110,7 +110,7 @@ unsafe extern "C" fn execute_fn_wrapper( let private_data = plan.private_data as *const ExecutionPlanPrivateData; let plan = &(*private_data).plan; let ctx = &(*private_data).context; - let runtime = (*private_data).runtime.as_ref().map(Arc::clone); + let runtime = (*private_data).runtime.clone(); match plan.execute(partition, Arc::clone(ctx)) { Ok(rbs) => RResult::ROk(FFI_RecordBatchStream::new(rbs, runtime)), @@ -153,7 +153,7 @@ impl FFI_ExecutionPlan { pub fn new( plan: Arc, context: Arc, - runtime: Option>, + runtime: Option, ) -> Self { let private_data = Box::new(ExecutionPlanPrivateData { plan, diff --git a/datafusion/ffi/src/insert_op.rs b/datafusion/ffi/src/insert_op.rs new file mode 100644 index 000000000000..e44262377405 --- /dev/null +++ b/datafusion/ffi/src/insert_op.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use abi_stable::StableAbi; +use datafusion::logical_expr::logical_plan::dml::InsertOp; + +/// FFI safe version of [`InsertOp`]. +#[repr(C)] +#[derive(StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_InsertOp { + Append, + Overwrite, + Replace, +} + +impl From for InsertOp { + fn from(value: FFI_InsertOp) -> Self { + match value { + FFI_InsertOp::Append => InsertOp::Append, + FFI_InsertOp::Overwrite => InsertOp::Overwrite, + FFI_InsertOp::Replace => InsertOp::Replace, + } + } +} + +impl From for FFI_InsertOp { + fn from(value: InsertOp) -> Self { + match value { + InsertOp::Append => FFI_InsertOp::Append, + InsertOp::Overwrite => FFI_InsertOp::Overwrite, + InsertOp::Replace => FFI_InsertOp::Replace, + } + } +} diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index bef36b1ddd48..b25528234773 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -20,6 +20,7 @@ pub mod arrow_wrappers; pub mod execution_plan; +pub mod insert_op; pub mod plan_properties; pub mod record_batch_stream; pub mod session_config; diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 878ac24f6765..466ce247678a 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, sync::Arc, task::Poll}; +use std::{ffi::c_void, task::Poll}; use abi_stable::{ std_types::{ROption, RResult, RString}, @@ -33,7 +33,7 @@ use datafusion::{ execution::{RecordBatchStream, SendableRecordBatchStream}, }; use futures::{Stream, TryStreamExt}; -use tokio::runtime::Runtime; +use tokio::runtime::Handle; use crate::arrow_wrappers::{WrappedArray, WrappedSchema}; @@ -61,7 +61,7 @@ pub struct FFI_RecordBatchStream { pub struct RecordBatchStreamPrivateData { pub rbs: SendableRecordBatchStream, - pub runtime: Option>, + pub runtime: Option, } impl From for FFI_RecordBatchStream { @@ -71,7 +71,7 @@ impl From for FFI_RecordBatchStream { } impl FFI_RecordBatchStream { - pub fn new(stream: SendableRecordBatchStream, runtime: Option>) -> Self { + pub fn new(stream: SendableRecordBatchStream, runtime: Option) -> Self { let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData { rbs: stream, runtime, diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index 183dfc8755d1..978ac10206bd 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -28,8 +28,8 @@ use datafusion::{ catalog::{Session, TableProvider}, datasource::TableType, error::DataFusionError, - execution::session_state::SessionStateBuilder, - logical_expr::TableProviderFilterPushDown, + execution::{session_state::SessionStateBuilder, TaskContext}, + logical_expr::{logical_plan::dml::InsertOp, TableProviderFilterPushDown}, physical_plan::ExecutionPlan, prelude::{Expr, SessionContext}, }; @@ -40,7 +40,7 @@ use datafusion_proto::{ protobuf::LogicalExprList, }; use prost::Message; -use tokio::runtime::Runtime; +use tokio::runtime::Handle; use crate::{ arrow_wrappers::WrappedSchema, @@ -50,6 +50,7 @@ use crate::{ use super::{ execution_plan::{FFI_ExecutionPlan, ForeignExecutionPlan}, + insert_op::FFI_InsertOp, session_config::FFI_SessionConfig, }; use datafusion::error::Result; @@ -133,6 +134,14 @@ pub struct FFI_TableProvider { -> RResult, RString>, >, + pub insert_into: + unsafe extern "C" fn( + provider: &Self, + session_config: &FFI_SessionConfig, + input: &FFI_ExecutionPlan, + insert_op: FFI_InsertOp, + ) -> FfiFuture>, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -153,7 +162,7 @@ unsafe impl Sync for FFI_TableProvider {} struct ProviderPrivateData { provider: Arc, - runtime: Option>, + runtime: Option, } unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema { @@ -276,6 +285,53 @@ unsafe extern "C" fn scan_fn_wrapper( .into_ffi() } +unsafe extern "C" fn insert_into_fn_wrapper( + provider: &FFI_TableProvider, + session_config: &FFI_SessionConfig, + input: &FFI_ExecutionPlan, + insert_op: FFI_InsertOp, +) -> FfiFuture> { + let private_data = provider.private_data as *mut ProviderPrivateData; + let internal_provider = &(*private_data).provider; + let session_config = session_config.clone(); + let input = input.clone(); + let runtime = &(*private_data).runtime; + + async move { + let config = match ForeignSessionConfig::try_from(&session_config) { + Ok(c) => c, + Err(e) => return RResult::RErr(e.to_string().into()), + }; + let session = SessionStateBuilder::new() + .with_default_features() + .with_config(config.0) + .build(); + let ctx = SessionContext::new_with_state(session); + + let input = match ForeignExecutionPlan::try_from(&input) { + Ok(input) => Arc::new(input), + Err(e) => return RResult::RErr(e.to_string().into()), + }; + + let insert_op = InsertOp::from(insert_op); + + let plan = match internal_provider + .insert_into(&ctx.state(), input, insert_op) + .await + { + Ok(p) => p, + Err(e) => return RResult::RErr(e.to_string().into()), + }; + + RResult::ROk(FFI_ExecutionPlan::new( + plan, + ctx.task_ctx(), + runtime.clone(), + )) + } + .into_ffi() +} + unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_TableProvider) { let private_data = Box::from_raw(provider.private_data as *mut ProviderPrivateData); drop(private_data); @@ -295,6 +351,7 @@ unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_Table scan: scan_fn_wrapper, table_type: table_type_fn_wrapper, supports_filters_pushdown: provider.supports_filters_pushdown, + insert_into: provider.insert_into, clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, @@ -313,7 +370,7 @@ impl FFI_TableProvider { pub fn new( provider: Arc, can_support_pushdown_filters: bool, - runtime: Option>, + runtime: Option, ) -> Self { let private_data = Box::new(ProviderPrivateData { provider, runtime }); @@ -325,6 +382,7 @@ impl FFI_TableProvider { true => Some(supports_filters_pushdown_fn_wrapper), false => None, }, + insert_into: insert_into_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, @@ -443,6 +501,37 @@ impl TableProvider for ForeignTableProvider { } } } + + async fn insert_into( + &self, + session: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> Result> { + let session_config: FFI_SessionConfig = session.config().into(); + + let rc = Handle::try_current().ok(); + let input = + FFI_ExecutionPlan::new(input, Arc::new(TaskContext::from(session)), rc); + let insert_op: FFI_InsertOp = insert_op.into(); + + let plan = unsafe { + let maybe_plan = + (self.0.insert_into)(&self.0, &session_config, &input, insert_op).await; + + match maybe_plan { + RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?, + RResult::RErr(e) => { + return Err(DataFusionError::Internal(format!( + "Unable to perform insert_into via FFI: {}", + e + ))) + } + } + }; + + Ok(Arc::new(plan)) + } } #[cfg(test)] @@ -453,7 +542,7 @@ mod tests { use super::*; #[tokio::test] - async fn test_round_trip_ffi_table_provider() -> Result<()> { + async fn test_round_trip_ffi_table_provider_scan() -> Result<()> { use arrow::datatypes::Field; use datafusion::arrow::{ array::Float32Array, datatypes::DataType, record_batch::RecordBatch, @@ -493,4 +582,54 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_round_trip_ffi_table_provider_insert_into() -> Result<()> { + use arrow::datatypes::Field; + use datafusion::arrow::{ + array::Float32Array, datatypes::DataType, record_batch::RecordBatch, + }; + use datafusion::datasource::MemTable; + + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![64.0]))], + )?; + + let ctx = SessionContext::new(); + + let provider = + Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?); + + let ffi_provider = FFI_TableProvider::new(provider, true, None); + + let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); + + ctx.register_table("t", Arc::new(foreign_table_provider))?; + + let result = ctx + .sql("INSERT INTO t VALUES (128.0);") + .await? + .collect() + .await?; + + assert!(result.len() == 1 && result[0].num_rows() == 1); + + ctx.table("t") + .await? + .select(vec![col("a")])? + .filter(col("a").gt(lit(3.0)))? + .show() + .await?; + + Ok(()) + } } diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index 38ddd13952b0..0e56a318cd87 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -41,7 +41,7 @@ use datafusion::{ }; use futures::Stream; use tokio::{ - runtime::Runtime, + runtime::Handle, sync::{broadcast, mpsc}, }; @@ -59,7 +59,7 @@ fn async_table_provider_thread( mut shutdown: mpsc::Receiver, mut batch_request: mpsc::Receiver, batch_sender: broadcast::Sender>, - tokio_rt: mpsc::Sender>, + tokio_rt: mpsc::Sender, ) { let runtime = Arc::new( tokio::runtime::Builder::new_current_thread() @@ -68,7 +68,7 @@ fn async_table_provider_thread( ); let _runtime_guard = runtime.enter(); tokio_rt - .blocking_send(Arc::clone(&runtime)) + .blocking_send(runtime.handle().clone()) .expect("Unable to send tokio runtime back to main thread"); runtime.block_on(async move { @@ -91,7 +91,7 @@ fn async_table_provider_thread( let _ = shutdown.blocking_recv(); } -pub fn start_async_provider() -> (AsyncTableProvider, Arc) { +pub fn start_async_provider() -> (AsyncTableProvider, Handle) { let (batch_request_tx, batch_request_rx) = mpsc::channel(10); let (record_batch_tx, record_batch_rx) = broadcast::channel(10); let (tokio_rt_tx, mut tokio_rt_rx) = mpsc::channel(10);