validate and adjust Substrait NamedTable schemas (#12223) #12245

Merged on Sep 10, 2024 (21 commits)
Changes from 8 commits
67 changes: 59 additions & 8 deletions datafusion/substrait/src/logical_plan/consumer.rs
@@ -657,16 +657,34 @@ pub async fn from_substrait_rel(
table: nt.names[2].clone().into(),
},
};
let t = ctx.table(table_reference).await?;
let t = ctx.table(table_reference.clone()).await?;
let t = t.into_optimized_plan()?;
let datafusion_schema = t.schema();

let named_struct = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Named Scan")
})?;
let substrait_schema = from_substrait_named_struct(
named_struct,
extensions,
Some(table_reference),
)?;

if !validate_substrait_schema(datafusion_schema, &substrait_schema) {
return Err(substrait_datafusion_err!(
"Schema mismatch in ReadRel: substrait: {:?}, DataFusion: {:?}",
substrait_schema,
datafusion_schema
));
};
extract_projection(t, &read.projection)
Contributor

I started wondering if the ensure_schema_compatability can now conflict with extract_projection - and I think it can, either by failing if DF doesn't optimize the select into a projection, or if DF does, then by overriding the select's projection with the Substrait projection...

I guess a fix would be something like this: in extract_projection, if there is an existing scan.projection, apply columnIndices to it first.
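
The fix suggested in the comment above can be sketched as composing two index lists. This is a minimal illustration, not the code from the PR: `compose_projection` is a hypothetical helper, and it assumes the scan's projection is an optional list of table column indices while the Substrait column indices refer to positions in the scan's current output.

```rust
/// Hypothetical helper (not part of DataFusion): remap `column_indices`,
/// which index into the scan's current output, onto the underlying table.
fn compose_projection(
    existing: Option<&[usize]>,
    column_indices: &[usize],
) -> Option<Vec<usize>> {
    match existing {
        // No projection on the scan yet: the requested indices already
        // refer to table columns.
        None => Some(column_indices.to_vec()),
        // Otherwise look each requested index up in the existing projection.
        // Any out-of-range index makes the whole result `None`.
        Some(current) => column_indices
            .iter()
            .map(|&i| current.get(i).copied())
            .collect(),
    }
}
```

For example, if the scan currently emits table columns `[1, 0, 2]` and the Substrait projection keeps outputs `0` and `1`, the composed scan projection is `[1, 0]`, with no extra Projection node needed.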

Contributor Author

That did indeed cause problems. It triggered an "unexpected plan for table" error in extract_projection.

I added some code for this case in d571eb2 (#12245). Is something like this what you had in mind?

I am noticing that the plans generated look a little weird/bad with a lot of redundant projects

                "Projection: DATA.a, DATA.b\
                \n  Projection: DATA.a, DATA.b\
                \n    Projection: DATA.a, DATA.b, DATA.c\
                \n      TableScan: DATA projection=[b, a, c]"

but they are at least correct for now.
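
To illustrate why the stacked Projections above are redundant, here is a toy sketch. The `Plan` enum and the `scan`, `projection`, and `collapse_projections` functions are hypothetical and only model column-only projections by name; they are not DataFusion's plan types or optimizer rules.

```rust
/// Toy plan model for illustration only; these are not DataFusion's types.
#[derive(Debug, Clone, PartialEq)]
enum Plan {
    Scan { table: String, columns: Vec<String> },
    Projection { columns: Vec<String>, input: Box<Plan> },
}

fn scan(table: &str, columns: &[&str]) -> Plan {
    Plan::Scan {
        table: table.to_string(),
        columns: columns.iter().map(|c| c.to_string()).collect(),
    }
}

fn projection(columns: &[&str], input: Plan) -> Plan {
    Plan::Projection {
        columns: columns.iter().map(|c| c.to_string()).collect(),
        input: Box::new(input),
    }
}

/// Collapse stacked column-only projections: an outer projection over an
/// inner projection is rewritten to sit directly on the inner one's input,
/// provided every outer column is produced by the inner projection.
fn collapse_projections(plan: Plan) -> Plan {
    match plan {
        Plan::Projection { columns, input } => match collapse_projections(*input) {
            Plan::Projection { columns: inner, input: inner_input }
                if columns.iter().all(|c| inner.contains(c)) =>
            {
                Plan::Projection { columns, input: inner_input }
            }
            other => Plan::Projection { columns, input: Box::new(other) },
        },
        // Scans are left untouched.
        other => other,
    }
}
```

Applied to a model of the plan quoted above, the three Projections reduce to a single `Projection: a, b` over the scan, which is the shape a follow-up cleanup would aim for.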

Contributor

> Is something like this what you had in mind?

What I had in mind was manipulating the scan.projection directly, kinda like it is already done in extract_projection; we could do it that way for ensure_schema_compatibility as well. That way there wouldn't be additional Projections, and it might be a bit more efficient if the current setup doesn't push the column pruning down to the scan level (though I'm a bit surprised they don't get optimized away anyway).

But I don't think it's necessary - the way you've done it here seems correct, and we (I?) can do the project-mangling as a followup, unless you want to take a stab at it :)

Contributor Author

I think project unmangling would be better as a follow-up, possibly as part of #12347, because supporting remaps is going to add yet another layer of Projects 😅

}
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Virtual Table")
})?;

let schema = from_substrait_named_struct(base_schema, extensions)?;
let schema = from_substrait_named_struct(base_schema, extensions, None)?;

if vt.values.is_empty() {
return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
@@ -850,6 +868,31 @@ pub async fn from_substrait_rel(
}
}

/// Validates that the given Substrait schema matches the given DataFusion schema
/// Returns true if the two schemas have the same qualified named fields with the same data types.
/// Returns false otherwise.
/// Ignores case when comparing field names
///
/// This code is equivalent to [DFSchema::equivalent_names_and_types] except that the field name
/// checking is case-insensitive
fn validate_substrait_schema(
datafusion_schema: &DFSchema,
substrait_schema: &DFSchema,
) -> bool {
if datafusion_schema.fields().len() != substrait_schema.fields().len() {
return false;
}
let datafusion_fields = datafusion_schema.iter();
let substrait_fields = substrait_schema.iter();
datafusion_fields
.zip(substrait_fields)
.all(|((q1, f1), (q2, f2))| {
q1 == q2
&& f1.name().to_lowercase() == f2.name().to_lowercase()
&& f1.data_type().equals_datatype(f2.data_type())
})
}

/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise
/// conflict with the columns from the other.
/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For
@@ -1586,9 +1629,12 @@ fn next_struct_field_name(
}
}

fn from_substrait_named_struct(
/// Convert Substrait NamedStruct to DataFusion DFSchemaRef
pub fn from_substrait_named_struct(
base_schema: &NamedStruct,
extensions: &Extensions,
// optional qualifier to apply to every field in the schema
field_qualifier: Option<TableReference>,
) -> Result<DFSchemaRef> {
let mut name_idx = 0;
let fields = from_substrait_struct_type(
@@ -1601,12 +1647,17 @@ fn from_substrait_named_struct(
);
if name_idx != base_schema.names.len() {
return substrait_err!(
"Names list must match exactly to nested schema, but found {} uses for {} names",
name_idx,
base_schema.names.len()
);
"Names list must match exactly to nested schema, but found {} uses for {} names",
name_idx,
base_schema.names.len()
);
}
let mut df_schema = DFSchema::try_from(Schema::new(fields?))?;
match field_qualifier {
None => (),
Some(fq) => df_schema = df_schema.replace_qualifier(fq),
}
Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?))
Ok(DFSchemaRef::new(df_schema))
}

fn from_substrait_bound(
16 changes: 5 additions & 11 deletions datafusion/substrait/src/logical_plan/producer.rs
@@ -43,8 +43,8 @@ use crate::variation_const::{
use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
use datafusion::common::{
exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err,
substrait_err, DFSchemaRef, ToDFSchema,
};
use datafusion::common::{substrait_err, DFSchemaRef};
#[allow(unused_imports)]
use datafusion::logical_expr::expr::{
Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction,
@@ -140,19 +140,13 @@ pub fn to_substrait_rel(
maintain_singular_struct: false,
});

let table_schema = scan.source.schema().to_dfschema_ref()?;
let base_schema = to_substrait_named_struct(&table_schema, extensions)?;

Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
base_schema: Some(NamedStruct {
names: scan
.source
.schema()
.fields()
.iter()
.map(|f| f.name().to_owned())
.collect(),
r#struct: None,
}),
base_schema: Some(base_schema),
filter: None,
best_effort_filter: None,
projection,
23 changes: 8 additions & 15 deletions datafusion/substrait/tests/cases/function_test.rs
@@ -19,40 +19,33 @@

#[cfg(test)]
mod tests {
use crate::utils::test::TestSchemaCollector;
use datafusion::common::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;

#[tokio::test]
async fn contains_function_test() -> Result<()> {
let ctx = create_context().await?;

let path = "tests/testdata/contains_plan.substrait.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(

let proto_plan = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

let plan = from_substrait_plan(&ctx, &proto).await?;
let ctx = TestSchemaCollector::generate_context_from_plan(&proto_plan);
let plan = from_substrait_plan(&ctx, &proto_plan).await?;

let plan_str = format!("{}", plan);

assert_eq!(
plan_str,
"Projection: nation.b AS n_name\
\n Filter: contains(nation.b, Utf8(\"IA\"))\
\n TableScan: nation projection=[a, b, c, d, e, f]"
"Projection: nation.n_name\
\n Filter: contains(nation.n_name, Utf8(\"IA\"))\
\n TableScan: nation projection=[n_nationkey, n_name, n_regionkey, n_comment]"
Contributor Author

You can see in this change how the DataFusion and Substrait plans had different schemas.

);
Ok(())
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new())
.await?;
Ok(ctx)
}
}
58 changes: 19 additions & 39 deletions datafusion/substrait/tests/cases/logical_plans.rs
@@ -19,13 +19,10 @@

#[cfg(test)]
mod tests {
use crate::utils::test::{read_json, TestSchemaCollector};
use datafusion::common::Result;
use datafusion::dataframe::DataFrame;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;

#[tokio::test]
async fn scalar_function_compound_signature() -> Result<()> {
@@ -35,18 +32,17 @@ mod tests {
// we don't yet produce such plans.
// Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests.

let ctx = create_context().await?;

// File generated with substrait-java's Isthmus:
// ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)"
let proto = read_json("tests/testdata/test_plans/select_not_bool.substrait.json");

let plan = from_substrait_plan(&ctx, &proto).await?;
// ./isthmus-cli/build/graal/isthmus --create "create table data (d boolean)" "select not d from data"
let proto_plan =
read_json("tests/testdata/test_plans/select_not_bool.substrait.json");
let ctx = TestSchemaCollector::generate_context_from_plan(&proto_plan);
let plan = from_substrait_plan(&ctx, &proto_plan).await?;

assert_eq!(
format!("{}", plan),
"Projection: NOT DATA.a AS EXPR$0\
\n TableScan: DATA projection=[a, b, c, d, e, f]"
"Projection: NOT DATA.D AS EXPR$0\
\n TableScan: DATA projection=[D]"
);
Ok(())
}
@@ -61,19 +57,18 @@ mod tests {
// we don't yet produce such plans.
// Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests.

let ctx = create_context().await?;

// File generated with substrait-java's Isthmus:
// ./isthmus-cli/build/graal/isthmus "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (d int, part int, ord int)"
let proto = read_json("tests/testdata/test_plans/select_window.substrait.json");

let plan = from_substrait_plan(&ctx, &proto).await?;
// ./isthmus-cli/build/graal/isthmus --create "create table data (d int, part int, ord int)" "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data"
let proto_plan =
read_json("tests/testdata/test_plans/select_window.substrait.json");
let ctx = TestSchemaCollector::generate_context_from_plan(&proto_plan);
let plan = from_substrait_plan(&ctx, &proto_plan).await?;

assert_eq!(
format!("{}", plan),
"Projection: sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\
\n WindowAggr: windowExpr=[[sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: DATA projection=[a, b, c, d, e, f]"
"Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\
\n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: DATA projection=[D, PART, ORD]"
);
Ok(())
}
@@ -83,11 +78,10 @@ mod tests {
// DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable.
// That's because implementing the non-nullability consistently is non-trivial.
// This test confirms that reading a plan with non-nullable lists works as expected.
let ctx = create_context().await?;
let proto =
let proto_plan =
read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json");

let plan = from_substrait_plan(&ctx, &proto).await?;
let ctx = TestSchemaCollector::generate_context_from_plan(&proto_plan);
let plan = from_substrait_plan(&ctx, &proto_plan).await?;

assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))");

@@ -96,18 +90,4 @@ mod tests {

Ok(())
}

fn read_json(path: &str) -> Plan {
serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json")
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new())
.await?;
Ok(ctx)
}
}
1 change: 1 addition & 0 deletions datafusion/substrait/tests/cases/mod.rs
@@ -21,3 +21,4 @@ mod logical_plans;
mod roundtrip_logical_plan;
mod roundtrip_physical_plan;
mod serialize;
mod substrait_validations;
101 changes: 101 additions & 0 deletions datafusion/substrait/tests/cases/substrait_validations.rs
@@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#[cfg(test)]
mod tests {
use crate::utils::test::read_json;
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::catalog_common::TableReference;
use datafusion::common::{DFSchema, Result};
use datafusion::datasource::empty::EmptyTable;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::collections::HashMap;
use std::sync::Arc;

fn generate_context_with_table(
table_name: &str,
field_data_type_pairs: Vec<(&str, DataType)>,
) -> Result<SessionContext> {
let table_ref = TableReference::bare(table_name);
let fields: Vec<(Option<TableReference>, Arc<Field>)> = field_data_type_pairs
.into_iter()
.map(|pair| {
let (field_name, data_type) = pair;
(
Some(table_ref.clone()),
Arc::new(Field::new(field_name, data_type, false)),
)
})
.collect();

let df_schema = DFSchema::new_with_metadata(fields, HashMap::default())?;

let ctx = SessionContext::new();
ctx.register_table(
table_ref,
Arc::new(EmptyTable::new(df_schema.inner().clone())),
)?;
Ok(ctx)
}

#[tokio::test]
async fn substrait_schema_validation_ignores_field_name_case() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/simple_select.substrait.json");

let ctx = generate_context_with_table("DATA", vec![("a", DataType::Int32)])?;
from_substrait_plan(&ctx, &proto_plan).await?;
Ok(())
}

#[tokio::test]
async fn reject_plans_with_mismatched_number_of_fields() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/simple_select.substrait.json");

let ctx = generate_context_with_table(
"DATA",
vec![("a", DataType::Int32), ("b", DataType::Int32)],
)?;
let res = from_substrait_plan(&ctx, &proto_plan).await;
assert!(res.is_err());
Ok(())
}

#[tokio::test]
async fn reject_plans_with_mismatched_field_names() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/simple_select.substrait.json");

let ctx = generate_context_with_table("DATA", vec![("b", DataType::Date32)])?;
let res = from_substrait_plan(&ctx, &proto_plan).await;
assert!(res.is_err());
Ok(())
}

#[tokio::test]
async fn reject_plans_with_incompatible_field_types() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/simple_select.substrait.json");

let ctx = generate_context_with_table("DATA", vec![("a", DataType::Date32)])?;
let res = from_substrait_plan(&ctx, &proto_plan).await;
assert!(res.is_err());
Ok(())
}
}
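
The rule these tests exercise (same field count, same qualifiers, same data types, names equal ignoring case) can be sketched without DataFusion types. This is an illustrative stand-in for `validate_substrait_schema`, not the PR's implementation: `QualifiedField`, `field`, and `schemas_match` are hypothetical names, and data types are modelled as plain strings.

```rust
/// Simplified stand-in for a qualified field; not DataFusion's `DFSchema`.
#[derive(Debug, Clone, PartialEq)]
struct QualifiedField {
    qualifier: Option<String>,
    name: String,
    data_type: String,
}

fn field(qualifier: &str, name: &str, data_type: &str) -> QualifiedField {
    QualifiedField {
        qualifier: Some(qualifier.to_string()),
        name: name.to_string(),
        data_type: data_type.to_string(),
    }
}

/// Returns true when both schemas have the same number of fields and each
/// pair of fields has equal qualifiers, equal data types, and names that
/// are equal when compared case-insensitively.
fn schemas_match(left: &[QualifiedField], right: &[QualifiedField]) -> bool {
    left.len() == right.len()
        && left.iter().zip(right).all(|(l, r)| {
            l.qualifier == r.qualifier
                && l.name.to_lowercase() == r.name.to_lowercase()
                && l.data_type == r.data_type
        })
}
```

Under this model, a registered field `a: Int32` matches a plan field `A: Int32`, while an extra field, a different name, or a different type is rejected, mirroring the four test cases.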