diff --git a/README.md b/README.md
index 409577a..dd4db52 100644
--- a/README.md
+++ b/README.md
@@ -22,8 +22,6 @@ project.
   - Permission control
   - Built-in `pg_catalog` tables
   - Built-in Postgres functions for common meta queries
-  - [x] DBeaver compatibility
-  - [x] pgcli compatibility
 - `datafusion-postgres-cli`: A CLI tool that starts a Postgres-compatible server
   for DataFusion-supported file formats, just like Python's `SimpleHTTPServer`.
 - `arrow-pg`: A data type mapping, encoding/decoding library for arrow and
@@ -31,6 +29,14 @@ project.
 
 See `auth.rs` for complete implementation examples using `DfAuthSource`.
 
+## Supported Database Clients
+
+- Database clients
+  - [x] DBeaver
+  - [x] pgcli
+- BI tools
+  - [x] Metabase
+
 ## Quick Start
 
 ### The Library `datafusion-postgres`
diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs
index b808012..b41ccd5 100644
--- a/datafusion-postgres/src/handlers.rs
+++ b/datafusion-postgres/src/handlers.rs
@@ -3,12 +3,15 @@ use std::sync::Arc;
 
 use crate::auth::{AuthManager, Permission, ResourceType};
 use crate::sql::{
-    parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter, FixArrayLiteral,
-    PrependUnqualifiedPgTableName, RemoveTableFunctionQualifier, RemoveUnsupportedTypes,
-    ResolveUnqualifiedIdentifer, RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
+    parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter,
+    CurrentUserVariableToSessionUserFunctionCall, FixArrayLiteral, PrependUnqualifiedPgTableName,
+    RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer,
+    RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
 };
 use async_trait::async_trait;
-use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::datatypes::{DataType, Field, Schema};
+use datafusion::common::ToDFSchema;
+use datafusion::error::DataFusionError;
 use datafusion::logical_expr::LogicalPlan;
 use datafusion::prelude::*;
 use datafusion::sql::parser::Statement;
@@ -107,6 +110,7 @@ impl DfSessionService {
             Arc::new(PrependUnqualifiedPgTableName),
             Arc::new(FixArrayLiteral),
             Arc::new(RemoveTableFunctionQualifier),
+            Arc::new(CurrentUserVariableToSessionUserFunctionCall),
         ];
         let parser = Arc::new(Parser {
             session_context: session_context.clone(),
@@ -420,13 +424,15 @@ impl DfSessionService {
                 let resp = Self::mock_show_response("statement_timeout", &timeout_str)?;
                 Ok(Some(Response::Query(resp)))
             }
-            _ => Err(PgWireError::UserError(Box::new(
-                pgwire::error::ErrorInfo::new(
-                    "ERROR".to_string(),
-                    "42704".to_string(),
-                    format!("Unrecognized SHOW command: {query_lower}"),
-                ),
-            ))),
+            "show transaction isolation level" => {
+                let resp = Self::mock_show_response("transaction_isolation", "read committed")?;
+                Ok(Some(Response::Query(resp)))
+            }
+            _ => {
+                info!("Unsupported show statement: {query_lower}");
+                let resp = Self::mock_show_response("unsupported_show_statement", "")?;
+                Ok(Some(Response::Query(resp)))
+            }
         }
     } else {
         Ok(None)
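Reviewer note: with the fallback arm above, an unrecognized `SHOW` variable no longer raises SQLSTATE 42704. A quick way to exercise the new behavior from any connected client — the returned values are the mocked constants from `mock_show_response`, not real session state, and `some_unknown_setting` is a hypothetical name:

```sql
SHOW TRANSACTION ISOLATION LEVEL;  -- new arm: returns the mocked 'read committed'
SHOW statement_timeout;            -- existing arm, unchanged
SHOW some_unknown_setting;         -- now answered with an empty mock row instead of an error
```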
@@ -714,24 +720,15 @@ pub struct Parser {
     sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
 }
 
-#[async_trait]
-impl QueryParser for Parser {
-    type Statement = (String, LogicalPlan);
-
-    async fn parse_sql<C>(
-        &self,
-        _client: &C,
-        sql: &str,
-        _types: &[Type],
-    ) -> PgWireResult<Self::Statement> {
-        log::debug!("Received parse extended query: {sql}"); // Log for debugging
-
+impl Parser {
+    fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
         // Check for transaction commands that shouldn't be parsed by DataFusion
         let sql_lower = sql.to_lowercase();
         let sql_trimmed = sql_lower.trim();
+
         if matches!(
             sql_trimmed,
-            "begin"
+            "" | "begin"
                 | "begin transaction"
                 | "begin work"
                 | "start transaction"
@@ -747,13 +744,50 @@ impl QueryParser for Parser {
         ) {
             // Return a dummy plan for transaction commands - they'll be handled by the transaction handler
             let dummy_schema = datafusion::common::DFSchema::empty();
-            let dummy_plan = datafusion::logical_expr::LogicalPlan::EmptyRelation(
+            return Ok(Some(LogicalPlan::EmptyRelation(
                 datafusion::logical_expr::EmptyRelation {
                     produce_one_row: false,
-                    schema: std::sync::Arc::new(dummy_schema),
+                    schema: Arc::new(dummy_schema),
                 },
-            );
-            return Ok((sql.to_string(), dummy_plan));
+            )));
+        }
+
+        // SHOW statements may not be supported by DataFusion
+        if sql_trimmed.starts_with("show") {
+            // Return a dummy one-row plan for SHOW statements - they'll be answered by the SHOW handler
+            let show_schema =
+                Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
+            let df_schema = show_schema.to_dfschema()?;
+            return Ok(Some(LogicalPlan::EmptyRelation(
+                datafusion::logical_expr::EmptyRelation {
+                    produce_one_row: true,
+                    schema: Arc::new(df_schema),
+                },
+            )));
+        }
+
+        Ok(None)
+    }
+}
+
+#[async_trait]
+impl QueryParser for Parser {
+    type Statement = (String, LogicalPlan);
+
+    async fn parse_sql<C>(
+        &self,
+        _client: &C,
+        sql: &str,
+        _types: &[Type],
+    ) -> PgWireResult<Self::Statement> {
+        log::debug!("Received parse extended query: {sql}"); // Log for debugging
+
+        // Shortcut transaction and SHOW commands that shouldn't be parsed by DataFusion
+        if let Some(plan) = self
+            .try_shortcut_parse_plan(sql)
+            .map_err(|e| PgWireError::ApiError(Box::new(e)))?
+        {
+            return Ok((sql.to_string(), plan));
         }
 
         let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
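The shortcut path matters mostly for extended-protocol clients (JDBC drivers, Metabase), which route everything through `Parse`. A sketch of statements that now bypass DataFusion's SQL parser entirely, assuming a server wired up with this `Parser`:

```sql
BEGIN;                   -- matched by the transaction list: dummy zero-row EmptyRelation plan
COMMIT;                  -- same; the transaction handler produces the real response
SHOW statement_timeout;  -- starts with "show": one-row dummy plan, answered by the SHOW handler
```

The new empty-string arm (`""`) covers clients that send an empty `Parse` message to probe the connection.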
diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs
index 6f53a6a..1e77f60 100644
--- a/datafusion-postgres/src/pg_catalog.rs
+++ b/datafusion-postgres/src/pg_catalog.rs
@@ -4,8 +4,7 @@ use std::sync::Arc;
 
 use async_trait::async_trait;
 use datafusion::arrow::array::{
-    as_boolean_array, ArrayRef, BooleanArray, BooleanBuilder, RecordBatch, StringArray,
-    StringBuilder,
+    as_boolean_array, ArrayRef, BooleanBuilder, RecordBatch, StringArray, StringBuilder,
 };
 use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
 use datafusion::arrow::ipc::reader::FileReader;
@@ -20,12 +19,16 @@ use datafusion::prelude::{create_udf, Expr, SessionContext};
 use postgres_types::Oid;
 use tokio::sync::RwLock;
 
+mod empty_table;
+mod has_privilege_udf;
 mod pg_attribute;
 mod pg_class;
 mod pg_database;
 mod pg_get_expr_udf;
 mod pg_namespace;
 mod pg_settings;
+mod pg_tables;
+mod pg_views;
 
 const PG_CATALOG_TABLE_PG_AGGREGATE: &str = "pg_aggregate";
 const PG_CATALOG_TABLE_PG_AM: &str = "pg_am";
@@ -89,6 +92,10 @@ const PG_CATALOG_TABLE_PG_TABLESPACE: &str = "pg_tablespace";
 const PG_CATALOG_TABLE_PG_TRIGGER: &str = "pg_trigger";
 const PG_CATALOG_TABLE_PG_USER_MAPPING: &str = "pg_user_mapping";
 const PG_CATALOG_VIEW_PG_SETTINGS: &str = "pg_settings";
+const PG_CATALOG_VIEW_PG_VIEWS: &str = "pg_views";
+const PG_CATALOG_VIEW_PG_MATVIEWS: &str = "pg_matviews";
+const PG_CATALOG_VIEW_PG_TABLES: &str = "pg_tables";
+const PG_CATALOG_VIEW_PG_STAT_USER_TABLES: &str = "pg_stat_user_tables";
 
 /// Determine PostgreSQL table type (relkind) from DataFusion TableProvider
 fn get_table_type(table: &Arc<dyn TableProvider>) -> &'static str {
@@ -184,6 +191,9 @@ pub const PG_CATALOG_TABLES: &[&str] = &[
     PG_CATALOG_TABLE_PG_TRIGGER,
     PG_CATALOG_TABLE_PG_USER_MAPPING,
     PG_CATALOG_VIEW_PG_SETTINGS,
+    PG_CATALOG_VIEW_PG_VIEWS,
+    PG_CATALOG_VIEW_PG_MATVIEWS,
+    PG_CATALOG_VIEW_PG_STAT_USER_TABLES,
 ];
 
 #[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
@@ -315,9 +325,10 @@ impl SchemaProvider for PgCatalogSchemaProvider {
                     self.oid_counter.clone(),
                     self.oid_cache.clone(),
                 ));
-                Ok(Some(Arc::new(
-                    StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
-                )))
+                Ok(Some(Arc::new(StreamingTable::try_new(
+                    Arc::clone(table.schema()),
+                    vec![table],
+                )?)))
             }
             PG_CATALOG_TABLE_PG_CLASS => {
                 let table = Arc::new(pg_class::PgClassTable::new(
@@ -325,9 +336,10 @@ impl SchemaProvider for PgCatalogSchemaProvider {
                     self.oid_counter.clone(),
                     self.oid_cache.clone(),
                 ));
-                Ok(Some(Arc::new(
-                    StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
-                )))
+                Ok(Some(Arc::new(StreamingTable::try_new(
+                    Arc::clone(table.schema()),
+                    vec![table],
+                )?)))
             }
             PG_CATALOG_TABLE_PG_DATABASE => {
                 let table = Arc::new(pg_database::PgDatabaseTable::new(
@@ -335,9 +347,10 @@ impl SchemaProvider for PgCatalogSchemaProvider {
                     self.oid_counter.clone(),
                     self.oid_cache.clone(),
                 ));
-                Ok(Some(Arc::new(
-                    StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
-                )))
+                Ok(Some(Arc::new(StreamingTable::try_new(
+                    Arc::clone(table.schema()),
+                    vec![table],
+                )?)))
             }
             PG_CATALOG_TABLE_PG_NAMESPACE => {
                 let table = Arc::new(pg_namespace::PgNamespaceTable::new(
@@ -345,14 +358,27 @@ impl SchemaProvider for PgCatalogSchemaProvider {
                     self.oid_counter.clone(),
                     self.oid_cache.clone(),
                 ));
-                Ok(Some(Arc::new(
-                    StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
-                )))
+                Ok(Some(Arc::new(StreamingTable::try_new(
+                    Arc::clone(table.schema()),
+                    vec![table],
+                )?)))
+            }
+            PG_CATALOG_VIEW_PG_TABLES => {
+                let table = Arc::new(pg_tables::PgTablesTable::new(self.catalog_list.clone()));
+                Ok(Some(Arc::new(StreamingTable::try_new(
+                    Arc::clone(table.schema()),
+                    vec![table],
+                )?)))
             }
             PG_CATALOG_VIEW_PG_SETTINGS => {
                 let table = pg_settings::PgSettingsView::try_new()?;
                 Ok(Some(Arc::new(table.try_into_memtable()?)))
             }
+            PG_CATALOG_VIEW_PG_VIEWS => Ok(Some(Arc::new(pg_views::pg_views()?))),
+            PG_CATALOG_VIEW_PG_MATVIEWS => Ok(Some(Arc::new(pg_views::pg_matviews()?))),
+            PG_CATALOG_VIEW_PG_STAT_USER_TABLES => {
+                Ok(Some(Arc::new(pg_views::pg_stat_user_tables()?)))
+            }
 
             _ => Ok(None),
         }
@@ -687,7 +713,7 @@ impl PgCatalogStaticTables {
     }
 }
 
-pub fn create_current_schemas_udf() -> ScalarUDF {
+pub fn create_current_schemas_udf(name: &str) -> ScalarUDF {
     // Define the function implementation
     let func = move |args: &[ColumnarValue]| {
         let args = ColumnarValue::values_to_arrays(args)?;
@@ -710,7 +736,7 @@ pub fn create_current_schemas_udf() -> ScalarUDF {
 
     // Wrap the implementation in a scalar function
     create_udf(
-        "current_schemas",
+        name,
        vec![DataType::Boolean],
         DataType::List(Arc::new(Field::new("schema", DataType::Utf8, false))),
         Volatility::Immutable,
@@ -718,7 +744,7 @@ pub fn create_current_schemas_udf() -> ScalarUDF {
     )
 }
 
-pub fn create_current_schema_udf() -> ScalarUDF {
+pub fn create_current_schema_udf(name: &str) -> ScalarUDF {
     // Define the function implementation
     let func = move |_args: &[ColumnarValue]| {
         // Create a UTF8 array with a single value
@@ -731,7 +757,28 @@ pub fn create_current_schema_udf() -> ScalarUDF {
 
     // Wrap the implementation in a scalar function
     create_udf(
-        "current_schema",
+        name,
         vec![],
         DataType::Utf8,
         Volatility::Immutable,
         Arc::new(func),
     )
 }
+
+pub fn create_current_database_udf(name: &str) -> ScalarUDF {
+    // Define the function implementation
+    let func = move |_args: &[ColumnarValue]| {
+        // Create a UTF8 array with a single value
+        let mut builder = StringBuilder::new();
+        builder.append_value("datafusion");
+        let array: ArrayRef = Arc::new(builder.finish());
+
+        Ok(ColumnarValue::Array(array))
+    };
+
+    // Wrap the implementation in a scalar function
+    create_udf(
+        name,
+        vec![],
+        DataType::Utf8,
+        Volatility::Immutable,
+        Arc::new(func),
+    )
+}
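Since the factories now take the registered name as a parameter, one implementation can back both the unqualified and the `pg_catalog`-qualified spelling (the registrations happen in `setup_pg_catalog` below). A sketch of what clients can then call — note `current_database()` is hard-coded to return `'datafusion'`:

```sql
SELECT current_database();               -- always 'datafusion' in this build
SELECT current_schema();                 -- same UDF body as pg_catalog.current_schema()
SELECT pg_catalog.current_schemas(true); -- the qualified spelling Metabase's JDBC driver emits
```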
@@ -789,7 +836,7 @@ pub fn create_pg_get_userbyid_udf() -> ScalarUDF {
     )
 }
 
-pub fn create_pg_table_is_visible() -> ScalarUDF {
+pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF {
     // Define the function implementation
     let func = move |args: &[ColumnarValue]| {
         let args = ColumnarValue::values_to_arrays(args)?;
@@ -808,7 +855,7 @@ pub fn create_pg_table_is_visible() -> ScalarUDF {
 
     // Wrap the implementation in a scalar function
     create_udf(
-        "pg_catalog.pg_table_is_visible",
+        name,
         vec![DataType::Int32],
         DataType::Boolean,
         Volatility::Stable,
@@ -816,62 +863,6 @@ pub fn create_pg_table_is_visible() -> ScalarUDF {
     )
 }
 
-pub fn create_has_table_privilege_3param_udf() -> ScalarUDF {
-    // Define the function implementation for 3-parameter version
-    let func = move |args: &[ColumnarValue]| {
-        let args = ColumnarValue::values_to_arrays(args)?;
-        let user = &args[0]; // User (can be name or OID)
-        let _table = &args[1]; // Table (can be name or OID)
-        let _privilege = &args[2]; // Privilege type (SELECT, INSERT, etc.)
-
-        // For now, always return true (full access)
-        let mut builder = BooleanArray::builder(user.len());
-        for _ in 0..user.len() {
-            builder.append_value(true);
-        }
-
-        let array: ArrayRef = Arc::new(builder.finish());
-
-        Ok(ColumnarValue::Array(array))
-    };
-
-    // Wrap the implementation in a scalar function
-    create_udf(
-        "has_table_privilege",
-        vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
-        DataType::Boolean,
-        Volatility::Stable,
-        Arc::new(func),
-    )
-}
-
-pub fn create_has_table_privilege_2param_udf() -> ScalarUDF {
-    // Define the function implementation for 2-parameter version (current user, table, privilege)
-    let func = move |args: &[ColumnarValue]| {
-        let args = ColumnarValue::values_to_arrays(args)?;
-        let table = &args[0]; // Table (can be name or OID)
-        let _privilege = &args[1]; // Privilege type (SELECT, INSERT, etc.)
-
-        // For now, always return true (full access for current user)
-        let mut builder = BooleanArray::builder(table.len());
-        for _ in 0..table.len() {
-            builder.append_value(true);
-        }
-        let array: ArrayRef = Arc::new(builder.finish());
-
-        Ok(ColumnarValue::Array(array))
-    };
-
-    // Wrap the implementation in a scalar function
-    create_udf(
-        "has_table_privilege",
-        vec![DataType::Utf8, DataType::Utf8],
-        DataType::Boolean,
-        Volatility::Stable,
-        Arc::new(func),
-    )
-}
-
 pub fn create_format_type_udf() -> ScalarUDF {
     let func = move |args: &[ColumnarValue]| {
         let args = ColumnarValue::values_to_arrays(args)?;
@@ -923,7 +914,6 @@ pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
         let args = ColumnarValue::values_to_arrays(args)?;
         let oid = &args[0];
 
-        // For now, always return true (full access for current user)
         let mut builder = StringBuilder::new();
         for _ in 0..oid.len() {
             builder.append_value("");
@@ -962,16 +952,37 @@ pub fn setup_pg_catalog(
         })?
         .register_schema("pg_catalog", Arc::new(pg_catalog))?;
 
-    session_context.register_udf(create_current_schema_udf());
-    session_context.register_udf(create_current_schemas_udf());
+    session_context.register_udf(create_current_database_udf("current_database"));
+    session_context.register_udf(create_current_schema_udf("current_schema"));
+    session_context.register_udf(create_current_schema_udf("pg_catalog.current_schema"));
+    session_context.register_udf(create_current_schemas_udf("current_schemas"));
+    session_context.register_udf(create_current_schemas_udf("pg_catalog.current_schemas"));
     session_context.register_udf(create_version_udf());
     session_context.register_udf(create_pg_get_userbyid_udf());
-    session_context.register_udf(create_has_table_privilege_2param_udf());
-    session_context.register_udf(create_pg_table_is_visible());
+    session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
+        "has_table_privilege",
+    ));
+    session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
+        "pg_catalog.has_table_privilege",
+    ));
+    session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
+        "has_schema_privilege",
+    ));
+    session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
+        "pg_catalog.has_schema_privilege",
+    ));
+    session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
+        "has_any_column_privilege",
+    ));
+    session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
+        "pg_catalog.has_any_column_privilege",
+    ));
+    session_context.register_udf(create_pg_table_is_visible("pg_table_is_visible"));
+    session_context.register_udf(create_pg_table_is_visible("pg_catalog.pg_table_is_visible"));
     session_context.register_udf(create_format_type_udf());
     session_context.register_udf(create_session_user_udf());
     session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone());
-    session_context.register_udf(pg_get_expr_udf::PgGetExprUDF::new().into_scalar_udf());
+    session_context.register_udf(pg_get_expr_udf::create_pg_get_expr_udf());
     session_context.register_udf(create_pg_get_partkeydef_udf());
 
     Ok(())
diff --git a/datafusion-postgres/src/pg_catalog/empty_table.rs b/datafusion-postgres/src/pg_catalog/empty_table.rs
new file mode 100644
index 0000000..e12032d
--- /dev/null
+++ b/datafusion-postgres/src/pg_catalog/empty_table.rs
@@ -0,0 +1,18 @@
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::catalog::MemTable;
+use datafusion::error::Result;
+
+#[derive(Debug, Clone)]
+pub(crate) struct EmptyTable {
+    schema: SchemaRef,
+}
+
+impl EmptyTable {
+    pub(crate) fn new(schema: SchemaRef) -> Self {
+        Self { schema }
+    }
+
+    pub fn try_into_memtable(self) -> Result<MemTable> {
+        MemTable::try_new(self.schema, vec![vec![]])
+    }
+}
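`EmptyTable` is a small helper for catalog views that need a correct schema but no rows yet: `try_into_memtable` wraps the schema in a `MemTable` with a single empty partition. Queries against the stub views therefore succeed and simply return zero rows, e.g.:

```sql
SELECT count(*) FROM pg_catalog.pg_views;  -- 0: schema-only stub
SELECT * FROM pg_catalog.pg_matviews;      -- empty result, but with the full column list
```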
diff --git a/datafusion-postgres/src/pg_catalog/has_privilege_udf.rs b/datafusion-postgres/src/pg_catalog/has_privilege_udf.rs
new file mode 100644
index 0000000..89f56cc
--- /dev/null
+++ b/datafusion-postgres/src/pg_catalog/has_privilege_udf.rs
@@ -0,0 +1,71 @@
+use std::sync::Arc;
+
+use datafusion::arrow::array::{ArrayRef, BooleanArray};
+use datafusion::error::Result;
+use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF};
+use datafusion::{
+    arrow::datatypes::DataType,
+    logical_expr::{ScalarUDFImpl, Signature, TypeSignature, Volatility},
+};
+
+#[derive(Debug)]
+pub struct PgHasPrivilegeUDF {
+    signature: Signature,
+    name: String,
+}
+
+impl PgHasPrivilegeUDF {
+    pub(crate) fn new(name: &str) -> PgHasPrivilegeUDF {
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
+                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
+                ],
+                Volatility::Stable,
+            ),
+            name: name.to_owned(),
+        }
+    }
+
+    pub fn into_scalar_udf(self) -> ScalarUDF {
+        ScalarUDF::new_from_impl(self)
+    }
+}
+
+impl ScalarUDFImpl for PgHasPrivilegeUDF {
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn name(&self) -> &str {
+        self.name.as_ref()
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args = ColumnarValue::values_to_arrays(&args.args)?;
+
+        let len = args[0].len();
+
+        // For now, always return true (full access for current user)
+        let mut builder = BooleanArray::builder(len);
+        for _ in 0..len {
+            builder.append_value(true);
+        }
+        let array: ArrayRef = Arc::new(builder.finish());
+
+        Ok(ColumnarValue::Array(array))
+    }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+}
+
+pub fn create_has_privilege_udf(name: &str) -> ScalarUDF {
+    PgHasPrivilegeUDF::new(name).into_scalar_udf()
+}
diff --git a/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs b/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs
index 20b759c..d6672cd 100644
--- a/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs
+++ b/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs
@@ -72,3 +72,7 @@ impl ScalarUDFImpl for PgGetExprUDF {
         self
     }
 }
+
+pub fn create_pg_get_expr_udf() -> ScalarUDF {
+    PgGetExprUDF::new().into_scalar_udf()
+}
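One `ScalarUDFImpl` now serves every `has_*_privilege` function and both arities, replacing the two per-arity `create_udf` versions deleted above. Both forms currently report full access unconditionally; the object names below are hypothetical:

```sql
-- 3-argument form: (user, object, privilege)
SELECT has_table_privilege('postgres', 'public.my_table', 'SELECT');
-- 2-argument form: (object, privilege), evaluated for the session user
SELECT has_schema_privilege('public', 'USAGE');
```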
diff --git a/datafusion-postgres/src/pg_catalog/pg_tables.rs b/datafusion-postgres/src/pg_catalog/pg_tables.rs
new file mode 100644
index 0000000..7220d9d
--- /dev/null
+++ b/datafusion-postgres/src/pg_catalog/pg_tables.rs
@@ -0,0 +1,102 @@
+use std::sync::Arc;
+
+use datafusion::arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
+use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::catalog::CatalogProviderList;
+use datafusion::error::Result;
+use datafusion::execution::{SendableRecordBatchStream, TaskContext};
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::streaming::PartitionStream;
+
+#[derive(Debug, Clone)]
+pub(crate) struct PgTablesTable {
+    schema: SchemaRef,
+    catalog_list: Arc<dyn CatalogProviderList>,
+}
+
+impl PgTablesTable {
+    pub(crate) fn new(catalog_list: Arc<dyn CatalogProviderList>) -> PgTablesTable {
+        // Define the schema for pg_tables
+        // This matches key columns from PostgreSQL's pg_tables view
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("schemaname", DataType::Utf8, false),
+            Field::new("tablename", DataType::Utf8, false),
+            Field::new("tableowner", DataType::Utf8, false),
+            Field::new("tablespace", DataType::Utf8, true),
+            Field::new("hasindex", DataType::Boolean, false),
+            Field::new("hasrules", DataType::Boolean, false),
+            Field::new("hastriggers", DataType::Boolean, false),
+            Field::new("rowsecurity", DataType::Boolean, false),
+        ]));
+
+        Self {
+            schema,
+            catalog_list,
+        }
+    }
+
+    /// Generate record batches based on the current state of the catalog
+    async fn get_data(this: PgTablesTable) -> Result<RecordBatch> {
+        // Vectors to store column data
+        let mut schema_names = Vec::new();
+        let mut table_names = Vec::new();
+        let mut table_owners = Vec::new();
+        let mut table_spaces: Vec<Option<String>> = Vec::new();
+        let mut has_index = Vec::new();
+        let mut has_rules = Vec::new();
+        let mut has_triggers = Vec::new();
+        let mut row_security = Vec::new();
+
+        // Iterate through all catalogs and schemas
+        for catalog_name in this.catalog_list.catalog_names() {
+            if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
+                for schema_name in catalog.schema_names() {
+                    if let Some(schema) = catalog.schema(&schema_name) {
+                        // Now process all tables in this schema
+                        for table_name in schema.table_names() {
+                            schema_names.push(schema_name.to_string());
+                            table_names.push(table_name.to_string());
+                            table_owners.push("postgres".to_string());
+                            table_spaces.push(None);
+                            has_index.push(false);
+                            has_rules.push(false);
+                            has_triggers.push(false);
+                            row_security.push(false);
+                        }
+                    }
+                }
+            }
+        }
+
+        // Create Arrow arrays from the collected data
+        let arrays: Vec<ArrayRef> = vec![
+            Arc::new(StringArray::from(schema_names)),
+            Arc::new(StringArray::from(table_names)),
+            Arc::new(StringArray::from(table_owners)),
+            Arc::new(StringArray::from(table_spaces)),
+            Arc::new(BooleanArray::from(has_index)),
+            Arc::new(BooleanArray::from(has_rules)),
+            Arc::new(BooleanArray::from(has_triggers)),
+            Arc::new(BooleanArray::from(row_security)),
+        ];
+
+        // Create a record batch
+        let batch = RecordBatch::try_new(this.schema.clone(), arrays)?;
+
+        Ok(batch)
+    }
+}
+
+impl PartitionStream for PgTablesTable {
+    fn schema(&self) -> &SchemaRef {
+        &self.schema
+    }
+
+    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
+        let this = self.clone();
+        Box::pin(RecordBatchStreamAdapter::new(
+            this.schema.clone(),
+            futures::stream::once(async move { PgTablesTable::get_data(this).await }),
+        ))
+    }
+}
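Unlike the `EmptyTable`-backed stubs, `pg_tables` is populated live from the registered catalogs on every scan. Only the name columns carry real data for now; a sketch of a check against an assumed `public` schema:

```sql
SELECT schemaname, tablename, tableowner
FROM pg_catalog.pg_tables
WHERE schemaname = 'public';
-- tableowner is always 'postgres'; the boolean flags are always false
```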
Field::new("n_tup_del", DataType::Int64, false), + Field::new("n_tup_hot_upd", DataType::Int64, false), + Field::new("n_tup_newpage_upd", DataType::Int64, false), + Field::new("n_live_tup", DataType::Int64, false), + Field::new("n_dead_tup", DataType::Int64, false), + Field::new("n_mod_since_analyze", DataType::Int64, false), + Field::new("n_ins_since_vacuum", DataType::Int64, false), + Field::new( + "last_vacuum", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "last_autovacuum", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "last_analyze", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "last_autoanalyze", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new("vacuum_count", DataType::Int64, false), + Field::new("autovacuum_count", DataType::Int64, false), + Field::new("analyze_count", DataType::Int64, false), + Field::new("autoanalyze_count", DataType::Int64, false), + Field::new("total_vacuum_time", DataType::Float64, false), + Field::new("total_autovacuum_time", DataType::Float64, false), + Field::new("total_analyze_time", DataType::Float64, false), + Field::new("total_autoanalyze_time", DataType::Float64, false), + ])); + + EmptyTable::new(schema).try_into_memtable() +} diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index ffb2066..736769c 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -613,6 +613,57 @@ impl SqlStatementRewriteRule for RemoveTableFunctionQualifier { } } +/// Replace `current_user` with `session_user()` +#[derive(Debug)] +pub struct CurrentUserVariableToSessionUserFunctionCall; + +struct CurrentUserVariableToSessionUserFunctionCallVisitor; + +impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + if let Expr::Identifier(ident) = expr { + if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" { + *expr = Expr::Function(Function { + name: ObjectName::from(vec![Ident::new("session_user")]), + args: FunctionArguments::None, + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }); + } + } + + if let Expr::Function(func) = expr { + let fname = func + .name + .0 + .iter() + .map(|ident| ident.to_string()) + .collect::>() + .join("."); + if fname.to_lowercase() == "current_user" { + func.name = ObjectName::from(vec![Ident::new("session_user")]) + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -802,4 +853,20 @@ mod tests { "SELECT * FROM pg_get_keywords()" ); } + + #[test] + fn test_current_user() { + let rules: Vec> = + vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)]; + + assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user"); + + assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user"); + + assert_rewrite!( + &rules, + "SELECT is_null(current_user)", + "SELECT is_null(session_user)" + ); + } } diff --git a/datafusion-postgres/tests/metabase.rs b/datafusion-postgres/tests/metabase.rs new file mode 100644 index 0000000..3c15700 --- 
diff --git a/datafusion-postgres/tests/metabase.rs b/datafusion-postgres/tests/metabase.rs
new file mode 100644
index 0000000..3c15700
--- /dev/null
+++ b/datafusion-postgres/tests/metabase.rs
@@ -0,0 +1,53 @@
+mod common;
+
+use common::*;
+use pgwire::api::query::SimpleQueryHandler;
+
+const METABASE_QUERIES: &[&str] = &[
+    "SET extra_float_digits = 2",
+    "SET application_name = 'Metabase v0.55.1 [f8f63fdf-d8f8-4573-86ea-4fe4a9548041]'",
+    "SHOW TRANSACTION ISOLATION LEVEL",
+    "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ UNCOMMITTED",
+    r#"SELECT nspname AS "TABLE_SCHEM", current_database() AS "TABLE_CATALOG" FROM pg_catalog.pg_namespace WHERE nspname <> 'pg_toast' AND (nspname !~ '^pg_temp_' OR nspname = (pg_catalog.current_schemas(true))[1]) AND (nspname !~ '^pg_toast_temp_' OR nspname = replace((pg_catalog.current_schemas(true))[1], 'pg_temp_', 'pg_toast_temp_')) ORDER BY "TABLE_SCHEM""#,
+    r#"with table_privileges as (
+  select
+    NULL as role,
+    t.schemaname as schema,
+    t.objectname as table,
+    pg_catalog.has_any_column_privilege(current_user, '"' || replace(t.schemaname, '"', '""') || '"' || '.' || '"' || replace(t.objectname, '"', '""') || '"', 'update') as update,
+    pg_catalog.has_any_column_privilege(current_user, '"' || replace(t.schemaname, '"', '""') || '"' || '.' || '"' || replace(t.objectname, '"', '""') || '"', 'select') as select,
+    pg_catalog.has_any_column_privilege(current_user, '"' || replace(t.schemaname, '"', '""') || '"' || '.' || '"' || replace(t.objectname, '"', '""') || '"', 'insert') as insert,
+    pg_catalog.has_table_privilege( current_user, '"' || replace(t.schemaname, '"', '""') || '"' || '.' || '"' || replace(t.objectname, '"', '""') || '"', 'delete') as delete
+  from (
+    select schemaname, tablename as objectname from pg_catalog.pg_tables
+    union
+    select schemaname, viewname as objectname from pg_catalog.pg_views
+    union
+    select schemaname, matviewname as objectname from pg_catalog.pg_matviews
+  ) t
+  where t.schemaname !~ '^pg_'
+    and t.schemaname <> 'information_schema'
+    and pg_catalog.has_schema_privilege(current_user, t.schemaname, 'usage')
+)
+select t.*
+from table_privileges t"#,
+    r#"SELECT "n"."nspname" AS "schema", "c"."relname" AS "name", CASE "c"."relkind" WHEN 'r' THEN 'TABLE' WHEN 'p' THEN 'PARTITIONED TABLE' WHEN 'v' THEN 'VIEW' WHEN 'f' THEN 'FOREIGN TABLE' WHEN 'm' THEN 'MATERIALIZED VIEW' ELSE NULL END AS "type", "d"."description" AS "description", "stat"."n_live_tup" AS "estimated_row_count" FROM "pg_catalog"."pg_class" AS "c" INNER JOIN "pg_catalog"."pg_namespace" AS "n" ON "c"."relnamespace" = "n"."oid" LEFT JOIN "pg_catalog"."pg_description" AS "d" ON ("c"."oid" = "d"."objoid") AND ("d"."objsubid" = '0') AND ("d"."classoid" = 'pg_class'::regclass) LEFT JOIN "pg_stat_user_tables" AS "stat" ON ("n"."nspname" = "stat"."schemaname") AND ("c"."relname" = "stat"."relname") WHERE ("c"."relnamespace" = "n"."oid") AND ("n"."nspname" !~ '^pg_') AND ("n"."nspname" <> 'information_schema') AND c.relkind in ('r', 'p', 'v', 'f', 'm') AND ("n"."nspname" IN ('public')) ORDER BY "type" ASC, "schema" ASC, "name" ASC"#,
+    "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED",
+    "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ UNCOMMITTED",
+    "show timezone",
+];
+
+#[tokio::test]
+pub async fn test_metabase_startup_sql() {
+    env_logger::init();
+    let service = setup_handlers();
+    let mut client = MockClient::new();
+
+    for query in METABASE_QUERIES {
+        SimpleQueryHandler::do_query(&service, &mut client, query)
+            .await
+            .expect(&format!(
+                "failed to run sql: \n--------------\n {query}\n--------------\n"
+            ));
+    }
+}