-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
INSERT returns number of rows written, add InsertExec
to handle common case.
#6354
Changes from 6 commits
095486f
c8a4a55
3000463
3acd24d
0e20c08
b1464da
f4fba65
7c2f189
a665265
b1b0146
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
use futures::StreamExt; | ||
use std::any::Any; | ||
use std::fmt::{self, Debug, Display}; | ||
use std::sync::Arc; | ||
|
||
use arrow::datatypes::SchemaRef; | ||
|
@@ -30,11 +31,11 @@ use crate::datasource::{TableProvider, TableType}; | |
use crate::error::{DataFusionError, Result}; | ||
use crate::execution::context::SessionState; | ||
use crate::logical_expr::Expr; | ||
use crate::physical_plan::common; | ||
use crate::physical_plan::common::AbortOnDropSingle; | ||
use crate::physical_plan::insert::{DataSink, InsertExec}; | ||
use crate::physical_plan::memory::MemoryExec; | ||
use crate::physical_plan::memory::MemoryWriteExec; | ||
use crate::physical_plan::ExecutionPlan; | ||
use crate::physical_plan::{common, SendableRecordBatchStream}; | ||
use crate::physical_plan::{repartition::RepartitionExec, Partitioning}; | ||
|
||
/// Type alias for partition data | ||
|
@@ -164,7 +165,8 @@ impl TableProvider for MemTable { | |
)?)) | ||
} | ||
|
||
/// Inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`]. | ||
/// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`]. | ||
/// | ||
/// The [`ExecutionPlan`] must have the same schema as this [`MemTable`]. | ||
/// | ||
/// # Arguments | ||
|
@@ -174,7 +176,7 @@ impl TableProvider for MemTable { | |
/// | ||
/// # Returns | ||
/// | ||
/// * A `Result` indicating success or failure. | ||
/// * A plan that returns the number of rows written. | ||
async fn insert_into( | ||
&self, | ||
_state: &SessionState, | ||
|
@@ -187,27 +189,61 @@ impl TableProvider for MemTable { | |
"Inserting query must have the same schema with the table.".to_string(), | ||
)); | ||
} | ||
let sink = Arc::new(MemSink::new(self.batches.clone())); | ||
Ok(Arc::new(InsertExec::new(input, sink))) | ||
} | ||
} | ||
|
||
if self.batches.is_empty() { | ||
return Err(DataFusionError::Plan( | ||
"The table must have partitions.".to_string(), | ||
)); | ||
/// Implements writing to a [`MemTable`] | ||
struct MemSink { | ||
/// Target locations for writing data | ||
batches: Vec<PartitionData>, | ||
} | ||
|
||
impl Debug for MemSink { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
f.debug_struct("MemSink") | ||
.field("num_partitions", &self.batches.len()) | ||
.finish() | ||
} | ||
} | ||
|
||
impl Display for MemSink { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
let partition_count = self.batches.len(); | ||
write!(f, "MemoryTable (partitions={partition_count})") | ||
} | ||
} | ||
|
||
impl MemSink { | ||
fn new(batches: Vec<PartitionData>) -> Self { | ||
Self { batches } | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl DataSink for MemSink { | ||
async fn write_all(&self, mut data: SendableRecordBatchStream) -> Result<u64> { | ||
let num_partitions = self.batches.len(); | ||
|
||
// buffer up the data round robin style into num_partitions | ||
|
||
let mut new_batches = vec![vec![]; num_partitions]; | ||
let mut i = 0; | ||
let mut row_count = 0; | ||
while let Some(batch) = data.next().await.transpose()? { | ||
row_count += batch.num_rows(); | ||
new_batches[i].push(batch); | ||
i = (i + 1) % num_partitions; | ||
} | ||
|
||
let input = if self.batches.len() > 1 { | ||
Arc::new(RepartitionExec::try_new( | ||
input, | ||
Partitioning::RoundRobinBatch(self.batches.len()), | ||
)?) | ||
} else { | ||
input | ||
}; | ||
// write the outputs into the batches | ||
for (target, mut batches) in self.batches.iter().zip(new_batches.into_iter()) { | ||
// Append all the new batches in one go to minimize locking overhead | ||
target.write().await.append(&mut batches); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we can now remove the per-partition locks 🤔 Possibly even switching to using a non-async lock even... |
||
} | ||
|
||
Ok(Arc::new(MemoryWriteExec::try_new( | ||
input, | ||
self.batches.clone(), | ||
self.schema.clone(), | ||
)?)) | ||
Ok(row_count as u64) | ||
} | ||
} | ||
|
||
|
@@ -218,8 +254,8 @@ mod tests { | |
use crate::from_slice::FromSlice; | ||
use crate::physical_plan::collect; | ||
use crate::prelude::SessionContext; | ||
use arrow::array::Int32Array; | ||
use arrow::datatypes::{DataType, Field, Schema}; | ||
use arrow::array::{AsArray, Int32Array}; | ||
use arrow::datatypes::{DataType, Field, Schema, UInt64Type}; | ||
use arrow::error::ArrowError; | ||
use datafusion_expr::LogicalPlanBuilder; | ||
use futures::StreamExt; | ||
|
@@ -457,6 +493,11 @@ mod tests { | |
initial_data: Vec<Vec<RecordBatch>>, | ||
inserted_data: Vec<Vec<RecordBatch>>, | ||
) -> Result<Vec<Vec<RecordBatch>>> { | ||
let expected_count: u64 = inserted_data | ||
.iter() | ||
.flat_map(|batches| batches.iter().map(|batch| batch.num_rows() as u64)) | ||
.sum(); | ||
|
||
// Create a new session context | ||
let session_ctx = SessionContext::new(); | ||
// Create and register the initial table with the provided schema and data | ||
|
@@ -480,8 +521,8 @@ mod tests { | |
|
||
// Execute the physical plan and collect the results | ||
let res = collect(plan, session_ctx.task_ctx()).await?; | ||
// Ensure the result is empty after the insert operation | ||
assert!(res.is_empty()); | ||
assert_eq!(extract_count(res), expected_count); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here is the test for the new feature of returning row counts |
||
|
||
// Read the data from the initial table and store it in a vector of partitions | ||
let mut partitions = vec![]; | ||
for partition in initial_table.batches.iter() { | ||
|
@@ -491,6 +532,34 @@ mod tests { | |
Ok(partitions) | ||
} | ||
|
||
/// Returns the value of the results. For example, returns 6 given the following | ||
/// | ||
/// ```text | ||
/// +-------+, | ||
/// | count |, | ||
/// +-------+, | ||
/// | 6 |, | ||
/// +-------+, | ||
/// ``` | ||
fn extract_count(res: Vec<RecordBatch>) -> u64 { | ||
assert_eq!(res.len(), 1, "expected one batch, got {}", res.len()); | ||
let batch = &res[0]; | ||
assert_eq!( | ||
batch.num_columns(), | ||
1, | ||
"expected 1 column, got {}", | ||
batch.num_columns() | ||
); | ||
let col = batch.column(0).as_primitive::<UInt64Type>(); | ||
assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len()); | ||
let val = col | ||
.iter() | ||
.next() | ||
.expect("had value") | ||
.expect("expected non null"); | ||
val | ||
} | ||
|
||
// Test inserting a single batch of data into a single partition | ||
#[tokio::test] | ||
async fn test_insert_into_single_partition() -> Result<()> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this formulation of
MemTable
insert is easier to understand as the code is now in the same place alongside the code that defines the rest of the data structures