Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions rust/arrow/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,49 @@ impl Field {

Ok(())
}

/// Check to see if `self` is a superset of `other` field. Superset is defined as:
///
/// * if nullability doesn't match, self needs to be nullable
/// * self.metadata is a superset of other.metadata
/// * all other fields are equal
pub fn contains(&self, other: &Field) -> bool {
if self.name != other.name
|| self.data_type != other.data_type
|| self.dict_id != other.dict_id
|| self.dict_is_ordered != other.dict_is_ordered
{
return false;
}

if self.nullable != other.nullable && !self.nullable {
return false;
}

// make sure self.metadata is a superset of other.metadata
match (&self.metadata, &other.metadata) {
(None, Some(_)) => {
return false;
}
(Some(self_meta), Some(other_meta)) => {
for (k, v) in other_meta.iter() {
match self_meta.get(k) {
Some(s) => {
if s != v {
return false;
}
}
None => {
return false;
}
}
}
}
_ => {}
}

true
}
}

// TODO: improve display with crate https://crates.io/crates/derive_more ?
Expand Down
36 changes: 36 additions & 0 deletions rust/arrow/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,42 @@ impl Schema {
)),
}
}

/// Check to see if `self` is a superset of `other` schema. Here are the comparision rules:
///
/// * `self` and `other` should contain the same number of fields
/// * for every field `f` in `other`, the field in `self` with corresponding index should be a
/// superset of `f`.
/// * self.metadata is a superset of other.metadata
///
/// In other words, any record conforms to `other` should also conform to `self`.
pub fn contains(&self, other: &Schema) -> bool {
if self.fields.len() != other.fields.len() {
return false;
}

for (i, field) in other.fields.iter().enumerate() {
if !self.fields[i].contains(field) {
return false;
}
}

// make sure self.metadata is a superset of other.metadata
for (k, v) in &other.metadata {
match self.metadata.get(k) {
Some(s) => {
if s != v {
return false;
}
}
None => {
return false;
}
}
}

true
}
}

impl fmt::Display for Schema {
Expand Down
92 changes: 90 additions & 2 deletions rust/datafusion/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl MemTable {
if partitions
.iter()
.flatten()
.all(|batches| batches.schema() == schema)
.all(|batches| schema.contains(&batches.schema()))
{
let statistics = calculate_statistics(&schema, &partitions);
debug!("MemTable statistics: {:?}", statistics);
Expand Down Expand Up @@ -220,6 +220,7 @@ mod tests {
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use futures::StreamExt;
use std::collections::HashMap;

#[tokio::test]
async fn test_with_projection() -> Result<()> {
Expand Down Expand Up @@ -333,7 +334,7 @@ mod tests {
}

#[test]
fn test_schema_validation() -> Result<()> {
fn test_schema_validation_incompatible_column() -> Result<()> {
let schema1 = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Expand Down Expand Up @@ -365,4 +366,91 @@ mod tests {

Ok(())
}

#[test]
fn test_schema_validation_different_column_count() -> Result<()> {
let schema1 = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]));

let schema2 = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]));

let batch = RecordBatch::try_new(
schema1,
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![7, 5, 9])),
],
)?;

match MemTable::try_new(schema2, vec![vec![batch]]) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Err(DataFusionError::Plan(e)) => assert_eq!(
"\"Mismatch between schema and batches\"",
format!("{:?}", e)
),
_ => panic!("MemTable::new should have failed due to schema mismatch"),
}

Ok(())
}

#[tokio::test]
async fn test_merged_schema() -> Result<()> {
let mut metadata = HashMap::new();
metadata.insert("foo".to_string(), "bar".to_string());

let schema1 = Schema::new_with_metadata(
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
],
// test for comparing metadata
metadata,
);

let schema2 = Schema::new(vec![
// test for comparing nullability
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]);

let merged_schema = Schema::try_merge(vec![schema1.clone(), schema2.clone()])?;

let batch1 = RecordBatch::try_new(
Arc::new(schema1),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
)?;

let batch2 = RecordBatch::try_new(
Arc::new(schema2),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![4, 5, 6])),
Arc::new(Int32Array::from(vec![7, 8, 9])),
],
)?;

let provider =
MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;

let exec = provider.scan(&None, 1024, &[])?;
let mut it = exec.execute(0).await?;
let batch1 = it.next().await.unwrap()?;
assert_eq!(3, batch1.schema().fields().len());
assert_eq!(3, batch1.num_columns());
assert_eq!(provider.statistics().num_rows, Some(6));

Ok(())
}
}