diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs
index b0a6cedd172..b6a3a4161ee 100644
--- a/rust/benchmarks/src/bin/tpch.rs
+++ b/rust/benchmarks/src/bin/tpch.rs
@@ -157,9 +157,9 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> {
 fn get_table(
     path: &str,
     table: &str,
     table_format: &str,
-) -> Result<Arc<dyn TableProvider + Send + Sync>> {
+) -> Result<Arc<dyn TableProvider>> {
     match table_format {
         // dbgen creates .tbl ('|' delimited) files without header
         "tbl" => {
@@ -1614,7 +1614,7 @@ mod tests {
             let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;
-            ctx.register_table(table, Arc::new(provider));
+            ctx.register_table(table, Arc::new(provider))?;
         }
 
         let plan = create_logical_plan(&mut ctx, n)?;
diff --git a/rust/datafusion/benches/aggregate_query_sql.rs b/rust/datafusion/benches/aggregate_query_sql.rs
index 75d9d3432ba..8f1a97e198d 100644
--- a/rust/datafusion/benches/aggregate_query_sql.rs
+++ b/rust/datafusion/benches/aggregate_query_sql.rs
@@ -150,7 +150,7 @@ fn create_context(
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, partitions)?;
-    ctx.register_table("t", Arc::new(provider));
+    ctx.register_table("t", Arc::new(provider))?;
 
     Ok(Arc::new(Mutex::new(ctx)))
 }
diff --git a/rust/datafusion/benches/filter_query_sql.rs b/rust/datafusion/benches/filter_query_sql.rs
index 363ae416f67..8600bdc88c6 100644
--- a/rust/datafusion/benches/filter_query_sql.rs
+++ b/rust/datafusion/benches/filter_query_sql.rs
@@ -62,7 +62,7 @@ fn create_context(array_len: usize, batch_size: usize) -> Result<ExecutionContext> {
-    ctx.register_table("t", Arc::new(provider));
+    ctx.register_table("t", Arc::new(provider))?;
diff --git a/rust/datafusion/benches/sort_limit_query_sql.rs b/rust/datafusion/benches/sort_limit_query_sql.rs
--- a/rust/datafusion/benches/sort_limit_query_sql.rs
+++ b/rust/datafusion/benches/sort_limit_query_sql.rs
@@ ... @@ fn create_context() -> Arc<Mutex<ExecutionContext>> {
         // create local execution context
         let mut ctx = ExecutionContext::new();
         ctx.state.lock().unwrap().config.concurrency = 1;
-        ctx.register_table("aggregate_test_100", Arc::new(mem_table));
+        ctx.register_table("aggregate_test_100", Arc::new(mem_table))
+            .unwrap();
 
         ctx_holder.lock().unwrap().push(Arc::new(Mutex::new(ctx)))
     });
diff --git a/rust/datafusion/examples/dataframe_in_memory.rs b/rust/datafusion/examples/dataframe_in_memory.rs
index 28414bf8700..de8552a3bba 100644
--- a/rust/datafusion/examples/dataframe_in_memory.rs
+++ b/rust/datafusion/examples/dataframe_in_memory.rs
@@ -49,7 +49,7 @@ async fn main() -> Result<()> {
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch]])?;
-    ctx.register_table("t", Arc::new(provider));
+    ctx.register_table("t", Arc::new(provider))?;
     let df = ctx.table("t")?;
 
     // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
diff --git a/rust/datafusion/examples/simple_udaf.rs b/rust/datafusion/examples/simple_udaf.rs
index a36d200235a..55aa350b13d 100644
--- a/rust/datafusion/examples/simple_udaf.rs
+++ b/rust/datafusion/examples/simple_udaf.rs
@@ -48,7 +48,7 @@ fn create_context() -> Result<ExecutionContext> {
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider));
+    ctx.register_table("t", Arc::new(provider))?;
     Ok(ctx)
 }
diff --git a/rust/datafusion/examples/simple_udf.rs b/rust/datafusion/examples/simple_udf.rs
index d49aac48527..00debdbddac 100644
--- a/rust/datafusion/examples/simple_udf.rs
+++ b/rust/datafusion/examples/simple_udf.rs
@@ -50,7 +50,7 @@ fn create_context() -> Result<ExecutionContext> {
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch]])?;
-    ctx.register_table("t", Arc::new(provider));
+    ctx.register_table("t", Arc::new(provider))?;
     Ok(ctx)
 }
diff --git a/rust/datafusion/src/catalog/catalog.rs b/rust/datafusion/src/catalog/catalog.rs
new file mode 100644
index 00000000000..69059d13bb3
--- /dev/null
+++ b/rust/datafusion/src/catalog/catalog.rs
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Describes the interface and built-in implementations of catalogs,
+//! representing collections of named schemas.
+
+use crate::catalog::schema::SchemaProvider;
+use std::any::Any;
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
+
+/// Represents a catalog, comprising a number of named schemas.
+pub trait CatalogProvider: Sync + Send {
+    /// Returns the catalog provider as [`Any`](std::any::Any)
+    /// so that it can be downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any;
+
+    /// Retrieves the list of available schema names in this catalog.
+    fn schema_names(&self) -> Vec<String>;
+
+    /// Retrieves a specific schema from the catalog by name, provided it exists.
+    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>>;
+}
+
+/// Simple in-memory implementation of a catalog.
+pub struct MemoryCatalogProvider {
+    schemas: RwLock<HashMap<String, Arc<dyn SchemaProvider>>>,
+}
+
+impl MemoryCatalogProvider {
+    /// Instantiates a new MemoryCatalogProvider with an empty collection of schemas.
+    pub fn new() -> Self {
+        Self {
+            schemas: RwLock::new(HashMap::new()),
+        }
+    }
+
+    /// Adds a new schema to this catalog.
+    /// If a schema of the same name existed before, it is replaced in the catalog and returned.
+    pub fn register_schema(
+        &self,
+        name: impl Into<String>,
+        schema: Arc<dyn SchemaProvider>,
+    ) -> Option<Arc<dyn SchemaProvider>> {
+        let mut schemas = self.schemas.write().unwrap();
+        schemas.insert(name.into(), schema)
+    }
+}
+
+impl CatalogProvider for MemoryCatalogProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema_names(&self) -> Vec<String> {
+        let schemas = self.schemas.read().unwrap();
+        schemas.keys().cloned().collect()
+    }
+
+    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
+        let schemas = self.schemas.read().unwrap();
+        schemas.get(name).cloned()
+    }
+}
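The `CatalogProvider` trait and its `MemoryCatalogProvider` implementation above form the top level of the new naming hierarchy. A minimal usage sketch, assuming only the API added in this file (the function name and the "public" schema name are illustrative, not part of the diff):

use datafusion::catalog::catalog::{CatalogProvider, MemoryCatalogProvider};
use datafusion::catalog::schema::MemorySchemaProvider;
use std::sync::Arc;

fn catalog_sketch() {
    let catalog = MemoryCatalogProvider::new();

    // register_schema returns the schema previously stored under that
    // name, if any, mirroring HashMap::insert semantics.
    let previous = catalog.register_schema("public", Arc::new(MemorySchemaProvider::new()));
    assert!(previous.is_none());

    assert_eq!(catalog.schema_names(), vec!["public".to_string()]);
    assert!(catalog.schema("public").is_some());
}

Note that `register_schema` takes `&self` and synchronizes on the internal `RwLock`, so a catalog can be shared behind an `Arc` and still accept new schemas.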
diff --git a/rust/datafusion/src/catalog/mod.rs b/rust/datafusion/src/catalog/mod.rs
new file mode 100644
index 00000000000..b61ed154acc
--- /dev/null
+++ b/rust/datafusion/src/catalog/mod.rs
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module contains interfaces and default implementations
+//! of table namespacing concepts, including catalogs and schemas.
+
+pub mod catalog;
+pub mod schema;
+
+use crate::error::DataFusionError;
+use std::convert::TryFrom;
+
+/// Represents a resolved path to a table of the form "catalog.schema.table"
+#[derive(Clone, Copy)]
+pub struct ResolvedTableReference<'a> {
+    /// The catalog (aka database) containing the table
+    pub catalog: &'a str,
+    /// The schema containing the table
+    pub schema: &'a str,
+    /// The table name
+    pub table: &'a str,
+}
+
+/// Represents a path to a table that may require further resolution
+#[derive(Clone, Copy)]
+pub enum TableReference<'a> {
+    /// An unqualified table reference, e.g. "table"
+    Bare {
+        /// The table name
+        table: &'a str,
+    },
+    /// A partially resolved table reference, e.g. "schema.table"
+    Partial {
+        /// The schema containing the table
+        schema: &'a str,
+        /// The table name
+        table: &'a str,
+    },
+    /// A fully resolved table reference, e.g. "catalog.schema.table"
+    Full {
+        /// The catalog (aka database) containing the table
+        catalog: &'a str,
+        /// The schema containing the table
+        schema: &'a str,
+        /// The table name
+        table: &'a str,
+    },
+}
+
+impl<'a> TableReference<'a> {
+    /// Retrieve the actual table name, regardless of qualification
+    pub fn table(&self) -> &str {
+        match self {
+            Self::Full { table, .. }
+            | Self::Partial { table, .. }
+            | Self::Bare { table } => table,
+        }
+    }
+
+    /// Given a default catalog and schema, ensure this table reference is fully resolved
+    pub fn resolve(
+        self,
+        default_catalog: &'a str,
+        default_schema: &'a str,
+    ) -> ResolvedTableReference<'a> {
+        match self {
+            Self::Full {
+                catalog,
+                schema,
+                table,
+            } => ResolvedTableReference {
+                catalog,
+                schema,
+                table,
+            },
+            Self::Partial { schema, table } => ResolvedTableReference {
+                catalog: default_catalog,
+                schema,
+                table,
+            },
+            Self::Bare { table } => ResolvedTableReference {
+                catalog: default_catalog,
+                schema: default_schema,
+                table,
+            },
+        }
+    }
+}
+
+impl<'a> From<&'a str> for TableReference<'a> {
+    fn from(s: &'a str) -> Self {
+        Self::Bare { table: s }
+    }
+}
+
+impl<'a> From<ResolvedTableReference<'a>> for TableReference<'a> {
+    fn from(resolved: ResolvedTableReference<'a>) -> Self {
+        Self::Full {
+            catalog: resolved.catalog,
+            schema: resolved.schema,
+            table: resolved.table,
+        }
+    }
+}
+
+impl<'a> TryFrom<&'a sqlparser::ast::ObjectName> for TableReference<'a> {
+    type Error = DataFusionError;
+
+    fn try_from(value: &'a sqlparser::ast::ObjectName) -> Result<Self, Self::Error> {
+        let idents = &value.0;
+
+        match idents.len() {
+            1 => Ok(Self::Bare {
+                table: &idents[0].value,
+            }),
+            2 => Ok(Self::Partial {
+                schema: &idents[0].value,
+                table: &idents[1].value,
+            }),
+            3 => Ok(Self::Full {
+                catalog: &idents[0].value,
+                schema: &idents[1].value,
+                table: &idents[2].value,
+            }),
+            _ => Err(DataFusionError::Plan(format!(
+                "invalid table reference: {}",
+                value
+            ))),
+        }
+    }
+}
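`TableReference::resolve` carries the core lookup rule: missing qualifiers are filled in from the configured defaults. Also worth noting: the `From<&'a str>` impl always produces a `Bare` reference, so a string such as "a.b.c" is treated as one table name rather than being split on dots; dotted paths only become `Partial` or `Full` references through the `TryFrom<&ObjectName>` impl used by the SQL planner. A small sketch of the resolution rules (all names here are illustrative):

use datafusion::catalog::TableReference;

fn resolve_sketch() {
    // A bare name picks up both defaults...
    let bare: TableReference = "orders".into();
    let r = bare.resolve("datafusion", "public");
    assert_eq!((r.catalog, r.schema, r.table), ("datafusion", "public", "orders"));

    // ...while a partial reference only borrows the default catalog.
    let partial = TableReference::Partial { schema: "sales", table: "orders" };
    let r = partial.resolve("datafusion", "public");
    assert_eq!((r.catalog, r.schema, r.table), ("datafusion", "sales", "orders"));
}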
diff --git a/rust/datafusion/src/catalog/schema.rs b/rust/datafusion/src/catalog/schema.rs
new file mode 100644
index 00000000000..0e39546a5f8
--- /dev/null
+++ b/rust/datafusion/src/catalog/schema.rs
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Describes the interface and built-in implementations of schemas,
+//! representing collections of named tables.
+
+use crate::datasource::TableProvider;
+use crate::error::{DataFusionError, Result};
+use std::any::Any;
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
+
+/// Represents a schema, comprising a number of named tables.
+pub trait SchemaProvider: Sync + Send {
+    /// Returns the schema provider as [`Any`](std::any::Any)
+    /// so that it can be downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any;
+
+    /// Retrieves the list of available table names in this schema.
+    fn table_names(&self) -> Vec<String>;
+
+    /// Retrieves a specific table from the schema by name, provided it exists.
+    fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>>;
+
+    /// If supported by the implementation, adds a new table to this schema.
+    /// If a table of the same name existed before, it is replaced in the schema and returned.
+    #[allow(unused_variables)]
+    fn register_table(
+        &self,
+        name: String,
+        table: Arc<dyn TableProvider>,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        Err(DataFusionError::Execution(
+            "schema provider does not support registering tables".to_owned(),
+        ))
+    }
+
+    /// If supported by the implementation, removes an existing table from this schema and returns it.
+    /// If no table of that name exists, returns Ok(None).
+    #[allow(unused_variables)]
+    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
+        Err(DataFusionError::Execution(
+            "schema provider does not support deregistering tables".to_owned(),
+        ))
+    }
+}
+
+/// Simple in-memory implementation of a schema.
+pub struct MemorySchemaProvider {
+    tables: RwLock<HashMap<String, Arc<dyn TableProvider>>>,
+}
+
+impl MemorySchemaProvider {
+    /// Instantiates a new MemorySchemaProvider with an empty collection of tables.
+    pub fn new() -> Self {
+        Self {
+            tables: RwLock::new(HashMap::new()),
+        }
+    }
+}
+
+impl SchemaProvider for MemorySchemaProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn table_names(&self) -> Vec<String> {
+        let tables = self.tables.read().unwrap();
+        tables.keys().cloned().collect()
+    }
+
+    fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
+        let tables = self.tables.read().unwrap();
+        tables.get(name).cloned()
+    }
+
+    fn register_table(
+        &self,
+        name: String,
+        table: Arc<dyn TableProvider>,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        let mut tables = self.tables.write().unwrap();
+        Ok(tables.insert(name, table))
+    }
+
+    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
+        let mut tables = self.tables.write().unwrap();
+        Ok(tables.remove(name))
+    }
+}
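`SchemaProvider` deliberately ships `register_table` and `deregister_table` as defaulted methods that return an `Execution` error, so read-only schema implementations only need `as_any`, `table_names`, and `table`. A sketch of the mutable in-memory implementation (the table name "t" is illustrative; `Result` is `datafusion::error::Result`):

use datafusion::catalog::schema::{MemorySchemaProvider, SchemaProvider};
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use std::sync::Arc;

fn schema_sketch(table: Arc<dyn TableProvider>) -> Result<()> {
    let schema = MemorySchemaProvider::new();

    // Insertion hands back the provider previously registered under
    // that name, if any.
    assert!(schema.register_table("t".to_owned(), table)?.is_none());
    assert_eq!(schema.table_names(), vec!["t".to_string()]);

    // Removal returns the provider; a second call yields Ok(None).
    assert!(schema.deregister_table("t")?.is_some());
    assert!(schema.deregister_table("t")?.is_none());
    Ok(())
}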
diff --git a/rust/datafusion/src/datasource/datasource.rs b/rust/datafusion/src/datasource/datasource.rs
index 4e6ad36160c..e2b07336486 100644
--- a/rust/datafusion/src/datasource/datasource.rs
+++ b/rust/datafusion/src/datasource/datasource.rs
@@ -67,7 +67,7 @@ pub enum TableProviderFilterPushDown {
 }
 
 /// Source table
-pub trait TableProvider {
+pub trait TableProvider: Sync + Send {
     /// Returns the table provider as [`Any`](std::any::Any) so that it can be
     /// downcast to a specific implementation.
     fn as_any(&self) -> &dyn Any;
diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs
index 514eac6dcda..af404808702 100644
--- a/rust/datafusion/src/datasource/memory.rs
+++ b/rust/datafusion/src/datasource/memory.rs
@@ -110,7 +110,7 @@ impl MemTable {
     /// Create a mem table by reading from another data source
     pub async fn load(
-        t: Arc<dyn TableProvider + Send + Sync>,
+        t: Arc<dyn TableProvider>,
         batch_size: usize,
         output_partitions: Option<usize>,
     ) -> Result<Self> {
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index caf09be708d..f0902a995a1 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -32,6 +32,11 @@ use tokio::task::{self, JoinHandle};
 
 use arrow::csv;
 
+use crate::catalog::{
+    catalog::{CatalogProvider, MemoryCatalogProvider},
+    schema::{MemorySchemaProvider, SchemaProvider},
+    ResolvedTableReference, TableReference,
+};
 use crate::datasource::csv::CsvFile;
 use crate::datasource::parquet::ParquetTable;
 use crate::datasource::TableProvider;
@@ -111,9 +116,24 @@ impl ExecutionContext {
     /// Creates a new execution context using the provided configuration.
     pub fn with_config(config: ExecutionConfig) -> Self {
+        let mut catalogs = HashMap::new();
+
+        if config.create_default_catalog_and_schema {
+            let default_catalog = MemoryCatalogProvider::new();
+            default_catalog.register_schema(
+                config.default_schema.clone(),
+                Arc::new(MemorySchemaProvider::new()),
+            );
+
+            catalogs.insert(
+                config.default_catalog.clone(),
+                Arc::new(default_catalog) as Arc<dyn CatalogProvider>,
+            );
+        }
+
         Self {
             state: Arc::new(Mutex::new(ExecutionContextState {
-                datasources: HashMap::new(),
+                catalogs,
                 scalar_functions: HashMap::new(),
                 var_provider: HashMap::new(),
                 aggregate_functions: HashMap::new(),
@@ -239,7 +259,7 @@ impl ExecutionContext {
     /// Creates a DataFrame for reading a custom TableProvider.
     pub fn read_table(
         &mut self,
-        provider: Arc<dyn TableProvider + Send + Sync>,
+        provider: Arc<dyn TableProvider>,
     ) -> Result<Arc<dyn DataFrame>> {
         let schema = provider.schema();
         let table_scan = LogicalPlan::TableScan {
@@ -264,7 +284,7 @@ impl ExecutionContext {
         filename: &str,
         options: CsvReadOptions,
     ) -> Result<()> {
-        self.register_table(name, Arc::new(CsvFile::try_new(filename, options)?));
+        self.register_table(name, Arc::new(CsvFile::try_new(filename, options)?))?;
         Ok(())
     }
 
@@ -275,48 +295,83 @@ impl ExecutionContext {
             &filename,
             self.state.lock().unwrap().config.concurrency,
         )?;
-        self.register_table(name, Arc::new(table));
+        self.register_table(name, Arc::new(table))?;
         Ok(())
     }
 
-    /// Registers a named table using a custom `TableProvider` so that
+    /// Registers a named catalog using a custom `CatalogProvider` so that
     /// it can be referenced from SQL statements executed against this
     /// context.
     ///
-    /// Returns the `TableProvider` previously registered for this
+    /// Returns the `CatalogProvider` previously registered for this
     /// name, if any
-    pub fn register_table(
-        &mut self,
-        name: &str,
-        provider: Arc<dyn TableProvider + Send + Sync>,
-    ) -> Option<Arc<dyn TableProvider + Send + Sync>> {
+    pub fn register_catalog(
+        &self,
+        name: impl Into<String>,
+        catalog: Arc<dyn CatalogProvider>,
+    ) -> Option<Arc<dyn CatalogProvider>> {
+        self.state
+            .lock()
+            .unwrap()
+            .catalogs
+            .insert(name.into(), catalog)
+    }
+
+    /// Retrieves a `CatalogProvider` instance by name
+    pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+        self.state.lock().unwrap().catalogs.get(name).cloned()
+    }
+
+    /// Registers a table using a custom `TableProvider` so that
+    /// it can be referenced from SQL statements executed against this
+    /// context.
+    ///
+    /// Returns the `TableProvider` previously registered for this
+    /// reference, if any
+    pub fn register_table<'a>(
+        &'a mut self,
+        table_ref: impl Into<TableReference<'a>>,
+        provider: Arc<dyn TableProvider>,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        let table_ref = table_ref.into();
         self.state
             .lock()
             .unwrap()
-            .datasources
-            .insert(name.to_string(), provider)
+            .schema_for_ref(table_ref)?
+            .register_table(table_ref.table().to_owned(), provider)
     }
 
-    /// Deregisters the named table.
+    /// Deregisters the given table.
     ///
     /// Returns the registered provider, if any
-    pub fn deregister_table(
-        &mut self,
-        name: &str,
-    ) -> Option<Arc<dyn TableProvider + Send + Sync>> {
-        self.state.lock().unwrap().datasources.remove(name)
+    pub fn deregister_table<'a>(
+        &'a mut self,
+        table_ref: impl Into<TableReference<'a>>,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        let table_ref = table_ref.into();
+        self.state
+            .lock()
+            .unwrap()
+            .schema_for_ref(table_ref)?
+            .deregister_table(table_ref.table())
     }
 
     /// Retrieves a DataFrame representing a table previously registered by calling the
     /// register_table function.
     ///
-    /// Returns an error if no table has been registered with the provided name.
-    pub fn table(&self, table_name: &str) -> Result<Arc<dyn DataFrame>> {
-        match self.state.lock().unwrap().datasources.get(table_name) {
-            Some(provider) => {
+    /// Returns an error if no table has been registered with the provided reference.
+    pub fn table<'a>(
+        &self,
+        table_ref: impl Into<TableReference<'a>>,
+    ) -> Result<Arc<dyn DataFrame>> {
+        let table_ref = table_ref.into();
+        let schema = self.state.lock().unwrap().schema_for_ref(table_ref)?;
+
+        match schema.table(table_ref.table()) {
+            Some(ref provider) => {
                 let schema = provider.schema();
                 let table_scan = LogicalPlan::TableScan {
-                    table_name: table_name.to_string(),
+                    table_name: table_ref.table().to_owned(),
                     source: Arc::clone(provider),
                     projected_schema: schema.to_dfschema_ref()?,
                     projection: None,
@@ -330,24 +385,30 @@ impl ExecutionContext {
             }
             _ => Err(DataFusionError::Plan(format!(
                 "No table named '{}'",
-                table_name
+                table_ref.table()
             ))),
         }
     }
 
-    /// Returns the set of available tables.
+    /// Returns the set of available tables in the default catalog and schema.
     ///
     /// Use [`table`] to get a specific table.
     ///
     /// [`table`]: ExecutionContext::table
-    pub fn tables(&self) -> HashSet<String> {
-        self.state
+    #[deprecated(
+        note = "Please use the catalog provider interface (`ExecutionContext::catalog`) to examine available catalogs, schemas, and tables"
+    )]
+    pub fn tables(&self) -> Result<HashSet<String>> {
+        Ok(self
+            .state
             .lock()
             .unwrap()
-            .datasources
-            .keys()
+            // a bare reference will always resolve to the default catalog and schema
+            .schema_for_ref(TableReference::Bare { table: "" })?
+            .table_names()
+            .iter()
             .cloned()
-            .collect()
+            .collect())
     }
 
     /// Optimizes the logical plan by applying optimizer rules.
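With these changes, `register_table`, `deregister_table`, and `table` all accept `impl Into<TableReference>` and return `Result`, since resolving the catalog or schema can now fail. A plain `&str` still works but becomes a bare reference resolved against the defaults; addressing a specific schema requires an explicit `TableReference` variant. A hedged sketch of the call sites (names are illustrative):

use datafusion::catalog::TableReference;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use std::sync::Arc;

fn register_sketch(ctx: &mut ExecutionContext, provider: Arc<dyn TableProvider>) -> Result<()> {
    // A &str is a *bare* reference: it is not split on '.', and resolves
    // against the default catalog and schema ("datafusion"/"public").
    ctx.register_table("t", provider)?;

    // The same table addressed with an explicit, fully qualified reference.
    let full = TableReference::Full {
        catalog: "datafusion",
        schema: "public",
        table: "t",
    };
    let _df = ctx.table(full)?;
    Ok(())
}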
@@ -512,6 +573,12 @@ pub struct ExecutionConfig {
     optimizers: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
     /// Responsible for planning `LogicalPlan`s, and `ExecutionPlan`
     query_planner: Arc<dyn QueryPlanner + Send + Sync>,
+    /// Default catalog name for table resolution
+    default_catalog: String,
+    /// Default schema name for table resolution
+    default_schema: String,
+    /// Whether the default catalog and schema should be created automatically
+    create_default_catalog_and_schema: bool,
 }
 
 impl ExecutionConfig {
@@ -528,6 +595,9 @@ impl ExecutionConfig {
                 Arc::new(LimitPushDown::new()),
             ],
             query_planner: Arc::new(DefaultQueryPlanner {}),
+            default_catalog: "datafusion".to_owned(),
+            default_schema: "public".to_owned(),
+            create_default_catalog_and_schema: true,
         }
     }
 
@@ -564,13 +634,30 @@ impl ExecutionConfig {
         self.optimizers.push(optimizer_rule);
         self
     }
+
+    /// Selects a name for the default catalog and schema
+    pub fn with_default_catalog_and_schema(
+        mut self,
+        catalog: impl Into<String>,
+        schema: impl Into<String>,
+    ) -> Self {
+        self.default_catalog = catalog.into();
+        self.default_schema = schema.into();
+        self
+    }
+
+    /// Controls whether the default catalog and schema will be automatically created
+    pub fn create_default_catalog_and_schema(mut self, create: bool) -> Self {
+        self.create_default_catalog_and_schema = create;
+        self
+    }
 }
 
 /// Execution context for registering data sources and executing queries
 #[derive(Clone)]
 pub struct ExecutionContextState {
-    /// Data sources that are registered with the context
-    pub datasources: HashMap<String, Arc<dyn TableProvider + Send + Sync>>,
+    /// Collection of catalogs containing schemas and ultimately TableProviders
+    pub catalogs: HashMap<String, Arc<dyn CatalogProvider>>,
     /// Scalar functions that are registered with the context
     pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
     /// Variable provider that are registered with the context
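The three new `ExecutionConfig` fields are wired through `with_config` above: when `create_default_catalog_and_schema` is true (the default), the context starts with a `MemoryCatalogProvider` named by `default_catalog` containing one `MemorySchemaProvider` named by `default_schema`. A sketch of the builder calls (the catalog and schema names are illustrative):

use datafusion::execution::context::{ExecutionConfig, ExecutionContext};

fn configured_context() -> ExecutionContext {
    // Bare and partially qualified table references will now resolve
    // against "my_catalog" and "my_schema" instead of the defaults.
    let config = ExecutionConfig::new()
        .with_default_catalog_and_schema("my_catalog", "my_schema")
        .create_default_catalog_and_schema(true);
    ExecutionContext::with_config(config)
}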
@@ -581,12 +668,45 @@ pub struct ExecutionContextState {
     pub config: ExecutionConfig,
 }
 
+impl ExecutionContextState {
+    fn resolve_table_ref<'a>(
+        &'a self,
+        table_ref: impl Into<TableReference<'a>>,
+    ) -> ResolvedTableReference<'a> {
+        table_ref
+            .into()
+            .resolve(&self.config.default_catalog, &self.config.default_schema)
+    }
+
+    fn schema_for_ref<'a>(
+        &'a self,
+        table_ref: impl Into<TableReference<'a>>,
+    ) -> Result<Arc<dyn SchemaProvider>> {
+        let resolved_ref = self.resolve_table_ref(table_ref.into());
+
+        self.catalogs
+            .get(resolved_ref.catalog)
+            .ok_or_else(|| {
+                DataFusionError::Plan(format!(
+                    "failed to resolve catalog: {}",
+                    resolved_ref.catalog
+                ))
+            })?
+            .schema(resolved_ref.schema)
+            .ok_or_else(|| {
+                DataFusionError::Plan(format!(
+                    "failed to resolve schema: {}",
+                    resolved_ref.schema
+                ))
+            })
+    }
+}
+
 impl ContextProvider for ExecutionContextState {
-    fn get_table_provider(
-        &self,
-        name: &str,
-    ) -> Option<Arc<dyn TableProvider + Send + Sync>> {
-        self.datasources.get(name).map(|ds| Arc::clone(ds))
+    fn get_table_provider(&self, name: TableReference) -> Option<Arc<dyn TableProvider>> {
+        let resolved_ref = self.resolve_table_ref(name);
+        let schema = self.schema_for_ref(resolved_ref).ok()?;
+        schema.table(resolved_ref.table)
     }
 
     fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
@@ -731,7 +851,7 @@ mod tests {
         ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));
 
         let provider = test::create_table_dual();
-        ctx.register_table("dual", provider);
+        ctx.register_table("dual", provider)?;
 
         let results =
             plan_and_collect(&mut ctx, "SELECT @@version, @name FROM dual").await?;
@@ -755,10 +875,10 @@ mod tests {
         let mut ctx = create_ctx(&tmp_dir, partition_count)?;
 
         let provider = test::create_table_dual();
-        ctx.register_table("dual", provider);
+        ctx.register_table("dual", provider)?;
 
-        assert!(ctx.deregister_table("dual").is_some());
-        assert!(ctx.deregister_table("dual").is_none());
+        assert!(ctx.deregister_table("dual")?.is_some());
+        assert!(ctx.deregister_table("dual")?.is_none());
 
         Ok(())
     }
@@ -875,17 +995,10 @@ mod tests {
         let tmp_dir = TempDir::new()?;
         let ctx = create_ctx(&tmp_dir, 1)?;
 
-        let schema = ctx
-            .state
-            .lock()
-            .unwrap()
-            .datasources
-            .get("test")
-            .unwrap()
-            .schema();
+        let schema: Schema = ctx.table("test").unwrap().schema().clone().into();
 
         assert_eq!(schema.field_with_name("c1")?.is_nullable(), false);
 
-        let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)?
+        let plan = LogicalPlanBuilder::scan_empty("", &schema, None)?
             .project(&[col("c1")])?
             .build()?;
@@ -1362,7 +1475,7 @@ mod tests {
             .unwrap();
 
         let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
-        ctx.register_table("t", Arc::new(provider));
+        ctx.register_table("t", Arc::new(provider)).unwrap();
 
         let results = plan_and_collect(
             &mut ctx,
@@ -1745,7 +1858,7 @@ mod tests {
         let mut ctx = ExecutionContext::new();
 
         let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;
-        ctx.register_table("t", Arc::new(provider));
+        ctx.register_table("t", Arc::new(provider))?;
 
         let myfunc = |args: &[ArrayRef]| {
             let l = &args[0]
@@ -1825,7 +1938,7 @@ mod tests {
             assert_eq!(a.value(i) + b.value(i), sum.value(i));
         }
 
-        ctx.deregister_table("t");
+        ctx.deregister_table("t")?;
         Ok(())
     }
 
@@ -1847,7 +1960,7 @@ mod tests {
         let provider =
             MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
-        ctx.register_table("t", Arc::new(provider));
+        ctx.register_table("t", Arc::new(provider))?;
 
         let result = plan_and_collect(&mut ctx, "SELECT AVG(a) FROM t").await?;
 
@@ -1884,7 +1997,7 @@ mod tests {
         let provider =
             MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
-        ctx.register_table("t", Arc::new(provider));
+        ctx.register_table("t", Arc::new(provider))?;
 
         // define a udaf, using a DataFusion's accumulator
         let my_avg = create_udaf(
@@ -1922,6 +2035,114 @@ mod tests {
         Ok(())
     }
 
+    fn table_with_sequence(
+        seq_start: i32,
+        seq_end: i32,
+    ) -> Result<Arc<dyn TableProvider>> {
+        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+        let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::<Vec<_>>()));
+        let partitions = vec![vec![RecordBatch::try_new(
+            schema.clone(),
+            vec![arr as ArrayRef],
+        )?]];
+        Ok(Arc::new(MemTable::try_new(schema, partitions)?))
+    }
+
+    #[tokio::test]
+    async fn disabled_default_catalog_and_schema() -> Result<()> {
+        let mut ctx = ExecutionContext::with_config(
+            ExecutionConfig::new().create_default_catalog_and_schema(false),
+        );
+
+        assert!(matches!(
+            ctx.register_table("test", table_with_sequence(1, 1)?),
+            Err(DataFusionError::Plan(_))
+        ));
+
+        assert!(matches!(
+            ctx.sql("select * from datafusion.public.test"),
+            Err(DataFusionError::Plan(_))
+        ));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn custom_catalog_and_schema() -> Result<()> {
+        let mut ctx = ExecutionContext::with_config(
+            ExecutionConfig::new()
+                .create_default_catalog_and_schema(false)
+                .with_default_catalog_and_schema("my_catalog", "my_schema"),
+        );
+
+        let catalog = MemoryCatalogProvider::new();
+        let schema = MemorySchemaProvider::new();
+        schema.register_table("test".to_owned(), table_with_sequence(1, 1)?)?;
+        catalog.register_schema("my_schema", Arc::new(schema));
+        ctx.register_catalog("my_catalog", Arc::new(catalog));
+
+        for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
+            let result = plan_and_collect(
+                &mut ctx,
+                &format!("SELECT COUNT(*) AS count FROM {}", table_ref),
+            )
+            .await?;
+
+            let expected = vec![
+                "+-------+",
+                "| count |",
+                "+-------+",
+                "| 1     |",
+                "+-------+",
+            ];
+            assert_batches_eq!(expected, &result);
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn cross_catalog_access() -> Result<()> {
+        let mut ctx = ExecutionContext::new();
+
+        let catalog_a = MemoryCatalogProvider::new();
+        let schema_a = MemorySchemaProvider::new();
+        schema_a.register_table("table_a".to_owned(), table_with_sequence(1, 1)?)?;
+        catalog_a.register_schema("schema_a", Arc::new(schema_a));
+        ctx.register_catalog("catalog_a", Arc::new(catalog_a));
+
+        let catalog_b = MemoryCatalogProvider::new();
+        let schema_b = MemorySchemaProvider::new();
+        schema_b.register_table("table_b".to_owned(), table_with_sequence(1, 2)?)?;
+        catalog_b.register_schema("schema_b", Arc::new(schema_b));
+        ctx.register_catalog("catalog_b", Arc::new(catalog_b));
+
+        let result = plan_and_collect(
+            &mut ctx,
+            "SELECT cat, SUM(i) AS total FROM (
+                SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
+                UNION ALL
+                SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
+            )
+            GROUP BY cat
+            ORDER BY cat
+            ",
+        )
+        .await?;
+
+        let expected = vec![
+            "+-----+-------+",
+            "| cat | total |",
+            "+-----+-------+",
+            "| a   | 1     |",
+            "| b   | 3     |",
+            "+-----+-------+",
+        ];
+        assert_batches_eq!(expected, &result);
+
+        Ok(())
+    }
+
     struct MyPhysicalPlanner {}
 
     impl PhysicalPlanner for MyPhysicalPlanner {
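Since `ExecutionContext::tables` is now deprecated, enumerating what is registered goes through the catalog hierarchy instead. A sketch of walking the catalog, schema, and table levels using only the accessors added in this diff (the "datafusion" catalog name assumes the default configuration):

use datafusion::catalog::catalog::CatalogProvider;
use datafusion::catalog::schema::SchemaProvider;
use datafusion::execution::context::ExecutionContext;

fn list_qualified_names(ctx: &ExecutionContext) -> Vec<String> {
    let mut names = vec![];
    if let Some(catalog) = ctx.catalog("datafusion") {
        for schema_name in catalog.schema_names() {
            if let Some(schema) = catalog.schema(&schema_name) {
                for table in schema.table_names() {
                    names.push(format!("datafusion.{}.{}", schema_name, table));
                }
            }
        }
    }
    names
}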
diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs
index f0fcc4f1d29..e9bf2b4877a 100644
--- a/rust/datafusion/src/lib.rs
+++ b/rust/datafusion/src/lib.rs
@@ -157,6 +157,7 @@ extern crate arrow;
 extern crate sqlparser;
 
+pub mod catalog;
 pub mod dataframe;
 pub mod datasource;
 pub mod error;
diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs
index 58dfd0fa5d6..a89e797c7b6 100644
--- a/rust/datafusion/src/logical_plan/builder.rs
+++ b/rust/datafusion/src/logical_plan/builder.rs
@@ -103,7 +103,7 @@ impl LogicalPlanBuilder {
     /// Convert a table provider into a builder with a TableScan
     pub fn scan(
         name: &str,
-        provider: Arc<dyn TableProvider + Send + Sync>,
+        provider: Arc<dyn TableProvider>,
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
         let schema = provider.schema();
diff --git a/rust/datafusion/src/logical_plan/plan.rs b/rust/datafusion/src/logical_plan/plan.rs
index 92110c8daf1..00d25fb0f3f 100644
--- a/rust/datafusion/src/logical_plan/plan.rs
+++ b/rust/datafusion/src/logical_plan/plan.rs
@@ -134,7 +134,7 @@ pub enum LogicalPlan {
         /// The name of the table
         table_name: String,
         /// The source of the table
-        source: Arc<dyn TableProvider + Send + Sync>,
+        source: Arc<dyn TableProvider>,
         /// Optional column indices to use as a projection
         projection: Option<Vec<usize>>,
         /// The schema description of the output
diff --git a/rust/datafusion/src/physical_plan/parquet.rs b/rust/datafusion/src/physical_plan/parquet.rs
index 5e9b1562751..569eb59847e 100644
--- a/rust/datafusion/src/physical_plan/parquet.rs
+++ b/rust/datafusion/src/physical_plan/parquet.rs
@@ -392,7 +392,7 @@ impl RowGroupPredicateBuilder {
             .collect::<Vec<_>>();
         let stat_schema = Schema::new(stat_fields);
         let execution_context_state = ExecutionContextState {
-            datasources: HashMap::new(),
+            catalogs: HashMap::new(),
             scalar_functions: HashMap::new(),
             var_provider: HashMap::new(),
             aggregate_functions: HashMap::new(),
diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs
index c6f321ca9b8..81500b3522c 100644
--- a/rust/datafusion/src/physical_plan/planner.rs
+++ b/rust/datafusion/src/physical_plan/planner.rs
@@ -764,7 +764,7 @@ mod tests {
     fn make_ctx_state() -> ExecutionContextState {
         ExecutionContextState {
-            datasources: HashMap::new(),
+            catalogs: HashMap::new(),
             scalar_functions: HashMap::new(),
             var_provider: HashMap::new(),
             aggregate_functions: HashMap::new(),
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 31cea52a0f8..45ad6891866 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -17,9 +17,11 @@
 
 //! SQL Query Planner (produces logical plan from SQL AST)
 
+use std::convert::TryInto;
 use std::str::FromStr;
 use std::sync::Arc;
 
+use crate::catalog::TableReference;
 use crate::datasource::TableProvider;
 use crate::logical_plan::Expr::Alias;
 use crate::logical_plan::{
@@ -58,10 +60,7 @@ use super::utils::{
 /// functions referenced in SQL statements
 pub trait ContextProvider {
     /// Getter for a datasource
-    fn get_table_provider(
-        &self,
-        name: &str,
-    ) -> Option<Arc<dyn TableProvider + Send + Sync>>;
+    fn get_table_provider(&self, name: TableReference) -> Option<Arc<dyn TableProvider>>;
     /// Getter for a UDF description
     fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
     /// Getter for a UDAF description
@@ -376,7 +375,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         match relation {
             TableFactor::Table { name, .. } => {
                 let table_name = name.to_string();
-                match self.schema_provider.get_table_provider(&table_name) {
+                match self.schema_provider.get_table_provider(name.try_into()?) {
                     Some(provider) => {
                         LogicalPlanBuilder::scan(&table_name, provider, None)?.build()
                     }
@@ -2495,9 +2494,9 @@ mod tests {
     impl ContextProvider for MockContextProvider {
         fn get_table_provider(
             &self,
-            name: &str,
-        ) -> Option<Arc<dyn TableProvider + Send + Sync>> {
-            let schema = match name {
+            name: TableReference,
+        ) -> Option<Arc<dyn TableProvider>> {
+            let schema = match name.table() {
                 "person" => Some(Schema::new(vec![
                     Field::new("id", DataType::UInt32, false),
                     Field::new("first_name", DataType::Utf8, false),
@@ -2540,7 +2539,7 @@ mod tests {
                 ])),
                 _ => None,
             };
-            schema.map(|s| -> Arc<dyn TableProvider + Send + Sync> {
+            schema.map(|s| -> Arc<dyn TableProvider> {
                 Arc::new(EmptyTable::new(Arc::new(s)))
             })
         }
diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs
index 75a956f1cf4..04f340a9936 100644
--- a/rust/datafusion/src/test/mod.rs
+++ b/rust/datafusion/src/test/mod.rs
@@ -29,7 +29,7 @@ use std::io::{BufReader, BufWriter};
 use std::sync::Arc;
 use tempfile::TempDir;
 
-pub fn create_table_dual() -> Arc<dyn TableProvider + Send + Sync> {
+pub fn create_table_dual() -> Arc<dyn TableProvider> {
     let dual_schema = Arc::new(Schema::new(vec![
         Field::new("id", DataType::Int32, false),
         Field::new("name", DataType::Utf8, false),
diff --git a/rust/datafusion/tests/dataframe.rs b/rust/datafusion/tests/dataframe.rs
index e0c698ed5fb..9d5f92a7753 100644
--- a/rust/datafusion/tests/dataframe.rs
+++ b/rust/datafusion/tests/dataframe.rs
@@ -61,11 +61,11 @@ async fn join() -> Result<()> {
     let table1 = MemTable::try_new(schema1, vec![vec![batch1]])?;
     let table2 = MemTable::try_new(schema2, vec![vec![batch2]])?;
 
-    ctx.register_table("aa", Arc::new(table1));
+    ctx.register_table("aa", Arc::new(table1))?;
 
     let df1 = ctx.table("aa")?;
 
-    ctx.register_table("aaa", Arc::new(table2));
+    ctx.register_table("aaa", Arc::new(table2))?;
 
     let df2 = ctx.table("aaa")?;
 
diff --git a/rust/datafusion/tests/provider_filter_pushdown.rs b/rust/datafusion/tests/provider_filter_pushdown.rs
index a64f7fb74fb..f38ac59341e 100644
--- a/rust/datafusion/tests/provider_filter_pushdown.rs
+++ b/rust/datafusion/tests/provider_filter_pushdown.rs
@@ -156,7 +156,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()> {
     let result_col: &UInt64Array = as_primitive_array(results[0].column(0));
     assert_eq!(result_col.value(0), expected_count);
 
-    ctx.register_table("data", Arc::new(provider));
+    ctx.register_table("data", Arc::new(provider))?;
     let sql_results = ctx
         .sql(&format!("select count(*) from data where flag = {}", value))?
         .collect()
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 1bce3f7c07c..ba8230a442f 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -35,7 +35,10 @@ use datafusion::{
     datasource::{csv::CsvReadOptions, MemTable},
     physical_plan::collect,
 };
-use datafusion::{error::Result, physical_plan::ColumnarValue};
+use datafusion::{
+    error::{DataFusionError, Result},
+    physical_plan::ColumnarValue,
+};
 
 #[tokio::test]
 async fn nyc() -> Result<()> {
@@ -1168,7 +1171,7 @@ fn create_case_context() -> Result<ExecutionContext> {
         ]))],
     )?;
     let table = MemTable::try_new(schema, vec![vec![data]])?;
-    ctx.register_table("t1", Arc::new(table));
+    ctx.register_table("t1", Arc::new(table))?;
     Ok(ctx)
 }
 
@@ -1317,7 +1320,7 @@ fn create_join_context(
         ],
     )?;
     let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?;
-    ctx.register_table("t1", Arc::new(t1_table));
+    ctx.register_table("t1", Arc::new(t1_table))?;
 
     let t2_schema = Arc::new(Schema::new(vec![
         Field::new(column_right, DataType::UInt32, true),
@@ -1336,7 +1339,7 @@ fn create_join_context(
         ],
     )?;
     let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?;
-    ctx.register_table("t2", Arc::new(t2_table));
+    ctx.register_table("t2", Arc::new(t2_table))?;
 
     Ok(ctx)
 }
@@ -1358,7 +1361,7 @@ fn create_join_context_qualified() -> Result<ExecutionContext> {
         ],
     )?;
     let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?;
-    ctx.register_table("t1", Arc::new(t1_table));
+    ctx.register_table("t1", Arc::new(t1_table))?;
 
     let t2_schema = Arc::new(Schema::new(vec![
         Field::new("a", DataType::UInt32, true),
@@ -1374,7 +1377,7 @@ fn create_join_context_qualified() -> Result<ExecutionContext> {
         ],
     )?;
     let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?;
-    ctx.register_table("t2", Arc::new(t2_table));
+    ctx.register_table("t2", Arc::new(t2_table))?;
 
     Ok(ctx)
 }
@@ -1594,7 +1597,7 @@ async fn generic_query_length<T: 'static + Array + From<Vec<&'static str>>>(
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT length(c1) FROM test";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![vec!["0"], vec!["1"], vec!["2"], vec!["3"]];
@@ -1630,7 +1633,7 @@ async fn query_not() -> Result<()> {
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT NOT c1 FROM test";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![vec!["true"], vec!["NULL"], vec!["false"]];
@@ -1656,7 +1659,7 @@ async fn query_concat() -> Result<()> {
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![
@@ -1687,7 +1690,7 @@ async fn query_array() -> Result<()> {
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![
@@ -1780,7 +1783,7 @@ fn make_timestamp_nano_table() -> Result<Arc<MemTable>> {
 #[tokio::test]
 async fn to_timestamp() -> Result<()> {
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("ts_data", make_timestamp_nano_table()?);
+    ctx.register_table("ts_data", make_timestamp_nano_table()?)?;
 
     let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp('2020-09-08T12:00:00+00:00')";
     let actual = execute(&mut ctx, sql).await;
@@ -1806,7 +1809,7 @@ async fn query_is_null() -> Result<()> {
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT c1 IS NULL FROM test";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![vec!["false"], vec!["true"], vec!["false"]];
@@ -1830,7 +1833,7 @@ async fn query_is_not_null() -> Result<()> {
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT c1 IS NOT NULL FROM test";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![vec!["true"], vec!["false"], vec!["true"]];
@@ -1857,7 +1860,7 @@ async fn query_count_distinct() -> Result<()> {
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT COUNT(DISTINCT c1) FROM test";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![vec!["3".to_string()]];
@@ -1886,7 +1889,7 @@ async fn query_on_string_dictionary() -> Result<()> {
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
 
     // Basic SELECT
     let sql = "SELECT * FROM test";
@@ -1957,7 +1960,7 @@ async fn query_scalar_minus_array() -> Result<()> {
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
     let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table));
+    ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT 4 - c1 FROM test";
     let actual = execute(&mut ctx, sql).await;
     let expected = vec![vec!["4"], vec!["3"], vec!["NULL"], vec!["1"]];
@@ -2036,7 +2039,7 @@ async fn csv_group_by_date() -> Result<()> {
     )?;
     let table = MemTable::try_new(schema, vec![vec![data]])?;
 
-    ctx.register_table("dates", Arc::new(table));
+    ctx.register_table("dates", Arc::new(table))?;
     let sql = "SELECT SUM(cnt) FROM dates GROUP BY date";
     let actual = execute(&mut ctx, sql).await;
    let mut actual: Vec<String> = actual.iter().flatten().cloned().collect();
@@ -2504,3 +2507,36 @@ async fn inner_join_qualified_names() -> Result<()> {
     }
     Ok(())
 }
+
+#[tokio::test]
+async fn qualified_table_references() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx)?;
+
+    for table_ref in &[
+        "aggregate_test_100",
+        "public.aggregate_test_100",
+        "datafusion.public.aggregate_test_100",
+    ] {
+        let sql = format!("SELECT COUNT(*) FROM {}", table_ref);
+        let results = execute(&mut ctx, &sql).await;
+        assert_eq!(results, vec![vec!["100"]]);
+    }
+    Ok(())
+}
+
+#[tokio::test]
+async fn invalid_qualified_table_references() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx)?;
+
+    for table_ref in &[
+        "nonexistentschema.aggregate_test_100",
+        "nonexistentcatalog.public.aggregate_test_100",
+        "way.too.many.namespaces.as.ident.prefixes.aggregate_test_100",
+    ] {
+        let sql = format!("SELECT COUNT(*) FROM {}", table_ref);
+        assert!(matches!(ctx.sql(&sql), Err(DataFusionError::Plan(_))));
+    }
+    Ok(())
+}