From 36b1703fd59b7065811447c52a4b391da7093703 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sat, 20 Mar 2021 12:21:42 +0000 Subject: [PATCH 01/17] Require TableProvider to always implement Send + Sync --- rust/benchmarks/src/bin/tpch.rs | 2 +- rust/datafusion/src/datasource/datasource.rs | 2 +- rust/datafusion/src/datasource/memory.rs | 2 +- rust/datafusion/src/execution/context.rs | 18 ++++++------------ rust/datafusion/src/logical_plan/builder.rs | 2 +- rust/datafusion/src/logical_plan/plan.rs | 2 +- rust/datafusion/src/sql/planner.rs | 12 +++--------- rust/datafusion/src/test/mod.rs | 2 +- 8 files changed, 15 insertions(+), 27 deletions(-) diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index b0a6cedd172..60baa6605f2 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -1105,7 +1105,7 @@ fn get_table( table: &str, table_format: &str, max_concurrency: usize, -) -> Result> { +) -> Result> { match table_format { // dbgen creates .tbl ('|' delimited) files without header "tbl" => { diff --git a/rust/datafusion/src/datasource/datasource.rs b/rust/datafusion/src/datasource/datasource.rs index 4e6ad36160c..e2b07336486 100644 --- a/rust/datafusion/src/datasource/datasource.rs +++ b/rust/datafusion/src/datasource/datasource.rs @@ -67,7 +67,7 @@ pub enum TableProviderFilterPushDown { } /// Source table -pub trait TableProvider { +pub trait TableProvider: Sync + Send { /// Returns the table provider as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs index 514eac6dcda..af404808702 100644 --- a/rust/datafusion/src/datasource/memory.rs +++ b/rust/datafusion/src/datasource/memory.rs @@ -110,7 +110,7 @@ impl MemTable { /// Create a mem table by reading from another data source pub async fn load( - t: Arc, + t: Arc, batch_size: usize, output_partitions: Option, ) -> Result { diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index caf09be708d..4e6c0c481f3 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -239,7 +239,7 @@ impl ExecutionContext { /// Creates a DataFrame for reading a custom TableProvider. pub fn read_table( &mut self, - provider: Arc, + provider: Arc, ) -> Result> { let schema = provider.schema(); let table_scan = LogicalPlan::TableScan { @@ -288,8 +288,8 @@ impl ExecutionContext { pub fn register_table( &mut self, name: &str, - provider: Arc, - ) -> Option> { + provider: Arc, + ) -> Option> { self.state .lock() .unwrap() @@ -300,10 +300,7 @@ impl ExecutionContext { /// Deregisters the named table. /// /// Returns the registered provider, if any - pub fn deregister_table( - &mut self, - name: &str, - ) -> Option> { + pub fn deregister_table(&mut self, name: &str) -> Option> { self.state.lock().unwrap().datasources.remove(name) } @@ -570,7 +567,7 @@ impl ExecutionConfig { #[derive(Clone)] pub struct ExecutionContextState { /// Data sources that are registered with the context - pub datasources: HashMap>, + pub datasources: HashMap>, /// Scalar functions that are registered with the context pub scalar_functions: HashMap>, /// Variable provider that are registered with the context @@ -582,10 +579,7 @@ pub struct ExecutionContextState { } impl ContextProvider for ExecutionContextState { - fn get_table_provider( - &self, - name: &str, - ) -> Option> { + fn get_table_provider(&self, name: &str) -> Option> { self.datasources.get(name).map(|ds| Arc::clone(ds)) } diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs index 58dfd0fa5d6..a89e797c7b6 100644 --- a/rust/datafusion/src/logical_plan/builder.rs +++ b/rust/datafusion/src/logical_plan/builder.rs @@ -103,7 +103,7 @@ impl LogicalPlanBuilder { /// Convert a table provider into a builder with a TableScan pub fn scan( name: &str, - provider: Arc, + provider: Arc, projection: Option>, ) -> Result { let schema = provider.schema(); diff --git a/rust/datafusion/src/logical_plan/plan.rs b/rust/datafusion/src/logical_plan/plan.rs index 92110c8daf1..00d25fb0f3f 100644 --- a/rust/datafusion/src/logical_plan/plan.rs +++ b/rust/datafusion/src/logical_plan/plan.rs @@ -134,7 +134,7 @@ pub enum LogicalPlan { /// The name of the table table_name: String, /// The source of the table - source: Arc, + source: Arc, /// Optional column indices to use as a projection projection: Option>, /// The schema description of the output diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 31cea52a0f8..b6e269bfbd0 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -58,10 +58,7 @@ use super::utils::{ /// functions referenced in SQL statements pub trait ContextProvider { /// Getter for a datasource - fn get_table_provider( - &self, - name: &str, - ) -> Option>; + fn get_table_provider(&self, name: &str) -> Option>; /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description @@ -2493,10 +2490,7 @@ mod tests { struct MockContextProvider {} impl ContextProvider for MockContextProvider { - fn get_table_provider( - &self, - name: &str, - ) -> Option> { + fn get_table_provider(&self, name: &str) -> Option> { let schema = match name { "person" => Some(Schema::new(vec![ Field::new("id", DataType::UInt32, false), @@ -2540,7 +2534,7 @@ mod tests { ])), _ => None, }; - schema.map(|s| -> Arc { + schema.map(|s| -> Arc { Arc::new(EmptyTable::new(Arc::new(s))) }) } diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index 75a956f1cf4..04f340a9936 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -29,7 +29,7 @@ use std::io::{BufReader, BufWriter}; use std::sync::Arc; use tempfile::TempDir; -pub fn create_table_dual() -> Arc { +pub fn create_table_dual() -> Arc { let dual_schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, false), From 866e85f31ff11f5abc96be7aa8b35b31065790e6 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sat, 20 Mar 2021 15:16:19 +0000 Subject: [PATCH 02/17] Initial implementation of catalogs and schemas --- rust/datafusion/src/catalog/catalog.rs | 69 +++++++++++++++++ rust/datafusion/src/catalog/mod.rs | 22 ++++++ rust/datafusion/src/catalog/schema.rs | 79 ++++++++++++++++++++ rust/datafusion/src/execution/context.rs | 67 +++++++++++------ rust/datafusion/src/lib.rs | 1 + rust/datafusion/src/physical_plan/parquet.rs | 3 +- rust/datafusion/src/physical_plan/planner.rs | 3 +- 7 files changed, 221 insertions(+), 23 deletions(-) create mode 100644 rust/datafusion/src/catalog/catalog.rs create mode 100644 rust/datafusion/src/catalog/mod.rs create mode 100644 rust/datafusion/src/catalog/schema.rs diff --git a/rust/datafusion/src/catalog/catalog.rs b/rust/datafusion/src/catalog/catalog.rs new file mode 100644 index 00000000000..9255ff1e550 --- /dev/null +++ b/rust/datafusion/src/catalog/catalog.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Describes the interface and built-in implementations of catalogs, +//! representing collections of named schemas. + +use crate::catalog::schema::SchemaProvider; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +/// Represents a catalog, comprising a number of named schemas. +pub trait CatalogProvider: Sync + Send { + /// Retrieves the list of available schema names in this catalog. + fn schema_names(&self) -> Vec; + + /// Retrieves a specific schema from the catalog by name, provided it exists. + fn schema(&self, name: &str) -> Option>; +} + +/// Simple in-memory implementation of a catalog. +pub struct MemoryCatalogProvider { + schemas: RwLock>>, +} + +impl MemoryCatalogProvider { + /// Instantiates a new MemoryCatalogProvider with an empty collection of schemas. + pub fn new() -> Self { + Self { + schemas: RwLock::new(HashMap::new()), + } + } + + /// Adds a new schema to this catalog. + /// If a schema of the same name existed before, it is replaced in the catalog and returned. + pub fn register_schema( + &mut self, + name: impl Into, + schema: Arc, + ) -> Option> { + let mut schemas = self.schemas.write().unwrap(); + schemas.insert(name.into(), schema) + } +} + +impl CatalogProvider for MemoryCatalogProvider { + fn schema_names(&self) -> Vec { + let schemas = self.schemas.read().unwrap(); + schemas.keys().cloned().collect() + } + + fn schema(&self, name: &str) -> Option> { + let schemas = self.schemas.read().unwrap(); + schemas.get(name).cloned() + } +} diff --git a/rust/datafusion/src/catalog/mod.rs b/rust/datafusion/src/catalog/mod.rs new file mode 100644 index 00000000000..c98f6e4254f --- /dev/null +++ b/rust/datafusion/src/catalog/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains interfaces and default implementations +//! of table namespacing concepts, including catalogs and schemas. + +pub mod catalog; +pub mod schema; diff --git a/rust/datafusion/src/catalog/schema.rs b/rust/datafusion/src/catalog/schema.rs new file mode 100644 index 00000000000..2786ba9d902 --- /dev/null +++ b/rust/datafusion/src/catalog/schema.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Describes the interface and built-in implementations of schemas, +//! representing collections of named tables. + +use crate::datasource::TableProvider; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +/// Represents a schema, comprising a number of named tables. +pub trait SchemaProvider: Sync + Send { + /// Retrieves the list of available table names in this schema. + fn table_names(&self) -> Vec; + + /// Retrieves a specific table from the schema by name, provided it exists. + fn table(&self, name: &str) -> Option>; +} + +/// Simple in-memory implementation of a schema. +pub struct MemorySchemaProvider { + tables: RwLock>>, +} + +impl MemorySchemaProvider { + /// Instantiates a new MemorySchemaProvider with an empty collection of tables. + pub fn new() -> Self { + Self { + tables: RwLock::new(HashMap::new()), + } + } + + /// Adds a new table to this schema. + /// If a table of the same name existed before, it is replaced in the schema and returned. + pub fn register_table( + &self, + name: impl Into, + table: Arc, + ) -> Option> { + let mut tables = self.tables.write().unwrap(); + tables.insert(name.into(), table) + } + + /// Removes an existing table from this schema and returns it. + /// If no table of that name exists, returns None. + pub fn deregister_table( + &self, + name: impl AsRef, + ) -> Option> { + let mut tables = self.tables.write().unwrap(); + tables.remove(name.as_ref()) + } +} + +impl SchemaProvider for MemorySchemaProvider { + fn table_names(&self) -> Vec { + let tables = self.tables.read().unwrap(); + tables.keys().cloned().collect() + } + + fn table(&self, name: &str) -> Option> { + let tables = self.tables.read().unwrap(); + tables.get(name).cloned() + } +} diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 4e6c0c481f3..baa712c5de9 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -32,6 +32,10 @@ use tokio::task::{self, JoinHandle}; use arrow::csv; +use crate::catalog::{ + catalog::{CatalogProvider, MemoryCatalogProvider}, + schema::{MemorySchemaProvider, SchemaProvider}, +}; use crate::datasource::csv::CsvFile; use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; @@ -111,9 +115,20 @@ impl ExecutionContext { /// Creates a new execution context using the provided configuration. pub fn with_config(config: ExecutionConfig) -> Self { + let default_schema = Arc::new(MemorySchemaProvider::new()); + let mut default_catalog = MemoryCatalogProvider::new(); + default_catalog.register_schema("public", default_schema.clone()); + + let mut catalogs = HashMap::new(); + catalogs.insert( + "datafusion".to_owned(), + Arc::new(default_catalog) as Arc, + ); + Self { state: Arc::new(Mutex::new(ExecutionContextState { - datasources: HashMap::new(), + default_schema: Some(default_schema.clone()), + catalogs, scalar_functions: HashMap::new(), var_provider: HashMap::new(), aggregate_functions: HashMap::new(), @@ -279,6 +294,10 @@ impl ExecutionContext { Ok(()) } + fn get_default_schema(&self) -> Option> { + self.state.lock().unwrap().default_schema.clone() + } + /// Registers a named table using a custom `TableProvider` so that /// it can be referenced from SQL statements executed against this /// context. @@ -290,18 +309,18 @@ impl ExecutionContext { name: &str, provider: Arc, ) -> Option> { - self.state - .lock() - .unwrap() - .datasources - .insert(name.to_string(), provider) + self.get_default_schema() + .expect("no default schema available for table registration") + .register_table(name.to_string(), provider) } /// Deregisters the named table. /// /// Returns the registered provider, if any pub fn deregister_table(&mut self, name: &str) -> Option> { - self.state.lock().unwrap().datasources.remove(name) + self.get_default_schema() + .expect("no default schema available for table deregistration") + .deregister_table(name) } /// Retrieves a DataFrame representing a table previously registered by calling the @@ -309,8 +328,12 @@ impl ExecutionContext { /// /// Returns an error if no table has been registered with the provided name. pub fn table(&self, table_name: &str) -> Result> { - match self.state.lock().unwrap().datasources.get(table_name) { - Some(provider) => { + let default_schema = self + .get_default_schema() + .expect("no default schema available for table retrieval"); + + match default_schema.table(table_name) { + Some(ref provider) => { let schema = provider.schema(); let table_scan = LogicalPlan::TableScan { table_name: table_name.to_string(), @@ -338,11 +361,10 @@ impl ExecutionContext { /// /// [`table`]: ExecutionContext::table pub fn tables(&self) -> HashSet { - self.state - .lock() - .unwrap() - .datasources - .keys() + self.get_default_schema() + .expect("no default schema available for table listing") + .table_names() + .iter() .cloned() .collect() } @@ -566,8 +588,10 @@ impl ExecutionConfig { /// Execution context for registering data sources and executing queries #[derive(Clone)] pub struct ExecutionContextState { - /// Data sources that are registered with the context - pub datasources: HashMap>, + /// + pub default_schema: Option>, + /// + pub catalogs: HashMap>, /// Scalar functions that are registered with the context pub scalar_functions: HashMap>, /// Variable provider that are registered with the context @@ -580,7 +604,10 @@ pub struct ExecutionContextState { impl ContextProvider for ExecutionContextState { fn get_table_provider(&self, name: &str) -> Option> { - self.datasources.get(name).map(|ds| Arc::clone(ds)) + self.default_schema + .as_ref() + .expect("no default schema available for table retrieval") + .table(name) } fn get_function_meta(&self, name: &str) -> Option> { @@ -870,11 +897,9 @@ mod tests { let ctx = create_ctx(&tmp_dir, 1)?; let schema = ctx - .state - .lock() + .get_default_schema() .unwrap() - .datasources - .get("test") + .table("test") .unwrap() .schema(); assert_eq!(schema.field_with_name("c1")?.is_nullable(), false); diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs index f0fcc4f1d29..e9bf2b4877a 100644 --- a/rust/datafusion/src/lib.rs +++ b/rust/datafusion/src/lib.rs @@ -157,6 +157,7 @@ extern crate arrow; extern crate sqlparser; +pub mod catalog; pub mod dataframe; pub mod datasource; pub mod error; diff --git a/rust/datafusion/src/physical_plan/parquet.rs b/rust/datafusion/src/physical_plan/parquet.rs index 5e9b1562751..320f1718deb 100644 --- a/rust/datafusion/src/physical_plan/parquet.rs +++ b/rust/datafusion/src/physical_plan/parquet.rs @@ -392,7 +392,8 @@ impl RowGroupPredicateBuilder { .collect::>(); let stat_schema = Schema::new(stat_fields); let execution_context_state = ExecutionContextState { - datasources: HashMap::new(), + default_schema: None, + catalogs: HashMap::new(), scalar_functions: HashMap::new(), var_provider: HashMap::new(), aggregate_functions: HashMap::new(), diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index c6f321ca9b8..d68f6a83326 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -764,7 +764,8 @@ mod tests { fn make_ctx_state() -> ExecutionContextState { ExecutionContextState { - datasources: HashMap::new(), + default_schema: None, + catalogs: HashMap::new(), scalar_functions: HashMap::new(), var_provider: HashMap::new(), aggregate_functions: HashMap::new(), From 7ab647e75ed22daa4fd9262130d6d7bd3c63a617 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sat, 20 Mar 2021 20:08:26 +0000 Subject: [PATCH 03/17] Implement table references All tests passing with fallible table management --- rust/benchmarks/src/bin/tpch.rs | 6 +- .../datafusion/benches/aggregate_query_sql.rs | 2 +- rust/datafusion/benches/filter_query_sql.rs | 2 +- rust/datafusion/benches/math_query_sql.rs | 2 +- .../benches/sort_limit_query_sql.rs | 3 +- .../examples/dataframe_in_memory.rs | 2 +- rust/datafusion/examples/simple_udaf.rs | 2 +- rust/datafusion/examples/simple_udf.rs | 2 +- rust/datafusion/src/catalog/mod.rs | 143 ++++++++++++++ rust/datafusion/src/catalog/schema.rs | 56 +++--- rust/datafusion/src/execution/context.rs | 178 ++++++++++++------ rust/datafusion/src/physical_plan/parquet.rs | 1 - rust/datafusion/src/physical_plan/planner.rs | 1 - rust/datafusion/src/sql/planner.rs | 14 +- rust/datafusion/tests/dataframe.rs | 4 +- .../tests/provider_filter_pushdown.rs | 2 +- rust/datafusion/tests/sql.rs | 32 ++-- 17 files changed, 340 insertions(+), 112 deletions(-) diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index 60baa6605f2..61135437b4b 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -157,9 +157,9 @@ async fn benchmark(opt: BenchmarkOpt) -> Result Result Arc> { // create local execution context let mut ctx = ExecutionContext::new(); ctx.state.lock().unwrap().config.concurrency = 1; - ctx.register_table("aggregate_test_100", Arc::new(mem_table)); + ctx.register_table("aggregate_test_100", Arc::new(mem_table)) + .unwrap(); ctx_holder.lock().unwrap().push(Arc::new(Mutex::new(ctx))) }); diff --git a/rust/datafusion/examples/dataframe_in_memory.rs b/rust/datafusion/examples/dataframe_in_memory.rs index 28414bf8700..a2e78fba50a 100644 --- a/rust/datafusion/examples/dataframe_in_memory.rs +++ b/rust/datafusion/examples/dataframe_in_memory.rs @@ -49,7 +49,7 @@ async fn main() -> Result<()> { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; - ctx.register_table("t", Arc::new(provider)); + ctx.register_table("t", Arc::new(provider)).unwrap(); let df = ctx.table("t")?; // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL diff --git a/rust/datafusion/examples/simple_udaf.rs b/rust/datafusion/examples/simple_udaf.rs index a36d200235a..55aa350b13d 100644 --- a/rust/datafusion/examples/simple_udaf.rs +++ b/rust/datafusion/examples/simple_udaf.rs @@ -48,7 +48,7 @@ fn create_context() -> Result { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider)); + ctx.register_table("t", Arc::new(provider))?; Ok(ctx) } diff --git a/rust/datafusion/examples/simple_udf.rs b/rust/datafusion/examples/simple_udf.rs index d49aac48527..00debdbddac 100644 --- a/rust/datafusion/examples/simple_udf.rs +++ b/rust/datafusion/examples/simple_udf.rs @@ -50,7 +50,7 @@ fn create_context() -> Result { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; - ctx.register_table("t", Arc::new(provider)); + ctx.register_table("t", Arc::new(provider))?; Ok(ctx) } diff --git a/rust/datafusion/src/catalog/mod.rs b/rust/datafusion/src/catalog/mod.rs index c98f6e4254f..84a626ebf27 100644 --- a/rust/datafusion/src/catalog/mod.rs +++ b/rust/datafusion/src/catalog/mod.rs @@ -20,3 +20,146 @@ pub mod catalog; pub mod schema; + +use crate::error::DataFusionError; +use std::convert::TryFrom; + +/// Represents a resolved path to a table of the form "catalog.schema.table" +#[derive(Clone, Copy)] +pub struct ResolvedTableReference<'a> { + /// The catalog (aka database) containing the table + pub catalog: &'a str, + /// The schema containing the table + pub schema: &'a str, + /// The table name + pub table: &'a str, +} + +impl<'a> std::fmt::Display for ResolvedTableReference<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.catalog, self.schema, self.table) + } +} + +/// Represents a path to a table that may require further resolution +#[derive(Clone, Copy)] +pub enum TableReference<'a> { + /// An unqualified table reference, e.g. "table" + Bare { + /// The table name + table: &'a str, + }, + /// A partially resolved table reference, e.g. "schema.table" + Partial { + /// The schema containing the table + schema: &'a str, + /// The table name + table: &'a str, + }, + /// A fully resolved table reference, e.g. "catalog.schema.table" + Full { + /// The catalog (aka database) containing the table + catalog: &'a str, + /// The schema containing the table + schema: &'a str, + /// The table name + table: &'a str, + }, +} + +impl<'a> TableReference<'a> { + /// Retrieve the actual table name, regardless of qualification + pub fn table(&self) -> &str { + match self { + Self::Full { table, .. } + | Self::Partial { table, .. } + | Self::Bare { table } => table, + } + } + + /// Given a default catalog and schema, ensure this table reference is fully resolved + pub fn resolve( + self, + default_catalog: &'a str, + default_schema: &'a str, + ) -> ResolvedTableReference<'a> { + match self { + Self::Full { + catalog, + schema, + table, + } => ResolvedTableReference { + catalog, + schema, + table, + }, + Self::Partial { schema, table } => ResolvedTableReference { + catalog: default_catalog, + schema, + table, + }, + Self::Bare { table } => ResolvedTableReference { + catalog: default_catalog, + schema: default_schema, + table, + }, + } + } +} + +impl<'a> std::fmt::Display for TableReference<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Bare { table } => write!(f, "{}", table), + Self::Partial { schema, table } => write!(f, "{}.{}", schema, table), + Self::Full { + catalog, + schema, + table, + } => write!(f, "{}.{}.{}", catalog, schema, table), + } + } +} + +impl<'a> From<&'a str> for TableReference<'a> { + fn from(s: &'a str) -> Self { + Self::Bare { table: s } + } +} + +impl<'a> From> for TableReference<'a> { + fn from(resolved: ResolvedTableReference<'a>) -> Self { + Self::Full { + catalog: resolved.catalog, + schema: resolved.schema, + table: resolved.table, + } + } +} + +impl<'a> TryFrom<&'a sqlparser::ast::ObjectName> for TableReference<'a> { + type Error = DataFusionError; + + fn try_from(value: &'a sqlparser::ast::ObjectName) -> Result { + let idents = &value.0; + + match idents.len() { + 1 => Ok(Self::Bare { + table: &idents[0].value, + }), + 2 => Ok(Self::Partial { + schema: &idents[0].value, + table: &idents[1].value, + }), + 3 => Ok(Self::Full { + catalog: &idents[0].value, + schema: &idents[1].value, + table: &idents[2].value, + }), + _ => Err(DataFusionError::Plan(format!( + "invalid table reference: {}", + value + ))), + } + } +} diff --git a/rust/datafusion/src/catalog/schema.rs b/rust/datafusion/src/catalog/schema.rs index 2786ba9d902..827e2d75510 100644 --- a/rust/datafusion/src/catalog/schema.rs +++ b/rust/datafusion/src/catalog/schema.rs @@ -19,6 +19,7 @@ //! representing collections of named tables. use crate::datasource::TableProvider; +use crate::error::{DataFusionError, Result}; use std::collections::HashMap; use std::sync::{Arc, RwLock}; @@ -29,6 +30,26 @@ pub trait SchemaProvider: Sync + Send { /// Retrieves a specific table from the schema by name, provided it exists. fn table(&self, name: &str) -> Option>; + + /// If supported by the implementation, adds a new table to this schema. + /// If a table of the same name existed before, it is replaced in the schema and returned. + fn register_table( + &self, + _name: String, + _table: Arc, + ) -> Result>> { + Err(DataFusionError::Execution( + "schema provider does not support registering tables".to_owned(), + )) + } + + /// If supported by the implementation, removes an existing table from this schema and returns it. + /// If no table of that name exists, returns Ok(None). + fn deregister_table(&self, _name: &str) -> Result>> { + Err(DataFusionError::Execution( + "schema provider does not support deregistering tables".to_owned(), + )) + } } /// Simple in-memory implementation of a schema. @@ -43,27 +64,6 @@ impl MemorySchemaProvider { tables: RwLock::new(HashMap::new()), } } - - /// Adds a new table to this schema. - /// If a table of the same name existed before, it is replaced in the schema and returned. - pub fn register_table( - &self, - name: impl Into, - table: Arc, - ) -> Option> { - let mut tables = self.tables.write().unwrap(); - tables.insert(name.into(), table) - } - - /// Removes an existing table from this schema and returns it. - /// If no table of that name exists, returns None. - pub fn deregister_table( - &self, - name: impl AsRef, - ) -> Option> { - let mut tables = self.tables.write().unwrap(); - tables.remove(name.as_ref()) - } } impl SchemaProvider for MemorySchemaProvider { @@ -76,4 +76,18 @@ impl SchemaProvider for MemorySchemaProvider { let tables = self.tables.read().unwrap(); tables.get(name).cloned() } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> Result>> { + let mut tables = self.tables.write().unwrap(); + Ok(tables.insert(name, table)) + } + + fn deregister_table(&self, name: &str) -> Result>> { + let mut tables = self.tables.write().unwrap(); + Ok(tables.remove(name)) + } } diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index baa712c5de9..102b48fd9cf 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -35,6 +35,7 @@ use arrow::csv; use crate::catalog::{ catalog::{CatalogProvider, MemoryCatalogProvider}, schema::{MemorySchemaProvider, SchemaProvider}, + ResolvedTableReference, TableReference, }; use crate::datasource::csv::CsvFile; use crate::datasource::parquet::ParquetTable; @@ -115,19 +116,23 @@ impl ExecutionContext { /// Creates a new execution context using the provided configuration. pub fn with_config(config: ExecutionConfig) -> Self { - let default_schema = Arc::new(MemorySchemaProvider::new()); - let mut default_catalog = MemoryCatalogProvider::new(); - default_catalog.register_schema("public", default_schema.clone()); - let mut catalogs = HashMap::new(); - catalogs.insert( - "datafusion".to_owned(), - Arc::new(default_catalog) as Arc, - ); + + if config.create_default_catalog_and_schema { + let mut default_catalog = MemoryCatalogProvider::new(); + default_catalog.register_schema( + config.default_schema.clone(), + Arc::new(MemorySchemaProvider::new()), + ); + + catalogs.insert( + config.default_catalog.clone(), + Arc::new(default_catalog) as Arc, + ); + } Self { state: Arc::new(Mutex::new(ExecutionContextState { - default_schema: Some(default_schema.clone()), catalogs, scalar_functions: HashMap::new(), var_provider: HashMap::new(), @@ -279,7 +284,7 @@ impl ExecutionContext { filename: &str, options: CsvReadOptions, ) -> Result<()> { - self.register_table(name, Arc::new(CsvFile::try_new(filename, options)?)); + self.register_table(name, Arc::new(CsvFile::try_new(filename, options)?))?; Ok(()) } @@ -290,12 +295,19 @@ impl ExecutionContext { &filename, self.state.lock().unwrap().config.concurrency, )?; - self.register_table(name, Arc::new(table)); + self.register_table(name, Arc::new(table))?; Ok(()) } - fn get_default_schema(&self) -> Option> { - self.state.lock().unwrap().default_schema.clone() + fn get_schema<'a>( + &'a self, + table_ref: impl Into>, + ) -> Result> { + self.state + .lock() + .unwrap() + .schema_for_ref(table_ref.into()) + .ok_or_else(|| DataFusionError::Plan("failed to resolve schema".to_owned())) } /// Registers a named table using a custom `TableProvider` so that @@ -304,39 +316,55 @@ impl ExecutionContext { /// /// Returns the `TableProvider` previously registered for this /// name, if any - pub fn register_table( - &mut self, - name: &str, + pub fn register_table<'a>( + &'a mut self, + table_ref: impl Into>, provider: Arc, - ) -> Option> { - self.get_default_schema() - .expect("no default schema available for table registration") - .register_table(name.to_string(), provider) + ) -> Result>> { + let table_ref = table_ref.into(); + self.get_schema(table_ref)? + .register_table(table_ref.table().to_owned(), provider) } /// Deregisters the named table. /// /// Returns the registered provider, if any - pub fn deregister_table(&mut self, name: &str) -> Option> { - self.get_default_schema() - .expect("no default schema available for table deregistration") - .deregister_table(name) + pub fn deregister_table<'a>( + &'a mut self, + table_ref: impl Into>, + ) -> Result>> { + let table_ref = table_ref.into(); + self.get_schema(table_ref)? + .deregister_table(table_ref.table()) } /// Retrieves a DataFrame representing a table previously registered by calling the /// register_table function. /// /// Returns an error if no table has been registered with the provided name. - pub fn table(&self, table_name: &str) -> Result> { - let default_schema = self - .get_default_schema() - .expect("no default schema available for table retrieval"); + pub fn table<'a>( + &self, + table_ref: impl Into>, + ) -> Result> { + let table_ref = table_ref.into(); + + let (table_name, table) = { + let state = self.state.lock().unwrap(); + let resolved_ref = state.resolve_table_ref(table_ref); + let table = state + .schema_for_ref(resolved_ref) + .and_then(|s| s.table(resolved_ref.table)); + + // intentionally returning only the table name for now + // TODO: + (resolved_ref.table.to_string(), table) + }; - match default_schema.table(table_name) { + match table { Some(ref provider) => { let schema = provider.schema(); let table_scan = LogicalPlan::TableScan { - table_name: table_name.to_string(), + table_name: table_name, source: Arc::clone(provider), projected_schema: schema.to_dfschema_ref()?, projection: None, @@ -361,9 +389,11 @@ impl ExecutionContext { /// /// [`table`]: ExecutionContext::table pub fn tables(&self) -> HashSet { - self.get_default_schema() - .expect("no default schema available for table listing") - .table_names() + // quick hack to get default schema + self.get_schema(TableReference::Bare { table: "" }) + .ok() + .map(|s| s.table_names()) + .unwrap_or_default() .iter() .cloned() .collect() @@ -531,6 +561,12 @@ pub struct ExecutionConfig { optimizers: Vec>, /// Responsible for planning `LogicalPlan`s, and `ExecutionPlan` query_planner: Arc, + /// Default catalog name for table resolution + default_catalog: String, + /// Default schema name for table resolution + default_schema: String, + /// Whether the default catalog and schema should be created automatically + create_default_catalog_and_schema: bool, } impl ExecutionConfig { @@ -547,6 +583,9 @@ impl ExecutionConfig { Arc::new(LimitPushDown::new()), ], query_planner: Arc::new(DefaultQueryPlanner {}), + default_catalog: "datafusion".to_owned(), + default_schema: "public".to_owned(), + create_default_catalog_and_schema: true, } } @@ -583,13 +622,22 @@ impl ExecutionConfig { self.optimizers.push(optimizer_rule); self } + + /// Selects a name for the default catalog and schema + pub fn with_default_catalog_and_schema( + mut self, + catalog: impl Into, + schema: impl Into, + ) -> Self { + self.default_catalog = catalog.into(); + self.default_schema = schema.into(); + self + } } /// Execution context for registering data sources and executing queries #[derive(Clone)] pub struct ExecutionContextState { - /// - pub default_schema: Option>, /// pub catalogs: HashMap>, /// Scalar functions that are registered with the context @@ -602,12 +650,35 @@ pub struct ExecutionContextState { pub config: ExecutionConfig, } +impl ExecutionContextState { + fn resolve_table_ref<'a>( + &'a self, + table_ref: impl Into>, + ) -> ResolvedTableReference<'a> { + table_ref + .into() + .resolve(&self.config.default_catalog, &self.config.default_schema) + } + + fn schema_for_ref<'a>( + &'a self, + table_ref: impl Into>, + ) -> Option> { + let resolved_ref = self.resolve_table_ref(table_ref.into()); + + Some( + self.catalogs + .get(resolved_ref.catalog)? + .schema(resolved_ref.schema)?, + ) + } +} + impl ContextProvider for ExecutionContextState { - fn get_table_provider(&self, name: &str) -> Option> { - self.default_schema - .as_ref() - .expect("no default schema available for table retrieval") - .table(name) + fn get_table_provider(&self, name: TableReference) -> Option> { + let resolved_ref = self.resolve_table_ref(name); + let schema = self.schema_for_ref(resolved_ref)?; + schema.table(resolved_ref.table) } fn get_function_meta(&self, name: &str) -> Option> { @@ -752,7 +823,7 @@ mod tests { ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider)); let provider = test::create_table_dual(); - ctx.register_table("dual", provider); + ctx.register_table("dual", provider).unwrap(); let results = plan_and_collect(&mut ctx, "SELECT @@version, @name FROM dual").await?; @@ -776,10 +847,10 @@ mod tests { let mut ctx = create_ctx(&tmp_dir, partition_count)?; let provider = test::create_table_dual(); - ctx.register_table("dual", provider); + ctx.register_table("dual", provider).unwrap(); - assert!(ctx.deregister_table("dual").is_some()); - assert!(ctx.deregister_table("dual").is_none()); + assert!(ctx.deregister_table("dual").unwrap().is_some()); + assert!(ctx.deregister_table("dual").unwrap().is_none()); Ok(()) } @@ -896,15 +967,10 @@ mod tests { let tmp_dir = TempDir::new()?; let ctx = create_ctx(&tmp_dir, 1)?; - let schema = ctx - .get_default_schema() - .unwrap() - .table("test") - .unwrap() - .schema(); + let schema: Schema = ctx.table("test").unwrap().schema().clone().into(); assert_eq!(schema.field_with_name("c1")?.is_nullable(), false); - let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)? + let plan = LogicalPlanBuilder::scan_empty("", &schema, None)? .project(&[col("c1")])? .build()?; @@ -1381,7 +1447,7 @@ mod tests { .unwrap(); let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap(); - ctx.register_table("t", Arc::new(provider)); + ctx.register_table("t", Arc::new(provider)).unwrap(); let results = plan_and_collect( &mut ctx, @@ -1764,7 +1830,7 @@ mod tests { let mut ctx = ExecutionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; - ctx.register_table("t", Arc::new(provider)); + ctx.register_table("t", Arc::new(provider)).unwrap(); let myfunc = |args: &[ArrayRef]| { let l = &args[0] @@ -1844,7 +1910,7 @@ mod tests { assert_eq!(a.value(i) + b.value(i), sum.value(i)); } - ctx.deregister_table("t"); + ctx.deregister_table("t").unwrap(); Ok(()) } @@ -1866,7 +1932,7 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider)); + ctx.register_table("t", Arc::new(provider)).unwrap(); let result = plan_and_collect(&mut ctx, "SELECT AVG(a) FROM t").await?; @@ -1903,7 +1969,7 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider)); + ctx.register_table("t", Arc::new(provider)).unwrap(); // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( diff --git a/rust/datafusion/src/physical_plan/parquet.rs b/rust/datafusion/src/physical_plan/parquet.rs index 320f1718deb..569eb59847e 100644 --- a/rust/datafusion/src/physical_plan/parquet.rs +++ b/rust/datafusion/src/physical_plan/parquet.rs @@ -392,7 +392,6 @@ impl RowGroupPredicateBuilder { .collect::>(); let stat_schema = Schema::new(stat_fields); let execution_context_state = ExecutionContextState { - default_schema: None, catalogs: HashMap::new(), scalar_functions: HashMap::new(), var_provider: HashMap::new(), diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index d68f6a83326..81500b3522c 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -764,7 +764,6 @@ mod tests { fn make_ctx_state() -> ExecutionContextState { ExecutionContextState { - default_schema: None, catalogs: HashMap::new(), scalar_functions: HashMap::new(), var_provider: HashMap::new(), diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index b6e269bfbd0..8f7d912e978 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -17,9 +17,11 @@ //! SQL Query Planner (produces logical plan from SQL AST) +use std::convert::TryInto; use std::str::FromStr; use std::sync::Arc; +use crate::catalog::TableReference; use crate::datasource::TableProvider; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ @@ -58,7 +60,7 @@ use super::utils::{ /// functions referenced in SQL statements pub trait ContextProvider { /// Getter for a datasource - fn get_table_provider(&self, name: &str) -> Option>; + fn get_table_provider(&self, name: TableReference) -> Option>; /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description @@ -373,7 +375,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match relation { TableFactor::Table { name, .. } => { let table_name = name.to_string(); - match self.schema_provider.get_table_provider(&table_name) { + match self.schema_provider.get_table_provider(name.try_into()?) { Some(provider) => { LogicalPlanBuilder::scan(&table_name, provider, None)?.build() } @@ -2490,8 +2492,12 @@ mod tests { struct MockContextProvider {} impl ContextProvider for MockContextProvider { - fn get_table_provider(&self, name: &str) -> Option> { - let schema = match name { + fn get_table_provider( + &self, + name: TableReference, + ) -> Option> { + let resolved_ref = name.resolve("", ""); + let schema = match resolved_ref.table { "person" => Some(Schema::new(vec![ Field::new("id", DataType::UInt32, false), Field::new("first_name", DataType::Utf8, false), diff --git a/rust/datafusion/tests/dataframe.rs b/rust/datafusion/tests/dataframe.rs index e0c698ed5fb..9d5f92a7753 100644 --- a/rust/datafusion/tests/dataframe.rs +++ b/rust/datafusion/tests/dataframe.rs @@ -61,11 +61,11 @@ async fn join() -> Result<()> { let table1 = MemTable::try_new(schema1, vec![vec![batch1]])?; let table2 = MemTable::try_new(schema2, vec![vec![batch2]])?; - ctx.register_table("aa", Arc::new(table1)); + ctx.register_table("aa", Arc::new(table1))?; let df1 = ctx.table("aa")?; - ctx.register_table("aaa", Arc::new(table2)); + ctx.register_table("aaa", Arc::new(table2))?; let df2 = ctx.table("aaa")?; diff --git a/rust/datafusion/tests/provider_filter_pushdown.rs b/rust/datafusion/tests/provider_filter_pushdown.rs index a64f7fb74fb..f38ac59341e 100644 --- a/rust/datafusion/tests/provider_filter_pushdown.rs +++ b/rust/datafusion/tests/provider_filter_pushdown.rs @@ -156,7 +156,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<() let result_col: &UInt64Array = as_primitive_array(results[0].column(0)); assert_eq!(result_col.value(0), expected_count); - ctx.register_table("data", Arc::new(provider)); + ctx.register_table("data", Arc::new(provider))?; let sql_results = ctx .sql(&format!("select count(*) from data where flag = {}", value))? .collect() diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 1bce3f7c07c..93e478d80e6 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1168,7 +1168,7 @@ fn create_case_context() -> Result { ]))], )?; let table = MemTable::try_new(schema, vec![vec![data]])?; - ctx.register_table("t1", Arc::new(table)); + ctx.register_table("t1", Arc::new(table))?; Ok(ctx) } @@ -1317,7 +1317,7 @@ fn create_join_context( ], )?; let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(t1_table)); + ctx.register_table("t1", Arc::new(t1_table))?; let t2_schema = Arc::new(Schema::new(vec![ Field::new(column_right, DataType::UInt32, true), @@ -1336,7 +1336,7 @@ fn create_join_context( ], )?; let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; - ctx.register_table("t2", Arc::new(t2_table)); + ctx.register_table("t2", Arc::new(t2_table))?; Ok(ctx) } @@ -1358,7 +1358,7 @@ fn create_join_context_qualified() -> Result { ], )?; let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(t1_table)); + ctx.register_table("t1", Arc::new(t1_table))?; let t2_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, true), @@ -1374,7 +1374,7 @@ fn create_join_context_qualified() -> Result { ], )?; let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; - ctx.register_table("t2", Arc::new(t2_table)); + ctx.register_table("t2", Arc::new(t2_table))?; Ok(ctx) } @@ -1594,7 +1594,7 @@ async fn generic_query_length>>( let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = "SELECT length(c1) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["0"], vec!["1"], vec!["2"], vec!["3"]]; @@ -1630,7 +1630,7 @@ async fn query_not() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = "SELECT NOT c1 FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["true"], vec!["NULL"], vec!["false"]]; @@ -1656,7 +1656,7 @@ async fn query_concat() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![ @@ -1687,7 +1687,7 @@ async fn query_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![ @@ -1780,7 +1780,7 @@ fn make_timestamp_nano_table() -> Result> { #[tokio::test] async fn to_timestamp() -> Result<()> { let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_data", make_timestamp_nano_table()?); + ctx.register_table("ts_data", make_timestamp_nano_table()?)?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp('2020-09-08T12:00:00+00:00')"; let actual = execute(&mut ctx, sql).await; @@ -1806,7 +1806,7 @@ async fn query_is_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = "SELECT c1 IS NULL FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["false"], vec!["true"], vec!["false"]]; @@ -1830,7 +1830,7 @@ async fn query_is_not_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = "SELECT c1 IS NOT NULL FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["true"], vec!["false"], vec!["true"]]; @@ -1857,7 +1857,7 @@ async fn query_count_distinct() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(DISTINCT c1) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["3".to_string()]]; @@ -1886,7 +1886,7 @@ async fn query_on_string_dictionary() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; // Basic SELECT let sql = "SELECT * FROM test"; @@ -1957,7 +1957,7 @@ async fn query_scalar_minus_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = "SELECT 4 - c1 FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["4"], vec!["3"], vec!["NULL"], vec!["1"]]; @@ -2036,7 +2036,7 @@ async fn csv_group_by_date() -> Result<()> { )?; let table = MemTable::try_new(schema, vec![vec![data]])?; - ctx.register_table("dates", Arc::new(table)); + ctx.register_table("dates", Arc::new(table))?; let sql = "SELECT SUM(cnt) FROM dates GROUP BY date"; let actual = execute(&mut ctx, sql).await; let mut actual: Vec = actual.iter().flatten().cloned().collect(); From 02e83b37e5a2c2aff15f1452f428964e76bb7a47 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sat, 20 Mar 2021 20:26:33 +0000 Subject: [PATCH 04/17] Add register_catalog method to context --- rust/datafusion/src/execution/context.rs | 42 +++++++++++++++--------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 102b48fd9cf..106524b53d1 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -299,6 +299,29 @@ impl ExecutionContext { Ok(()) } + /// Registers a named catalog using a custom `CatalogProvider` so that + /// it can be referenced from SQL statements executed against this + /// context. + /// + /// Returns the `CatalogProvider` previously registered for this + /// name, if any + pub fn register_catalog( + &self, + name: impl Into, + catalog: Arc, + ) -> Option> { + self.state + .lock() + .unwrap() + .catalogs + .insert(name.into(), catalog) + } + + /// Retrieves a `CatalogProvider` instance by name + pub fn catalog(&self, name: &str) -> Option> { + self.state.lock().unwrap().catalogs.get(name).cloned() + } + fn get_schema<'a>( &'a self, table_ref: impl Into>, @@ -347,24 +370,13 @@ impl ExecutionContext { table_ref: impl Into>, ) -> Result> { let table_ref = table_ref.into(); + let schema = self.get_schema(table_ref)?; - let (table_name, table) = { - let state = self.state.lock().unwrap(); - let resolved_ref = state.resolve_table_ref(table_ref); - let table = state - .schema_for_ref(resolved_ref) - .and_then(|s| s.table(resolved_ref.table)); - - // intentionally returning only the table name for now - // TODO: - (resolved_ref.table.to_string(), table) - }; - - match table { + match schema.table(table_ref.table()) { Some(ref provider) => { let schema = provider.schema(); let table_scan = LogicalPlan::TableScan { - table_name: table_name, + table_name: table_ref.table().to_owned(), source: Arc::clone(provider), projected_schema: schema.to_dfschema_ref()?, projection: None, @@ -378,7 +390,7 @@ impl ExecutionContext { } _ => Err(DataFusionError::Plan(format!( "No table named '{}'", - table_name + table_ref.table() ))), } } From 4f000f31ce7c7b03c011d642b08308106818ee52 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sat, 20 Mar 2021 20:49:46 +0000 Subject: [PATCH 05/17] Remove unnecessary unwraps from tests --- rust/benchmarks/src/bin/tpch.rs | 6 +++--- rust/datafusion/examples/dataframe_in_memory.rs | 2 +- rust/datafusion/src/execution/context.rs | 16 ++++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index 61135437b4b..b6a3a4161ee 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -157,9 +157,9 @@ async fn benchmark(opt: BenchmarkOpt) -> Result Result<()> { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; - ctx.register_table("t", Arc::new(provider)).unwrap(); + ctx.register_table("t", Arc::new(provider))?; let df = ctx.table("t")?; // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 106524b53d1..112d4c8e582 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -835,7 +835,7 @@ mod tests { ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider)); let provider = test::create_table_dual(); - ctx.register_table("dual", provider).unwrap(); + ctx.register_table("dual", provider)?; let results = plan_and_collect(&mut ctx, "SELECT @@version, @name FROM dual").await?; @@ -859,10 +859,10 @@ mod tests { let mut ctx = create_ctx(&tmp_dir, partition_count)?; let provider = test::create_table_dual(); - ctx.register_table("dual", provider).unwrap(); + ctx.register_table("dual", provider)?; - assert!(ctx.deregister_table("dual").unwrap().is_some()); - assert!(ctx.deregister_table("dual").unwrap().is_none()); + assert!(ctx.deregister_table("dual")?.is_some()); + assert!(ctx.deregister_table("dual")?.is_none()); Ok(()) } @@ -1842,7 +1842,7 @@ mod tests { let mut ctx = ExecutionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; - ctx.register_table("t", Arc::new(provider)).unwrap(); + ctx.register_table("t", Arc::new(provider))?; let myfunc = |args: &[ArrayRef]| { let l = &args[0] @@ -1922,7 +1922,7 @@ mod tests { assert_eq!(a.value(i) + b.value(i), sum.value(i)); } - ctx.deregister_table("t").unwrap(); + ctx.deregister_table("t")?; Ok(()) } @@ -1944,7 +1944,7 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider)).unwrap(); + ctx.register_table("t", Arc::new(provider))?; let result = plan_and_collect(&mut ctx, "SELECT AVG(a) FROM t").await?; @@ -1981,7 +1981,7 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider)).unwrap(); + ctx.register_table("t", Arc::new(provider))?; // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( From fecf4982f3008319c43833a6184c022224469ec7 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sat, 20 Mar 2021 21:02:15 +0000 Subject: [PATCH 06/17] Add test for all table reference types --- rust/datafusion/tests/sql.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 93e478d80e6..6caaa884cf3 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2504,3 +2504,20 @@ async fn inner_join_qualified_names() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn qualified_table_references() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + + for table_ref in &[ + "aggregate_test_100", + "public.aggregate_test_100", + "datafusion.public.aggregate_test_100", + ] { + let sql = format!("SELECT COUNT(*) FROM {}", table_ref); + let results = execute(&mut ctx, &sql).await; + assert_eq!(results, vec![vec!["100"]]); + } + Ok(()) +} From 1460598967323a690b2732a65da67fd66a09563e Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sat, 20 Mar 2021 21:27:14 +0000 Subject: [PATCH 07/17] Allow downcasting of catalog/schema providers --- rust/datafusion/src/catalog/catalog.rs | 9 +++++++++ rust/datafusion/src/catalog/schema.rs | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/rust/datafusion/src/catalog/catalog.rs b/rust/datafusion/src/catalog/catalog.rs index 9255ff1e550..df01ffc4d70 100644 --- a/rust/datafusion/src/catalog/catalog.rs +++ b/rust/datafusion/src/catalog/catalog.rs @@ -19,11 +19,16 @@ //! representing collections of named schemas. use crate::catalog::schema::SchemaProvider; +use std::any::Any; use std::collections::HashMap; use std::sync::{Arc, RwLock}; /// Represents a catalog, comprising a number of named schemas. pub trait CatalogProvider: Sync + Send { + /// Returns the catalog provider as [`Any`](std::any::Any) + /// so that it can be downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + /// Retrieves the list of available schema names in this catalog. fn schema_names(&self) -> Vec; @@ -57,6 +62,10 @@ impl MemoryCatalogProvider { } impl CatalogProvider for MemoryCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + fn schema_names(&self) -> Vec { let schemas = self.schemas.read().unwrap(); schemas.keys().cloned().collect() diff --git a/rust/datafusion/src/catalog/schema.rs b/rust/datafusion/src/catalog/schema.rs index 827e2d75510..41e7a669efc 100644 --- a/rust/datafusion/src/catalog/schema.rs +++ b/rust/datafusion/src/catalog/schema.rs @@ -20,11 +20,16 @@ use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; +use std::any::Any; use std::collections::HashMap; use std::sync::{Arc, RwLock}; /// Represents a schema, comprising a number of named tables. pub trait SchemaProvider: Sync + Send { + /// Returns the schema provider as [`Any`](std::any::Any) + /// so that it can be downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + /// Retrieves the list of available table names in this schema. fn table_names(&self) -> Vec; @@ -67,6 +72,10 @@ impl MemorySchemaProvider { } impl SchemaProvider for MemorySchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + fn table_names(&self) -> Vec { let tables = self.tables.read().unwrap(); tables.keys().cloned().collect() From 8d30ee46290b116f50aacba4864ec281a6d7adaa Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sat, 20 Mar 2021 22:03:33 +0000 Subject: [PATCH 08/17] Add doc comment for catalogs in execution state --- rust/datafusion/src/execution/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 112d4c8e582..6ed40b38530 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -650,7 +650,7 @@ impl ExecutionConfig { /// Execution context for registering data sources and executing queries #[derive(Clone)] pub struct ExecutionContextState { - /// + /// Collection of catalogs containing schemas and ultimately tables pub catalogs: HashMap>, /// Scalar functions that are registered with the context pub scalar_functions: HashMap>, From 5e1e942bb43a4f16f1c0e2a44ff1a428b24ba7fd Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sun, 21 Mar 2021 11:35:13 +0000 Subject: [PATCH 09/17] Remove unnecessary table resolution in mock context provider --- rust/datafusion/src/sql/planner.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 8f7d912e978..45ad6891866 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -2496,8 +2496,7 @@ mod tests { &self, name: TableReference, ) -> Option> { - let resolved_ref = name.resolve("", ""); - let schema = match resolved_ref.table { + let schema = match name.table() { "person" => Some(Schema::new(vec![ Field::new("id", DataType::UInt32, false), Field::new("first_name", DataType::Utf8, false), From cf5b8466171dd914b89f80dca8ecc2b6bf6654a4 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sun, 21 Mar 2021 11:41:40 +0000 Subject: [PATCH 10/17] Removed unused Display impls for table references --- rust/datafusion/src/catalog/mod.rs | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/rust/datafusion/src/catalog/mod.rs b/rust/datafusion/src/catalog/mod.rs index 84a626ebf27..b61ed154acc 100644 --- a/rust/datafusion/src/catalog/mod.rs +++ b/rust/datafusion/src/catalog/mod.rs @@ -35,12 +35,6 @@ pub struct ResolvedTableReference<'a> { pub table: &'a str, } -impl<'a> std::fmt::Display for ResolvedTableReference<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}.{}", self.catalog, self.schema, self.table) - } -} - /// Represents a path to a table that may require further resolution #[derive(Clone, Copy)] pub enum TableReference<'a> { @@ -107,20 +101,6 @@ impl<'a> TableReference<'a> { } } -impl<'a> std::fmt::Display for TableReference<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Bare { table } => write!(f, "{}", table), - Self::Partial { schema, table } => write!(f, "{}.{}", schema, table), - Self::Full { - catalog, - schema, - table, - } => write!(f, "{}.{}.{}", catalog, schema, table), - } - } -} - impl<'a> From<&'a str> for TableReference<'a> { fn from(s: &'a str) -> Self { Self::Bare { table: s } From 0f33b6f7fca060010fd53d735b25324223c970c3 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sun, 21 Mar 2021 11:51:20 +0000 Subject: [PATCH 11/17] Add test for invalid qualified table references --- rust/datafusion/tests/sql.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 6caaa884cf3..feae10511c1 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -35,7 +35,10 @@ use datafusion::{ datasource::{csv::CsvReadOptions, MemTable}, physical_plan::collect, }; -use datafusion::{error::Result, physical_plan::ColumnarValue}; +use datafusion::{ + error::{DataFusionError, Result}, + physical_plan::ColumnarValue, +}; #[tokio::test] async fn nyc() -> Result<()> { @@ -2521,3 +2524,18 @@ async fn qualified_table_references() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn invalid_qualified_table_references() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + + for table_ref in &[ + "nonexistentschema.aggregate_test_100", + "nonexistentcatalog.public.aggregate_test_100", + ] { + let sql = format!("SELECT COUNT(*) FROM {}", table_ref); + assert!(matches!(ctx.sql(&sql), Err(DataFusionError::Plan(_)))); + } + Ok(()) +} From 58451a0397be477fa11a552504ab3f41980f8ae9 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sun, 21 Mar 2021 11:57:05 +0000 Subject: [PATCH 12/17] Add test case for invalid compound ident table ref --- rust/datafusion/tests/sql.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index feae10511c1..ba8230a442f 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2533,6 +2533,7 @@ async fn invalid_qualified_table_references() -> Result<()> { for table_ref in &[ "nonexistentschema.aggregate_test_100", "nonexistentcatalog.public.aggregate_test_100", + "way.too.many.namespaces.as.ident.prefixes.aggregate_test_100", ] { let sql = format!("SELECT COUNT(*) FROM {}", table_ref); assert!(matches!(ctx.sql(&sql), Err(DataFusionError::Plan(_)))); From d79bdd5260bb8477fbf3f7584b3978769cafc5d2 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sun, 21 Mar 2021 19:11:05 +0000 Subject: [PATCH 13/17] Clean up schema retrieval and comments, deprecate tables method --- rust/datafusion/src/execution/context.rs | 76 ++++++++++++++---------- 1 file changed, 43 insertions(+), 33 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 6ed40b38530..04c0eec060a 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -322,34 +322,26 @@ impl ExecutionContext { self.state.lock().unwrap().catalogs.get(name).cloned() } - fn get_schema<'a>( - &'a self, - table_ref: impl Into>, - ) -> Result> { - self.state - .lock() - .unwrap() - .schema_for_ref(table_ref.into()) - .ok_or_else(|| DataFusionError::Plan("failed to resolve schema".to_owned())) - } - - /// Registers a named table using a custom `TableProvider` so that + /// Registers a table using a custom `TableProvider` so that /// it can be referenced from SQL statements executed against this /// context. /// /// Returns the `TableProvider` previously registered for this - /// name, if any + /// reference, if any pub fn register_table<'a>( &'a mut self, table_ref: impl Into>, provider: Arc, ) -> Result>> { let table_ref = table_ref.into(); - self.get_schema(table_ref)? + self.state + .lock() + .unwrap() + .schema_for_ref(table_ref)? .register_table(table_ref.table().to_owned(), provider) } - /// Deregisters the named table. + /// Deregisters the given table. /// /// Returns the registered provider, if any pub fn deregister_table<'a>( @@ -357,20 +349,23 @@ impl ExecutionContext { table_ref: impl Into>, ) -> Result>> { let table_ref = table_ref.into(); - self.get_schema(table_ref)? + self.state + .lock() + .unwrap() + .schema_for_ref(table_ref)? .deregister_table(table_ref.table()) } /// Retrieves a DataFrame representing a table previously registered by calling the /// register_table function. /// - /// Returns an error if no table has been registered with the provided name. + /// Returns an error if no table has been registered with the provided reference. pub fn table<'a>( &self, table_ref: impl Into>, ) -> Result> { let table_ref = table_ref.into(); - let schema = self.get_schema(table_ref)?; + let schema = self.state.lock().unwrap().schema_for_ref(table_ref)?; match schema.table(table_ref.table()) { Some(ref provider) => { @@ -395,20 +390,25 @@ impl ExecutionContext { } } - /// Returns the set of available tables. + /// Returns the set of available tables in the default catalog and schema. /// /// Use [`table`] to get a specific table. /// /// [`table`]: ExecutionContext::table - pub fn tables(&self) -> HashSet { - // quick hack to get default schema - self.get_schema(TableReference::Bare { table: "" }) - .ok() - .map(|s| s.table_names()) - .unwrap_or_default() + #[deprecated( + note = "Please use the catalog provider interface (`ExecutionContext::catalog`) to examine available catalogs, schemas, and tables" + )] + pub fn tables(&self) -> Result> { + Ok(self + .state + .lock() + .unwrap() + // a bare reference will always resolve to the default catalog and schema + .schema_for_ref(TableReference::Bare { table: "" })? + .table_names() .iter() .cloned() - .collect() + .collect()) } /// Optimizes the logical plan by applying optimizer rules. @@ -675,21 +675,31 @@ impl ExecutionContextState { fn schema_for_ref<'a>( &'a self, table_ref: impl Into>, - ) -> Option> { + ) -> Result> { let resolved_ref = self.resolve_table_ref(table_ref.into()); - Some( - self.catalogs - .get(resolved_ref.catalog)? - .schema(resolved_ref.schema)?, - ) + self.catalogs + .get(resolved_ref.catalog) + .ok_or_else(|| { + DataFusionError::Plan(format!( + "failed to resolve catalog: {}", + resolved_ref.catalog + )) + })? + .schema(resolved_ref.schema) + .ok_or_else(|| { + DataFusionError::Plan(format!( + "failed to resolve schema: {}", + resolved_ref.schema + )) + }) } } impl ContextProvider for ExecutionContextState { fn get_table_provider(&self, name: TableReference) -> Option> { let resolved_ref = self.resolve_table_ref(name); - let schema = self.schema_for_ref(resolved_ref)?; + let schema = self.schema_for_ref(resolved_ref).ok()?; schema.table(resolved_ref.table) } From 66de8d300db2b5246a6d85e49dcd453be10e5185 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sun, 21 Mar 2021 20:57:12 +0000 Subject: [PATCH 14/17] Add tests for custom catalog and schema --- rust/datafusion/src/catalog/catalog.rs | 2 +- rust/datafusion/src/execution/context.rs | 71 +++++++++++++++++++++++- 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/catalog/catalog.rs b/rust/datafusion/src/catalog/catalog.rs index df01ffc4d70..69059d13bb3 100644 --- a/rust/datafusion/src/catalog/catalog.rs +++ b/rust/datafusion/src/catalog/catalog.rs @@ -52,7 +52,7 @@ impl MemoryCatalogProvider { /// Adds a new schema to this catalog. /// If a schema of the same name existed before, it is replaced in the catalog and returned. pub fn register_schema( - &mut self, + &self, name: impl Into, schema: Arc, ) -> Option> { diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 04c0eec060a..e0f2ba4f69c 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -119,7 +119,7 @@ impl ExecutionContext { let mut catalogs = HashMap::new(); if config.create_default_catalog_and_schema { - let mut default_catalog = MemoryCatalogProvider::new(); + let default_catalog = MemoryCatalogProvider::new(); default_catalog.register_schema( config.default_schema.clone(), Arc::new(MemorySchemaProvider::new()), @@ -645,6 +645,12 @@ impl ExecutionConfig { self.default_schema = schema.into(); self } + + /// Controls whether the default catalog and schema will be automatically created + pub fn create_default_catalog_and_schema(mut self, create: bool) -> Self { + self.create_default_catalog_and_schema = create; + self + } } /// Execution context for registering data sources and executing queries @@ -753,7 +759,7 @@ mod tests { logical_plan::{col, create_udf, sum}, }; use crate::{ - datasource::MemTable, logical_plan::create_udaf, + datasource::empty::EmptyTable, datasource::MemTable, logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator, }; use arrow::array::{ @@ -2029,6 +2035,67 @@ mod tests { Ok(()) } + fn empty_table_for_catalogs() -> Arc { + Arc::new(EmptyTable::new(Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Int32, + true, + )])))) + } + + #[tokio::test] + async fn disabled_default_catalog_and_schema() -> Result<()> { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().create_default_catalog_and_schema(false), + ); + + assert!(matches!( + ctx.register_table("test", empty_table_for_catalogs()), + Err(DataFusionError::Plan(_)) + )); + + assert!(matches!( + ctx.sql("select * from datafusion.public.test"), + Err(DataFusionError::Plan(_)) + )); + + Ok(()) + } + + #[tokio::test] + async fn custom_catalog_and_schema() -> Result<()> { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new() + .create_default_catalog_and_schema(false) + .with_default_catalog_and_schema("my_catalog", "my_schema"), + ); + + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + schema.register_table("test".to_owned(), empty_table_for_catalogs())?; + catalog.register_schema("my_schema", Arc::new(schema)); + ctx.register_catalog("my_catalog", Arc::new(catalog)); + + for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] { + let result = plan_and_collect( + &mut ctx, + &format!("SELECT COUNT(*) AS count FROM {}", table_ref), + ) + .await?; + + let expected = vec![ + "+-------+", + "| count |", + "+-------+", + "| 0 |", + "+-------+", + ]; + assert_batches_eq!(expected, &result); + } + + Ok(()) + } + struct MyPhysicalPlanner {} impl PhysicalPlanner for MyPhysicalPlanner { From a18829fa728cc434b1c12b86e0564e85ddc77689 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sun, 21 Mar 2021 21:11:29 +0000 Subject: [PATCH 15/17] Fix [de]register_table signatures --- rust/datafusion/src/catalog/schema.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/catalog/schema.rs b/rust/datafusion/src/catalog/schema.rs index 41e7a669efc..0e39546a5f8 100644 --- a/rust/datafusion/src/catalog/schema.rs +++ b/rust/datafusion/src/catalog/schema.rs @@ -38,10 +38,11 @@ pub trait SchemaProvider: Sync + Send { /// If supported by the implementation, adds a new table to this schema. /// If a table of the same name existed before, it is replaced in the schema and returned. + #[allow(unused_variables)] fn register_table( &self, - _name: String, - _table: Arc, + name: String, + table: Arc, ) -> Result>> { Err(DataFusionError::Execution( "schema provider does not support registering tables".to_owned(), @@ -50,7 +51,8 @@ pub trait SchemaProvider: Sync + Send { /// If supported by the implementation, removes an existing table from this schema and returns it. /// If no table of that name exists, returns Ok(None). - fn deregister_table(&self, _name: &str) -> Result>> { + #[allow(unused_variables)] + fn deregister_table(&self, name: &str) -> Result>> { Err(DataFusionError::Execution( "schema provider does not support deregistering tables".to_owned(), )) From 03048c634a2aeaee15163d6765111ce12ddfe9cf Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Sun, 21 Mar 2021 21:37:49 +0000 Subject: [PATCH 16/17] Add test for cross-catalog queries --- rust/datafusion/src/execution/context.rs | 67 ++++++++++++++++++++---- 1 file changed, 57 insertions(+), 10 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index e0f2ba4f69c..de67023469d 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -759,7 +759,7 @@ mod tests { logical_plan::{col, create_udf, sum}, }; use crate::{ - datasource::empty::EmptyTable, datasource::MemTable, logical_plan::create_udaf, + datasource::MemTable, logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator, }; use arrow::array::{ @@ -2035,12 +2035,17 @@ mod tests { Ok(()) } - fn empty_table_for_catalogs() -> Arc { - Arc::new(EmptyTable::new(Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Int32, - true, - )])))) + fn table_with_sequence( + seq_start: i32, + seq_end: i32, + ) -> Result> { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::>())); + let partitions = vec![vec![RecordBatch::try_new( + schema.clone(), + vec![arr as ArrayRef], + )?]]; + Ok(Arc::new(MemTable::try_new(schema, partitions)?)) } #[tokio::test] @@ -2050,7 +2055,7 @@ mod tests { ); assert!(matches!( - ctx.register_table("test", empty_table_for_catalogs()), + ctx.register_table("test", table_with_sequence(1, 1)?), Err(DataFusionError::Plan(_)) )); @@ -2072,7 +2077,7 @@ mod tests { let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); - schema.register_table("test".to_owned(), empty_table_for_catalogs())?; + schema.register_table("test".to_owned(), table_with_sequence(1, 1)?)?; catalog.register_schema("my_schema", Arc::new(schema)); ctx.register_catalog("my_catalog", Arc::new(catalog)); @@ -2087,7 +2092,7 @@ mod tests { "+-------+", "| count |", "+-------+", - "| 0 |", + "| 1 |", "+-------+", ]; assert_batches_eq!(expected, &result); @@ -2096,6 +2101,48 @@ mod tests { Ok(()) } + #[tokio::test] + async fn cross_catalog_access() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let catalog_a = MemoryCatalogProvider::new(); + let schema_a = MemorySchemaProvider::new(); + schema_a.register_table("table_a".to_owned(), table_with_sequence(1, 1)?)?; + catalog_a.register_schema("schema_a", Arc::new(schema_a)); + ctx.register_catalog("catalog_a", Arc::new(catalog_a)); + + let catalog_b = MemoryCatalogProvider::new(); + let schema_b = MemorySchemaProvider::new(); + schema_b.register_table("table_b".to_owned(), table_with_sequence(1, 2)?)?; + catalog_b.register_schema("schema_b", Arc::new(schema_b)); + ctx.register_catalog("catalog_b", Arc::new(catalog_b)); + + let result = plan_and_collect( + &mut ctx, + "SELECT cat, SUM(i) AS total FROM ( + SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a + UNION ALL + SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b + ) + GROUP BY cat + ORDER BY cat + ", + ) + .await?; + + let expected = vec![ + "+-----+-------+", + "| cat | total |", + "+-----+-------+", + "| a | 1 |", + "| b | 3 |", + "+-----+-------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) + } + struct MyPhysicalPlanner {} impl PhysicalPlanner for MyPhysicalPlanner { From 0870acf9cda6bfc90068a64d58bc1c1fd43c1139 Mon Sep 17 00:00:00 2001 From: Ruan Pearce-Authers Date: Mon, 22 Mar 2021 20:28:44 +0000 Subject: [PATCH 17/17] Review suggestion for doc comment Co-authored-by: Andrew Lamb --- rust/datafusion/src/execution/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index de67023469d..f0902a995a1 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -656,7 +656,7 @@ impl ExecutionConfig { /// Execution context for registering data sources and executing queries #[derive(Clone)] pub struct ExecutionContextState { - /// Collection of catalogs containing schemas and ultimately tables + /// Collection of catalogs containing schemas and ultimately TableProviders pub catalogs: HashMap>, /// Scalar functions that are registered with the context pub scalar_functions: HashMap>,