diff --git a/rust/arrow/src/datatypes/schema.rs b/rust/arrow/src/datatypes/schema.rs index ad89b29cacd..35ee3353d62 100644 --- a/rust/arrow/src/datatypes/schema.rs +++ b/rust/arrow/src/datatypes/schema.rs @@ -152,6 +152,12 @@ impl Schema { }) } + /// Appends `field` to the end of this `Schema`'s list of + /// fields. + pub fn push(&mut self, field: Field) { + self.fields.push(field) + } + /// Returns an immutable reference of the vector of `Field` instances. #[inline] pub const fn fields(&self) -> &Vec { diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs index 93abb909d02..1d1a583ca71 100644 --- a/rust/arrow/src/record_batch.rs +++ b/rust/arrow/src/record_batch.rs @@ -93,7 +93,7 @@ impl RecordBatch { Ok(RecordBatch { schema, columns }) } - /// Creates a new empty [`RecordBatch`]. + /// Creates a new empty [`RecordBatch`] based on `schema`. pub fn new_empty(schema: SchemaRef) -> Self { let columns = schema .fields() @@ -103,6 +103,56 @@ impl RecordBatch { RecordBatch { schema, columns } } + /// Creates a new [`RecordBatch`] with an empty schema and no columns. + /// + /// TODO: add a code example using `append` + pub fn new() -> Self { + Self { + schema: Arc::new(Schema::empty()), + columns: Vec::new(), + } + } + + /// Appends the `field_values` array to this `RecordBatch` as a + /// field named `field_name`. Fails if its length differs from the first column's. + /// + /// TODO: code example + /// + /// TODO: on error, can we return `Self` in some meaningful way? + pub fn append(self, field_name: &str, field_values: ArrayRef) -> Result { + if let Some(col) = self.columns.get(0) { + if col.len() != field_values.len() { + return Err(ArrowError::InvalidArgumentError( + format!("all columns in a record batch must have the same length. 
expected {}, field {} had {} ", + col.len(), field_name, field_values.len()) + )); + } + } + + let Self { + schema, + mut columns, + } = self; + + // modify the schema we have if possible, otherwise copy + let mut schema = match Arc::try_unwrap(schema) { + Ok(schema) => schema, + Err(shared_schema) => shared_schema.as_ref().clone(), + }; + + let nullable = field_values.null_count() > 0; + schema.push(Field::new( + field_name, + field_values.data_type().clone(), + nullable, + )); + let schema = Arc::new(schema); + + columns.push(field_values); + + Ok(Self { schema, columns }) + } + /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error /// if any validation check fails. fn validate_new_batch( @@ -245,6 +295,12 @@ impl RecordBatch { } } +impl Default for RecordBatch { + fn default() -> Self { + Self::new() + } +} + /// Options that control the behaviour used when creating a [`RecordBatch`]. #[derive(Debug)] pub struct RecordBatchOptions { @@ -337,6 +393,38 @@ mod tests { assert_eq!(5, record_batch.column(1).data().len()); } + #[test] + fn create_record_batch_builder() { + let a = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + ])); + let b = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); + + let record_batch = RecordBatch::new() + .append("a", a) + .unwrap() + .append("b", b) + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, false), + ]); + + assert_eq!(record_batch.schema().as_ref(), &expected_schema); + + assert_eq!(5, record_batch.num_rows()); + assert_eq!(2, record_batch.num_columns()); + assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type()); + assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type()); + assert_eq!(5, record_batch.column(0).data().len()); + assert_eq!(5, record_batch.column(1).data().len()); + } + #[test] fn create_record_batch_schema_mismatch() { let 
schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);