diff --git a/rust/arrow/src/datatypes/field.rs b/rust/arrow/src/datatypes/field.rs
index cd43510f55c..11fc31d6343 100644
--- a/rust/arrow/src/datatypes/field.rs
+++ b/rust/arrow/src/datatypes/field.rs
@@ -488,6 +488,38 @@ impl Field {
 
         Ok(())
     }
+
+    /// Check to see if `self` is a superset of `other` field. Superset is defined as:
+    ///
+    /// * name, data type, dictionary id and dictionary ordering are all equal
+    /// * if nullability doesn't match, `self` needs to be nullable
+    /// * every key/value pair in `other.metadata` is also present in
+    ///   `self.metadata`; absent (`None`) metadata counts as an empty map
+    pub fn contains(&self, other: &Field) -> bool {
+        let same_definition = self.name == other.name
+            && self.data_type == other.data_type
+            && self.dict_id == other.dict_id
+            && self.dict_is_ordered == other.dict_is_ordered;
+        if !same_definition {
+            return false;
+        }
+
+        // nullability may only differ when `self` is the nullable side
+        if other.nullable && !self.nullable {
+            return false;
+        }
+
+        // make sure self.metadata is a superset of other.metadata
+        match (&self.metadata, &other.metadata) {
+            // `other` has no metadata, so there is nothing to cover
+            (_, None) => true,
+            // without metadata of its own, `self` can only cover an empty map
+            (None, Some(theirs)) => theirs.is_empty(),
+            (Some(mine), Some(theirs)) => theirs
+                .iter()
+                .all(|(key, value)| mine.get(key) == Some(value)),
+        }
+    }
 }
 
 // TODO: improve display with crate https://crates.io/crates/derive_more ?
diff --git a/rust/arrow/src/datatypes/schema.rs b/rust/arrow/src/datatypes/schema.rs
index 1e9acf799fc..ad89b29cacd 100644
--- a/rust/arrow/src/datatypes/schema.rs
+++ b/rust/arrow/src/datatypes/schema.rs
@@ -279,6 +279,36 @@ impl Schema {
             )),
         }
     }
+
+    /// Check to see if `self` is a superset of `other` schema. Here are the
+    /// comparison rules:
+    ///
+    /// * `self` and `other` should contain the same number of fields
+    /// * for every field `f` in `other`, the field in `self` with corresponding
+    ///   index should be a superset of `f`
+    /// * `self.metadata` is a superset of `other.metadata`
+    ///
+    /// In other words, any record that conforms to `other` should also conform to
+    /// `self`.
+    pub fn contains(&self, other: &Schema) -> bool {
+        // fields are compared by position, so the counts have to agree first
+        if self.fields.len() != other.fields.len() {
+            return false;
+        }
+
+        let fields_covered = self
+            .fields
+            .iter()
+            .zip(other.fields.iter())
+            .all(|(mine, theirs)| mine.contains(theirs));
+
+        // make sure self.metadata is a superset of other.metadata
+        fields_covered
+            && other
+                .metadata
+                .iter()
+                .all(|(key, value)| self.metadata.get(key) == Some(value))
+    }
 }
 
 impl fmt::Display for Schema {
diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs
index 1fc0eaabc6c..0fafa0f6925 100644
--- a/rust/datafusion/src/datasource/memory.rs
+++ b/rust/datafusion/src/datasource/memory.rs
@@ -88,7 +88,8 @@ impl MemTable {
         if partitions
             .iter()
             .flatten()
-            .all(|batches| batches.schema() == schema)
+            // a batch is accepted when the table schema is a superset of its schema
+            .all(|batch| schema.contains(&batch.schema()))
         {
             let statistics = calculate_statistics(&schema, &partitions);
             debug!("MemTable statistics: {:?}", statistics);
@@ -220,6 +221,7 @@ mod tests {
     use arrow::array::Int32Array;
     use arrow::datatypes::{DataType, Field, Schema};
     use futures::StreamExt;
+    use std::collections::HashMap;
 
     #[tokio::test]
     async fn test_with_projection() -> Result<()> {
@@ -333,7 +335,7 @@ mod tests {
     }
 
     #[test]
-    fn test_schema_validation() -> Result<()> {
+    fn test_schema_validation_incompatible_column() -> Result<()> {
         let schema1 = Arc::new(Schema::new(vec![
             Field::new("a", DataType::Int32, false),
             Field::new("b", DataType::Int32, false),
@@ -365,4 +367,97 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_schema_validation_different_column_count() -> Result<()> {
+        // the batch only carries columns `a` and `c` ...
+        let batch_schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]));
+
+        // ... while the table declares an additional column `b`
+        let table_schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]));
+
+        let batch = RecordBatch::try_new(
+            batch_schema,
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])),
+                Arc::new(Int32Array::from(vec![7, 5, 9])),
+            ],
+        )?;
+
+        match MemTable::try_new(table_schema, vec![vec![batch]]) {
+            Err(DataFusionError::Plan(e)) => assert_eq!(
+                "\"Mismatch between schema and batches\"",
+                format!("{:?}", e)
+            ),
+            _ => panic!("MemTable::new should have failed due to schema mismatch"),
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_merged_schema() -> Result<()> {
+        let mut metadata = HashMap::new();
+        metadata.insert("foo".to_string(), "bar".to_string());
+
+        // carries schema-level metadata, to exercise the metadata comparison
+        let schema_with_metadata = Schema::new_with_metadata(
+            vec![
+                Field::new("a", DataType::Int32, false),
+                Field::new("b", DataType::Int32, false),
+                Field::new("c", DataType::Int32, false),
+            ],
+            metadata,
+        );
+
+        // column `a` is nullable here, to exercise the nullability comparison
+        let schema_with_nullable = Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]);
+
+        // the merged schema is expected to be accepted for batches of both inputs
+        let merged_schema = Schema::try_merge(vec![
+            schema_with_metadata.clone(),
+            schema_with_nullable.clone(),
+        ])?;
+
+        let batch1 = RecordBatch::try_new(
+            Arc::new(schema_with_metadata),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])),
+                Arc::new(Int32Array::from(vec![4, 5, 6])),
+                Arc::new(Int32Array::from(vec![7, 8, 9])),
+            ],
+        )?;
+
+        let batch2 = RecordBatch::try_new(
+            Arc::new(schema_with_nullable),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])),
+                Arc::new(Int32Array::from(vec![4, 5, 6])),
+                Arc::new(Int32Array::from(vec![7, 8, 9])),
+            ],
+        )?;
+
+        let provider =
+            MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
+
+        let exec = provider.scan(&None, 1024, &[])?;
+        let mut stream = exec.execute(0).await?;
+        let first_batch = stream.next().await.unwrap()?;
+        assert_eq!(3, first_batch.schema().fields().len());
+        assert_eq!(3, first_batch.num_columns());
+        assert_eq!(provider.statistics().num_rows, Some(6));
+
+        Ok(())
+    }
 }