Make SchemaProvider::table async #4607

Merged 9 commits on Jan 5, 2023
Changes from 6 commits
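
For callers, the visible effect of this change is that SessionContext::table (which resolves tables through SchemaProvider::table) now has to be awaited. Below is a minimal, self-contained sketch of the new call-site pattern; it assumes a Tokio runtime and uses the same in-memory registration API that appears in the examples further down.

    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::error::Result;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();

        // register an in-memory batch under the name "t"
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))])?;
        ctx.register_batch("t", batch)?;

        // `table` now goes through the (potentially async) catalog, so it must be awaited
        let df = ctx.table("t").await?;
        df.show().await?;
        Ok(())
    }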
12 changes: 12 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion datafusion-examples/examples/dataframe_in_memory.rs
@@ -47,7 +47,7 @@ async fn main() -> Result<()> {

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
ctx.register_batch("t", batch)?;
- let df = ctx.table("t")?;
+ let df = ctx.table("t").await?;

// construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
let filter = col("b").eq(lit(10));

2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udaf.rs
@@ -179,7 +179,7 @@ async fn main() -> Result<()> {

// get a DataFrame from the context
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
- let df = ctx.table("t")?;
+ let df = ctx.table("t").await?;

// perform the aggregation
let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;

2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udf.rs
@@ -122,7 +122,7 @@ async fn main() -> Result<()> {
let expr = pow.call(vec![col("a"), col("b")]);

// get a DataFrame from the context
- let df = ctx.table("t")?;
+ let df = ctx.table("t").await?;

// if we do not have `pow` in the scope and we registered it, we can get it from the registry
let pow = df.registry().udf("pow")?;

6 changes: 6 additions & 0 deletions datafusion/common/src/table_reference.rs
@@ -26,6 +26,12 @@ pub struct ResolvedTableReference<'a> {
pub table: &'a str,
}

+ impl<'a> std::fmt::Display for ResolvedTableReference<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}.{}.{}", self.catalog, self.schema, self.table)
+ }
+ }

/// Represents a path to a table that may require further resolution
#[derive(Debug, Clone, Copy)]
pub enum TableReference<'a> {

2 changes: 1 addition & 1 deletion datafusion/core/Cargo.toml
@@ -90,7 +90,7 @@ pyo3 = { version = "0.17.1", optional = true }
rand = "0.8"
rayon = { version = "1.5", optional = true }
smallvec = { version = "1.6", features = ["union"] }
- sqlparser = "0.30"
+ sqlparser = { version = "0.30", features = ["visitor"] }
tempfile = "3"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] }
tokio-stream = "0.1"

25 changes: 15 additions & 10 deletions datafusion/core/src/catalog/information_schema.rs
@@ -19,6 +19,7 @@
//!
//! Information Schema]<https://en.wikipedia.org/wiki/Information_schema>

+ use async_trait::async_trait;
use std::{any::Any, sync::Arc};

use arrow::{
@@ -43,6 +44,9 @@ pub const VIEWS: &str = "views";
pub const COLUMNS: &str = "columns";
pub const DF_SETTINGS: &str = "df_settings";

+ /// All information schema tables
+ pub const INFORMATION_SCHEMA_TABLES: &[&str] = &[TABLES, VIEWS, COLUMNS, DF_SETTINGS];

/// Implements the `information_schema` virtual schema and tables
///
/// The underlying tables in the `information_schema` are created on
@@ -69,7 +73,7 @@ struct InformationSchemaConfig {

impl InformationSchemaConfig {
/// Construct the `information_schema.tables` virtual table
- fn make_tables(&self, builder: &mut InformationSchemaTablesBuilder) {
+ async fn make_tables(&self, builder: &mut InformationSchemaTablesBuilder) {
// create a mem table with the names of tables

for catalog_name in self.catalog_list.catalog_names() {
@@ -79,7 +83,7 @@ impl InformationSchemaConfig {
if schema_name != INFORMATION_SCHEMA {
let schema = catalog.schema(&schema_name).unwrap();
for table_name in schema.table_names() {
- let table = schema.table(&table_name).unwrap();
+ let table = schema.table(&table_name).await.unwrap();
builder.add_table(
&catalog_name,
&schema_name,
@@ -108,15 +112,15 @@
}
}

- fn make_views(&self, builder: &mut InformationSchemaViewBuilder) {
+ async fn make_views(&self, builder: &mut InformationSchemaViewBuilder) {
for catalog_name in self.catalog_list.catalog_names() {
let catalog = self.catalog_list.catalog(&catalog_name).unwrap();

for schema_name in catalog.schema_names() {
if schema_name != INFORMATION_SCHEMA {
let schema = catalog.schema(&schema_name).unwrap();
for table_name in schema.table_names() {
- let table = schema.table(&table_name).unwrap();
+ let table = schema.table(&table_name).await.unwrap();
builder.add_view(
&catalog_name,
&schema_name,
@@ -130,15 +134,15 @@ impl InformationSchemaConfig {
}

/// Construct the `information_schema.columns` virtual table
- fn make_columns(&self, builder: &mut InformationSchemaColumnsBuilder) {
+ async fn make_columns(&self, builder: &mut InformationSchemaColumnsBuilder) {
for catalog_name in self.catalog_list.catalog_names() {
let catalog = self.catalog_list.catalog(&catalog_name).unwrap();

for schema_name in catalog.schema_names() {
if schema_name != INFORMATION_SCHEMA {
let schema = catalog.schema(&schema_name).unwrap();
for table_name in schema.table_names() {
- let table = schema.table(&table_name).unwrap();
+ let table = schema.table(&table_name).await.unwrap();
for (i, field) in table.schema().fields().iter().enumerate() {
builder.add_column(
&catalog_name,
@@ -168,6 +172,7 @@ impl InformationSchemaConfig {
}
}

+ #[async_trait]
impl SchemaProvider for InformationSchemaProvider {
fn as_any(&self) -> &(dyn Any + 'static) {
self
@@ -182,7 +187,7 @@ impl SchemaProvider for InformationSchemaProvider {
]
}

- fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
+ async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
let config = self.config.clone();
let table: Arc<dyn PartitionStream> = if name.eq_ignore_ascii_case("tables") {
Arc::new(InformationSchemaTables::new(config))
@@ -246,7 +251,7 @@ impl PartitionStream for InformationSchemaTables {
self.schema.clone(),
// TODO: Stream this
futures::stream::once(async move {
- config.make_tables(&mut builder);
+ config.make_tables(&mut builder).await;
Ok(builder.finish())
}),
))
@@ -337,7 +342,7 @@ impl PartitionStream for InformationSchemaViews {
self.schema.clone(),
// TODO: Stream this
futures::stream::once(async move {
- config.make_views(&mut builder);
+ config.make_views(&mut builder).await;
Ok(builder.finish())
}),
))
@@ -451,7 +456,7 @@ impl PartitionStream for InformationSchemaColumns {
self.schema.clone(),
// TODO: Stream this
futures::stream::once(async move {
- config.make_columns(&mut builder);
+ config.make_columns(&mut builder).await;
Ok(builder.finish())
}),
))

4 changes: 3 additions & 1 deletion datafusion/core/src/catalog/listing_schema.rs
@@ -20,6 +20,7 @@ use crate::catalog::schema::SchemaProvider;
use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::TableProvider;
use crate::execution::context::SessionState;
+ use async_trait::async_trait;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{DFSchema, DataFusionError, OwnedTableReference};
use datafusion_expr::CreateExternalTable;
@@ -156,6 +157,7 @@ impl ListingSchemaProvider {
}
}

+ #[async_trait]
impl SchemaProvider for ListingSchemaProvider {
fn as_any(&self) -> &dyn Any {
self
@@ -170,7 +172,7 @@ impl SchemaProvider for ListingSchemaProvider {
.collect()
}

- fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
+ async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
self.tables
.lock()
.expect("Can't lock tables")

7 changes: 5 additions & 2 deletions datafusion/core/src/catalog/schema.rs
@@ -18,6 +18,7 @@
//! Describes the interface and built-in implementations of schemas,
//! representing collections of named tables.

+ use async_trait::async_trait;
use dashmap::DashMap;
use std::any::Any;
use std::sync::Arc;
@@ -26,6 +27,7 @@ use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};

/// Represents a schema, comprising a number of named tables.
+ #[async_trait]
pub trait SchemaProvider: Sync + Send {
/// Returns the schema provider as [`Any`](std::any::Any)
/// so that it can be downcast to a specific implementation.
@@ -35,7 +37,7 @@ pub trait SchemaProvider: Sync + Send {
fn table_names(&self) -> Vec<String>;

/// Retrieves a specific table from the schema by name, provided it exists.
- fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>>;
+ async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>>;

Review comment (PR author): And this is the change to support async catalogs

/// If supported by the implementation, adds a new table to this schema.
/// If a table of the same name existed before, it returns "Table already exists" error.
@@ -85,6 +87,7 @@ impl Default for MemorySchemaProvider {
}
}

+ #[async_trait]
impl SchemaProvider for MemorySchemaProvider {
fn as_any(&self) -> &dyn Any {
self
@@ -97,7 +100,7 @@ impl SchemaProvider for MemorySchemaProvider {
.collect()
}

- fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
+ async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
self.tables.get(name).map(|table| table.value().clone())
}

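The review comment above notes that making table async is the change that enables async catalogs. As an illustration, here is a minimal, hypothetical sketch of a SchemaProvider whose table method resolves table definitions from a remote catalog service. RemoteCatalogClient is a made-up stub (a real one would issue network requests), and register_table/deregister_table are assumed to keep the default error-returning implementations the trait provides.

    use std::any::Any;
    use std::sync::Arc;

    use async_trait::async_trait;
    use datafusion::catalog::schema::SchemaProvider;
    use datafusion::datasource::TableProvider;

    /// Hypothetical client for a remote metadata service; stubbed out here.
    struct RemoteCatalogClient;

    impl RemoteCatalogClient {
        fn cached_table_names(&self) -> Vec<String> {
            vec!["orders".to_string()]
        }

        async fn fetch_table(&self, _name: &str) -> Option<Arc<dyn TableProvider>> {
            // a real client would fetch metadata over the network and build a TableProvider
            None
        }
    }

    struct RemoteSchemaProvider {
        client: RemoteCatalogClient,
    }

    #[async_trait]
    impl SchemaProvider for RemoteSchemaProvider {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn table_names(&self) -> Vec<String> {
            // still a synchronous method, so names are served from a local cache
            self.client.cached_table_names()
        }

        async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
            // the table definition can now be resolved lazily, over the network
            self.client.fetch_table(name).await
        }

        fn table_exist(&self, name: &str) -> bool {
            self.client.cached_table_names().iter().any(|n| n == name)
        }
    }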

27 changes: 18 additions & 9 deletions datafusion/core/src/dataframe.rs
@@ -917,7 +917,7 @@ mod tests {
let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;

- let df = ctx.table("t")?.select_columns(&["f.c1"])?;
+ let df = ctx.table("t").await?.select_columns(&["f.c1"])?;

let df_results = df.collect().await?;

@@ -1036,7 +1036,7 @@
));

// build query with a UDF using DataFrame API
- let df = ctx.table("aggregate_test_100")?;
+ let df = ctx.table("aggregate_test_100").await?;

let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]);
let df = df.select(vec![expr])?;
@@ -1101,7 +1101,7 @@
ctx.register_table("test_table", Arc::new(df_impl.clone()))?;

// pull the table out
- let table = ctx.table("test_table")?;
+ let table = ctx.table("test_table").await?;

let group_expr = vec![col("c1")];
let aggr_expr = vec![sum(col("c12"))];
@@ -1161,7 +1161,7 @@
async fn test_table_with_name(name: &str) -> Result<DataFrame> {
let mut ctx = SessionContext::new();
register_aggregate_csv(&mut ctx, name).await?;
- ctx.table(name)
+ ctx.table(name).await
}

async fn test_table() -> Result<DataFrame> {
@@ -1301,8 +1301,15 @@
ctx.register_table("t1", table.clone())?;
ctx.register_table("t2", table)?;
let df = ctx
.table("t1")?
.join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)?
.table("t1")
.await?
.join(
ctx.table("t2").await?,
JoinType::Inner,
&["c1"],
&["c1"],
None,
)?
.sort(vec![
// make the test deterministic
col("t1.c1").sort(true, true),
@@ -1379,10 +1386,11 @@
)
.await?;

ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
ctx.register_table("t1", Arc::new(ctx.table("test").await?))?;

let df = ctx
.table("t1")?
.table("t1")
.await?
.filter(col("id").eq(lit(1)))?
.select_columns(&["bool_col", "int_col"])?;

@@ -1463,7 +1471,8 @@
ctx.register_batch("t", batch)?;

let df = ctx
.table("t")?
.table("t")
.await?
// try and create a column with a '.' in it
.with_column("f.c2", lit("hello"))?;


10 changes: 6 additions & 4 deletions datafusion/core/src/datasource/view.rs
@@ -428,12 +428,13 @@ mod tests {
)
.await?;

ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
ctx.register_table("t1", Arc::new(ctx.table("test").await?))?;

ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?;

let df = ctx
.table("t2")?
.table("t2")
.await?
.filter(col("id").eq(lit(1)))?
.select_columns(&["bool_col", "int_col"])?;

@@ -457,12 +458,13 @@
)
.await?;

ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
ctx.register_table("t1", Arc::new(ctx.table("test").await?))?;

ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?;

let df = ctx
.table("t2")?
.table("t2")
.await?
.limit(0, Some(10))?
.select_columns(&["bool_col", "int_col"])?;
