INSERT returns number of rows written, add InsertExec to handle common case. #6354

Merged: 10 commits, May 19, 2023
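For context, here is a minimal sketch of the user-visible behavior this PR introduces: an `INSERT` now returns a one-row `count` column instead of an empty result. The sketch is illustrative, not from the PR; it assumes a DataFusion build that includes this change (circa v25), and details such as implicit casts from the `VALUES` literals may vary by release.

```rust
use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // Register a single-column, single-partition in-memory table
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1]))])?;
    let table = MemTable::try_new(schema.clone(), vec![vec![batch]])?;

    let ctx = SessionContext::new();
    ctx.register_table("t", Arc::new(table))?;

    // With this PR, INSERT reports how many rows were written
    let df = ctx.sql("INSERT INTO t VALUES (2), (3), (4)").await?;
    df.show().await?;
    // +-------+
    // | count |
    // +-------+
    // | 3     |
    // +-------+
    Ok(())
}
```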
21 changes: 20 additions & 1 deletion datafusion/core/src/datasource/datasource.rs
@@ -98,7 +98,26 @@ pub trait TableProvider: Sync + Send {
None
}

/// Insert into this table
/// Return an [`ExecutionPlan`] to insert data into this table, if
/// supported.
///
/// The returned plan should return a single row in a UInt64
/// column called "count" such as the following
///
/// ```text
/// +-------+
/// | count |
/// +-------+
/// | 6     |
/// +-------+
/// ```
///
/// # See Also
///
/// See [`InsertExec`] for the common pattern of inserting a
/// single stream of `RecordBatch`es.
///
/// [`InsertExec`]: crate::physical_plan::insert::InsertExec
async fn insert_into(
&self,
_state: &SessionState,
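To make the contract above concrete, here is a minimal sketch of a custom `DataSink` plugged into `InsertExec`. `NullSink` and `make_insert_plan` are hypothetical names; the trait shape (`Display + Debug` plus `write_all`) follows the `MemSink` implementation later in this diff and may differ in other DataFusion releases.

```rust
use std::fmt::{self, Display};
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::error::Result;
use datafusion::physical_plan::insert::{DataSink, InsertExec};
use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream};
use futures::StreamExt;

/// Hypothetical sink that discards every batch but reports how many rows
/// it saw, satisfying the "single count row" contract.
#[derive(Debug)]
struct NullSink;

impl Display for NullSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "NullSink")
    }
}

#[async_trait]
impl DataSink for NullSink {
    async fn write_all(&self, mut data: SendableRecordBatchStream) -> Result<u64> {
        let mut row_count = 0u64;
        while let Some(batch) = data.next().await.transpose()? {
            row_count += batch.num_rows() as u64;
        }
        Ok(row_count)
    }
}

/// Sketch of how a `TableProvider::insert_into` implementation would wire
/// the sink into the returned plan.
fn make_insert_plan(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    Arc::new(InsertExec::new(input, Arc::new(NullSink)))
}
```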
119 changes: 94 additions & 25 deletions datafusion/core/src/datasource/memory.rs
@@ -19,6 +19,7 @@

use futures::StreamExt;
use std::any::Any;
use std::fmt::{self, Debug, Display};
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
@@ -30,11 +31,11 @@ use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::common;
use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::insert::{DataSink, InsertExec};
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::memory::MemoryWriteExec;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::{common, SendableRecordBatchStream};
use crate::physical_plan::{repartition::RepartitionExec, Partitioning};

/// Type alias for partition data
@@ -164,7 +165,8 @@ impl TableProvider for MemTable {
)?))
}

/// Inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
/// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
///
/// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
///
/// # Arguments
@@ -174,7 +176,7 @@ impl TableProvider for MemTable {
///
/// # Returns
///
/// * A `Result` indicating success or failure.
/// * A plan that returns the number of rows written.
async fn insert_into(
&self,
_state: &SessionState,
@@ -187,27 +189,61 @@
"Inserting query must have the same schema with the table.".to_string(),
));
}
let sink = Arc::new(MemSink::new(self.batches.clone()));
Ok(Arc::new(InsertExec::new(input, sink)))
}
}

if self.batches.is_empty() {
return Err(DataFusionError::Plan(
"The table must have partitions.".to_string(),
));
}

/// Implements [`DataSink`] for writing to a [`MemTable`]
struct MemSink {
/// Target locations for writing data
batches: Vec<PartitionData>,
}

impl Debug for MemSink {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MemSink")
.field("num_partitions", &self.batches.len())
.finish()
}
}

impl Display for MemSink {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let partition_count = self.batches.len();
write!(f, "MemoryTable (partitions={partition_count})")
}
}

impl MemSink {
fn new(batches: Vec<PartitionData>) -> Self {
Self { batches }
}
}

#[async_trait]
impl DataSink for MemSink {
async fn write_all(&self, mut data: SendableRecordBatchStream) -> Result<u64> {
let num_partitions = self.batches.len();

// buffer up the data round robin style into num_partitions

let mut new_batches = vec![vec![]; num_partitions];
let mut i = 0;
let mut row_count = 0;
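// `transpose` turns the stream's Option<Result<RecordBatch>> into a
// Result<Option<RecordBatch>>, so `?` propagates stream errors while
// `Some`/`None` drives the loop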
while let Some(batch) = data.next().await.transpose()? {

[Review comment from the PR author]: I think this formulation of MemTable insert is easier to understand, as the code is now in the same place alongside the code that defines the rest of the data structures.

row_count += batch.num_rows();
new_batches[i].push(batch);
i = (i + 1) % num_partitions;
}

let input = if self.batches.len() > 1 {
Arc::new(RepartitionExec::try_new(
input,
Partitioning::RoundRobinBatch(self.batches.len()),
)?)
} else {
input
};
// write the outputs into the batches
for (target, mut batches) in self.batches.iter().zip(new_batches.into_iter()) {
// Append all the new batches in one go to minimize locking overhead
target.write().await.append(&mut batches);

[Review comment from @tustvold, May 15, 2023]: I wonder if we can now remove the per-partition locks 🤔 Possibly even switching to a non-async lock...

}

Ok(Arc::new(MemoryWriteExec::try_new(
input,
self.batches.clone(),
self.schema.clone(),
)?))
Ok(row_count as u64)
}
}

@@ -218,8 +254,8 @@ mod tests {
use crate::from_slice::FromSlice;
use crate::physical_plan::collect;
use crate::prelude::SessionContext;
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::array::{AsArray, Int32Array};
use arrow::datatypes::{DataType, Field, Schema, UInt64Type};
use arrow::error::ArrowError;
use datafusion_expr::LogicalPlanBuilder;
use futures::StreamExt;
Expand Down Expand Up @@ -457,6 +493,11 @@ mod tests {
initial_data: Vec<Vec<RecordBatch>>,
inserted_data: Vec<Vec<RecordBatch>>,
) -> Result<Vec<Vec<RecordBatch>>> {
let expected_count: u64 = inserted_data
.iter()
.flat_map(|batches| batches.iter().map(|batch| batch.num_rows() as u64))
.sum();

// Create a new session context
let session_ctx = SessionContext::new();
// Create and register the initial table with the provided schema and data
@@ -480,8 +521,8 @@

// Execute the physical plan and collect the results
let res = collect(plan, session_ctx.task_ctx()).await?;
// Ensure the result is empty after the insert operation
assert!(res.is_empty());
assert_eq!(extract_count(res), expected_count);

[Review comment from the PR author]: here is the test for the new feature of returning row counts

// Read the data from the initial table and store it in a vector of partitions
let mut partitions = vec![];
for partition in initial_table.batches.iter() {
Expand All @@ -491,6 +532,34 @@ mod tests {
Ok(partitions)
}

/// Returns the count value in the results. For example, returns 6 given the following
///
/// ```text
/// +-------+
/// | count |
/// +-------+
/// | 6     |
/// +-------+
/// ```
fn extract_count(res: Vec<RecordBatch>) -> u64 {
assert_eq!(res.len(), 1, "expected one batch, got {}", res.len());
let batch = &res[0];
assert_eq!(
batch.num_columns(),
1,
"expected 1 column, got {}",
batch.num_columns()
);
let col = batch.column(0).as_primitive::<UInt64Type>();
assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len());
let val = col
.iter()
.next()
.expect("had value")
.expect("expected non null");
val
}

// Test inserting a single batch of data into a single partition
#[tokio::test]
async fn test_insert_into_single_partition() -> Result<()> {