diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index 7dc7ec63b9564..c78b3bcd1743a 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -269,8 +269,9 @@ impl BallistaContext { options: CsvReadOptions<'_>, ) -> Result<()> { match self.read_csv(path, options).await?.to_logical_plan() { - LogicalPlan::TableScan(TableScan { source, .. }) => { - self.register_table(name, source) + LogicalPlan::TableScan(TableScan { table_name, .. }) => { + todo!("ballista context") + //self.register_table(name, source) } _ => Err(DataFusionError::Internal("Expected table scan".to_owned())), } @@ -283,8 +284,9 @@ impl BallistaContext { options: ParquetReadOptions<'_>, ) -> Result<()> { match self.read_parquet(path, options).await?.to_logical_plan() { - LogicalPlan::TableScan(TableScan { source, .. }) => { - self.register_table(name, source) + LogicalPlan::TableScan(TableScan { table_name, .. }) => { + todo!("ballista context") + // self.register_table(name, source) } _ => Err(DataFusionError::Internal("Expected table scan".to_owned())), } @@ -297,8 +299,9 @@ impl BallistaContext { options: AvroReadOptions<'_>, ) -> Result<()> { match self.read_avro(path, options).await?.to_logical_plan() { - LogicalPlan::TableScan(TableScan { source, .. }) => { - self.register_table(name, source) + LogicalPlan::TableScan(TableScan { table_name, .. }) => { + todo!("ballista context") + // self.register_table(name, source) } _ => Err(DataFusionError::Internal("Expected table scan".to_owned())), } diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs index b0d3bef1f062a..94183e71ea77c 100644 --- a/ballista/rust/core/src/execution_plans/distributed_query.rs +++ b/ballista/rust/core/src/execution_plans/distributed_query.rs @@ -43,6 +43,7 @@ use datafusion::physical_plan::{ use crate::serde::protobuf::execute_query_params::OptionalSessionId; use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec}; use async_trait::async_trait; +use datafusion::catalog::catalog::{CatalogList, MemoryCatalogList}; use datafusion::execution::context::TaskContext; use futures::future; use futures::StreamExt; @@ -164,7 +165,7 @@ impl ExecutionPlan for DistributedQueryExec { async fn execute( &self, partition: usize, - _context: Arc<TaskContext>, + _task_context: Arc<TaskContext>, ) -> Result<SendableRecordBatchStream> { assert_eq!(0, partition); @@ -176,17 +177,29 @@ impl ExecutionPlan for DistributedQueryExec { .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; let schema: Schema = self.plan.schema().as_ref().clone().into(); + let catalog_list: Arc<dyn CatalogList> = Arc::new(MemoryCatalogList::new()); + println!( + "ballista catalogs BEFORE decoding logical plan: {:?}", + catalog_list.catalog_names() + ); let mut buf: Vec<u8> = vec![]; - let plan_message = - T::try_from_logical_plan(&self.plan, self.extension_codec.as_ref()).map_err( - |e| { - DataFusionError::Internal(format!( - "failed to serialize logical plan: {:?}", - e - )) - }, - )?; + let plan_message = T::try_from_logical_plan( + &self.plan, + catalog_list.as_ref(), + self.extension_codec.as_ref(), + ) + .map_err(|e| { + DataFusionError::Internal(format!( + "failed to serialize logical plan: {:?}", + e + )) + })?; + println!( + "ballista catalogs AFTER decoding logical plan: {:?}", + catalog_list.catalog_names() + ); + plan_message.try_encode(&mut buf).map_err(|e| { DataFusionError::Execution(format!("failed to encode logical plan: {:?}", e)) })?; diff --git
a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index a0264271a5eea..af49567b8f6f2 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -38,6 +38,8 @@ use datafusion::logical_plan::{ }; use datafusion::prelude::SessionContext; +use datafusion::catalog::catalog::{get_table_provider, CatalogList}; +use datafusion::execution::context::{DEFAULT_CATALOG, DEFAULT_SCHEMA}; use datafusion_proto::from_proto::parse_expr; use prost::bytes::BufMut; use prost::Message; @@ -248,16 +250,42 @@ impl AsLogicalPlan for LogicalPlanNode { .with_listing_options(options) .with_schema(Arc::new(schema)); - let provider = ListingTable::try_new(config)?; + let provider = Arc::new(ListingTable::try_new(config)?); - LogicalPlanBuilder::scan_with_filters( - &scan.table_name, - Arc::new(provider), - projection, - filters, - )? - .build() - .map_err(|e| e.into()) + //TODO remove hard-coded catalog and schema here and parse table name + // into TableReference first + let catalog_name = DEFAULT_CATALOG; + let schema_name = DEFAULT_SCHEMA; + + let session_state = ctx.state.write(); + match session_state.catalog_list.catalog(catalog_name) { + Some(catalog) => match catalog.schema(schema_name) { + Some(schema) => { + let scan_info = LogicalPlanBuilder::scan_with_filters( + &scan.table_name, + provider, + projection, + filters, + )?; + + println!("Registering table '{}'", scan_info.table_name); + schema.register_table( + scan_info.table_name.to_string(), + scan_info.provider.clone(), + )?; + + scan_info.build().map_err(|e| e.into()) + } + _ => Err(BallistaError::General(format!( + "schema '{}' not found in catalog '{}'", + schema_name, catalog_name + ))), + }, + _ => Err(BallistaError::General(format!( + "catalog '{}' not found", + catalog_name + ))), + } } LogicalPlanType::Sort(sort) => { let input: LogicalPlan = @@ -477,6 +505,7 @@ impl AsLogicalPlan for LogicalPlanNode { fn try_from_logical_plan( plan: &LogicalPlan, + catalog_list: &dyn CatalogList, extension_codec: &dyn LogicalExtensionCodec, ) -> Result<Self> where @@ -505,84 +534,108 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlan::TableScan(TableScan { table_name, - source, filters, projection, ..
}) => { - let schema = source.schema(); - let source = source.as_any(); - - let projection = match projection { - None => None, - Some(columns) => { - let column_names = columns - .iter() - .map(|i| schema.field(*i).name().to_owned()) - .collect(); - Some(protobuf::ProjectionColumns { - columns: column_names, - }) + for catalog_name in catalog_list.catalog_names() { + let catalog = catalog_list.catalog(&catalog_name).unwrap(); + for schema_name in catalog.schema_names() { + let schema = catalog.schema(&schema_name).unwrap(); + println!( + "{}.{} tables = {:?}", + catalog_name, + schema_name, + schema.table_names() + ); } - }; - let schema: datafusion_proto::protobuf::Schema = schema.as_ref().into(); - - let filters: Vec<protobuf::LogicalExprNode> = filters - .iter() - .map(|filter| filter.try_into()) - .collect::<Result<Vec<_>, _>>()?; + } + match get_table_provider(catalog_list, &table_name) { + Some(source) => { + let schema = source.schema(); + let source = source.as_any(); + + let projection = match projection { + None => None, + Some(columns) => { + let column_names = columns + .iter() + .map(|i| schema.field(*i).name().to_owned()) + .collect(); + Some(protobuf::ProjectionColumns { + columns: column_names, + }) + } + }; + let schema: datafusion_proto::protobuf::Schema = + schema.as_ref().into(); - if let Some(listing_table) = source.downcast_ref::<ListingTable>() { - let any = listing_table.options().format.as_any(); - let file_format_type = if let Some(parquet) = - any.downcast_ref::<ParquetFormat>() - { - FileFormatType::Parquet(protobuf::ParquetFormat { - enable_pruning: parquet.enable_pruning(), - }) - } else if let Some(csv) = any.downcast_ref::<CsvFormat>() { - FileFormatType::Csv(protobuf::CsvFormat { - delimiter: byte_to_string(csv.delimiter())?, - has_header: csv.has_header(), - }) - } else if any.is::<AvroFormat>() { - FileFormatType::Avro(protobuf::AvroFormat {}) - } else { - return Err(proto_error(format!( - "Error converting file format, {:?} is invalid as a datafusion format.", - listing_table.options().format - ))); - }; - Ok(protobuf::LogicalPlanNode { - logical_plan_type: Some(LogicalPlanType::ListingScan( - protobuf::ListingTableScanNode { - file_format_type: Some(file_format_type), - table_name: table_name.to_owned(), - collect_stat: listing_table.options().collect_stat, - file_extension: listing_table - .options() - .file_extension - .clone(), - table_partition_cols: listing_table - .options() - .table_partition_cols - .clone(), - path: listing_table.table_path().to_owned(), - schema: Some(schema), - projection, - filters, - target_partitions: listing_table - .options() - .target_partitions - as u32, - }, - )), - }) - } else { - Err(BallistaError::General(format!( - "logical plan to_proto unsupported table provider {:?}", - source - ))) + let filters: Vec<protobuf::LogicalExprNode> = + filters + .iter() + .map(|filter| filter.try_into()) + .collect::<Result<Vec<_>, _>>()?; + + if let Some(listing_table) = source.downcast_ref::<ListingTable>() + { + let any = listing_table.options().format.as_any(); + let file_format_type = if let Some(parquet) = + any.downcast_ref::<ParquetFormat>() + { + FileFormatType::Parquet(protobuf::ParquetFormat { + enable_pruning: parquet.enable_pruning(), + }) + } else if let Some(csv) = any.downcast_ref::<CsvFormat>() { + FileFormatType::Csv(protobuf::CsvFormat { + delimiter: byte_to_string(csv.delimiter())?, + has_header: csv.has_header(), + }) + } else if any.is::<AvroFormat>() { + FileFormatType::Avro(protobuf::AvroFormat {}) + } else { + return Err(proto_error(format!( + "Error converting file format, {:?} is invalid as a datafusion format.", + listing_table.options().format + ))); + }; + Ok(protobuf::LogicalPlanNode
{ + logical_plan_type: Some(LogicalPlanType::ListingScan( + protobuf::ListingTableScanNode { + file_format_type: Some(file_format_type), + table_name: table_name.to_owned(), + collect_stat: listing_table + .options() + .collect_stat, + file_extension: listing_table + .options() + .file_extension + .clone(), + table_partition_cols: listing_table + .options() + .table_partition_cols + .clone(), + path: listing_table.table_path().to_owned(), + schema: Some(schema), + projection, + filters, + target_partitions: listing_table + .options() + .target_partitions + as u32, + }, + )), + }) + } else { + Err(BallistaError::General(format!( + "logical plan to_proto unsupported table provider {:?}", + source + ))) + } + } + _ => Err(BallistaError::General(format!( + "logical plan to_proto table '{}' does not exist in catalog", + table_name + ))), } } LogicalPlan::Projection(Projection { @@ -593,6 +646,7 @@ impl AsLogicalPlan for LogicalPlanNode { input: Some(Box::new( protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), + catalog_list, extension_codec, )?, )), @@ -611,6 +665,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), + catalog_list, extension_codec, )?; Ok(protobuf::LogicalPlanNode { @@ -628,6 +683,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), + catalog_list, extension_codec, )?; Ok(protobuf::LogicalPlanNode { @@ -651,6 +707,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), + catalog_list, extension_codec, )?; Ok(protobuf::LogicalPlanNode { @@ -681,11 +738,13 @@ impl AsLogicalPlan for LogicalPlanNode { let left: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( left.as_ref(), + catalog_list, extension_codec, )?; let right: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( right.as_ref(), + catalog_list, extension_codec, )?; let (left_join_column, right_join_column) = @@ -711,6 +770,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), + catalog_list, extension_codec, )?; Ok(protobuf::LogicalPlanNode { @@ -726,6 +786,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), + catalog_list, extension_codec, )?; Ok(protobuf::LogicalPlanNode { @@ -741,6 +802,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), + catalog_list, extension_codec, )?; let selection_expr: Vec = @@ -764,6 +826,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), + catalog_list, extension_codec, )?; @@ -871,6 +934,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Analyze(a) => { let input = protobuf::LogicalPlanNode::try_from_logical_plan( a.input.as_ref(), + catalog_list, extension_codec, )?; Ok(protobuf::LogicalPlanNode { @@ -885,6 +949,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Explain(a) => { let input = protobuf::LogicalPlanNode::try_from_logical_plan( a.plan.as_ref(), + catalog_list, extension_codec, )?; Ok(protobuf::LogicalPlanNode { @@ -903,6 
+968,7 @@ impl AsLogicalPlan for LogicalPlanNode { .map(|i| { protobuf::LogicalPlanNode::try_from_logical_plan( i, + catalog_list, extension_codec, ) }) @@ -916,10 +982,12 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { let left = protobuf::LogicalPlanNode::try_from_logical_plan( left.as_ref(), + catalog_list, extension_codec, )?; let right = protobuf::LogicalPlanNode::try_from_logical_plan( right.as_ref(), + catalog_list, extension_codec, )?; Ok(protobuf::LogicalPlanNode { @@ -942,6 +1010,7 @@ impl AsLogicalPlan for LogicalPlanNode { .map(|i| { protobuf::LogicalPlanNode::try_from_logical_plan( i, + catalog_list, extension_codec, ) }) @@ -982,6 +1051,9 @@ mod roundtrip_tests { use crate::serde::{AsLogicalPlan, BallistaCodec}; use async_trait::async_trait; use core::panic; + use datafusion::catalog::catalog::{ + get_table_provider, CatalogList, MemoryCatalogList, + }; use datafusion::{ arrow::datatypes::{DataType, Field, Schema}, datafusion_data_access::{ @@ -1040,6 +1112,29 @@ mod roundtrip_tests { } } + fn roundtrip_test_new(plan: &LogicalPlan, ctx: &SessionContext) { + let codec: BallistaCodec<protobuf::LogicalPlanNode, protobuf::PhysicalPlanNode> = + BallistaCodec::default(); + + let proto: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + plan, + ctx.state.read().catalog_list.as_ref(), + codec.logical_extension_codec.as_ref(), + ) + .expect("from logical plan"); + + let ctx2 = SessionContext::new(); + + let round_trip: LogicalPlan = proto + .try_into_logical_plan(&ctx2, codec.logical_extension_codec()) + .expect("to logical plan"); + + // TODO compare catalogs in the contexts + + assert_eq!(format!("{:?}", plan), format!("{:?}", round_trip)); + } + // Given an identity of a LogicalPlan, converts it to protobuf and back, using debug formatting to test equality. macro_rules!
roundtrip_test { ($initial_struct:ident, $proto_type:ty, $struct_type:ty) => { @@ -1064,6 +1159,7 @@ mod roundtrip_tests { let proto: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( &$initial_struct, + ctx.state.read().catalog_list.as_ref(), codec.logical_extension_codec(), ) .expect("from logical plan"); @@ -1076,23 +1172,6 @@ mod roundtrip_tests { format!("{:?}", round_trip) ); }; - ($initial_struct:ident, $ctx:ident) => { - let codec: BallistaCodec< - protobuf::LogicalPlanNode, - protobuf::PhysicalPlanNode, - > = BallistaCodec::default(); - let proto: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan(&$initial_struct) - .expect("from logical plan"); - let round_trip: LogicalPlan = proto - .try_into_logical_plan(&$ctx, codec.logical_extension_codec()) - .expect("to logical plan"); - - assert_eq!( - format!("{:?}", $initial_struct), - format!("{:?}", round_trip) - ); - }; } #[tokio::test] @@ -1112,18 +1191,24 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); + let scan_info = LogicalPlanBuilder::scan_csv( + Arc::new(LocalFileSystem {}), + "employee.csv", + CsvReadOptions::new().schema(&schema).has_header(true), + Some(vec![3, 4]), + 4, + ) + .await?; + + let ctx = SessionContext::new(); + ctx.register_table(scan_info.table_name.as_ref(), scan_info.provider.clone())?; + let plan = std::sync::Arc::new( - LogicalPlanBuilder::scan_csv( - Arc::new(LocalFileSystem {}), - "employee.csv", - CsvReadOptions::new().schema(&schema).has_header(true), - Some(vec![3, 4]), - 4, - ) - .await - .and_then(|plan| plan.sort(vec![col("salary")])) - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?, + scan_info + .builder + .sort(vec![col("salary")])? + .build() + .map_err(BallistaError::DataFusionError)?, ); for partition_count in test_partition_counts.iter() { @@ -1134,7 +1219,7 @@ mod roundtrip_tests { partitioning_scheme: rr_repartition, }); - roundtrip_test!(roundtrip_plan); + roundtrip_test_new(&roundtrip_plan, &ctx); let h_repartition = Partitioning::Hash(test_expr.clone(), *partition_count); @@ -1143,7 +1228,7 @@ mod roundtrip_tests { partitioning_scheme: h_repartition, }); - roundtrip_test!(roundtrip_plan); + roundtrip_test_new(&roundtrip_plan, &ctx); let no_expr_hrepartition = Partitioning::Hash(Vec::new(), *partition_count); @@ -1152,7 +1237,7 @@ mod roundtrip_tests { partitioning_scheme: no_expr_hrepartition, }); - roundtrip_test!(roundtrip_plan); + roundtrip_test_new(&roundtrip_plan, &ctx); } Ok(()) @@ -1206,35 +1291,47 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); - let verbose_plan = LogicalPlanBuilder::scan_csv( + let scan_info = LogicalPlanBuilder::scan_csv( Arc::new(LocalFileSystem {}), "employee.csv", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![3, 4]), 4, ) - .await - .and_then(|plan| plan.sort(vec![col("salary")])) - .and_then(|plan| plan.explain(true, true)) - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?; + .await?; + + let ctx = SessionContext::new(); + ctx.register_table(scan_info.table_name.as_str(), scan_info.provider.clone())?; - let plan = LogicalPlanBuilder::scan_csv( + let verbose_plan = scan_info + .builder + .sort(vec![col("salary")])? + .explain(true, true)? 
+ .build() + .map_err(BallistaError::DataFusionError)?; + + roundtrip_test_new(&verbose_plan, &ctx); + + let scan_info = LogicalPlanBuilder::scan_csv( Arc::new(LocalFileSystem {}), "employee.csv", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![3, 4]), 4, ) - .await - .and_then(|plan| plan.sort(vec![col("salary")])) - .and_then(|plan| plan.explain(false, true)) - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?; + .await?; - roundtrip_test!(plan); + let ctx = SessionContext::new(); + ctx.register_table(scan_info.table_name.as_str(), scan_info.provider.clone())?; + + let plan = scan_info + .builder + .sort(vec![col("salary")])? + .explain(false, true)? + .build() + .map_err(BallistaError::DataFusionError)?; - roundtrip_test!(verbose_plan); + roundtrip_test_new(&plan, &ctx); Ok(()) } @@ -1249,35 +1346,47 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); - let verbose_plan = LogicalPlanBuilder::scan_csv( + let scan_info = LogicalPlanBuilder::scan_csv( Arc::new(LocalFileSystem {}), "employee.csv", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![3, 4]), 4, ) - .await - .and_then(|plan| plan.sort(vec![col("salary")])) - .and_then(|plan| plan.explain(true, false)) - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?; + .await?; - let plan = LogicalPlanBuilder::scan_csv( + let ctx = SessionContext::new(); + ctx.register_table(scan_info.table_name.as_ref(), scan_info.provider.clone())?; + + let verbose_plan = scan_info + .builder + .sort(vec![col("salary")])? + .explain(true, false)? + .build() + .map_err(BallistaError::DataFusionError)?; + + roundtrip_test_new(&verbose_plan, &ctx); + + let scan_info = LogicalPlanBuilder::scan_csv( Arc::new(LocalFileSystem {}), "employee.csv", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![3, 4]), 4, ) - .await - .and_then(|plan| plan.sort(vec![col("salary")])) - .and_then(|plan| plan.explain(false, false)) - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?; + .await?; + + let ctx = SessionContext::new(); + ctx.register_table(scan_info.table_name.as_ref(), scan_info.provider.clone())?; - roundtrip_test!(plan); + let plan = scan_info + .builder + .sort(vec![col("salary")])? + .explain(false, false)? + .build() + .map_err(BallistaError::DataFusionError)?; - roundtrip_test!(verbose_plan); + roundtrip_test_new(&plan, &ctx); Ok(()) } @@ -1292,30 +1401,42 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); - let scan_plan = LogicalPlanBuilder::scan_csv( + let ctx = SessionContext::new(); + + let scan_info = LogicalPlanBuilder::scan_csv( Arc::new(LocalFileSystem {}), "employee1", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![0, 3, 4]), 4, ) - .await? 
- .build() - .map_err(BallistaError::DataFusionError)?; + .await?; + + ctx.register_table(scan_info.table_name.as_ref(), scan_info.provider.clone())?; + + let scan_plan = scan_info + .builder + .build() + .map_err(BallistaError::DataFusionError)?; - let plan = LogicalPlanBuilder::scan_csv( + let scan_info = LogicalPlanBuilder::scan_csv( Arc::new(LocalFileSystem {}), "employee2", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![0, 3, 4]), 4, ) - .await - .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, (vec!["id"], vec!["id"]))) - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?; + .await?; + + ctx.register_table(scan_info.table_name.as_ref(), scan_info.provider.clone())?; + + let plan = scan_info + .builder + .join(&scan_plan, JoinType::Inner, (vec!["id"], vec!["id"]))? + .build() + .map_err(BallistaError::DataFusionError)?; - roundtrip_test!(plan); + roundtrip_test_new(&plan, &ctx); Ok(()) } @@ -1329,35 +1450,44 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); - let plan = LogicalPlanBuilder::scan_csv( + let scan_info = LogicalPlanBuilder::scan_csv( Arc::new(LocalFileSystem {}), "employee.csv", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![3, 4]), 4, ) - .await - .and_then(|plan| plan.sort(vec![col("salary")])) - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?; - roundtrip_test!(plan); + .await?; + + let ctx = SessionContext::new(); + ctx.register_table(scan_info.table_name.as_ref(), scan_info.provider.clone())?; + + let plan = scan_info + .builder + .sort(vec![col("salary")])? + .build() + .map_err(BallistaError::DataFusionError)?; + + roundtrip_test_new(&plan, &ctx); Ok(()) } #[tokio::test] async fn roundtrip_empty_relation() -> Result<()> { + let ctx = SessionContext::new(); + let plan_false = LogicalPlanBuilder::empty(false) .build() .map_err(BallistaError::DataFusionError)?; - roundtrip_test!(plan_false); + roundtrip_test_new(&plan_false, &ctx); let plan_true = LogicalPlanBuilder::empty(true) .build() .map_err(BallistaError::DataFusionError)?; - roundtrip_test!(plan_true); + roundtrip_test_new(&plan_true, &ctx); Ok(()) } @@ -1372,19 +1502,27 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); - let plan = LogicalPlanBuilder::scan_csv( + let scan_info = LogicalPlanBuilder::scan_csv( Arc::new(LocalFileSystem {}), "employee.csv", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![3, 4]), 4, ) - .await - .and_then(|plan| plan.aggregate(vec![col("state")], vec![max(col("salary"))])) - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?; + .await?; + + assert_eq!("employee_csv", scan_info.table_name); - roundtrip_test!(plan); + let ctx = SessionContext::new(); + ctx.register_table(scan_info.table_name.as_str(), scan_info.provider.clone())?; + + let plan = scan_info + .builder + .aggregate(vec![col("state")], vec![max(col("salary"))])? 
+ .build() + .map_err(BallistaError::DataFusionError)?; + + roundtrip_test_new(&plan, &ctx); Ok(()) } @@ -1410,20 +1548,27 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); - let plan = LogicalPlanBuilder::scan_csv( + let scan_info = LogicalPlanBuilder::scan_csv_with_name( custom_object_store.clone(), "test://employee.csv", CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![3, 4]), + "employee", 4, ) - .await - .and_then(|plan| plan.build()) - .map_err(BallistaError::DataFusionError)?; + .await?; + + let ctx = SessionContext::new(); + ctx.register_table(scan_info.table_name.as_ref(), scan_info.provider.clone())?; + + let plan = scan_info.build().map_err(BallistaError::DataFusionError)?; + + let catalog_list: Arc<dyn CatalogList> = Arc::new(MemoryCatalogList::default()); let proto: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( &plan, + catalog_list.as_ref(), codec.logical_extension_codec(), ) .expect("from logical plan"); @@ -1435,7 +1580,9 @@ mod roundtrip_tests { let round_trip_store = match round_trip { LogicalPlan::TableScan(scan) => { - match scan.source.as_ref().as_any().downcast_ref::<ListingTable>() { + let source = + get_table_provider(catalog_list.as_ref(), &scan.table_name).unwrap(); + match source.as_ref().as_any().downcast_ref::<ListingTable>() { Some(listing_table) => { format!("{:?}", listing_table.object_store()) } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index ed41ce61c4c46..1fdb23353377c 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -30,6 +30,7 @@ use datafusion::logical_plan::{ use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; +use datafusion::catalog::catalog::CatalogList; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::plan::Extension; use datafusion::physical_plan::ExecutionPlan; @@ -76,6 +77,7 @@ pub trait AsLogicalPlan: Debug + Send + Sync + Clone { fn try_from_logical_plan( plan: &LogicalPlan, + catalog_list: &dyn CatalogList, extension_codec: &dyn LogicalExtensionCodec, ) -> Result<Self> where @@ -726,7 +728,11 @@ mod tests { let extension_codec = TopKExtensionCodec {}; - let proto = LogicalPlanNode::try_from_logical_plan(&topk_plan, &extension_codec)?; + let proto = LogicalPlanNode::try_from_logical_plan( + &topk_plan, + ctx.state.read().catalog_list.as_ref(), + &extension_codec, + )?; let logical_round_trip = proto.try_into_logical_plan(&ctx, &extension_codec)?; assert_eq!( diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 7e759bd606ebf..a0dd199ff045f 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -165,7 +165,7 @@ pub(crate) fn parse_physical_expr( .collect::<Result<Vec<_>, _>>()?; // TODO Do not create a new ExecutionProps - let execution_props = ExecutionProps::new(); + let execution_props = ExecutionProps::default(); let fun_expr = functions::create_physical_fun( &(&scalar_function).into(), diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index 955725957d11a..dff5fd2dc01d5 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -1334,7 +1334,7 @@ mod roundtrip_tests { let input = Arc::new(EmptyExec::new(false, schema.clone())); - let execution_props = ExecutionProps::new(); + let execution_props =
ExecutionProps::default(); let fun_expr = functions::create_physical_fun( &BuiltinScalarFunction::Abs, diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs index 06128f8dbc532..d2c3a40de4fd2 100644 --- a/ballista/rust/executor/src/execution_loop.rs +++ b/ballista/rust/executor/src/execution_loop.rs @@ -36,6 +36,7 @@ use ballista_core::error::BallistaError; use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning; use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; +use datafusion::catalog::catalog::MemoryCatalogList; use datafusion::execution::context::TaskContext; pub async fn poll_loop( diff --git a/ballista/rust/executor/src/executor.rs b/ballista/rust/executor/src/executor.rs index fa092137abd6e..fa00bf3853353 100644 --- a/ballista/rust/executor/src/executor.rs +++ b/ballista/rust/executor/src/executor.rs @@ -25,6 +25,7 @@ use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriterExec; use ballista_core::serde::protobuf; use ballista_core::serde::protobuf::ExecutorRegistration; +use datafusion::catalog::catalog::CatalogList; use datafusion::error::DataFusionError; use datafusion::execution::context::TaskContext; use datafusion::execution::runtime_env::RuntimeEnv; diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs index 11a5c75527ccd..e8f832ef3efe9 100644 --- a/ballista/rust/executor/src/executor_server.rs +++ b/ballista/rust/executor/src/executor_server.rs @@ -38,6 +38,7 @@ use ballista_core::serde::protobuf::{ }; use ballista_core::serde::scheduler::ExecutorState; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; +use datafusion::catalog::catalog::MemoryCatalogList; use datafusion::execution::context::TaskContext; use datafusion::physical_plan::ExecutionPlan; diff --git a/ballista/rust/scheduler/src/scheduler_server/mod.rs b/ballista/rust/scheduler/src/scheduler_server/mod.rs index 4b47e5239bde3..d8e654b6128b7 100644 --- a/ballista/rust/scheduler/src/scheduler_server/mod.rs +++ b/ballista/rust/scheduler/src/scheduler_server/mod.rs @@ -603,6 +603,7 @@ mod test { LogicalPlanBuilder::scan_empty(None, &schema, Some(vec![0, 1])) .unwrap() + .builder .aggregate(vec![col("id")], vec![sum(col("gmv"))]) .unwrap() .build() diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 1060bd2e0f94a..1ed04ed74d562 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -1394,7 +1394,7 @@ mod tests { let config = SessionConfig::new() .with_target_partitions(1) .with_batch_size(10); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::with_config(config.clone()); let codec: BallistaCodec< protobuf::LogicalPlanNode, protobuf::PhysicalPlanNode, @@ -1428,11 +1428,14 @@ mod tests { let proto: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( &plan, + ctx.state.read().catalog_list.as_ref(), codec.logical_extension_codec(), ) .unwrap(); + + let round_trip_ctx = SessionContext::with_config(config.clone()); let round_trip: LogicalPlan = (&proto) - .try_into_logical_plan(&ctx, codec.logical_extension_codec()) + .try_into_logical_plan(&round_trip_ctx, codec.logical_extension_codec()) .unwrap(); assert_eq!( format!("{:?}", plan), @@ -1445,11 +1448,14 @@ mod tests { let proto: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( 
&plan, + ctx.state.read().catalog_list.as_ref(), codec.logical_extension_codec(), ) .unwrap(); + + let round_trip_ctx = SessionContext::with_config(config.clone()); let round_trip: LogicalPlan = (&proto) - .try_into_logical_plan(&ctx, codec.logical_extension_codec()) + .try_into_logical_plan(&round_trip_ctx, codec.logical_extension_codec()) .unwrap(); assert_eq!( format!("{:?}", plan), diff --git a/datafusion/core/src/catalog/catalog.rs b/datafusion/core/src/catalog/catalog.rs index 9a932ee35e1c1..140396c94a0a2 100644 --- a/datafusion/core/src/catalog/catalog.rs +++ b/datafusion/core/src/catalog/catalog.rs @@ -19,6 +19,8 @@ //! representing collections of named schemas. use crate::catalog::schema::SchemaProvider; +use crate::catalog::TableReference; +use crate::datasource::TableProvider; use datafusion_common::{DataFusionError, Result}; use parking_lot::RwLock; use std::any::Any; @@ -46,6 +48,46 @@ pub trait CatalogList: Sync + Send { fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>>; } +/// Get a TableProvider from the catalog +pub fn get_table_provider( + catalog_list: &dyn CatalogList, + table_name: &str, +) -> Option<Arc<dyn TableProvider>> { + // TODO do we have these defined as defaults somewhere? + let mut catalog_name = "datafusion".to_owned(); + let mut schema_name = "public".to_owned(); + let table_ref_name; + + let table_ref: TableReference = table_name.into(); + match table_ref { + TableReference::Bare { table } => table_ref_name = table.to_string(), + TableReference::Partial { schema, table } => { + schema_name = schema.to_string(); + table_ref_name = table.to_string(); + } + TableReference::Full { + catalog, + schema, + table, + } => { + catalog_name = catalog.to_string(); + schema_name = schema.to_string(); + table_ref_name = table.to_string(); + } + } + + match catalog_list.catalog(&catalog_name) { + Some(catalog) => match catalog.schema(&schema_name) { + Some(schema) => match schema.table(&table_ref_name) { + Some(table) => Some(table.clone()), + _ => None, + }, + _ => None, + }, + _ => None, + } +} + /// Simple in-memory list of catalogs pub struct MemoryCatalogList { /// Collection of catalogs containing schemas and ultimately TableProviders diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 9414f2887d3ed..5fc8d6096abcd 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -658,8 +658,8 @@ mod tests { #[tokio::test] async fn select_columns() -> Result<()> { // build plan using Table API - - let t = test_table().await?; + let mut ctx = SessionContext::new(); + let t = test_table(&mut ctx).await?; let t2 = t.select_columns(&["c1", "c2", "c11"])?; let plan = t2.to_logical_plan(); @@ -675,7 +675,8 @@ mod tests { #[tokio::test] async fn select_expr() -> Result<()> { // build plan using Table API - let t = test_table().await?; + let mut ctx = SessionContext::new(); + let t = test_table(&mut ctx).await?; let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?; let plan = t2.to_logical_plan(); @@ -691,7 +692,8 @@ mod tests { #[tokio::test] async fn select_with_window_exprs() -> Result<()> { // build plan using Table API - let t = test_table().await?; + let mut ctx = SessionContext::new(); + let t = test_table(&mut ctx).await?; let first_row = Expr::WindowFunction { fun: window_functions::WindowFunction::BuiltInWindowFunction( window_functions::BuiltInWindowFunction::FirstValue, ), @@ -716,7 +718,8 @@ mod tests { #[tokio::test] async fn aggregate() -> Result<()> { // build plan using DataFrame API - let df = test_table().await?; + let mut ctx
= SessionContext::new(); + let df = test_table(&mut ctx).await?; let group_expr = vec![col("c1")]; let aggr_expr = vec![ min(col("c12")), @@ -749,8 +752,9 @@ mod tests { #[tokio::test] async fn join() -> Result<()> { - let left = test_table().await?.select_columns(&["c1", "c2"])?; - let right = test_table_with_name("c2") + let mut ctx = SessionContext::new(); + let left = test_table(&mut ctx).await?.select_columns(&["c1", "c2"])?; + let right = test_table_with_name(&mut ctx, "c2") .await? .select_columns(&["c1", "c3"])?; let left_rows = left.collect().await?; @@ -766,7 +770,8 @@ mod tests { #[tokio::test] async fn limit() -> Result<()> { // build query using Table API - let t = test_table().await?; + let mut ctx = SessionContext::new(); + let t = test_table(&mut ctx).await?; let t2 = t.select_columns(&["c1", "c2", "c11"])?.limit(10)?; let plan = t2.to_logical_plan(); @@ -783,7 +788,8 @@ mod tests { #[tokio::test] async fn explain() -> Result<()> { // build query using Table API - let df = test_table().await?; + let mut ctx = SessionContext::new(); + let df = test_table(&mut ctx).await?; let df = df .select_columns(&["c1", "c2", "c11"])? .limit(10)? @@ -839,7 +845,8 @@ mod tests { #[tokio::test] async fn sendable() { - let df = test_table().await.unwrap(); + let mut ctx = SessionContext::new(); + let df = test_table(&mut ctx).await.unwrap(); // dataframes should be sendable between threads/tasks let task = tokio::task::spawn(async move { df.select_columns(&["c1"]) @@ -850,7 +857,8 @@ mod tests { #[tokio::test] async fn intersect() -> Result<()> { - let df = test_table().await?.select_columns(&["c1", "c3"])?; + let mut ctx = SessionContext::new(); + let df = test_table(&mut ctx).await?.select_columns(&["c1", "c3"])?; let plan = df.intersect(df.clone())?; let result = plan.to_logical_plan(); let expected = create_plan( @@ -864,7 +872,8 @@ mod tests { #[tokio::test] async fn except() -> Result<()> { - let df = test_table().await?.select_columns(&["c1", "c3"])?; + let mut ctx = SessionContext::new(); + let df = test_table(&mut ctx).await?.select_columns(&["c1", "c3"])?; let plan = df.except(df.clone())?; let result = plan.to_logical_plan(); let expected = create_plan( @@ -878,8 +887,8 @@ mod tests { #[tokio::test] async fn register_table() -> Result<()> { - let df = test_table().await?.select_columns(&["c1", "c12"])?; - let ctx = SessionContext::new(); + let mut ctx = SessionContext::new(); + let df = test_table(&mut ctx).await?.select_columns(&["c1", "c12"])?; let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.to_logical_plan())); // register a dataframe as a table @@ -942,14 +951,16 @@ mod tests { ctx.create_logical_plan(sql) } - async fn test_table_with_name(name: &str) -> Result<Arc<DataFrame>> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx, name).await?; + async fn test_table_with_name( + ctx: &mut SessionContext, + name: &str, + ) -> Result<Arc<DataFrame>> { + register_aggregate_csv(ctx, name).await?; ctx.table(name) } - async fn test_table() -> Result<Arc<DataFrame>> { - test_table_with_name("aggregate_test_100").await + async fn test_table(ctx: &mut SessionContext) -> Result<Arc<DataFrame>> { + test_table_with_name(ctx, "aggregate_test_100").await } async fn register_aggregate_csv( diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 4bb77fb93c5b3..7048b92773efd 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -100,9 +100,9 @@ use super::options::{ }; /// The default catalog name - this impacts what SQL
queries use if not specified -const DEFAULT_CATALOG: &str = "datafusion"; +pub const DEFAULT_CATALOG: &str = "datafusion"; /// The default schema name - this impacts what SQL queries use if not specified -const DEFAULT_SCHEMA: &str = "public"; +pub const DEFAULT_SCHEMA: &str = "public"; /// SessionContext is the main interface for executing queries with DataFusion. It stands for /// the connection between user and DataFusion/Ballista cluster. @@ -474,17 +474,18 @@ impl SessionContext { let uri: String = uri.into(); let (object_store, path) = self.runtime_env().object_store(&uri)?; let target_partitions = self.copied_config().target_partitions; + let scan = &LogicalPlanBuilder::scan_avro( + object_store, + path, + options, + None, + target_partitions, + ) + .await?; + self.register_table(scan.table_name.to_owned().as_str(), scan.provider.clone())?; Ok(Arc::new(DataFrame::new( self.state.clone(), - &LogicalPlanBuilder::scan_avro( - object_store, - path, - options, - None, - target_partitions, - ) - .await? - .build()?, + &scan.builder.build()?, ))) } @@ -497,17 +498,18 @@ impl SessionContext { let uri: String = uri.into(); let (object_store, path) = self.runtime_env().object_store(&uri)?; let target_partitions = self.copied_config().target_partitions; + let scan = LogicalPlanBuilder::scan_json( + object_store, + path, + options, + None, + target_partitions, + ) + .await?; + self.register_table(scan.table_name.as_str(), scan.provider.clone())?; Ok(Arc::new(DataFrame::new( self.state.clone(), - &LogicalPlanBuilder::scan_json( - object_store, - path, - options, - None, - target_partitions, - ) - .await? - .build()?, + &scan.builder.build()?, ))) } @@ -528,17 +530,18 @@ impl SessionContext { let uri: String = uri.into(); let (object_store, path) = self.runtime_env().object_store(&uri)?; let target_partitions = self.copied_config().target_partitions; + let scan = LogicalPlanBuilder::scan_csv( + object_store, + path, + options, + None, + target_partitions, + ) + .await?; + self.register_table(scan.table_name.as_str(), scan.provider.clone())?; Ok(Arc::new(DataFrame::new( self.state.clone(), - &LogicalPlanBuilder::scan_csv( - object_store, - path, - options, - None, - target_partitions, - ) - .await? - .build()?, + &scan.builder.build()?, ))) } @@ -551,23 +554,28 @@ impl SessionContext { let uri: String = uri.into(); let (object_store, path) = self.runtime_env().object_store(&uri)?; let target_partitions = self.copied_config().target_partitions; - let logical_plan = LogicalPlanBuilder::scan_parquet( + let scan = LogicalPlanBuilder::scan_parquet( object_store, path, options, None, target_partitions, ) - .await? - .build()?; - Ok(Arc::new(DataFrame::new(self.state.clone(), &logical_plan))) + .await?; + self.register_table(scan.table_name.as_str(), scan.provider.clone())?; + Ok(Arc::new(DataFrame::new( + self.state.clone(), + &scan.builder.build()?, + ))) } /// Creates a DataFrame for reading a custom TableProvider. 
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<Arc<DataFrame>> { + let scan = LogicalPlanBuilder::scan(UNNAMED_TABLE, provider.clone(), None)?; + self.register_table(scan.table_name.as_str(), scan.provider.clone())?; Ok(Arc::new(DataFrame::new( self.state.clone(), - &LogicalPlanBuilder::scan(UNNAMED_TABLE, provider, None)?.build()?, + &scan.builder.build()?, ))) } @@ -762,13 +770,15 @@ impl SessionContext { let schema = self.state.read().schema_for_ref(table_ref)?; match schema.table(table_ref.table()) { Some(ref provider) => { - let plan = LogicalPlanBuilder::scan( + let scan = LogicalPlanBuilder::scan( table_ref.table(), Arc::clone(provider), None, - )? - .build()?; - Ok(Arc::new(DataFrame::new(self.state.clone(), &plan))) + )?; + Ok(Arc::new(DataFrame::new( + self.state.clone(), + &scan.builder.build()?, + ))) } _ => Err(DataFusionError::Plan(format!( "No table named '{}'", @@ -1226,7 +1236,7 @@ impl SessionState { scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), config, - execution_props: ExecutionProps::new(), + execution_props: ExecutionProps::default(), runtime_env: runtime, } } @@ -1351,7 +1361,11 @@ impl SessionState { debug!("Input logical plan:\n{}\n", plan.display_indent()); trace!("Full input logical plan:\n{:?}", plan); for optimizer in optimizers { - new_plan = optimizer.optimize(&new_plan, execution_props)?; + new_plan = optimizer.optimize( + &new_plan, + execution_props, + self.catalog_list.as_ref(), + )?; observer(&new_plan, optimizer.as_ref()); } debug!("Optimized logical plan:\n{}\n", new_plan.display_indent()); @@ -1463,6 +1477,7 @@ impl TaskContext { scalar_functions: HashMap<String, Arc<ScalarUDF>>, aggregate_functions: HashMap<String, Arc<AggregateUDF>>, runtime: Arc<RuntimeEnv>, + // catalog_list: Arc<dyn CatalogList>, ) -> Self { Self { task_id: Some(task_id), @@ -1471,6 +1486,7 @@ impl TaskContext { scalar_functions, aggregate_functions, runtime, + // catalog_list, } } @@ -2693,6 +2709,7 @@ mod tests { ])); let plan = LogicalPlanBuilder::scan_empty(None, schema.as_ref(), None)? + .builder .aggregate(vec![col("c1")], vec![sum(col("c2"))])? .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? .build()?; diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index c88b25d0a2251..308e37201e3d1 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -54,6 +54,32 @@ use crate::sql::utils::group_window_expr_by_sort_keys; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; +/// Return type from all scan methods +pub struct ScanInfo { + pub builder: LogicalPlanBuilder, + pub table_name: String, + pub provider: Arc<dyn TableProvider>, +} + +impl ScanInfo { + fn new( + builder: LogicalPlanBuilder, + table_name: String, + provider: Arc<dyn TableProvider>, + ) -> Self { + Self { + builder, + table_name, + provider, + } + } + + // TODO delete this hack once finished refactoring + pub fn build(&self) -> Result<LogicalPlan> { + self.builder.build() + } +} + /// Builder for logical plans /// /// ``` @@ -83,6 +109,7 @@ pub const UNNAMED_TABLE: &str = "?table?"; /// &employee_schema(), /// None, /// )? +/// .builder /// // Keep only rows where salary < 1000 /// .filter(col("salary").lt_eq(lit(1000)))?
/// // only show "last_name" in the final results @@ -93,6 +120,7 @@ pub const UNNAMED_TABLE: &str = "?table?"; /// # } /// ``` pub struct LogicalPlanBuilder { + /// The plan that is being built plan: LogicalPlan, } @@ -197,7 +225,7 @@ impl LogicalPlanBuilder { partitions: Vec<Vec<RecordBatch>>, schema: SchemaRef, projection: Option<Vec<usize>>, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let provider = Arc::new(MemTable::try_new(schema, partitions)?); Self::scan(UNNAMED_TABLE, provider, projection) } @@ -209,7 +237,7 @@ impl LogicalPlanBuilder { options: CsvReadOptions<'_>, projection: Option<Vec<usize>>, target_partitions: usize, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let path = path.into(); Self::scan_csv_with_name( object_store, @@ -230,7 +258,7 @@ impl LogicalPlanBuilder { projection: Option<Vec<usize>>, table_name: impl Into<String>, target_partitions: usize, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let listing_options = options.to_listing_options(target_partitions); let path: String = path.into(); @@ -258,7 +286,7 @@ impl LogicalPlanBuilder { options: ParquetReadOptions<'_>, projection: Option<Vec<usize>>, target_partitions: usize, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let path = path.into(); Self::scan_parquet_with_name( object_store, @@ -279,7 +307,7 @@ impl LogicalPlanBuilder { projection: Option<Vec<usize>>, target_partitions: usize, table_name: impl Into<String>, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let listing_options = options.to_listing_options(target_partitions); let path: String = path.into(); @@ -303,7 +331,7 @@ impl LogicalPlanBuilder { options: AvroReadOptions<'_>, projection: Option<Vec<usize>>, target_partitions: usize, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let path = path.into(); Self::scan_avro_with_name( object_store, @@ -324,7 +352,7 @@ impl LogicalPlanBuilder { projection: Option<Vec<usize>>, table_name: impl Into<String>, target_partitions: usize, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let listing_options = options.to_listing_options(target_partitions); let path: String = path.into(); @@ -352,7 +380,7 @@ impl LogicalPlanBuilder { options: NdJsonReadOptions<'_>, projection: Option<Vec<usize>>, target_partitions: usize, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let path = path.into(); Self::scan_json_with_name( object_store, @@ -373,7 +401,7 @@ impl LogicalPlanBuilder { projection: Option<Vec<usize>>, table_name: impl Into<String>, target_partitions: usize, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let listing_options = options.to_listing_options(target_partitions); let path: String = path.into(); @@ -399,7 +427,7 @@ impl LogicalPlanBuilder { name: Option<&str>, table_schema: &Schema, projection: Option<Vec<usize>>, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let table_schema = Arc::new(table_schema.clone()); let provider = Arc::new(EmptyTable::new(table_schema)); Self::scan(name.unwrap_or(UNNAMED_TABLE), provider, projection) } @@ -410,7 +438,7 @@ impl LogicalPlanBuilder { table_name: impl Into<String>, provider: Arc<dyn TableProvider>, projection: Option<Vec<usize>>, - ) -> Result<Self> { + ) -> Result<ScanInfo> { Self::scan_with_filters(table_name, provider, projection, vec![]) } @@ -420,7 +448,7 @@ impl LogicalPlanBuilder { provider: Arc<dyn TableProvider>, projection: Option<Vec<usize>>, filters: Vec<Expr>, - ) -> Result<Self> { + ) -> Result<ScanInfo> { let table_name = table_name.into(); if table_name.is_empty() { return Err(DataFusionError::Plan( "table_name cannot be empty".to_string(), )); } + //TODO hack so we don't register "employee.csv" as schema "employee" table "csv" + // this did not come up before because we accessed TableProvider directly + let table_name = table_name.replace(".", "_"); + let schema = provider.schema(); let projected_schema = projection @@ -448,14 +480,16 @@ impl LogicalPlanBuilder { })?; let table_scan = LogicalPlan::TableScan(TableScan { - table_name, - source: provider, + table_name:
table_name.clone(), projected_schema: Arc::new(projected_schema), projection, filters, + full_filters: vec![], + partial_filters: vec![], + unsupported_filters: vec![], limit: None, }); - Ok(Self::from(table_scan)) + Ok(ScanInfo::new(Self::from(table_scan), table_name, provider)) } /// Wrap a plan in a window pub(crate) fn window_plan( @@ -1214,6 +1248,7 @@ mod tests { &employee_schema(), Some(vec![0, 3]), )? + .builder .filter(col("state").eq(lit("CO")))? .project(vec![col("id")])? .build()?; @@ -1230,8 +1265,9 @@ mod tests { #[test] fn plan_builder_schema() { let schema = employee_schema(); - let plan = - LogicalPlanBuilder::scan_empty(Some("employee_csv"), &schema, None).unwrap(); + let plan = LogicalPlanBuilder::scan_empty(Some("employee_csv"), &schema, None) + .unwrap() + .builder; let expected = DFSchema::try_from_qualified_schema("employee_csv", &schema).unwrap(); @@ -1246,6 +1282,7 @@ mod tests { &employee_schema(), Some(vec![3, 4]), )? + .builder .aggregate( vec![col("state")], vec![sum(col("salary")).alias("total_salary")], @@ -1269,6 +1306,7 @@ mod tests { &employee_schema(), Some(vec![3, 4]), )? + .builder .sort(vec![ Expr::Sort { expr: Box::new(col("state")), @@ -1297,6 +1335,7 @@ mod tests { .build()?; let plan = LogicalPlanBuilder::scan_empty(Some("t1"), &employee_schema(), None)? + .builder .join_using(&t2, JoinType::Inner, vec!["id"])? .project(vec![Expr::Wildcard])? .build()?; @@ -1321,6 +1360,7 @@ mod tests { )?; let plan = plan + .builder .union(plan.build()?)? .union(plan.build()?)? .union(plan.build()?)? @@ -1346,6 +1386,7 @@ mod tests { // project id and first_name by column index Some(vec![0, 1]), )? + .builder // two columns with the same name => error .project(vec![col("id"), col("first_name").alias("id")]); @@ -1372,6 +1413,7 @@ mod tests { // project state and salary by column index Some(vec![3, 4]), )? 
+ .builder // two columns with the same name => error .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); diff --git a/datafusion/core/src/logical_plan/plan.rs b/datafusion/core/src/logical_plan/plan.rs index 66307c6aba464..a12eb91649377 100644 --- a/datafusion/core/src/logical_plan/plan.rs +++ b/datafusion/core/src/logical_plan/plan.rs @@ -20,8 +20,6 @@ use super::display::{GraphvizVisitor, IndentVisitor}; use super::expr::{Column, Expr}; use super::extension::UserDefinedLogicalNode; -use crate::datasource::datasource::TableProviderFilterPushDown; -use crate::datasource::TableProvider; use crate::error::DataFusionError; use crate::logical_plan::dfschema::DFSchemaRef; use crate::sql::parser::FileType; @@ -130,14 +128,18 @@ pub struct Window { pub struct TableScan { /// The name of the table pub table_name: String, - /// The source of the table - pub source: Arc<dyn TableProvider>, /// Optional column indices to use as a projection pub projection: Option<Vec<usize>>, /// The schema description of the output pub projected_schema: DFSchemaRef, /// Optional expressions to be used as filters by the table provider pub filters: Vec<Expr>, + /// Filters that are fully supported by the table provider + pub full_filters: Vec<Expr>, + /// Filters that are partially supported by the table provider + pub partial_filters: Vec<Expr>, + /// Filters that are not supported by the table provider + pub unsupported_filters: Vec<Expr>, /// Optional limit to skip reading pub limit: Option<usize>, } @@ -774,7 +776,8 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap() + /// let (plan, _, _) = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap(); + /// plan /// .filter(col("id").eq(lit(5))).unwrap() /// .build().unwrap(); /// @@ -815,7 +818,8 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap() + /// let (plan, _, _) = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap(); + /// plan /// .filter(col("id").eq(lit(5))).unwrap() /// .build().unwrap(); /// @@ -855,7 +859,8 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty(Some("foo.csv"), &schema, None).unwrap() + /// let (plan, _, _) = LogicalPlanBuilder::scan_empty(Some("foo.csv"), &schema, None).unwrap(); + /// plan /// .filter(col("id").eq(lit(5))).unwrap() /// .build().unwrap(); /// @@ -950,10 +955,12 @@ impl LogicalPlan { } LogicalPlan::TableScan(TableScan { - ref source, ref table_name, ref projection, ref filters, + ref full_filters, + ref partial_filters, + ref unsupported_filters, ref limit, ..
}) => { @@ -964,31 +971,11 @@ impl LogicalPlan { )?; if !filters.is_empty() { - let mut full_filter = vec![]; - let mut partial_filter = vec![]; - let mut unsupported_filters = vec![]; - - filters.iter().for_each(|x| { - if let Ok(t) = source.supports_filter_pushdown(x) { - match t { - TableProviderFilterPushDown::Exact => { - full_filter.push(x) - } - TableProviderFilterPushDown::Inexact => { - partial_filter.push(x) - } - TableProviderFilterPushDown::Unsupported => { - unsupported_filters.push(x) - } - } - } - }); - - if !full_filter.is_empty() { - write!(f, ", full_filters={:?}", full_filter)?; + if !full_filters.is_empty() { + write!(f, ", full_filters={:?}", full_filters)?; }; - if !partial_filter.is_empty() { - write!(f, ", partial_filters={:?}", partial_filter)?; + if !partial_filters.is_empty() { + write!(f, ", partial_filters={:?}", partial_filters)?; } if !unsupported_filters.is_empty() { write!( @@ -1240,18 +1227,19 @@ mod tests { } fn display_plan() -> LogicalPlan { - LogicalPlanBuilder::scan_empty( + let scan = LogicalPlanBuilder::scan_empty( Some("employee_csv"), &employee_schema(), Some(vec![0, 3]), ) - .unwrap() - .filter(col("state").eq(lit("CO"))) - .unwrap() - .project(vec![col("id")]) - .unwrap() - .build() - .unwrap() + .unwrap(); + scan.builder + .filter(col("state").eq(lit("CO"))) + .unwrap() + .project(vec![col("id")]) + .unwrap() + .build() + .unwrap() } #[test] @@ -1551,6 +1539,7 @@ mod tests { LogicalPlanBuilder::scan_empty(None, &schema, Some(vec![0, 1])) .unwrap() + .builder .filter(col("state").eq(lit("CO"))) .unwrap() .project(vec![col("id")]) diff --git a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs index 39964df4a6635..2db6bd21f1b69 100644 --- a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs @@ -17,6 +17,7 @@ //! Eliminate common sub-expression. 
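
The three optimizer-rule diffs that follow all thread a &dyn CatalogList through OptimizerRule::optimize, so that a rule can resolve a TableProvider by name now that TableScan no longer carries its source. The trait definition itself (datafusion/core/src/optimizer/optimizer.rs) is not part of this diff; the following is only a minimal sketch of the signature these call sites imply, assuming the existing name() method is otherwise unchanged:

    use datafusion::catalog::catalog::CatalogList;
    use datafusion::error::Result;
    use datafusion::execution::context::ExecutionProps;
    use datafusion::logical_plan::LogicalPlan;

    /// Sketch of the revised optimizer-rule trait implied by this diff.
    pub trait OptimizerRule {
        /// Rewrite `plan`; `catalog_list` gives the rule access to the
        /// registered TableProviders, which previously hung off TableScan.
        fn optimize(
            &self,
            plan: &LogicalPlan,
            execution_props: &ExecutionProps,
            catalog_list: &dyn CatalogList,
        ) -> Result<LogicalPlan>;

        /// A human-readable name for this rule, used in debug output.
        fn name(&self) -> &str;
    }
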
+use crate::catalog::catalog::CatalogList; use crate::error::Result; use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{Filter, Projection, Window}; @@ -60,8 +61,9 @@ impl OptimizerRule for CommonSubexprEliminate { &self, plan: &LogicalPlan, execution_props: &ExecutionProps, + catalog_list: &dyn CatalogList, ) -> Result { - optimize(plan, execution_props) + optimize(plan, execution_props, catalog_list) } fn name(&self) -> &str { @@ -82,7 +84,11 @@ impl CommonSubexprEliminate { } } -fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result { +fn optimize( + plan: &LogicalPlan, + execution_props: &ExecutionProps, + catalog_list: &dyn CatalogList, +) -> Result { let mut expr_set = ExprSet::new(); match plan { @@ -101,6 +107,7 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result Result Result Result Result Result>>()?; utils::from_plan(plan, &expr, &new_inputs) @@ -299,6 +310,7 @@ fn rewrite_expr( expr_set: &mut ExprSet, schema: &DFSchema, execution_props: &ExecutionProps, + catalog_list: &dyn CatalogList, ) -> Result<(Vec>, LogicalPlan)> { let mut affected_id = HashSet::::new(); @@ -323,7 +335,7 @@ fn rewrite_expr( }) .collect::>>()?; - let mut new_input = optimize(input, execution_props)?; + let mut new_input = optimize(input, execution_props, catalog_list)?; if !affected_id.is_empty() { new_input = build_project_plan(new_input, affected_id, expr_set)?; } @@ -660,7 +672,11 @@ mod test { fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { let optimizer = CommonSubexprEliminate {}; let optimized_plan = optimizer - .optimize(plan, &ExecutionProps::new()) + .optimize( + plan, + &ExecutionProps::default(), + create_test_table_catalog_list().as_ref(), + ) .expect("failed to optimize plan"); let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); diff --git a/datafusion/core/src/optimizer/eliminate_filter.rs b/datafusion/core/src/optimizer/eliminate_filter.rs index 800963ef550f3..41474434ee433 100644 --- a/datafusion/core/src/optimizer/eliminate_filter.rs +++ b/datafusion/core/src/optimizer/eliminate_filter.rs @@ -18,6 +18,7 @@ //! Optimizer rule to replace `where false` on a plan with an empty relation. //! This saves time in planning and executing the query. //! Note that this rule should be applied after simplify expressions optimizer rule. 
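
The reworked optimizer tests in these files call create_test_table_catalog_list(), which is not defined anywhere in this diff. A plausible sketch of that helper, assuming the rules under test never actually resolve a table, so an empty in-memory catalog list suffices; if a rule did consult the catalog, the helper would also have to register the scanned test table:

    use std::sync::Arc;

    use datafusion::catalog::catalog::{CatalogList, MemoryCatalogList};

    /// Hypothetical test helper (not part of this diff): supplies the
    /// catalog-list argument now required by OptimizerRule::optimize.
    fn create_test_table_catalog_list() -> Arc<dyn CatalogList> {
        // An empty MemoryCatalogList is enough for rules that never
        // look a TableScan up against the catalog.
        Arc::new(MemoryCatalogList::new())
    }
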
+use crate::catalog::catalog::CatalogList; use datafusion_common::ScalarValue; use datafusion_expr::Expr; @@ -45,6 +46,7 @@ impl OptimizerRule for EliminateFilter { &self, plan: &LogicalPlan, execution_props: &ExecutionProps, + catalog_list: &dyn CatalogList, ) -> Result { match plan { LogicalPlan::Filter(Filter { @@ -65,7 +67,7 @@ impl OptimizerRule for EliminateFilter { let inputs = plan.inputs(); let new_inputs = inputs .iter() - .map(|plan| self.optimize(plan, execution_props)) + .map(|plan| self.optimize(plan, execution_props, catalog_list)) .collect::>>()?; utils::from_plan(plan, &plan.expressions(), &new_inputs) @@ -88,7 +90,11 @@ mod tests { fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { let rule = EliminateFilter::new(); let optimized_plan = rule - .optimize(plan, &ExecutionProps::new()) + .optimize( + plan, + &ExecutionProps::default(), + create_test_table_catalog_list().as_ref(), + ) .expect("failed to optimize plan"); let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); diff --git a/datafusion/core/src/optimizer/eliminate_limit.rs b/datafusion/core/src/optimizer/eliminate_limit.rs index c1fc2068d3250..bd4770d056fce 100644 --- a/datafusion/core/src/optimizer/eliminate_limit.rs +++ b/datafusion/core/src/optimizer/eliminate_limit.rs @@ -17,6 +17,7 @@ //! Optimizer rule to replace `LIMIT 0` on a plan with an empty relation. //! This saves time in planning and executing the query. +use crate::catalog::catalog::CatalogList; use crate::error::Result; use crate::logical_plan::{EmptyRelation, Limit, LogicalPlan}; use crate::optimizer::optimizer::OptimizerRule; @@ -40,6 +41,7 @@ impl OptimizerRule for EliminateLimit { &self, plan: &LogicalPlan, execution_props: &ExecutionProps, + catalog_list: &dyn CatalogList, ) -> Result { match plan { LogicalPlan::Limit(Limit { n, input }) if *n == 0 => { @@ -56,7 +58,7 @@ impl OptimizerRule for EliminateLimit { let inputs = plan.inputs(); let new_inputs = inputs .iter() - .map(|plan| self.optimize(plan, execution_props)) + .map(|plan| self.optimize(plan, execution_props, catalog_list)) .collect::>>()?; utils::from_plan(plan, &expr, &new_inputs) @@ -79,7 +81,11 @@ mod tests { fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { let rule = EliminateLimit::new(); let optimized_plan = rule - .optimize(plan, &ExecutionProps::new()) + .optimize( + plan, + &ExecutionProps::default(), + create_test_table_catalog_list().as_ref(), + ) .expect("failed to optimize plan"); let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); diff --git a/datafusion/core/src/optimizer/filter_push_down.rs b/datafusion/core/src/optimizer/filter_push_down.rs index 30a7ee97328e8..74ef083262623 100644 --- a/datafusion/core/src/optimizer/filter_push_down.rs +++ b/datafusion/core/src/optimizer/filter_push_down.rs @@ -14,6 +14,7 @@ //! 
diff --git a/datafusion/core/src/optimizer/filter_push_down.rs b/datafusion/core/src/optimizer/filter_push_down.rs
index 30a7ee97328e8..74ef083262623 100644
--- a/datafusion/core/src/optimizer/filter_push_down.rs
+++ b/datafusion/core/src/optimizer/filter_push_down.rs
@@ -14,6 +14,7 @@
 //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
+use crate::catalog::catalog::{get_table_provider, CatalogList};
 use crate::datasource::datasource::TableProviderFilterPushDown;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection, Union};
@@ -24,6 +25,7 @@ use crate::logical_plan::{DFSchema, Expr};
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
 use crate::{error::Result, logical_plan::Operator};
+use datafusion_common::DataFusionError;
 use std::{
     collections::{HashMap, HashSet},
     sync::Arc,
@@ -84,11 +86,16 @@ fn get_predicates<'a>(
 }
 
 /// Optimizes the plan
-fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
+fn push_down(
+    state: &State,
+    plan: &LogicalPlan,
+    execution_props: &ExecutionProps,
+    catalog_list: &dyn CatalogList,
+) -> Result<LogicalPlan> {
     let new_inputs = plan
         .inputs()
         .iter()
-        .map(|input| optimize(input, state.clone()))
+        .map(|input| optimize(input, state.clone(), execution_props, catalog_list))
        .collect::<Result<Vec<_>>>()?;
 
     let expr = plan.expressions();
@@ -140,6 +147,8 @@ fn keep_filters(
 /// in `state` depend on the columns `used_columns`.
 fn issue_filters(
     mut state: State,
+    execution_props: &ExecutionProps,
+    catalog_list: &dyn CatalogList,
     used_columns: HashSet<Column>,
     plan: &LogicalPlan,
 ) -> Result<LogicalPlan> {
@@ -147,7 +156,7 @@ fn issue_filters(
 
     if predicates.is_empty() {
         // all filters can be pushed down => optimize inputs and return new plan
-        return push_down(&state, plan);
+        return push_down(&state, plan, execution_props, catalog_list);
     }
 
     let plan = add_filter(plan.clone(), &predicates);
@@ -155,7 +164,7 @@ fn issue_filters(
     state.filters = remove_filters(&state.filters, &predicate_columns);
 
     // continue optimization over all input nodes by cloning the current state (i.e. each node is independent)
-    push_down(&state, &plan)
+    push_down(&state, &plan, execution_props, catalog_list)
 }
 
 /// converts "A AND B AND C" => [A, B, C]
@@ -254,6 +263,8 @@ fn get_pushable_join_predicates<'a>(
 
 fn optimize_join(
     mut state: State,
+    execution_props: &ExecutionProps,
+    catalog_list: &dyn CatalogList,
     plan: &LogicalPlan,
     left: &LogicalPlan,
     right: &LogicalPlan,
@@ -275,11 +286,11 @@ fn optimize_join(
 
     let mut left_state = state.clone();
     left_state.filters = keep_filters(&left_state.filters, &to_left);
-    let left = optimize(left, left_state)?;
+    let left = optimize(left, left_state, execution_props, catalog_list)?;
 
     let mut right_state = state.clone();
     right_state.filters = keep_filters(&right_state.filters, &to_right);
-    let right = optimize(right, right_state)?;
+    let right = optimize(right, right_state, execution_props, catalog_list)?;
 
     // create a new Join with the new `left` and `right`
     let expr = plan.expressions();
@@ -296,13 +307,20 @@ fn optimize_join(
     }
 }
 
-fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
+fn optimize(
+    plan: &LogicalPlan,
+    mut state: State,
+    execution_props: &ExecutionProps,
+    catalog_list: &dyn CatalogList,
+) -> Result<LogicalPlan> {
     match plan {
         LogicalPlan::Explain { .. } => {
             // push the optimization to the plan of this explain
-            push_down(&state, plan)
+            push_down(&state, plan, execution_props, catalog_list)
+        }
+        LogicalPlan::Analyze { .. } => {
+            push_down(&state, plan, execution_props, catalog_list)
         }
-        LogicalPlan::Analyze { .. } => push_down(&state, plan),
         LogicalPlan::Filter(Filter { input, predicate }) => {
             let mut predicates = vec![];
             split_members(predicate, &mut predicates);
@@ -328,9 +346,12 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
             // As those contain only literals, they could be optimized using constant folding
             // and removal of WHERE TRUE / WHERE FALSE
             if !no_col_predicates.is_empty() {
-                Ok(add_filter(optimize(input, state)?, &no_col_predicates))
+                Ok(add_filter(
+                    optimize(input, state, execution_props, catalog_list)?,
+                    &no_col_predicates,
+                ))
             } else {
-                optimize(input, state)
+                optimize(input, state, execution_props, catalog_list)
             }
         }
         LogicalPlan::Projection(Projection {
@@ -366,7 +387,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
             }
 
             // optimize inner
-            let new_input = optimize(input, state)?;
+            let new_input = optimize(input, state, execution_props, catalog_list)?;
 
             utils::from_plan(plan, expr, &[new_input])
         }
@@ -387,11 +408,11 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 .collect::<Result<Vec<_>>>()?;
             used_columns.extend(agg_columns);
 
-            issue_filters(state, used_columns, plan)
+            issue_filters(state, execution_props, catalog_list, used_columns, plan)
         }
         LogicalPlan::Sort { .. } => {
             // sort is filter-commutable
-            push_down(&state, plan)
+            push_down(&state, plan, execution_props, catalog_list)
         }
         LogicalPlan::Union(Union {
             inputs: _,
@@ -416,7 +437,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 }
             }
 
-            push_down(&state, plan)
+            push_down(&state, plan, execution_props, catalog_list)
         }
         LogicalPlan::Limit(Limit { input, .. }) => {
             // limit is _not_ filter-commutable => collect all columns from its input
@@ -426,10 +447,10 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 .iter()
                 .map(|f| f.qualified_column())
                 .collect::<HashSet<_>>();
-            issue_filters(state, used_columns, plan)
+            issue_filters(state, execution_props, catalog_list, used_columns, plan)
         }
         LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
-            optimize_join(state, plan, left, right)
+            optimize_join(state, execution_props, catalog_list, plan, left, right)
         }
         LogicalPlan::Join(Join {
             left,
@@ -494,53 +515,88 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 .collect::<Result<Vec<_>>>()?;
                 state.filters.extend(join_side_filters);
             }
-            optimize_join(state, plan, left, right)
+            optimize_join(state, execution_props, catalog_list, plan, left, right)
         }
         LogicalPlan::TableScan(TableScan {
-            source,
             projected_schema,
             filters,
             projection,
             table_name,
             limit,
+            ..
         }) => {
             let mut used_columns = HashSet::new();
             let mut new_filters = filters.clone();
 
-            for (filter_expr, cols) in &state.filters {
-                let (preserve_filter_node, add_to_provider) =
-                    match source.supports_filter_pushdown(filter_expr)? {
-                        TableProviderFilterPushDown::Unsupported => (true, false),
-                        TableProviderFilterPushDown::Inexact => (true, true),
-                        TableProviderFilterPushDown::Exact => (false, true),
-                    };
+            if let Some(source) = get_table_provider(catalog_list, table_name) {
+                // TODO consolidate this logic and remove duplication
+
+                let mut full_filters: Vec<Expr> = vec![];
+                let mut partial_filters: Vec<Expr> = vec![];
+                let mut unsupported_filters: Vec<Expr> = vec![];
 
-                if preserve_filter_node {
-                    used_columns.extend(cols.clone());
+                if !filters.is_empty() {
+                    filters.iter().for_each(|x| {
+                        if let Ok(t) = source.supports_filter_pushdown(x) {
+                            match t {
+                                TableProviderFilterPushDown::Exact => {
+                                    full_filters.push(x.clone())
+                                }
+                                TableProviderFilterPushDown::Inexact => {
+                                    partial_filters.push(x.clone())
+                                }
+                                TableProviderFilterPushDown::Unsupported => {
+                                    unsupported_filters.push(x.clone())
+                                }
+                            }
+                        }
+                    });
                 }
 
-                if add_to_provider {
-                    // Don't add expression again if it's already present in
-                    // pushed down filters.
-                    if new_filters.contains(filter_expr) {
-                        continue;
+                for (filter_expr, cols) in &state.filters {
+                    let (preserve_filter_node, add_to_provider) =
+                        match source.supports_filter_pushdown(filter_expr)? {
+                            TableProviderFilterPushDown::Unsupported => (true, false),
+                            TableProviderFilterPushDown::Inexact => (true, true),
+                            TableProviderFilterPushDown::Exact => (false, true),
+                        };
+
+                    if preserve_filter_node {
+                        used_columns.extend(cols.clone());
+                    }
+
+                    if add_to_provider {
+                        // Don't add expression again if it's already present in
+                        // pushed down filters.
+                        if new_filters.contains(filter_expr) {
+                            continue;
+                        }
+                        new_filters.push(filter_expr.clone());
                     }
-                    new_filters.push(filter_expr.clone());
                 }
-            }
 
-            issue_filters(
-                state,
-                used_columns,
-                &LogicalPlan::TableScan(TableScan {
-                    source: source.clone(),
-                    projection: projection.clone(),
-                    projected_schema: projected_schema.clone(),
-                    table_name: table_name.clone(),
-                    filters: new_filters,
-                    limit: *limit,
-                }),
-            )
+                issue_filters(
+                    state,
+                    execution_props,
+                    catalog_list,
+                    used_columns,
+                    &LogicalPlan::TableScan(TableScan {
+                        projection: projection.clone(),
+                        projected_schema: projected_schema.clone(),
+                        table_name: table_name.clone(),
+                        filters: new_filters,
+                        full_filters,
+                        partial_filters,
+                        unsupported_filters,
+                        limit: *limit,
+                    }),
+                )
+            } else {
+                Err(DataFusionError::Plan(format!(
+                    "No table provider named {}",
+                    table_name
+                )))
+            }
         }
         _ => {
             // all other plans are _not_ filter-commutable
@@ -550,7 +606,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 .iter()
                 .map(|f| f.qualified_column())
                 .collect::<HashSet<_>>();
-            issue_filters(state, used_columns, plan)
+            issue_filters(state, execution_props, catalog_list, used_columns, plan)
         }
     }
 }
@@ -560,8 +616,13 @@ impl OptimizerRule for FilterPushDown {
         "filter_push_down"
     }
 
-    fn optimize(&self, plan: &LogicalPlan, _: &ExecutionProps) -> Result<LogicalPlan> {
-        optimize(plan, State::default())
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        execution_props: &ExecutionProps,
+        catalog_list: &dyn CatalogList,
+    ) -> Result<LogicalPlan> {
+        optimize(plan, State::default(), execution_props, catalog_list)
     }
 }
 
@@ -605,8 +666,12 @@ mod tests {
     fn optimize_plan(plan: &LogicalPlan) -> LogicalPlan {
         let rule = FilterPushDown::new();
-        rule.optimize(plan, &ExecutionProps::new())
-            .expect("failed to optimize plan")
+        rule.optimize(
+            plan,
+            &ExecutionProps::default(),
+            create_test_table_catalog_list().as_ref(),
+        )
+        .expect("failed to optimize plan")
     }
 
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
@@ -1413,11 +1478,13 @@ mod tests {
         let table_scan = LogicalPlan::TableScan(TableScan {
             table_name: "test".to_string(),
             filters: vec![],
+            full_filters: vec![],
+            partial_filters: vec![],
+            unsupported_filters: vec![],
             projected_schema: Arc::new(DFSchema::try_from(
                 (*test_provider.schema()).clone(),
             )?),
             projection: None,
-            source: Arc::new(test_provider),
             limit: None,
         });
 
@@ -1486,11 +1553,13 @@ mod tests {
         let table_scan = LogicalPlan::TableScan(TableScan {
             table_name: "test".to_string(),
             filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
+            full_filters: vec![],
+            partial_filters: vec![],
+            unsupported_filters: vec![],
             projected_schema: Arc::new(DFSchema::try_from(
                 (*test_provider.schema()).clone(),
             )?),
             projection: Some(vec![0]),
-            source: Arc::new(test_provider),
             limit: None,
         });
diff --git a/datafusion/core/src/optimizer/limit_push_down.rs b/datafusion/core/src/optimizer/limit_push_down.rs
index 0c68f1761601d..d4f77f33f3780 100644
--- a/datafusion/core/src/optimizer/limit_push_down.rs
+++ b/datafusion/core/src/optimizer/limit_push_down.rs
@@ -18,6 +18,7 @@
 //! Optimizer rule to push down LIMIT in the query plan
 //! It will push down through projection, limits (taking the smaller limit)
 use super::utils;
+use crate::catalog::catalog::CatalogList;
 use crate::error::Result;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::plan::Projection;
@@ -42,7 +43,7 @@ fn limit_push_down(
     _optimizer: &LimitPushDown,
     upper_limit: Option<usize>,
     plan: &LogicalPlan,
-    _execution_props: &ExecutionProps,
+    execution_props: &ExecutionProps,
 ) -> Result<LogicalPlan> {
     match (plan, upper_limit) {
         (LogicalPlan::Limit(Limit { n, input }), upper_limit) => {
@@ -54,25 +55,29 @@ fn limit_push_down(
                 _optimizer,
                 Some(smallest),
                 input.as_ref(),
-                _execution_props,
+                execution_props,
             )?),
         }))
         }
         (
             LogicalPlan::TableScan(TableScan {
                 table_name,
-                source,
                 projection,
                 filters,
+                full_filters,
+                partial_filters,
+                unsupported_filters,
                 limit,
                 projected_schema,
             }),
             Some(upper_limit),
         ) => Ok(LogicalPlan::TableScan(TableScan {
             table_name: table_name.clone(),
-            source: source.clone(),
             projection: projection.clone(),
             filters: filters.clone(),
+            full_filters: full_filters.clone(),
+            partial_filters: partial_filters.clone(),
+            unsupported_filters: unsupported_filters.clone(),
             limit: limit
                 .map(|x| std::cmp::min(x, upper_limit))
                 .or(Some(upper_limit)),
@@ -94,7 +99,7 @@ fn limit_push_down(
                 _optimizer,
                 upper_limit,
                 input.as_ref(),
-                _execution_props,
+                execution_props,
             )?),
             schema: schema.clone(),
             alias: alias.clone(),
@@ -118,7 +123,7 @@ fn limit_push_down(
                         _optimizer,
                         Some(upper_limit),
                         x,
-                        _execution_props,
+                        execution_props,
                     )?),
                 }))
             })
@@ -138,7 +143,7 @@ fn limit_push_down(
             let inputs = plan.inputs();
             let new_inputs = inputs
                 .iter()
-                .map(|plan| limit_push_down(_optimizer, None, plan, _execution_props))
+                .map(|plan| limit_push_down(_optimizer, None, plan, execution_props))
                 .collect::<Result<Vec<_>>>()?;
 
             utils::from_plan(plan, &expr, &new_inputs)
@@ -151,6 +156,7 @@ impl OptimizerRule for LimitPushDown {
         &self,
         plan: &LogicalPlan,
         execution_props: &ExecutionProps,
+        _catalog_list: &dyn CatalogList,
     ) -> Result<LogicalPlan> {
         limit_push_down(self, None, plan, execution_props)
     }
@@ -171,7 +177,11 @@ mod test {
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
         let rule = LimitPushDown::new();
         let optimized_plan = rule
-            .optimize(plan, &ExecutionProps::new())
+            .optimize(
+                plan,
+                &ExecutionProps::default(),
+                create_test_table_catalog_list().as_ref(),
+            )
             .expect("failed to optimize plan");
         let formatted_plan = format!("{:?}", optimized_plan);
         assert_eq!(formatted_plan, expected);
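The FilterPushDown hunks above introduce the three new `TableScan` fields by classifying every scan filter according to how much of it the provider can evaluate. A sketch of that classification, with a toy `Expr` and a fake `supports` function standing in for `TableProvider::supports_filter_pushdown`:

#[derive(Debug, Clone)]
struct Expr(String);

enum FilterPushDown {
    Exact,       // provider guarantees the filter fully
    Inexact,     // provider prunes, but the filter must be re-checked
    Unsupported, // provider cannot use the filter at all
}

fn supports(expr: &Expr) -> FilterPushDown {
    // A real provider inspects the expression; here we fake it by prefix.
    match expr.0.as_str() {
        s if s.starts_with("a") => FilterPushDown::Exact,
        s if s.starts_with("b") => FilterPushDown::Inexact,
        _ => FilterPushDown::Unsupported,
    }
}

// Mirrors the full_filters / partial_filters / unsupported_filters split
// that the diff stores directly on the TableScan node.
fn classify(filters: &[Expr]) -> (Vec<Expr>, Vec<Expr>, Vec<Expr>) {
    let (mut full, mut partial, mut unsupported) = (vec![], vec![], vec![]);
    for f in filters {
        match supports(f) {
            FilterPushDown::Exact => full.push(f.clone()),
            FilterPushDown::Inexact => partial.push(f.clone()),
            FilterPushDown::Unsupported => unsupported.push(f.clone()),
        }
    }
    (full, partial, unsupported)
}

fn main() {
    let filters = vec![Expr("a = 1".into()), Expr("b > 2".into()), Expr("c".into())];
    let (full, partial, unsupported) = classify(&filters);
    println!("full={:?} partial={:?} unsupported={:?}", full, partial, unsupported);
}

Storing the classification on the node is what lets later passes (and the serialized plan) work without a live provider; the TODO in the hunk acknowledges that this duplicates the older per-state-filter loop.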
diff --git a/datafusion/core/src/optimizer/optimizer.rs b/datafusion/core/src/optimizer/optimizer.rs
index 5cf4047947044..2b6cfdb0a1490 100644
--- a/datafusion/core/src/optimizer/optimizer.rs
+++ b/datafusion/core/src/optimizer/optimizer.rs
@@ -17,6 +17,7 @@
 
 //! Query optimizer traits
 
+use crate::catalog::catalog::CatalogList;
 use crate::error::Result;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::LogicalPlan;
@@ -30,6 +31,7 @@ pub trait OptimizerRule {
         &self,
         plan: &LogicalPlan,
         execution_props: &ExecutionProps,
+        catalog_list: &dyn CatalogList,
     ) -> Result<LogicalPlan>;
 
     /// A human readable name for this optimizer rule
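The trait change above is the pivot of the whole diff. Reduced to a compilable sketch with stand-in types — note how a rule that does not need the catalog still has to accept the parameter, exactly as `SingleDistinctToGroupBy` and `ToApproxPerc` do further down:

struct LogicalPlan;
struct ExecutionProps;

trait CatalogList {
    fn table_names(&self) -> Vec<String>;
}

trait OptimizerRule {
    fn optimize(
        &self,
        plan: &LogicalPlan,
        execution_props: &ExecutionProps,
        catalog_list: &dyn CatalogList,
    ) -> Result<LogicalPlan, String>;

    /// A human readable name for this optimizer rule
    fn name(&self) -> &str;
}

// A rule that ignores the catalog simply binds it as `_catalog_list`.
struct NoopRule;

impl OptimizerRule for NoopRule {
    fn optimize(
        &self,
        _plan: &LogicalPlan,
        _execution_props: &ExecutionProps,
        _catalog_list: &dyn CatalogList,
    ) -> Result<LogicalPlan, String> {
        Ok(LogicalPlan)
    }

    fn name(&self) -> &str {
        "noop"
    }
}

fn main() {}

Passing `&dyn CatalogList` (rather than adding a generic parameter) keeps the trait object-safe, which matters because DataFusion stores its rules as `Arc<dyn OptimizerRule>`.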
diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs
index 10bf5d10f9602..0653c759cf5f8 100644
--- a/datafusion/core/src/optimizer/projection_push_down.rs
+++ b/datafusion/core/src/optimizer/projection_push_down.rs
@@ -18,6 +18,7 @@
 //! Projection Push Down optimizer rule ensures that only referenced columns are
 //! loaded into memory
 
+use crate::catalog::catalog::{get_table_provider, CatalogList};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::plan::{
@@ -47,6 +48,7 @@ impl OptimizerRule for ProjectionPushDown {
         &self,
         plan: &LogicalPlan,
         execution_props: &ExecutionProps,
+        catalog_list: &dyn CatalogList,
     ) -> Result<LogicalPlan> {
         // set of all columns referred to by the plan (and thus considered required by the root)
         let required_columns = plan
             .schema()
             .fields()
             .iter()
             .map(|f| f.qualified_column())
             .collect::<HashSet<Column>>();
-        optimize_plan(self, plan, &required_columns, false, execution_props)
+        optimize_plan(
+            self,
+            plan,
+            &required_columns,
+            false,
+            execution_props,
+            catalog_list,
+        )
     }
 
     fn name(&self) -> &str {
@@ -131,6 +140,7 @@ fn optimize_plan(
     required_columns: &HashSet<Column>, // set of columns required up to this step
     has_projection: bool,
     _execution_props: &ExecutionProps,
+    catalog_list: &dyn CatalogList,
 ) -> Result<LogicalPlan> {
     let mut new_required_columns = required_columns.clone();
     match plan {
@@ -170,6 +180,7 @@ fn optimize_plan(
             &new_required_columns,
             true,
             _execution_props,
+            catalog_list,
         )?;
 
         let new_required_columns_optimized = new_input
@@ -216,6 +227,7 @@ fn optimize_plan(
                 &new_required_columns,
                 true,
                 _execution_props,
+                catalog_list,
             )?);
 
             let optimized_right = Arc::new(optimize_plan(
@@ -224,6 +236,7 @@ fn optimize_plan(
                 &new_required_columns,
                 true,
                 _execution_props,
+                catalog_list,
             )?);
 
             let schema = build_join_schema(
@@ -277,6 +290,7 @@ fn optimize_plan(
                     &new_required_columns,
                     true,
                     _execution_props,
+                    catalog_list,
                 )?)
                 .window(new_window_expr)?
                 .build()
@@ -329,6 +343,7 @@ fn optimize_plan(
                 &new_required_columns,
                 true,
                 _execution_props,
+                catalog_list,
             )?),
             schema: DFSchemaRef::new(new_schema),
         }))
@@ -337,26 +352,37 @@ fn optimize_plan(
         // * remove un-used columns from the scan projection
         LogicalPlan::TableScan(TableScan {
             table_name,
-            source,
             filters,
+            full_filters,
+            partial_filters,
+            unsupported_filters,
             limit,
             ..
         }) => {
-            let (projection, projected_schema) = get_projected_schema(
-                Some(table_name),
-                &source.schema(),
-                required_columns,
-                has_projection,
-            )?;
-            // return the table scan with projection
-            Ok(LogicalPlan::TableScan(TableScan {
-                table_name: table_name.clone(),
-                source: source.clone(),
-                projection: Some(projection),
-                projected_schema,
-                filters: filters.clone(),
-                limit: *limit,
-            }))
+            if let Some(source) = get_table_provider(catalog_list, table_name) {
+                let (projection, projected_schema) = get_projected_schema(
+                    Some(table_name),
+                    &source.schema(),
+                    required_columns,
+                    has_projection,
+                )?;
+                // return the table scan with projection
+                Ok(LogicalPlan::TableScan(TableScan {
+                    table_name: table_name.clone(),
+                    projection: Some(projection),
+                    projected_schema,
+                    filters: filters.clone(),
+                    full_filters: full_filters.clone(),
+                    partial_filters: partial_filters.clone(),
+                    unsupported_filters: unsupported_filters.clone(),
+                    limit: *limit,
+                }))
+            } else {
+                Err(DataFusionError::Execution(format!(
+                    "Could not resolve table provider named {}",
+                    table_name
+                )))
+            }
         }
         LogicalPlan::Explain { .. } => Err(DataFusionError::Internal(
             "Unsupported logical plan: Explain must be root of the plan".to_string(),
@@ -378,6 +404,7 @@ fn optimize_plan(
             &required_columns,
             false,
             _execution_props,
+            catalog_list,
         )?),
         verbose: a.verbose,
         schema: a.schema.clone(),
@@ -414,6 +441,7 @@ fn optimize_plan(
                     &new_required_columns,
                     has_projection,
                     _execution_props,
+                    catalog_list,
                 )
             })
             .collect::<Result<Vec<_>>>()?;
@@ -451,6 +479,7 @@ fn optimize_plan(
                 &new_required_columns,
                 has_projection,
                 _execution_props,
+                catalog_list,
             )?];
             let expr = vec![];
             utils::from_plan(plan, &expr, &new_inputs)
@@ -490,6 +519,7 @@ fn optimize_plan(
                     &new_required_columns,
                     has_projection,
                     _execution_props,
+                    catalog_list,
                 )
             })
             .collect::<Result<Vec<_>>>()?;
@@ -981,6 +1011,10 @@ mod tests {
     fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
         let rule = ProjectionPushDown::new();
-        rule.optimize(plan, &ExecutionProps::new())
+        rule.optimize(
+            plan,
+            &ExecutionProps::default(),
+            create_test_table_catalog_list().as_ref(),
+        )
     }
 }
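`get_table_provider` is introduced by this PR, so its exact body is not shown in the diff; the sketch below is an assumption about what such a helper plausibly does: split a possibly-qualified name and fall back to the default catalog and schema for the missing parts (the `datafusion`/`public` values mirror DataFusion's `DEFAULT_CATALOG`/`DEFAULT_SCHEMA` constants):

use std::collections::HashMap;

const DEFAULT_CATALOG: &str = "datafusion";
const DEFAULT_SCHEMA: &str = "public";

struct CatalogList {
    // catalog -> schema -> table names (a provider in the real code)
    catalogs: HashMap<String, HashMap<String, Vec<String>>>,
}

fn get_table_provider(list: &CatalogList, name: &str) -> Option<String> {
    // Resolve "table", "schema.table" or "catalog.schema.table".
    let parts: Vec<&str> = name.split('.').collect();
    let (catalog, schema, table) = match parts.as_slice() {
        [t] => (DEFAULT_CATALOG, DEFAULT_SCHEMA, *t),
        [s, t] => (DEFAULT_CATALOG, *s, *t),
        [c, s, t] => (*c, *s, *t),
        _ => return None,
    };
    list.catalogs
        .get(catalog)?
        .get(schema)?
        .iter()
        .find(|t| t.as_str() == table)
        .cloned()
}

fn main() {
    let mut schemas = HashMap::new();
    schemas.insert(DEFAULT_SCHEMA.to_string(), vec!["test".to_string()]);
    let mut catalogs = HashMap::new();
    catalogs.insert(DEFAULT_CATALOG.to_string(), schemas);
    let list = CatalogList { catalogs };
    assert_eq!(get_table_provider(&list, "test"), Some("test".to_string()));
    assert!(get_table_provider(&list, "missing").is_none());
}

The earlier Ballista hunk carries the matching TODO ("remove hard-coded catalog and schema here and parse table name into TableReference first"), so the fallback behavior above is the intended end state rather than what every call site does yet.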
diff --git a/datafusion/core/src/optimizer/simplify_expressions.rs b/datafusion/core/src/optimizer/simplify_expressions.rs
index 8e1e7d319e8e7..c796ea5fb2329 100644
--- a/datafusion/core/src/optimizer/simplify_expressions.rs
+++ b/datafusion/core/src/optimizer/simplify_expressions.rs
@@ -17,6 +17,7 @@
 
 //! Simplify expressions optimizer rule
 
+use crate::catalog::catalog::CatalogList;
 use crate::error::DataFusionError;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::ExprSchemable;
@@ -195,6 +196,7 @@ impl OptimizerRule for SimplifyExpressions {
         &self,
         plan: &LogicalPlan,
         execution_props: &ExecutionProps,
+        catalog_list: &dyn CatalogList,
     ) -> Result<LogicalPlan> {
         // We need to pass down all the schemas within the plan tree to `optimize_expr` in order
         // to evaluate expression types. For example, a projection plan's schema will only include
@@ -206,7 +208,7 @@ impl OptimizerRule for SimplifyExpressions {
         let new_inputs = plan
             .inputs()
             .iter()
-            .map(|input| self.optimize(input, execution_props))
+            .map(|input| self.optimize(input, execution_props, catalog_list))
             .collect::<Result<Vec<_>>>()?;
 
         let expr = plan
@@ -257,7 +259,7 @@ impl SimplifyExpressions {
 /// # use datafusion::optimizer::simplify_expressions::ConstEvaluator;
 /// # use datafusion::execution::context::ExecutionProps;
 ///
-/// let execution_props = ExecutionProps::new();
+/// let execution_props = ExecutionProps::default();
 /// let mut const_evaluator = ConstEvaluator::new(&execution_props);
 ///
 /// // (1 + 2) + a
@@ -744,6 +746,7 @@ mod tests {
     };
     use crate::physical_plan::functions::{make_scalar_function, BuiltinScalarFunction};
     use crate::physical_plan::udf::ScalarUDF;
+    use crate::test::create_test_table_catalog_list;
 
     #[test]
     fn test_simplify_or_true() {
@@ -1201,7 +1204,7 @@ mod tests {
     fn simplify(expr: Expr) -> Expr {
         let schema = expr_test_schema();
-        let execution_props = ExecutionProps::new();
+        let execution_props = ExecutionProps::default();
         let info = SimplifyContext::new(vec![&schema], &execution_props);
         expr.simplify(&info).unwrap()
     }
@@ -1505,6 +1508,7 @@ mod tests {
         ]);
         LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)
             .expect("creating scan")
+            .builder
             .build()
             .expect("building plan")
     }
@@ -1512,7 +1516,11 @@ mod tests {
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
         let rule = SimplifyExpressions::new();
         let optimized_plan = rule
-            .optimize(plan, &ExecutionProps::new())
+            .optimize(
+                plan,
+                &ExecutionProps::default(),
+                create_test_table_catalog_list().as_ref(),
+            )
             .expect("failed to optimize plan");
         let formatted_plan = format!("{:?}", optimized_plan);
         assert_eq!(formatted_plan, expected);
@@ -1737,7 +1745,11 @@ mod tests {
         };
 
         let err = rule
-            .optimize(plan, &execution_props)
+            .optimize(
+                plan,
+                &execution_props,
+                create_test_table_catalog_list().as_ref(),
+            )
             .expect_err("expected optimization to fail");
 
         err.to_string()
@@ -1754,7 +1766,11 @@ mod tests {
         };
 
         let optimized_plan = rule
-            .optimize(plan, &execution_props)
+            .optimize(
+                plan,
+                &execution_props,
+                create_test_table_catalog_list().as_ref(),
+            )
             .expect("failed to optimize plan");
         return format!("{:?}", optimized_plan);
     }
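For reference, the rewrite SimplifyExpressions performs is ordinary constant folding over boolean operators. A tiny self-contained folder over a toy expression type (the real rule works on DataFusion's `Expr` and consults the plan's schemas for type information):

#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Literal(bool),
    Column(String),
    And(Box<Expr>, Box<Expr>),
    Or(Box<Expr>, Box<Expr>),
}

fn simplify(expr: &Expr) -> Expr {
    match expr {
        Expr::And(l, r) => match (simplify(l), simplify(r)) {
            // false AND x  ==>  false
            (Expr::Literal(false), _) | (_, Expr::Literal(false)) => Expr::Literal(false),
            // true AND x  ==>  x
            (Expr::Literal(true), other) | (other, Expr::Literal(true)) => other,
            (l, r) => Expr::And(Box::new(l), Box::new(r)),
        },
        Expr::Or(l, r) => match (simplify(l), simplify(r)) {
            // true OR x  ==>  true
            (Expr::Literal(true), _) | (_, Expr::Literal(true)) => Expr::Literal(true),
            // false OR x  ==>  x
            (Expr::Literal(false), other) | (other, Expr::Literal(false)) => other,
            (l, r) => Expr::Or(Box::new(l), Box::new(r)),
        },
        other => other.clone(),
    }
}

fn main() {
    // c OR true  ==>  true, mirroring the `test_simplify_or_true` test above.
    let e = Expr::Or(
        Box::new(Expr::Column("c".into())),
        Box::new(Expr::Literal(true)),
    );
    assert_eq!(simplify(&e), Expr::Literal(true));
}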
diff --git a/datafusion/core/src/optimizer/single_distinct_to_groupby.rs b/datafusion/core/src/optimizer/single_distinct_to_groupby.rs
index dfbefa63acd8f..f4695d60055c7 100644
--- a/datafusion/core/src/optimizer/single_distinct_to_groupby.rs
+++ b/datafusion/core/src/optimizer/single_distinct_to_groupby.rs
@@ -17,6 +17,7 @@
 
 //! single distinct to group by optimizer rule
 
+use crate::catalog::catalog::CatalogList;
 use crate::error::Result;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::plan::{Aggregate, Projection};
@@ -189,6 +190,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
         &self,
         plan: &LogicalPlan,
         _execution_props: &ExecutionProps,
+        _catalog_list: &dyn CatalogList,
     ) -> Result<LogicalPlan> {
         optimize(plan)
     }
@@ -207,7 +209,11 @@ mod tests {
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
         let rule = SingleDistinctToGroupBy::new();
         let optimized_plan = rule
-            .optimize(plan, &ExecutionProps::new())
+            .optimize(
+                plan,
+                &ExecutionProps::default(),
+                create_test_table_catalog_list().as_ref(),
+            )
             .expect("failed to optimize plan");
         let formatted_plan = format!("{}", optimized_plan.display_indent_schema());
         assert_eq!(formatted_plan, expected);
diff --git a/datafusion/core/src/optimizer/to_approx_perc.rs b/datafusion/core/src/optimizer/to_approx_perc.rs
index c33c3f67602a1..86486ac76a436 100644
--- a/datafusion/core/src/optimizer/to_approx_perc.rs
+++ b/datafusion/core/src/optimizer/to_approx_perc.rs
@@ -17,6 +17,7 @@
 
 //! expression/function to approx_percentile optimizer rule
 
+use crate::catalog::catalog::CatalogList;
 use crate::error::Result;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::plan::Aggregate;
@@ -114,6 +115,7 @@ impl OptimizerRule for ToApproxPerc {
         &self,
         plan: &LogicalPlan,
         _execution_props: &ExecutionProps,
+        _catalog_list: &dyn CatalogList,
     ) -> Result<LogicalPlan> {
         optimize(plan)
     }
@@ -132,7 +134,11 @@ mod tests {
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
         let rule = ToApproxPerc::new();
         let optimized_plan = rule
-            .optimize(plan, &ExecutionProps::new())
+            .optimize(
+                plan,
+                &ExecutionProps::default(),
+                create_test_table_catalog_list().as_ref(),
+            )
             .expect("failed to optimize plan");
         let formatted_plan = format!("{}", optimized_plan.display_indent_schema());
         assert_eq!(formatted_plan, expected);
diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs
index 0dab2d3ed7bcb..fc9d25fc64054 100644
--- a/datafusion/core/src/optimizer/utils.rs
+++ b/datafusion/core/src/optimizer/utils.rs
@@ -23,6 +23,7 @@ use crate::logical_plan::plan::{
     Aggregate, Analyze, Extension, Filter, Join, Projection, Sort, SubqueryAlias,
     Window,
 };
+use crate::catalog::catalog::CatalogList;
 use crate::logical_plan::{
     build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, ExprVisitable,
     Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion,
@@ -108,12 +109,13 @@ pub fn optimize_children(
     optimizer: &impl OptimizerRule,
     plan: &LogicalPlan,
     execution_props: &ExecutionProps,
+    catalog_list: &dyn CatalogList,
 ) -> Result<LogicalPlan> {
     let new_exprs = plan.expressions();
     let new_inputs = plan
         .inputs()
         .into_iter()
-        .map(|plan| optimizer.optimize(plan, execution_props))
+        .map(|plan| optimizer.optimize(plan, execution_props, catalog_list))
         .collect::<Result<Vec<_>>>()?;
 
     from_plan(plan, &new_exprs, &new_inputs)
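`optimize_children` above is the default recursion used by rules that only rewrite specific node kinds: apply the rule to every input, then rebuild the node from the rewritten children. Restated over toy types, with `utils::from_plan` becoming a plain struct update:

#[derive(Debug, Clone)]
struct Plan {
    name: &'static str,
    inputs: Vec<Plan>,
}

trait Rule {
    fn optimize(&self, plan: &Plan) -> Result<Plan, String>;
}

fn optimize_children(rule: &impl Rule, plan: &Plan) -> Result<Plan, String> {
    let new_inputs = plan
        .inputs
        .iter()
        .map(|child| rule.optimize(child))
        .collect::<Result<Vec<_>, _>>()?;
    // Stand-in for utils::from_plan: same node, new children.
    Ok(Plan { name: plan.name, inputs: new_inputs })
}

// A rule with no rewrite of its own just delegates to the helper,
// which is exactly what TopKOptimizerRule does at the end of this diff.
struct Identity;

impl Rule for Identity {
    fn optimize(&self, plan: &Plan) -> Result<Plan, String> {
        optimize_children(self, plan)
    }
}

fn main() {
    let plan = Plan {
        name: "limit",
        inputs: vec![Plan { name: "scan", inputs: vec![] }],
    };
    println!("{:?}", Identity.optimize(&plan).unwrap());
}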
diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 67b7476e55795..8dd2f2ae004c2 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -134,7 +134,7 @@ impl PruningPredicate {
         let stat_dfschema = DFSchema::try_from(stat_schema.clone())?;
 
         // TODO allow these properties to be passed in
-        let execution_props = ExecutionProps::new();
+        let execution_props = ExecutionProps::default();
         let predicate_expr = create_physical_expr(
             &logical_predicate_expr,
             &stat_dfschema,
diff --git a/datafusion/core/src/physical_plan/functions.rs b/datafusion/core/src/physical_plan/functions.rs
index ae7a2bd7bbd7c..4fabd0da71bc8 100644
--- a/datafusion/core/src/physical_plan/functions.rs
+++ b/datafusion/core/src/physical_plan/functions.rs
@@ -1196,7 +1196,7 @@ mod tests {
     ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => {
         // used to provide type annotation
         let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
-        let execution_props = ExecutionProps::new();
+        let execution_props = ExecutionProps::default();
 
         // any type works here: we evaluate against a literal of `value`
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
@@ -3431,7 +3431,7 @@ mod tests {
 
     #[test]
     fn test_empty_arguments_error() -> Result<()> {
-        let execution_props = ExecutionProps::new();
+        let execution_props = ExecutionProps::default();
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
 
         // pick some arbitrary functions to test
@@ -3474,7 +3474,7 @@ mod tests {
 
     #[test]
     fn test_empty_arguments() -> Result<()> {
-        let execution_props = ExecutionProps::new();
+        let execution_props = ExecutionProps::default();
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
         let funs = [BuiltinScalarFunction::Now, BuiltinScalarFunction::Random];
@@ -3497,7 +3497,7 @@ mod tests {
             Field::new("b", value2.data_type().clone(), false),
         ]);
         let columns: Vec<ArrayRef> = vec![value1, value2];
-        let execution_props = ExecutionProps::new();
+        let execution_props = ExecutionProps::default();
 
         let expr = create_physical_expr(
             &BuiltinScalarFunction::Array,
@@ -3560,7 +3560,7 @@ mod tests {
     fn test_regexp_match() -> Result<()> {
         use arrow::array::ListArray;
         let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
-        let execution_props = ExecutionProps::new();
+        let execution_props = ExecutionProps::default();
 
         let col_value: ArrayRef = Arc::new(StringArray::from_slice(&["aaa-555"]));
         let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
@@ -3599,7 +3599,7 @@ mod tests {
     fn test_regexp_match_all_literals() -> Result<()> {
         use arrow::array::ListArray;
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
-        let execution_props = ExecutionProps::new();
+        let execution_props = ExecutionProps::default();
 
         let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string())));
         let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 98076d1365bcd..f278fce94e1d9 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -22,6 +22,7 @@ use super::{
     aggregates, empty::EmptyExec, expressions::binary, functions,
     hash_join::PartitionMode, udaf, union::UnionExec, values::ValuesExec, windows,
 };
+use crate::catalog::catalog::get_table_provider;
 use crate::execution::context::{ExecutionProps, SessionState};
 use crate::logical_plan::plan::{
     Aggregate, EmptyRelation, Filter, Join, Projection, Sort, SubqueryAlias,
     TableScan,
@@ -333,18 +334,23 @@ impl DefaultPhysicalPlanner {
         async move {
             let exec_plan: Result<Arc<dyn ExecutionPlan>> = match logical_plan {
                 LogicalPlan::TableScan(TableScan {
-                    source,
+                    table_name,
                     projection,
                     filters,
                     limit,
                     ..
                 }) => {
-                    // Remove all qualifiers from the scan as the provider
-                    // doesn't know (nor should care) how the relation was
-                    // referred to in the query
-                    let filters = unnormalize_cols(filters.iter().cloned());
-                    let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
-                    source.scan(projection, &unaliased, *limit).await
+                    match get_table_provider(session_state.catalog_list.as_ref(), &table_name) {
+                        Some(t) => {
+                            // Remove all qualifiers from the scan as the provider
+                            // doesn't know (nor should care) how the relation was
+                            // referred to in the query
+                            let filters = unnormalize_cols(filters.iter().cloned());
+                            let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
+                            t.scan(projection, &unaliased, *limit).await
+                        }
+                        _ => Err(DataFusionError::Plan(format!("No table provider named {}", table_name)))
+                    }
                 }
                 LogicalPlan::Values(Values {
                     values,
@@ -1517,6 +1523,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
         // filter clause needs the type coercion rule applied
         .filter(col("c7").lt(lit(5_u8)))?
         .project(vec![col("c1"), col("c2")])?
@@ -1569,6 +1576,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
        .filter(col("c7").lt(col("c12")))?
         .build()?;
@@ -1613,6 +1621,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
         .project(vec![case.clone()]);
         let message = format!(
             "Expression {:?} expected to error due to impossible coercion",
@@ -1713,6 +1722,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
         // filter clause needs the type coercion rule applied
         .filter(col("c12").lt(lit(0.05)))?
         .project(vec![col("c1").in_list(list, false)])?
@@ -1735,6 +1745,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
         // filter clause needs the type coercion rule applied
         .filter(col("c12").lt(lit(0.05)))?
         .project(vec![col("c12").lt_eq(lit(0.025)).in_list(list, false)])?
@@ -1776,6 +1787,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
         .filter(col("c12").lt(lit(0.05)))?
         .project(vec![col("c1").in_list(list, false)])?
         .build()?;
@@ -1804,6 +1816,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
         .filter(col("c12").lt(lit(0.05)))?
         .project(vec![col("c1").in_list(list, false)])?
         .build()?;
@@ -1828,6 +1841,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
         .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
         .build()?;
@@ -1861,6 +1875,7 @@ mod tests {
             1,
         )
         .await?
+        .builder
         .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
         .build()?;
@@ -1881,6 +1896,7 @@ mod tests {
         let logical_plan =
             LogicalPlanBuilder::scan_empty(Some("employee"), &schema, None)
                 .unwrap()
+                .builder
                 .explain(true, false)
                 .unwrap()
                 .build()
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index cff38d47b4799..7b584216c49b9 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -648,10 +648,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             (Some(cte_plan), _) => Ok(cte_plan.clone()),
             (_, Some(provider)) => {
                 let scan =
-                    LogicalPlanBuilder::scan(&table_name, provider, None);
+                    LogicalPlanBuilder::scan(&table_name, provider, None)?;
                 let scan = match alias {
-                    Some(ref name) => scan?.alias(name.name.value.as_str()),
-                    _ => scan,
+                    Some(ref name) => {
+                        scan.builder.alias(name.name.value.as_str())
+                    }
+                    _ => Ok(scan.builder),
                 };
                 scan?.build()
             }
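The planner change above swaps an owned `source.scan(...)` call for a by-name lookup that can fail at planning time. The control flow, reduced to a sketch with a `HashMap` standing in for the catalog and a string standing in for the provider:

use std::collections::HashMap;

#[derive(Debug)]
enum PlanError {
    Plan(String),
}

fn plan_table_scan(
    providers: &HashMap<String, &'static str>,
    table_name: &str,
) -> Result<&'static str, PlanError> {
    match providers.get(table_name) {
        // A real planner would now call provider.scan(projection, filters, limit).
        Some(provider) => Ok(*provider),
        None => Err(PlanError::Plan(format!(
            "No table provider named {}",
            table_name
        ))),
    }
}

fn main() {
    let mut providers = HashMap::new();
    providers.insert("test".to_string(), "mem-table");
    assert!(plan_table_scan(&providers, "test").is_ok());
    assert!(plan_table_scan(&providers, "nope").is_err());
}

The design consequence is that "unknown table" moves from plan-construction time to physical-planning time: a deserialized plan referencing an unregistered table now fails here rather than failing to deserialize.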
diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs
index 4bdf6d666d717..48e5c6dc04e34 100644
--- a/datafusion/core/src/test/mod.rs
+++ b/datafusion/core/src/test/mod.rs
@@ -18,16 +18,22 @@
 //! Common unit test utility methods
 
 use crate::arrow::array::UInt32Array;
+use crate::catalog::catalog::{
+    CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider,
+};
+use crate::catalog::schema::MemorySchemaProvider;
+use crate::catalog::schema::SchemaProvider;
 use crate::datasource::{
     listing::{local_unpartitioned_file, PartitionedFile},
     MemTable, TableProvider,
 };
 use crate::error::Result;
+use crate::execution::context::{DEFAULT_CATALOG, DEFAULT_SCHEMA};
 use crate::from_slice::FromSlice;
 use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder};
 use array::{Array, ArrayRef};
 use arrow::array::{self, DecimalBuilder, Int32Array};
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use futures::{Future, FutureExt};
 use std::fs::File;
@@ -105,13 +111,17 @@ pub fn create_partitioned_csv(
     Ok((tmp_dir.into_path().to_str().unwrap().to_string(), groups))
 }
 
-/// some tests share a common table with different names
-pub fn test_table_scan_with_name(name: &str) -> Result<LogicalPlan> {
-    let schema = Schema::new(vec![
+pub fn test_table_schema() -> Schema {
+    Schema::new(vec![
         Field::new("a", DataType::UInt32, false),
         Field::new("b", DataType::UInt32, false),
         Field::new("c", DataType::UInt32, false),
-    ]);
+    ])
+}
+
+/// some tests share a common table with different names
+pub fn test_table_scan_with_name(name: &str) -> Result<LogicalPlan> {
+    let schema = test_table_schema();
     LogicalPlanBuilder::scan_empty(Some(name), &schema, None)?.build()
 }
 
@@ -120,6 +130,23 @@ pub fn test_table_scan() -> Result<LogicalPlan> {
     test_table_scan_with_name("test")
 }
 
+pub fn create_test_table_catalog_list() -> Arc<dyn CatalogList> {
+    let catalog_list = MemoryCatalogList::default();
+    let catalog = Arc::new(MemoryCatalogProvider::new());
+    catalog_list.register_catalog(DEFAULT_CATALOG.to_owned(), catalog.clone());
+    let schema = Arc::new(MemorySchemaProvider::new());
+
+    let test_table =
+        Arc::new(MemTable::try_new(SchemaRef::new(test_table_schema()), vec![]).unwrap());
+    for table_name in &["test", "test2"] {
+        schema
+            .register_table(table_name.to_string(), test_table.clone())
+            .unwrap();
+    }
+    catalog.register_schema(DEFAULT_SCHEMA, schema).unwrap();
+    Arc::new(catalog_list)
+}
+
 pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) {
     let actual: Vec<String> = plan
         .schema()
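What `create_test_table_catalog_list()` above builds, restated with plain `HashMap`s in place of the `Memory*` providers — one default catalog, one default schema, and the shared `test`/`test2` tables registered under it (the `datafusion`/`public` names are assumptions standing in for `DEFAULT_CATALOG`/`DEFAULT_SCHEMA`):

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

type Schema = HashMap<String, &'static str>; // table name -> provider stand-in
type Catalog = HashMap<String, Schema>;      // schema name -> tables

fn create_test_catalog_list() -> Arc<RwLock<HashMap<String, Catalog>>> {
    let mut schema = Schema::new();
    // Every optimizer test can now resolve these names through the catalog.
    for table in ["test", "test2"] {
        schema.insert(table.to_string(), "empty mem table");
    }
    let mut catalog = Catalog::new();
    catalog.insert("public".to_string(), schema);
    let mut catalogs = HashMap::new();
    catalogs.insert("datafusion".to_string(), catalog);
    Arc::new(RwLock::new(catalogs))
}

fn main() {
    let list = create_test_catalog_list();
    let guard = list.read().unwrap();
    assert!(guard["datafusion"]["public"].contains_key("test"));
}

Centralizing this in one helper is what lets every `assert_optimized_plan_eq` in the diff pass `create_test_table_catalog_list().as_ref()` instead of constructing providers inline.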
diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs
index 0916e966cf8b9..a3326669ff98b 100644
--- a/datafusion/core/tests/custom_sources.rs
+++ b/datafusion/core/tests/custom_sources.rs
@@ -203,19 +203,19 @@ impl TableProvider for CustomTableProvider {
 async fn custom_source_dataframe() -> Result<()> {
     let ctx = SessionContext::new();
 
-    let table = ctx.read_table(Arc::new(CustomTableProvider))?;
-    let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan())
-        .project(vec![col("c2")])?
-        .build()?;
+    let provider = Arc::new(CustomTableProvider);
+    let table = ctx.read_table(provider.clone())?;
+    let builder =
+        LogicalPlanBuilder::from(table.to_logical_plan()).project(vec![col("c2")])?;
+    let logical_plan = builder.build()?;
 
     let optimized_plan = ctx.optimize(&logical_plan)?;
     match &optimized_plan {
         LogicalPlan::Projection(Projection { input, .. }) => match &**input {
             LogicalPlan::TableScan(TableScan {
-                source,
-                projected_schema,
-                ..
+                projected_schema, ..
             }) => {
+                let source = provider;
                 assert_eq!(source.schema().fields().len(), 2);
                 assert_eq!(projected_schema.fields().len(), 1);
             }
diff --git a/datafusion/core/tests/parquet_pruning.rs b/datafusion/core/tests/parquet_pruning.rs
index d5392e9dcbff5..f53e8146fe4d3 100644
--- a/datafusion/core/tests/parquet_pruning.rs
+++ b/datafusion/core/tests/parquet_pruning.rs
@@ -501,6 +501,7 @@ impl ContextWithParquet {
         let sql = format!("EXPR only: {:?}", expr);
         let logical_plan = LogicalPlanBuilder::scan("t", self.provider.clone(), None)
             .unwrap()
+            .builder
             .filter(expr)
             .unwrap()
             .build()
diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs
index fe5f5e254b523..f47def9ae108f 100644
--- a/datafusion/core/tests/simplification.rs
+++ b/datafusion/core/tests/simplification.rs
@@ -59,7 +59,7 @@ impl From<DFSchema> for MyInfo {
     fn from(schema: DFSchema) -> Self {
         Self {
             schema,
-            execution_props: ExecutionProps::new(),
+            execution_props: ExecutionProps::default(),
         }
     }
 }
diff --git a/datafusion/core/tests/sql/explain.rs b/datafusion/core/tests/sql/explain.rs
index b85228016e507..bd68f1e7e5933 100644
--- a/datafusion/core/tests/sql/explain.rs
+++ b/datafusion/core/tests/sql/explain.rs
@@ -25,13 +25,14 @@ use datafusion::{
 fn optimize_explain() {
     let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
 
-    let plan = LogicalPlanBuilder::scan_empty(Some("employee"), &schema, None)
-        .unwrap()
-        .explain(true, false)
-        .unwrap()
-        .build()
+    let scan = LogicalPlanBuilder::scan_empty(Some("employee"), &schema, None).unwrap();
+
+    let ctx = SessionContext::new();
+    ctx.register_table(scan.table_name.as_str(), scan.provider)
         .unwrap();
 
+    let plan = scan.builder.explain(true, false).unwrap().build().unwrap();
+
     if let LogicalPlan::Explain(e) = &plan {
         assert_eq!(e.stringified_plans.len(), 1);
     } else {
@@ -39,7 +40,7 @@ fn optimize_explain() {
     }
 
     // now optimize the plan and expect to see more plans
-    let optimized_plan = SessionContext::new().optimize(&plan).unwrap();
+    let optimized_plan = ctx.optimize(&plan).unwrap();
     if let LogicalPlan::Explain(e) = &optimized_plan {
         // should have more than one plan
         assert!(
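The test updates above rely on `LogicalPlanBuilder::scan*` now returning a three-field value (`table_name`, `provider`, `builder`) that must be registered with the context before the plan is optimized. A toy model of that shape — the field names come from the diff, everything else is an assumption — to make the ownership flow explicit:

use std::collections::HashMap;

struct ScanInfo {
    table_name: String,
    provider: &'static str, // stand-in for Arc<dyn TableProvider>
    builder: PlanBuilder,
}

struct PlanBuilder {
    steps: Vec<String>,
}

impl PlanBuilder {
    fn project(mut self, cols: &str) -> Self {
        self.steps.push(format!("project({})", cols));
        self
    }
    fn build(self) -> Vec<String> {
        self.steps
    }
}

fn scan(table_name: &str) -> ScanInfo {
    ScanInfo {
        table_name: table_name.to_string(),
        provider: "provider",
        builder: PlanBuilder {
            steps: vec![format!("scan({})", table_name)],
        },
    }
}

fn main() {
    let mut registry: HashMap<String, &'static str> = HashMap::new();
    let s = scan("employee");
    // Register the provider first, as optimize_explain() does above...
    registry.insert(s.table_name.clone(), s.provider);
    // ...then keep building the plan from the returned builder.
    let plan = s.builder.project("c1").build();
    println!("{:?} / {:?}", registry, plan);
}

Splitting the return value this way is what forces call sites to hand the provider to the catalog: the plan itself no longer carries it, so forgetting to register becomes a runtime "no table provider" error rather than a silent success.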
diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs
index d717698e45222..01b2c65b9d5f5 100644
--- a/datafusion/core/tests/sql/projection.rs
+++ b/datafusion/core/tests/sql/projection.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use datafusion::catalog::TableReference;
 use datafusion::logical_plan::{LogicalPlanBuilder, UNNAMED_TABLE};
 use tempfile::TempDir;
 
@@ -149,10 +150,13 @@ async fn projection_on_table_scan() -> Result<()> {
     match &optimized_plan {
         LogicalPlan::Projection(Projection { input, .. }) => match &**input {
             LogicalPlan::TableScan(TableScan {
-                source,
+                table_name,
                 projected_schema,
                 ..
             }) => {
+                let source = ctx
+                    .table(TableReference::Bare { table: &table_name })
+                    .expect("table provider found");
                 assert_eq!(source.schema().fields().len(), 3);
                 assert_eq!(projected_schema.fields().len(), 1);
             }
@@ -184,9 +188,9 @@ async fn preserve_nullability_on_projection() -> Result<()> {
     let schema: Schema = ctx.table("test").unwrap().schema().clone().into();
     assert!(!schema.field_with_name("c1")?.is_nullable());
 
-    let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)?
-        .project(vec![col("c1")])?
-        .build()?;
+    let scan = LogicalPlanBuilder::scan_empty(None, &schema, None)?;
+    ctx.register_table(scan.table_name.as_str(), scan.provider)?;
+    let plan = scan.builder.project(vec![col("c1")])?.build()?;
 
     let plan = ctx.optimize(&plan)?;
     let physical_plan = ctx.create_physical_plan(&Arc::new(plan)).await?;
@@ -212,20 +216,25 @@ async fn projection_on_memory_scan() -> Result<()> {
         ],
     )?]];
 
-    let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)?
-        .project(vec![col("b")])?
-        .build()?;
-    assert_fields_eq(&plan, vec!["b"]);
+    let scan = LogicalPlanBuilder::scan_memory(partitions, schema.clone(), None)?;
 
     let ctx = SessionContext::new();
+    ctx.register_table(scan.table_name.as_str(), scan.provider)?;
+
+    let plan = scan.builder.project(vec![col("b")])?.build()?;
+    assert_fields_eq(&plan, vec!["b"]);
+
     let optimized_plan = ctx.optimize(&plan)?;
     match &optimized_plan {
         LogicalPlan::Projection(Projection { input, .. }) => match &**input {
             LogicalPlan::TableScan(TableScan {
-                source,
+                table_name,
                 projected_schema,
                 ..
             }) => {
+                let source = ctx
+                    .table(TableReference::Bare { table: &table_name })
+                    .expect("table provider found");
                 assert_eq!(source.schema().fields().len(), 3);
                 assert_eq!(projected_schema.fields().len(), 1);
             }
diff --git a/datafusion/core/tests/user_defined_plan.rs b/datafusion/core/tests/user_defined_plan.rs
index 43e6eeacdb104..9999c3a5157cc 100644
--- a/datafusion/core/tests/user_defined_plan.rs
+++ b/datafusion/core/tests/user_defined_plan.rs
@@ -86,6 +86,7 @@ use std::task::{Context, Poll};
 use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};
 
 use async_trait::async_trait;
+use datafusion::catalog::catalog::CatalogList;
 use datafusion::execution::context::{ExecutionProps, TaskContext};
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::logical_plan::plan::{Extension, Sort};
@@ -285,6 +286,7 @@ impl OptimizerRule for TopKOptimizerRule {
         &self,
         plan: &LogicalPlan,
         execution_props: &ExecutionProps,
+        catalog_list: &dyn CatalogList,
     ) -> Result<LogicalPlan> {
         // Note: this code simply looks for the pattern of a Limit followed by a
         // Sort and replaces it by a TopK node. It does not handle many
@@ -300,7 +302,11 @@ impl OptimizerRule for TopKOptimizerRule {
             return Ok(LogicalPlan::Extension(Extension {
                 node: Arc::new(TopKPlanNode {
                     k: *n,
-                    input: self.optimize(input.as_ref(), execution_props)?,
+                    input: self.optimize(
+                        input.as_ref(),
+                        execution_props,
+                        catalog_list,
+                    )?,
                     expr: expr[0].clone(),
                 }),
            }));
@@ -310,7 +316,7 @@ impl OptimizerRule for TopKOptimizerRule {
 
         // If we didn't find the Limit/Sort combination, recurse as
         // normal and build the result.
-        optimize_children(self, plan, execution_props)
+        optimize_children(self, plan, execution_props, catalog_list)
     }
 
     fn name(&self) -> &str {
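Finally, the TopK rule above shows how a user-defined rule participates in the same recursion: it is a breaking change for downstream crates, since every external `OptimizerRule` impl must now accept (even if it ignores) the catalog parameter. Its core pattern match — `Limit(n)` directly over `Sort` collapses into one TopK node — as a self-contained sketch over a toy plan:

#[derive(Debug, Clone)]
enum Plan {
    Limit { n: usize, input: Box<Plan> },
    Sort { input: Box<Plan> },
    TopK { k: usize, input: Box<Plan> },
    Scan,
}

fn top_k_rewrite(plan: &Plan) -> Plan {
    match plan {
        Plan::Limit { n, input } => {
            // The pattern: Limit directly over Sort becomes a single TopK.
            if let Plan::Sort { input: sort_input } = input.as_ref() {
                return Plan::TopK {
                    k: *n,
                    input: Box::new(top_k_rewrite(sort_input)),
                };
            }
            Plan::Limit { n: *n, input: Box::new(top_k_rewrite(input)) }
        }
        // Everything else recurses as normal, like optimize_children above.
        Plan::Sort { input } => Plan::Sort { input: Box::new(top_k_rewrite(input)) },
        Plan::TopK { k, input } => Plan::TopK { k: *k, input: Box::new(top_k_rewrite(input)) },
        Plan::Scan => Plan::Scan,
    }
}

fn main() {
    let plan = Plan::Limit {
        n: 10,
        input: Box::new(Plan::Sort { input: Box::new(Plan::Scan) }),
    };
    println!("{:?}", top_k_rewrite(&plan));
}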