validate and adjust Substrait NamedTable schemas (#12223) #12245

Merged · 21 commits · Sep 10, 2024
111 changes: 101 additions & 10 deletions datafusion/substrait/src/logical_plan/consumer.rs
@@ -53,6 +53,7 @@ use crate::variation_const::{
};
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::dataframe::DataFrame;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
@@ -640,6 +641,10 @@ pub async fn from_substrait_rel(
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
let named_struct = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Named Table")
})?;

let table_reference = match nt.names.len() {
0 => {
return plan_err!("No table name found in NamedTable");
@@ -657,7 +662,13 @@
table: nt.names[2].clone().into(),
},
};
let t = ctx.table(table_reference).await?;

let substrait_schema =
from_substrait_named_struct(named_struct, extensions)?
.replace_qualifier(table_reference.clone());

let t = ctx.table(table_reference.clone()).await?;
let t = ensure_schema_compatability(t, substrait_schema)?;
Contributor:

we should maybe add this for the local file reads below as well (I didn't have it in my branch yet as I didn't need it immediately). Somewhat annoyingly it'll cause the rest of the TPCH tests to fail...

Contributor Author:

Is this worth doing as one big swoop in this PR, or would it make sense to do it as a followup?

Contributor:

Fine by me to do it separately!

let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
Contributor:

I started wondering if the ensure_schema_compatability can now conflict with extract_projection - and I think it can, either by failing if DF doesn't optimize the select into a projection, or if DF does, then by overriding the select's projection with the Substrait projection...

I guess a fix would be something like in extract_projection, if there is an existing scan.projection, then apply columnIndices on it first

Contributor Author:

That did indeed cause problems. It triggered an "unexpected plan for table" error in extract_projection.

I added some code for this case in d571eb2 (#12245). Is something like this what you had in mind?

I am noticing that the plans generated look a little weird/bad, with a lot of redundant Projections

                "Projection: DATA.a, DATA.b\
                \n  Projection: DATA.a, DATA.b\
                \n    Projection: DATA.a, DATA.b, DATA.c\
                \n      TableScan: DATA projection=[b, a, c]"

but they are at least correct for now.
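[Editorial note: the following is an illustrative sketch, not code from this PR.] The stacked Projections above are collapsible because each layer selects by name from its child, so the net effect is just the outermost name list, validated against every inner layer and the scan. A hypothetical helper (all names here are made up for illustration):

```rust
/// Collapse a chain of name-based projections (outermost first) into a
/// single equivalent projection. Returns None if any selected name is not
/// produced by every inner projection or by the scan itself.
fn collapse_projections(
    chain: &[Vec<&str>],
    scan_columns: &[&str],
) -> Option<Vec<String>> {
    let outermost = chain.first()?;
    // A name kept by the outermost list must survive each inner list
    // and exist among the scan's columns.
    let valid = outermost.iter().all(|name| {
        chain[1..].iter().all(|inner| inner.contains(name))
            && scan_columns.contains(name)
    });
    valid.then(|| outermost.iter().map(|s| s.to_string()).collect())
}

fn main() {
    // Mirrors the plan above: three stacked projections over
    // `TableScan: DATA projection=[b, a, c]` collapse to one.
    let chain = vec![
        vec!["a", "b"],
        vec!["a", "b"],
        vec!["a", "b", "c"],
    ];
    let net = collapse_projections(&chain, &["b", "a", "c"]);
    println!("{:?}", net); // Some(["a", "b"])
}
```

DataFusion's optimizer normally performs this kind of merge itself; the sketch only shows why the nesting is redundant rather than incorrect.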

Contributor:

Is something like this what you had in mind?

What I had in mind was manipulating the scan.projection directly - kinda like it is already done in extract_projection, we could do it that way also for ensure_schema_compatibility. That way there wouldn't be additional Projections, and maybe it'd be a bit more efficient if the current setup doesn't push the column-pruning into the scan level (though I'm a bit surprised they don't get optimized anyways).

But I don't think it's necessary - the way you've done it here seems correct, and we (I?) can do the project-mangling as a followup, unless you want to take a stab at it :)
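[Editorial note: a sketch of the scan-level approach described above, not code from this PR or DataFusion's API.] Composing the Substrait column indices with an existing `scan.projection` avoids stacking Projection nodes; the helper name and signature here are hypothetical:

```rust
/// Compose Substrait column indices with an existing TableScan projection.
///
/// `existing` is the scan's current projection (indices into the full table
/// schema, or None if the scan reads every column). `substrait_indices` come
/// from the Substrait ReadRel and refer to the schema the scan currently
/// produces, so they must be mapped through `existing`.
fn compose_projection(
    existing: Option<&[usize]>,
    substrait_indices: &[usize],
) -> Vec<usize> {
    match existing {
        // Map each Substrait index through the scan's projection so the
        // result again indexes into the full table schema.
        Some(existing) => substrait_indices.iter().map(|&i| existing[i]).collect(),
        // The scan reads the full schema; the indices apply directly.
        None => substrait_indices.to_vec(),
    }
}

fn main() {
    // Scan reads columns [b, a, c] of a table with schema (a=0, b=1, c=2);
    // the Substrait projection keeps the first two of the pruned columns.
    let composed = compose_projection(Some(&[1, 0, 2]), &[0, 1]);
    println!("{:?}", composed); // [1, 0]
}
```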

Contributor Author:

I think project unmangling would be better as a follow-up. Possibly as part of #12347, because supporting remaps is going to add yet another layer of Projections 😅

}
@@ -671,7 +682,7 @@
if vt.values.is_empty() {
return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema,
schema: DFSchemaRef::new(schema),
}));
}

@@ -704,7 +715,10 @@
})
.collect::<Result<_>>()?;

Ok(LogicalPlan::Values(Values { schema, values }))
Ok(LogicalPlan::Values(Values {
schema: DFSchemaRef::new(schema),
values,
}))
}
Some(ReadType::LocalFiles(lf)) => {
fn extract_filename(name: &str) -> Option<String> {
@@ -850,6 +864,82 @@
}
}

/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion
///
/// This means:
/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The
/// DataFusion schema may have MORE fields, but not the other way around.
/// 2. All fields are compatible. See [`ensure_field_compatability`] for details
///
/// This function returns a DataFrame with fields adjusted if necessary in the event that the
/// Substrait schema is a subset of the DataFusion schema.
fn ensure_schema_compatability(
table: DataFrame,
substrait_schema: DFSchema,
) -> Result<DataFrame> {
// For now strip the qualifiers from the DataFusion and Substrait schemas to make comparisons
// easier as there are issues around qualifier case.
// TODO: ensure that qualifiers are the same
let df_schema = table.schema().to_owned().strip_qualifiers();
if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
return Ok(table);
}
let selected_columns = substrait_schema
.strip_qualifiers()
.fields()
.iter()
.map(|substrait_field| {
let df_field =
df_schema.field_with_unqualified_name(substrait_field.name())?;
ensure_field_compatability(df_field, substrait_field)?;
Ok(col(format!("\"{}\"", df_field.name())))
})
.collect::<Result<_>>()?;

table.select(selected_columns)
}

/// Ensures that the given Substrait field is compatible with the given DataFusion field
///
/// A field is compatible between Substrait and DataFusion if:
/// 1. They have logically equivalent types.
/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion field
/// is not nullable.
///
/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not
/// nullable. As such, if DataFusion has that field as nullable, the plan should be rejected.
fn ensure_field_compatability(
datafusion_field: &Field,
substrait_field: &Field,
) -> Result<()> {
if DFSchema::datatype_is_logically_equal(
datafusion_field.data_type(),
substrait_field.data_type(),
) {
if (datafusion_field.is_nullable() == substrait_field.is_nullable())
|| (substrait_field.is_nullable() && !datafusion_field.is_nullable())
{
Ok(())
} else {
// TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now.
substrait_err!(
"Field '{}' is not nullable in the Substrait schema but is nullable in the DataFusion schema.",
substrait_field.name()
)
Contributor Author:

Because of what's mentioned in the TODO, this code never fires. I'm opting not to fix that as part of this PR.

}
} else {
substrait_err!(
"Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).",
substrait_field.name(),
substrait_field.data_type(),
datafusion_field.data_type()
)
}
}

/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise
/// conflict with the columns from the other.
/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For
@@ -1586,10 +1676,11 @@ fn next_struct_field_name(
}
}

fn from_substrait_named_struct(
/// Convert Substrait NamedStruct to DataFusion DFSchema
pub fn from_substrait_named_struct(
base_schema: &NamedStruct,
extensions: &Extensions,
) -> Result<DFSchemaRef> {
) -> Result<DFSchema> {
let mut name_idx = 0;
let fields = from_substrait_struct_type(
base_schema.r#struct.as_ref().ok_or_else(|| {
@@ -1601,12 +1692,12 @@
);
if name_idx != base_schema.names.len() {
return substrait_err!(
"Names list must match exactly to nested schema, but found {} uses for {} names",
name_idx,
base_schema.names.len()
);
}
Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?))
DFSchema::try_from(Schema::new(fields?))
}

fn from_substrait_bound(
16 changes: 5 additions & 11 deletions datafusion/substrait/src/logical_plan/producer.rs
@@ -43,8 +43,8 @@ use crate::variation_const::{
use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
use datafusion::common::{
exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err,
substrait_err, DFSchemaRef, ToDFSchema,
};
use datafusion::common::{substrait_err, DFSchemaRef};
#[allow(unused_imports)]
use datafusion::logical_expr::expr::{
Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction,
@@ -140,19 +140,13 @@ pub fn to_substrait_rel(
maintain_singular_struct: false,
});

let table_schema = scan.source.schema().to_dfschema_ref()?;
let base_schema = to_substrait_named_struct(&table_schema, extensions)?;

Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
base_schema: Some(NamedStruct {
names: scan
.source
.schema()
.fields()
.iter()
.map(|f| f.name().to_owned())
.collect(),
r#struct: None,
}),
base_schema: Some(base_schema),
filter: None,
best_effort_filter: None,
projection,
23 changes: 8 additions & 15 deletions datafusion/substrait/tests/cases/function_test.rs
@@ -19,40 +19,33 @@

#[cfg(test)]
mod tests {
use crate::utils::test::TestSchemaCollector;
use datafusion::common::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;

#[tokio::test]
async fn contains_function_test() -> Result<()> {
let ctx = create_context().await?;

let path = "tests/testdata/contains_plan.substrait.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(

let proto_plan = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

let plan = from_substrait_plan(&ctx, &proto).await?;
let ctx = TestSchemaCollector::generate_context_from_plan(&proto_plan);
let plan = from_substrait_plan(&ctx, &proto_plan).await?;

let plan_str = format!("{}", plan);

assert_eq!(
plan_str,
"Projection: nation.b AS n_name\
\n Filter: contains(nation.b, Utf8(\"IA\"))\
\n TableScan: nation projection=[a, b, c, d, e, f]"
"Projection: nation.n_name\
\n Filter: contains(nation.n_name, Utf8(\"IA\"))\
\n TableScan: nation projection=[n_nationkey, n_name, n_regionkey, n_comment]"
Contributor Author:

You can see in this change how the DataFusion and Substrait plans had different schemas.

);
Ok(())
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new())
.await?;
Ok(ctx)
}
}
58 changes: 19 additions & 39 deletions datafusion/substrait/tests/cases/logical_plans.rs
@@ -19,13 +19,10 @@

#[cfg(test)]
mod tests {
use crate::utils::test::{read_json, TestSchemaCollector};
use datafusion::common::Result;
use datafusion::dataframe::DataFrame;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;

#[tokio::test]
async fn scalar_function_compound_signature() -> Result<()> {
@@ -35,18 +32,17 @@ mod tests {
// we don't yet produce such plans.
// Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests.

let ctx = create_context().await?;

// File generated with substrait-java's Isthmus:
// ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)"
let proto = read_json("tests/testdata/test_plans/select_not_bool.substrait.json");

let plan = from_substrait_plan(&ctx, &proto).await?;
// ./isthmus-cli/build/graal/isthmus --create "create table data (d boolean)" "select not d from data"
let proto_plan =
read_json("tests/testdata/test_plans/select_not_bool.substrait.json");
let ctx = TestSchemaCollector::generate_context_from_plan(&proto_plan);
let plan = from_substrait_plan(&ctx, &proto_plan).await?;

assert_eq!(
format!("{}", plan),
"Projection: NOT DATA.a AS EXPR$0\
\n TableScan: DATA projection=[a, b, c, d, e, f]"
"Projection: NOT DATA.D AS EXPR$0\
\n TableScan: DATA projection=[D]"
);
Ok(())
}
@@ -61,19 +57,18 @@
// we don't yet produce such plans.
// Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests.

let ctx = create_context().await?;

// File generated with substrait-java's Isthmus:
// ./isthmus-cli/build/graal/isthmus "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (d int, part int, ord int)"
let proto = read_json("tests/testdata/test_plans/select_window.substrait.json");

let plan = from_substrait_plan(&ctx, &proto).await?;
// ./isthmus-cli/build/graal/isthmus --create "create table data (d int, part int, ord int)" "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data"
let proto_plan =
read_json("tests/testdata/test_plans/select_window.substrait.json");
let ctx = TestSchemaCollector::generate_context_from_plan(&proto_plan);
let plan = from_substrait_plan(&ctx, &proto_plan).await?;

assert_eq!(
format!("{}", plan),
"Projection: sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\
\n WindowAggr: windowExpr=[[sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: DATA projection=[a, b, c, d, e, f]"
"Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\
\n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: DATA projection=[D, PART, ORD]"
);
Ok(())
}
@@ -83,11 +78,10 @@
// DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable.
// That's because implementing the non-nullability consistently is non-trivial.
// This test confirms that reading a plan with non-nullable lists works as expected.
let ctx = create_context().await?;
let proto =
let proto_plan =
read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json");

let plan = from_substrait_plan(&ctx, &proto).await?;
let ctx = TestSchemaCollector::generate_context_from_plan(&proto_plan);
let plan = from_substrait_plan(&ctx, &proto_plan).await?;

assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))");

@@ -96,18 +90,4 @@

Ok(())
}

fn read_json(path: &str) -> Plan {
serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json")
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new())
.await?;
Ok(ctx)
}
}
1 change: 1 addition & 0 deletions datafusion/substrait/tests/cases/mod.rs
@@ -21,3 +21,4 @@ mod logical_plans;
mod roundtrip_logical_plan;
mod roundtrip_physical_plan;
mod serialize;
mod substrait_validations;