diff --git a/rust/datafusion/src/physical_plan/common.rs b/rust/datafusion/src/physical_plan/common.rs
index 2b96c454fc9..40c9763c024 100644
--- a/rust/datafusion/src/physical_plan/common.rs
+++ b/rust/datafusion/src/physical_plan/common.rs
@@ -26,17 +26,26 @@ use super::{RecordBatchStream, SendableRecordBatchStream};
 use crate::error::{DataFusionError, Result};
 
 use array::{
-    BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
-    Int8Array, LargeStringArray, StringArray, UInt16Array, UInt32Array, UInt64Array,
+    ArrayData, BooleanArray, Date32Array, DecimalArray, Float32Array, Float64Array,
+    Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray,
+    Time32MillisecondArray, Time32SecondArray, UInt16Array, UInt32Array, UInt64Array,
     UInt8Array,
 };
-use arrow::datatypes::{DataType, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 use arrow::{
     array::{self, ArrayRef},
     datatypes::Schema,
 };
+use arrow::{
+    array::{
+        Date64Array, Time64MicrosecondArray, Time64NanosecondArray,
+        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
+        TimestampSecondArray,
+    },
+    buffer::Buffer,
+    datatypes::{DataType, SchemaRef, TimeUnit},
+};
 use futures::{Stream, TryStreamExt};
 
 /// Stream of record batches
@@ -157,6 +166,71 @@ pub fn create_batch_empty(schema: &Schema) -> ArrowResult<RecordBatch> {
             DataType::Boolean => {
                 Ok(Arc::new(BooleanArray::from(vec![] as Vec<bool>)) as ArrayRef)
             }
+            // `DataType::Decimal` carries (precision, scale), in that order.
+            DataType::Decimal(precision, scale) => {
+                // Zero-length array backed by an empty value buffer.
+                let array_data =
+                    ArrayData::builder(DataType::Decimal(*precision, *scale))
+                        .len(0)
+                        .add_buffer(Buffer::from(&[]))
+                        .build();
+
+                Ok(Arc::new(DecimalArray::from(array_data)) as ArrayRef)
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, tz) => Ok(Arc::new(
+                TimestampNanosecondArray::from_vec(vec![] as Vec<i64>, tz.clone()),
+            )
+                as ArrayRef),
+            DataType::Timestamp(TimeUnit::Microsecond, tz) => Ok(Arc::new(
+                TimestampMicrosecondArray::from_vec(vec![] as Vec<i64>, tz.clone()),
+            )
+                as ArrayRef),
+            DataType::Timestamp(TimeUnit::Millisecond, tz) => Ok(Arc::new(
+                TimestampMillisecondArray::from_vec(vec![] as Vec<i64>, tz.clone()),
+            )
+                as ArrayRef),
+            DataType::Timestamp(TimeUnit::Second, tz) => Ok(Arc::new(
+                TimestampSecondArray::from_vec(vec![] as Vec<i64>, tz.clone()),
+            ) as ArrayRef),
+            DataType::Date32(_) => {
+                Ok(Arc::new(Date32Array::from(vec![] as Vec<i32>)) as ArrayRef)
+            }
+            DataType::Date64(_) => {
+                Ok(Arc::new(Date64Array::from(vec![] as Vec<i64>)) as ArrayRef)
+            }
+            DataType::Time32(unit) => match unit {
+                TimeUnit::Second => {
+                    Ok(Arc::new(Time32SecondArray::from(vec![] as Vec<i32>)) as ArrayRef)
+                }
+                TimeUnit::Millisecond => {
+                    Ok(Arc::new(Time32MillisecondArray::from(vec![] as Vec<i32>))
+                        as ArrayRef)
+                }
+                // No empty array is built for Time32 with these units.
+                TimeUnit::Microsecond | TimeUnit::Nanosecond => {
+                    Err(DataFusionError::NotImplemented(format!(
+                        "Cannot convert datatype {:?} to array",
+                        f.data_type()
+                    )))
+                }
+            },
+            DataType::Time64(unit) => match unit {
+                // No empty array is built for Time64 with these units.
+                TimeUnit::Second | TimeUnit::Millisecond => {
+                    Err(DataFusionError::NotImplemented(format!(
+                        "Cannot convert datatype {:?} to array",
+                        f.data_type()
+                    )))
+                }
+                TimeUnit::Microsecond => {
+                    Ok(Arc::new(Time64MicrosecondArray::from(vec![] as Vec<i64>))
+                        as ArrayRef)
+                }
+                TimeUnit::Nanosecond => {
+                    Ok(Arc::new(Time64NanosecondArray::from(vec![] as Vec<i64>))
+                        as ArrayRef)
+                }
+            },
             _ => Err(DataFusionError::NotImplemented(format!(
                 "Cannot convert datatype {:?} to array",
                 f.data_type()
@@ -164,5 +238,58 @@ pub fn create_batch_empty(schema: &Schema) -> ArrowResult<RecordBatch> {
         })
         .collect::<Result<_>>()
         .map_err(DataFusionError::into_arrow_external_error)?;
+
     RecordBatch::try_new(Arc::new(schema.to_owned()), columns)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::Field;
+
+    /// `create_batch_empty` yields one zero-length column per schema field.
+    #[test]
+    fn test_create_batch_empty() {
+        let schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::UInt32, false),
+            Field::new("c3", DataType::Int8, false),
+            Field::new("c4", DataType::Int16, false),
+            Field::new("c5", DataType::Int32, false),
+            Field::new("c6", DataType::Int64, false),
+            Field::new("c7", DataType::UInt8, false),
+            Field::new("c8", DataType::UInt16, false),
+            Field::new("c9", DataType::UInt32, false),
+            Field::new("c10", DataType::UInt64, false),
+            Field::new("c11", DataType::Float32, false),
+            Field::new("c12", DataType::Float64, false),
+            Field::new("c13", DataType::Utf8, false),
+            Field::new("c14", DataType::Decimal(10, 10), false),
+            Field::new("c15", DataType::Timestamp(TimeUnit::Second, None), false),
+            Field::new(
+                "c16",
+                DataType::Timestamp(TimeUnit::Microsecond, None),
+                false,
+            ),
+            Field::new(
+                "c17",
+                DataType::Timestamp(TimeUnit::Millisecond, None),
+                false,
+            ),
+            Field::new(
+                "c18",
+                DataType::Timestamp(TimeUnit::Nanosecond, None),
+                false,
+            ),
+            Field::new("c19", DataType::Boolean, false),
+            Field::new("c20", DataType::Time32(TimeUnit::Second), false),
+            Field::new("c21", DataType::Time32(TimeUnit::Millisecond), false),
+            Field::new("c22", DataType::Time64(TimeUnit::Microsecond), false),
+            Field::new("c23", DataType::Time64(TimeUnit::Nanosecond), false),
+        ]);
+
+        let batch = create_batch_empty(&schema).unwrap();
+        assert_eq!(batch.columns().len(), 23);
+        assert_eq!(batch.num_rows(), 0);
+    }
+}