Skip to content

Commit

Permalink
support tpch_1 consumer_producer_test
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Jun 9, 2024
1 parent c012e9c commit 3c511a2
Show file tree
Hide file tree
Showing 6 changed files with 969 additions and 3 deletions.
1 change: 1 addition & 0 deletions datafusion/substrait/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ object_store = { workspace = true }
pbjson-types = "0.6"
prost = "0.12"
substrait = { version = "0.34.0", features = ["serde"] }
url = { workspace = true }

[dev-dependencies]
serde_json = "1.0"
Expand Down
97 changes: 94 additions & 3 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ use datafusion::arrow::datatypes::{
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};
use substrait::proto::expression::literal::IntervalDayToSecond;
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use url::Url;

use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
Expand Down Expand Up @@ -408,7 +411,6 @@ pub async fn from_substrait_rel(
};
aggr_expr.push(agg_func?.as_ref().clone());
}

input.aggregate(group_expr, aggr_expr)?.build()
} else {
not_impl_err!("Aggregate without an input is not valid")
Expand Down Expand Up @@ -569,7 +571,80 @@ pub async fn from_substrait_rel(

Ok(LogicalPlan::Values(Values { schema, values }))
}
_ => not_impl_err!("Only NamedTable and VirtualTable reads are supported"),
Some(ReadType::LocalFiles(lf)) => {
fn extract_filename(name: &str) -> Option<String> {
let corrected_url =
if name.starts_with("file://") && !name.starts_with("file:///") {
name.replacen("file://", "file:///", 1)
} else {
name.to_string()
};

Url::parse(&corrected_url).ok().and_then(|url| {
let path = url.path();
std::path::Path::new(path)
.file_name()
.map(|filename| filename.to_string_lossy().to_string())
})
}

// we could use the file name to check the original table provider
// TODO: currently does not support multiple local files
let filename: Option<String> =
lf.items.first().and_then(|x| match x.path_type.as_ref() {
Some(UriFile(name)) => extract_filename(name),
_ => None,
});

if lf.items.len() > 1 || filename.is_none() {
return not_impl_err!(
"Only NamedTable and VirtualTable reads are supported"
);
}
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a valid one
let table_reference = TableReference::Bare { table: name.into() };
let t = ctx.table(table_reference).await?;
let t = t.into_optimized_plan()?;
match &read.projection {
Some(MaskExpression { select, .. }) => match &select.as_ref() {
Some(projection) => {
let column_indices: Vec<usize> = projection
.struct_items
.iter()
.map(|item| item.field as usize)
.collect();
match &t {
LogicalPlan::TableScan(scan) => {
let fields = column_indices
.iter()
.map(|i| {
scan.projected_schema.qualified_field(*i)
})
.map(|(qualifier, field)| {
(qualifier.cloned(), Arc::new(field.clone()))
})
.collect();
let mut scan = scan.clone();
scan.projection = Some(column_indices);
scan.projected_schema =
DFSchemaRef::new(DFSchema::new_with_metadata(
fields,
HashMap::new(),
)?);
Ok(LogicalPlan::TableScan(scan))
}
_ => plan_err!("unexpected plan for table"),
}
}
_ => Ok(t),
},
_ => Ok(t),
}
}
_ => {
not_impl_err!("Only NamedTable and VirtualTable reads are supported")
}
},
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
Ok(set_op) => match set_op {
Expand Down Expand Up @@ -810,14 +885,21 @@ pub async fn from_substrait_agg_func(
f.function_reference
);
};

let function_name = function_name.split(':').next().unwrap_or(function_name);
// try udaf first, then built-in aggr fn.
if let Ok(fun) = ctx.udaf(function_name) {
Ok(Arc::new(Expr::AggregateFunction(
expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None),
)))
} else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
{
match &fun {
// deal with situation that count(*) got no arguments
aggregate_function::AggregateFunction::Count if args.is_empty() => {
args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
}
_ => {}
}
Ok(Arc::new(Expr::AggregateFunction(
expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None),
)))
Expand Down Expand Up @@ -1253,6 +1335,8 @@ fn from_substrait_type(
r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type(
s, dfs_names, name_idx,
)?)),
r#type::Kind::Varchar(_) => Ok(DataType::Utf8),
r#type::Kind::FixedChar(_) => Ok(DataType::Utf8),
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
},
_ => not_impl_err!("`None` Substrait kind is not supported"),
Expand Down Expand Up @@ -1541,6 +1625,13 @@ fn from_substrait_literal(
Some(LiteralType::Null(ntype)) => {
from_substrait_null(ntype, dfs_names, name_idx)?
}
Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond {
days,
seconds,
microseconds,
})) => {
ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000))
}
Some(LiteralType::UserDefined(user_defined)) => {
match user_defined.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
Expand Down
1 change: 1 addition & 0 deletions datafusion/substrait/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ mod logical_plans;
mod roundtrip_logical_plan;
mod roundtrip_physical_plan;
mod serialize;
mod tpch;
63 changes: 63 additions & 0 deletions datafusion/substrait/tests/cases/tpch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! tests contains in <https://github.com/substrait-io/consumer-testing/tree/main/substrait_consumer/tests/integration/queries/tpch_substrait_plans>

#[cfg(test)]
mod tests {
use datafusion::common::Result;
use datafusion::execution::options::ParquetReadOptions;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;

#[tokio::test]
async fn tpch_test_1() -> Result<()> {
let ctx = create_context().await?;
let path = "tests/testdata/query_1.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

let plan = from_substrait_plan(&ctx, &proto).await?;

assert!(
format!("{:?}", plan).eq_ignore_ascii_case(
"Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n \
Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\n \
Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\n \
Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\n \
TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
)
);
Ok(())
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_parquet(
"FILENAME_PLACEHOLDER_0",
"tests/testdata/tpch/lineitem.parquet",
ParquetReadOptions::default(),
)
.await?;
Ok(ctx)
}
}
Loading

0 comments on commit 3c511a2

Please sign in to comment.