diff --git a/Cargo.lock b/Cargo.lock index 33bc38a..d311afc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1494,6 +1494,7 @@ dependencies = [ "bytes", "chrono", "datafusion", + "env_logger", "futures", "getset", "log", diff --git a/README.md b/README.md index 6a82025..762d786 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ project. - Permission control - Built-in `pg_catalog` tables - Built-in postgres functions for common meta queries + - [x] DBeaver compatibility - `datafusion-postgres-cli`: A cli tool starts a postgres compatible server for datafusion supported file formats, just like python's `SimpleHTTPServer`. - `arrow-pg`: A data type mapping, encoding/decoding library for arrow and diff --git a/arrow-pg/src/list_encoder.rs b/arrow-pg/src/list_encoder.rs index d1ca983..49f6373 100644 --- a/arrow-pg/src/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -1,5 +1,6 @@ use std::{str::FromStr, sync::Arc}; +use arrow::array::{BinaryViewArray, StringViewArray}; #[cfg(not(feature = "datafusion"))] use arrow::{ array::{ @@ -150,6 +151,15 @@ pub(crate) fn encode_list( .collect(); encode_field(&value, type_, format) } + DataType::Utf8View => { + let value: Vec<Option<&str>> = arr + .as_any() + .downcast_ref::<StringViewArray>() + .unwrap() + .iter() + .collect(); + encode_field(&value, type_, format) + } DataType::Binary => { let value: Vec<Option<&[u8]>> = arr .as_any() .downcast_ref::<BinaryArray>() .unwrap() .iter() .collect(); encode_field(&value, type_, format) } @@ -168,6 +178,15 @@ .collect(); encode_field(&value, type_, format) } + DataType::BinaryView => { + let value: Vec<Option<&[u8]>> = arr + .as_any() + .downcast_ref::<BinaryViewArray>() + .unwrap() + .iter() + .collect(); + encode_field(&value, type_, format) + } DataType::Date32 => { let value: Vec<Option<i32>> = arr diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index c87f968..4889673 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -28,3 +28,6 @@ tokio = { version = "1.47", features = ["sync", "net"] } tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } rustls-pemfile = "2.0" rustls-pki-types = "1.0" + +[dev-dependencies] +env_logger = "0.11" diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index f5d5d52..7d2ccb3 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -3,14 +3,16 @@ use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; use crate::sql::{ - parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes, - ResolveUnqualifiedIdentifer, SqlStatementRewriteRule, + parse, rewrite, AliasDuplicatedProjectionRewrite, FixArrayLiteral, PrependUnqualifiedTableName, + RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, + RewriteArrayAnyAllOperation, SqlStatementRewriteRule, }; use async_trait::async_trait; use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::*; use datafusion::sql::parser::Statement; +use log::warn; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::auth::StartupHandler; use pgwire::api::portal::{Format, Portal}; @@ -80,6 +82,10 @@ impl DfSessionService { Arc::new(AliasDuplicatedProjectionRewrite), Arc::new(ResolveUnqualifiedIdentifer), Arc::new(RemoveUnsupportedTypes::new()), + Arc::new(RewriteArrayAnyAllOperation), + Arc::new(PrependUnqualifiedTableName::new()), + Arc::new(FixArrayLiteral), + Arc::new(RemoveTableFunctionQualifier), ]; let parser = Arc::new(Parser { session_context: session_context.clone(), @@ -211,14 +217,12 @@ impl
DfSessionService { } } else { // pass SET query to datafusion - let df = self - .session_context - .sql(query_lower) - .await - .map_err(|err| PgWireError::ApiError(Box::new(err)))?; - - let resp = df::encode_dataframe(df, &Format::UnifiedText).await?; - Ok(Some(Response::Query(resp))) + if let Err(e) = self.session_context.sql(query_lower).await { + warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored"); + } + + // Always return SET success + Ok(Some(Response::Execution(Tag::new("SET")))) } } else { Ok(None) @@ -297,8 +301,8 @@ impl DfSessionService { Ok(Some(Response::Query(resp))) } "show search_path" => { - let default_catalog = "datafusion"; - let resp = Self::mock_show_response("search_path", default_catalog)?; + let default_schema = "public"; + let resp = Self::mock_show_response("search_path", default_schema)?; Ok(Some(Response::Query(resp))) } _ => Err(PgWireError::UserError(Box::new( @@ -331,7 +335,8 @@ impl SimpleQueryHandler for DfSessionService { statement = rewrite(statement, &self.sql_rewrite_rules); // TODO: improve statement check by using statement directly - let query_lower = statement.to_string().to_lowercase().trim().to_string(); + let query = statement.to_string(); + let query_lower = query.to_lowercase().trim().to_string(); // Check permissions for the query (skip for SET, transaction, and SHOW statements) if !query_lower.starts_with("set") @@ -343,7 +348,7 @@ impl SimpleQueryHandler for DfSessionService { && !query_lower.starts_with("abort") && !query_lower.starts_with("show") { - self.check_query_permission(client, query).await?; + self.check_query_permission(client, &query).await?; } if let Some(resp) = self.try_respond_set_statements(&query_lower).await? { @@ -373,7 +378,7 @@ impl SimpleQueryHandler for DfSessionService { ))); } - let df_result = self.session_context.sql(query).await; + let df_result = self.session_context.sql(&query).await; // Handle query execution errors and transaction state let df = match df_result { diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 90e3a31..6701d58 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -10,13 +10,13 @@ use datafusion::arrow::array::{ use datafusion::arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion::arrow::ipc::reader::FileReader; use datafusion::catalog::streaming::StreamingTable; -use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider}; +use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider, TableFunctionImpl}; use datafusion::common::utils::SingleRowListArrayBuilder; use datafusion::datasource::{TableProvider, ViewTable}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility}; use datafusion::physical_plan::streaming::PartitionStream; -use datafusion::prelude::{create_udf, SessionContext}; +use datafusion::prelude::{create_udf, Expr, SessionContext}; use postgres_types::Oid; use tokio::sync::RwLock; @@ -24,6 +24,7 @@ mod pg_attribute; mod pg_class; mod pg_database; mod pg_namespace; +mod pg_settings; const PG_CATALOG_TABLE_PG_AGGREGATE: &str = "pg_aggregate"; const PG_CATALOG_TABLE_PG_AM: &str = "pg_am"; @@ -86,6 +87,7 @@ const PG_CATALOG_TABLE_PG_SUBSCRIPTION_REL: &str = "pg_subscription_rel"; const PG_CATALOG_TABLE_PG_TABLESPACE: &str = "pg_tablespace"; const PG_CATALOG_TABLE_PG_TRIGGER: &str = "pg_trigger"; const PG_CATALOG_TABLE_PG_USER_MAPPING: &str = 
"pg_user_mapping"; +const PG_CATALOG_VIEW_PG_SETTINGS: &str = "pg_settings"; /// Determine PostgreSQL table type (relkind) from DataFusion TableProvider fn get_table_type(table: &Arc) -> &'static str { @@ -180,6 +182,7 @@ pub const PG_CATALOG_TABLES: &[&str] = &[ PG_CATALOG_TABLE_PG_TABLESPACE, PG_CATALOG_TABLE_PG_TRIGGER, PG_CATALOG_TABLE_PG_USER_MAPPING, + PG_CATALOG_VIEW_PG_SETTINGS, ]; #[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)] @@ -196,7 +199,7 @@ pub struct PgCatalogSchemaProvider { catalog_list: Arc, oid_counter: Arc, oid_cache: Arc>>, - static_tables: PgCatalogStaticTables, + static_tables: Arc, } #[async_trait] @@ -345,6 +348,10 @@ impl SchemaProvider for PgCatalogSchemaProvider { StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(), ))) } + PG_CATALOG_VIEW_PG_SETTINGS => { + let table = pg_settings::PgSettingsView::try_new()?; + Ok(Some(Arc::new(table.try_into_memtable()?))) + } _ => Ok(None), } @@ -356,12 +363,15 @@ impl SchemaProvider for PgCatalogSchemaProvider { } impl PgCatalogSchemaProvider { - pub fn try_new(catalog_list: Arc) -> Result { + pub fn try_new( + catalog_list: Arc, + static_tables: Arc, + ) -> Result { Ok(Self { catalog_list, oid_counter: Arc::new(AtomicU32::new(16384)), oid_cache: Arc::new(RwLock::new(HashMap::new())), - static_tables: PgCatalogStaticTables::try_new()?, + static_tables, }) } } @@ -399,10 +409,17 @@ impl ArrowTable { } } +impl TableFunctionImpl for ArrowTable { + fn call(&self, _args: &[Expr]) -> Result> { + let table = self.clone().try_into_memtable()?; + Ok(Arc::new(table)) + } +} + /// pg_catalog table as datafusion table provider /// /// This implementation only contains static tables -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PgCatalogStaticTables { pub pg_aggregate: Arc, pub pg_am: Arc, @@ -461,6 +478,8 @@ pub struct PgCatalogStaticTables { pub pg_tablespace: Arc, pub pg_trigger: Arc, pub pg_user_mapping: Arc, + + pub pg_get_keywords: Arc, } impl PgCatalogStaticTables { @@ -647,6 +666,10 @@ impl PgCatalogStaticTables { pg_user_mapping: Self::create_arrow_table( include_bytes!("../../pg_catalog_arrow_exports/pg_user_mapping.feather").to_vec(), )?, + + pg_get_keywords: Self::create_arrow_table_function( + include_bytes!("../../pg_catalog_arrow_exports/pg_get_keywords.feather").to_vec(), + )?, }) } @@ -656,6 +679,11 @@ impl PgCatalogStaticTables { let mem_table = table.try_into_memtable()?; Ok(Arc::new(mem_table)) } + + fn create_arrow_table_function(data_bytes: Vec) -> Result> { + let table = ArrowTable::from_ipc_data(data_bytes)?; + Ok(Arc::new(table)) + } } pub fn create_current_schemas_udf() -> ScalarUDF { @@ -862,7 +890,78 @@ pub fn create_format_type_udf() -> ScalarUDF { create_udf( "format_type", - vec![DataType::Int32, DataType::Int32], + vec![DataType::Int64, DataType::Int32], + DataType::Utf8, + Volatility::Stable, + Arc::new(func), + ) +} + +pub fn create_session_user_udf() -> ScalarUDF { + let func = move |_args: &[ColumnarValue]| { + let mut builder = StringBuilder::new(); + // TODO: return real user + builder.append_value("postgres"); + + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + create_udf( + "session_user", + vec![], + DataType::Utf8, + Volatility::Stable, + Arc::new(func), + ) +} + +pub fn create_pg_get_expr_udf() -> ScalarUDF { + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let expr = &args[0]; + let _oid = &args[1]; + + // For now, always return true (full access for current 
+ let mut builder = StringBuilder::new(); + for _ in 0..expr.len() { + builder.append_value(""); + } + + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + create_udf( + "pg_catalog.pg_get_expr", + vec![DataType::Utf8, DataType::Int32], + DataType::Utf8, + Volatility::Stable, + Arc::new(func), + ) +} + +pub fn create_pg_get_partkeydef_udf() -> ScalarUDF { + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let oid = &args[0]; + + // For now, always return an empty string + let mut builder = StringBuilder::new(); + for _ in 0..oid.len() { + builder.append_value(""); + } + + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + create_udf( + "pg_catalog.pg_get_partkeydef", + vec![DataType::Utf8], DataType::Utf8, Volatility::Stable, Arc::new(func), ) @@ -874,8 +973,11 @@ pub fn setup_pg_catalog( session_context: &SessionContext, catalog_name: &str, ) -> Result<(), Box<dyn std::error::Error>> { - let pg_catalog = - PgCatalogSchemaProvider::try_new(session_context.state().catalog_list().clone())?; + let static_tables = Arc::new(PgCatalogStaticTables::try_new()?); + let pg_catalog = PgCatalogSchemaProvider::try_new( + session_context.state().catalog_list().clone(), + static_tables.clone(), + )?; session_context .catalog(catalog_name) .ok_or_else(|| { @@ -892,6 +994,10 @@ session_context.register_udf(create_has_table_privilege_2param_udf()); session_context.register_udf(create_pg_table_is_visible()); session_context.register_udf(create_format_type_udf()); + session_context.register_udf(create_session_user_udf()); + session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone()); + session_context.register_udf(create_pg_get_expr_udf()); + session_context.register_udf(create_pg_get_partkeydef_udf()); Ok(()) } @@ -1145,5 +1251,9 @@ mod test { include_bytes!("../../pg_catalog_arrow_exports/pg_user_mapping.feather").to_vec(), ) .expect("Failed to load ipc data"); + let _ = ArrowTable::from_ipc_data( + include_bytes!("../../pg_catalog_arrow_exports/pg_get_keywords.feather").to_vec(), + ) + .expect("Failed to load ipc data"); } } diff --git a/datafusion-postgres/src/pg_catalog/pg_class.rs b/datafusion-postgres/src/pg_catalog/pg_class.rs index 72f2211..2767c6c 100644 --- a/datafusion-postgres/src/pg_catalog/pg_class.rs +++ b/datafusion-postgres/src/pg_catalog/pg_class.rs @@ -63,6 +63,7 @@ impl PgClassTable { Field::new("relrewrite", DataType::Int32, true), // OID of a rule that rewrites this relation Field::new("relfrozenxid", DataType::Int32, false), // All transaction IDs before this have been replaced with a permanent ("frozen") transaction ID Field::new("relminmxid", DataType::Int32, false), // All Multixact IDs before this have been replaced with a transaction ID + Field::new("relpartbound", DataType::Utf8, true), ])); Self { @@ -106,6 +107,7 @@ impl PgClassTable { let mut relrewrites = Vec::new(); let mut relfrozenxids = Vec::new(); let mut relminmxids = Vec::new(); + let mut relpartbound = Vec::new(); let mut oid_cache = this.oid_cache.write().await; // Every time when call pg_catalog we generate a new cache and drop the @@ -190,6 +192,7 @@ relrewrites.push(None); relfrozenxids.push(0); relminmxids.push(0); + relpartbound.push("".to_string()); } } } @@ -231,6 +234,7 @@ Arc::new(Int32Array::from_iter(relrewrites.into_iter())), Arc::new(Int32Array::from(relfrozenxids)),
Arc::new(Int32Array::from(relminmxids)), + Arc::new(StringArray::from(relpartbound)), ]; // Create a record batch diff --git a/datafusion-postgres/src/pg_catalog/pg_settings.rs b/datafusion-postgres/src/pg_catalog/pg_settings.rs new file mode 100644 index 0000000..c94cd82 --- /dev/null +++ b/datafusion-postgres/src/pg_catalog/pg_settings.rs @@ -0,0 +1,115 @@ +use std::sync::Arc; + +use datafusion::arrow::array::{ArrayRef, BooleanArray, Int32Array, RecordBatch, StringArray}; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::catalog::MemTable; +use datafusion::error::Result; + +#[derive(Debug, Clone)] +pub(crate) struct PgSettingsView { + schema: SchemaRef, + data: Vec<RecordBatch>, +} + +impl PgSettingsView { + pub(crate) fn try_new() -> Result<Self> { + let schema = Arc::new(Schema::new(vec![ + // name | setting | unit | category | short_desc | extra_desc | context | vartype | source | min_val | max_val | enumvals | boot_val | reset_val | sourcefile | sourceline | pending_restart + Field::new("name", DataType::Utf8, true), + Field::new("setting", DataType::Utf8, true), + Field::new("unit", DataType::Utf8, true), + Field::new("category", DataType::Utf8, true), + Field::new("short_desc", DataType::Utf8, true), + Field::new("extra_desc", DataType::Utf8, true), + Field::new("context", DataType::Utf8, true), + Field::new("vartype", DataType::Utf8, true), + Field::new("source", DataType::Utf8, true), + Field::new("min_val", DataType::Utf8, true), + Field::new("max_val", DataType::Utf8, true), + Field::new("enumvals", DataType::Utf8, true), + Field::new("boot_val", DataType::Utf8, true), + Field::new("reset_val", DataType::Utf8, true), + Field::new("sourcefile", DataType::Utf8, true), + Field::new("sourceline", DataType::Int32, true), + Field::new("pending_restart", DataType::Boolean, true), + ])); + + let data = Self::create_data(schema.clone())?; + + Ok(Self { schema, data }) + } + + fn create_data(schema: Arc<Schema>) -> Result<Vec<RecordBatch>> { + let mut name: Vec<Option<&str>> = Vec::new(); + let mut setting: Vec<Option<&str>> = Vec::new(); + let mut unit: Vec<Option<&str>> = Vec::new(); + let mut category: Vec<Option<&str>> = Vec::new(); + let mut short_desc: Vec<Option<&str>> = Vec::new(); + let mut extra_desc: Vec<Option<&str>> = Vec::new(); + let mut context: Vec<Option<&str>> = Vec::new(); + let mut vartype: Vec<Option<&str>> = Vec::new(); + let mut source: Vec<Option<&str>> = Vec::new(); + let mut min_val: Vec<Option<&str>> = Vec::new(); + let mut max_val: Vec<Option<&str>> = Vec::new(); + let mut enumvals: Vec<Option<&str>> = Vec::new(); + let mut boot_val: Vec<Option<&str>> = Vec::new(); + let mut reset_val: Vec<Option<&str>> = Vec::new(); + let mut sourcefile: Vec<Option<&str>> = Vec::new(); + let mut sourceline: Vec<Option<i32>> = Vec::new(); + let mut pending_restart: Vec<Option<bool>> = Vec::new(); + + let data = vec![("standard_conforming_strings", "on")]; + + for (setting_name, setting_val) in data { + name.push(Some(setting_name)); + setting.push(Some(setting_val)); + + unit.push(None); + category.push(None); + short_desc.push(None); + extra_desc.push(None); + context.push(None); + vartype.push(None); + source.push(None); + min_val.push(None); + max_val.push(None); + enumvals.push(None); + boot_val.push(None); + reset_val.push(None); + sourcefile.push(None); + sourceline.push(None); + pending_restart.push(None); + } + + let arrays: Vec<ArrayRef> = vec![ + Arc::new(StringArray::from(name)), + Arc::new(StringArray::from(setting)), + Arc::new(StringArray::from(unit)), + Arc::new(StringArray::from(category)), + Arc::new(StringArray::from(short_desc)), + Arc::new(StringArray::from(extra_desc)), + Arc::new(StringArray::from(context)), + Arc::new(StringArray::from(vartype)), +
Arc::new(StringArray::from(source)), + Arc::new(StringArray::from(min_val)), + Arc::new(StringArray::from(max_val)), + Arc::new(StringArray::from(enumvals)), + Arc::new(StringArray::from(boot_val)), + Arc::new(StringArray::from(reset_val)), + Arc::new(StringArray::from(sourcefile)), + Arc::new(Int32Array::from(sourceline)), + Arc::new(BooleanArray::from(pending_restart)), + ]; + + let batch = RecordBatch::try_new(schema.clone(), arrays)?; + + Ok(vec![batch]) + } + + pub(crate) fn try_into_memtable(self) -> Result<MemTable> { + MemTable::try_new(self.schema, vec![self.data]) + } +} diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 3a7fe77..65c8021 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -2,8 +2,20 @@ use std::collections::HashSet; use std::ops::ControlFlow; use std::sync::Arc; +use datafusion::sql::sqlparser::ast::Array; +use datafusion::sql::sqlparser::ast::ArrayElemTypeDef; +use datafusion::sql::sqlparser::ast::BinaryOperator; +use datafusion::sql::sqlparser::ast::CastKind; +use datafusion::sql::sqlparser::ast::DataType; use datafusion::sql::sqlparser::ast::Expr; +use datafusion::sql::sqlparser::ast::Function; +use datafusion::sql::sqlparser::ast::FunctionArg; +use datafusion::sql::sqlparser::ast::FunctionArgExpr; +use datafusion::sql::sqlparser::ast::FunctionArgumentList; +use datafusion::sql::sqlparser::ast::FunctionArguments; use datafusion::sql::sqlparser::ast::Ident; +use datafusion::sql::sqlparser::ast::ObjectName; +use datafusion::sql::sqlparser::ast::ObjectNamePart; use datafusion::sql::sqlparser::ast::OrderByKind; use datafusion::sql::sqlparser::ast::Query; use datafusion::sql::sqlparser::ast::Select; @@ -13,7 +25,9 @@ use datafusion::sql::sqlparser::ast::SetExpr; use datafusion::sql::sqlparser::ast::Statement; use datafusion::sql::sqlparser::ast::TableFactor; use datafusion::sql::sqlparser::ast::TableWithJoins; +use datafusion::sql::sqlparser::ast::UnaryOperator; use datafusion::sql::sqlparser::ast::Value; +use datafusion::sql::sqlparser::ast::ValueWithSpan; use datafusion::sql::sqlparser::ast::VisitMut; use datafusion::sql::sqlparser::ast::VisitorMut; use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; @@ -272,6 +286,7 @@ impl RemoveUnsupportedTypes { pub fn new() -> Self { let mut unsupported_types = HashSet::new(); unsupported_types.insert("regclass".to_owned()); + unsupported_types.insert("regproc".to_owned()); Self { unsupported_types } } @@ -326,6 +341,262 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes { } } +/// Rewrite Postgres's ANY operator to array_contains +#[derive(Debug)] +pub struct RewriteArrayAnyAllOperation; + +struct RewriteArrayAnyAllOperationVisitor; + +impl RewriteArrayAnyAllOperationVisitor { + fn any_to_array_contains(&self, left: &Expr, right: &Expr) -> Expr { + Expr::Function(Function { + name: ObjectName::from(vec![Ident::new("array_contains")]), + args: FunctionArguments::List(FunctionArgumentList { + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(right.clone())), + FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())), + ], + duplicate_treatment: None, + clauses: vec![], + }), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }) + } +} + +impl VisitorMut for RewriteArrayAnyAllOperationVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> { + match expr { + Expr::AnyOp { + left, + compare_op, + right, + ..
+ } => match compare_op { + BinaryOperator::Eq => { + *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref()); + } + BinaryOperator::NotEq => { + // TODO:left not equals to any element in array + } + _ => {} + }, + Expr::AllOp { + left, + compare_op, + right, + } => match compare_op { + BinaryOperator::Eq => { + // TODO: left equals to every element in array + } + BinaryOperator::NotEq => { + *expr = Expr::UnaryOp { + op: UnaryOperator::Not, + expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())), + } + } + _ => {} + }, + _ => {} + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RewriteArrayAnyAllOperationVisitor; + + let _ = s.visit(&mut visitor); + + s + } +} + +/// Prepend qualifier to table_name +/// +/// Postgres has pg_catalog in search_path by default so it allow access to +/// `pg_namespace` without `pg_catalog.` qualifier +#[derive(Debug)] +pub struct PrependUnqualifiedTableName { + table_names: HashSet, +} + +impl PrependUnqualifiedTableName { + pub fn new() -> Self { + let mut table_names = HashSet::new(); + + table_names.insert("pg_namespace".to_owned()); + + Self { table_names } + } +} + +struct PrependUnqualifiedTableNameVisitor<'a> { + table_names: &'a HashSet, +} + +impl<'a> VisitorMut for PrependUnqualifiedTableNameVisitor<'a> { + type Break = (); + + fn pre_visit_table_factor( + &mut self, + table_factor: &mut TableFactor, + ) -> ControlFlow { + if let TableFactor::Table { name, .. } = table_factor { + if name.0.len() == 1 { + let ObjectNamePart::Identifier(ident) = &name.0[0]; + if self.table_names.contains(&ident.to_string()) { + *name = ObjectName(vec![ + ObjectNamePart::Identifier(Ident::new("pg_catalog")), + name.0[0].clone(), + ]); + } + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for PrependUnqualifiedTableName { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = PrependUnqualifiedTableNameVisitor { + table_names: &self.table_names, + }; + + let _ = s.visit(&mut visitor); + s + } +} + +#[derive(Debug)] +pub struct FixArrayLiteral; + +struct FixArrayLiteralVisitor; + +impl FixArrayLiteralVisitor { + fn is_string_type(dt: &DataType) -> bool { + matches!( + dt, + DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_) + ) + } +} + +impl VisitorMut for FixArrayLiteralVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + if let Expr::Cast { + kind, + expr, + data_type, + .. + } = expr + { + if kind == &CastKind::DoubleColon { + if let DataType::Array(arr) = data_type { + // cast some to + if let Expr::Value(ValueWithSpan { + value: Value::SingleQuotedString(array_literal), + .. 
}) = expr.as_ref() + { + let items = + array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' '); + let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()); + + let is_text = match arr { + ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()), + ArrayElemTypeDef::SquareBracket(dt, _) => { + Self::is_string_type(dt.as_ref()) + } + ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()), + _ => false, + }; + + let elems = items + .map(|s| { + if is_text { + Expr::Value( + Value::SingleQuotedString(s.to_string()).with_empty_span(), + ) + } else { + Expr::Value( + Value::Number(s.to_string(), false).with_empty_span(), + ) + } + }) + .collect(); + *expr = Box::new(Expr::Array(Array { + elem: elems, + named: true, + })); + } + } + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for FixArrayLiteral { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = FixArrayLiteralVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + +/// Remove qualifier from table function +/// +/// The query engine doesn't support qualified table function names +#[derive(Debug)] +pub struct RemoveTableFunctionQualifier; + +struct RemoveTableFunctionQualifierVisitor; + +impl VisitorMut for RemoveTableFunctionQualifierVisitor { + type Break = (); + + fn pre_visit_table_factor( + &mut self, + table_factor: &mut TableFactor, + ) -> ControlFlow<Self::Break> { + if let TableFactor::Table { name, args, .. } = table_factor { + if args.is_some() { + // multiple idents in name, which means it's a qualified table name + if name.0.len() > 1 { + if let Some(last_ident) = name.0.pop() { + *name = ObjectName(vec![last_ident]); + } + } + } + } + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RemoveTableFunctionQualifier { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RemoveTableFunctionQualifierVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -362,6 +633,15 @@ mod tests { "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id", "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id" ); + + let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspname"; + let statement = parse(sql).expect("Failed to parse").remove(0); + + let statement = rewrite(statement, &rules); + assert_eq!( + statement.to_string(), + "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname" + ); } #[test] @@ -417,4 +697,87 @@ mod tests { "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" ); } + + #[test] + fn test_any_to_array_contains() { + let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = + vec![Arc::new(RewriteArrayAnyAllOperation)]; + + assert_rewrite!( + &rules, + "SELECT a = ANY(current_schemas(true))", + "SELECT array_contains(current_schemas(true), a)" + ); + + assert_rewrite!( + &rules, + "SELECT a <> ALL(current_schemas(true))", + "SELECT NOT array_contains(current_schemas(true), a)" + ); + + assert_rewrite!( + &rules, + "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))", + "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)" + ); + } + + #[test] + fn test_prepend_unqualified_table_name() { + let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(PrependUnqualifiedTableName::new())]; + + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_namespace", + "SELECT * FROM pg_catalog.pg_namespace" + ); + + assert_rewrite!( + &rules, + "SELECT * FROM pg_namespace", + "SELECT * FROM pg_catalog.pg_namespace" + ); + + assert_rewrite!( + &rules, + "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid", + "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid" + ); + } + + #[test] + fn test_array_literal_fix() { + let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixArrayLiteral)]; + + assert_rewrite!( + &rules, + "SELECT '{a, abc}'::text[]", + "SELECT ARRAY['a', 'abc']::TEXT[]" + ); + + assert_rewrite!( + &rules, + "SELECT '{1, 2}'::int[]", + "SELECT ARRAY[1, 2]::INT[]" + ); + + assert_rewrite!( + &rules, + "SELECT '{t, f}'::bool[]", + "SELECT ARRAY[t, f]::BOOL[]" + ); + } + + #[test] + fn test_remove_qualifier_from_table_function() { + let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = + vec![Arc::new(RemoveTableFunctionQualifier)]; + + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_get_keywords()", + "SELECT * FROM pg_get_keywords()" + ); + } } diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index b1776d7..e132b91 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -3,12 +3,37 @@ mod common; use common::*; use pgwire::api::query::SimpleQueryHandler; +const DBEAVER_QUERIES: &[&str] = &[ + "SET extra_float_digits = 3", + "SET application_name = 'PostgreSQL JDBC Driver'", + "SET application_name = 'DBeaver 25.1.5 - Main '", + "SELECT current_schema(),session_user", + "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname", + "SELECT n.nspname = ANY(current_schemas(true)), n.nspname, t.typname FROM pg_catalog.pg_type t JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid WHERE t.oid = 1034", + "SELECT typinput='pg_catalog.array_in'::regproc as is_array, typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN (select ns.oid as nspoid, ns.nspname, r.r from pg_namespace as ns join ( select s.r, (current_schemas(false))[s.r] as nspname from generate_series(1, array_upper(current_schemas(false), 1)) as s(r) ) as r using ( nspname ) ) as sp ON sp.nspoid = typnamespace WHERE pg_type.oid = 1034 ORDER BY sp.r, pg_type.oid DESC", + "SHOW search_path", + "SELECT db.oid,db.* FROM pg_catalog.pg_database db WHERE datname='postgres'", + "SELECT * FROM pg_catalog.pg_settings where name='standard_conforming_strings'", + "SELECT string_agg(word, ',' ) from pg_catalog.pg_get_keywords() where word <> ALL
('{a,abs,absolute,action,ada,add,admin,after,all,allocate,alter,always,and,any,are,array,as,asc,asensitive,assertion,assignment,asymmetric,at,atomic,attribute,attributes,authorization,avg,before,begin,bernoulli,between,bigint,binary,blob,boolean,both,breadth,by,c,call,called,cardinality,cascade,cascaded,case,cast,catalog,catalog_name,ceil,ceiling,chain,char,char_length,character,character_length,character_set_catalog,character_set_name,character_set_schema,characteristics,characters,check,checked,class_origin,clob,close,coalesce,cobol,code_units,collate,collation,collation_catalog,collation_name,collation_schema,collect,column,column_name,command_function,command_function_code,commit,committed,condition,condition_number,connect,connection_name,constraint,constraint_catalog,constraint_name,constraint_schema,constraints,constructors,contains,continue,convert,corr,corresponding,count,covar_pop,covar_samp,create,cross,cube,cume_dist,current,current_collation,current_date,current_default_transform_group,current_path,current_role,current_time,current_timestamp,current_transform_group_for_type,current_user,cursor,cursor_name,cycle,data,date,datetime_interval_code,datetime_interval_precision,day,deallocate,dec,decimal,declare,default,defaults,not,null,nullable,nullif,nulls,number,numeric,object,octet_length,octets,of,old,on,only,open,option,options,or,order,ordering,ordinality,others,out,outer,output,over,overlaps,overlay,overriding,pad,parameter,parameter_mode,parameter_name,parameter_ordinal_position,parameter_specific_catalog,parameter_specific_name,parameter_specific_schema,partial,partition,pascal,path,percent_rank,percentile_cont,percentile_disc,placing,pli,position,power,preceding,precision,prepare,preserve,primary,prior,privileges,procedure,public,range,rank,read,reads,real,recursive,ref,references,referencing,regr_avgx,regr_avgy,regr_count,regr_intercept,regr_r2,regr_slope,regr_sxx,regr_sxy,regr_syy,relative,release,repeatable,restart,result,return,returned_cardinality,returned_length,returned_octet_length,returned_sqlstate,returns,revoke,right,role,rollback,rollup,routine,routine_catalog,routine_name,routine_schema,row,row_count,row_number,rows,savepoint,scale,schema,schema_name,scope_catalog,scope_name,scope_schema,scroll,search,second,section,security,select,self,sensitive,sequence,serializable,server_name,session,session_user,set,sets,similar,simple,size,smallint,some,source,space,specific,specific_name,specifictype,sql,sqlexception,sqlstate,sqlwarning,sqrt,start,state,statement,static,stddev_pop,stddev_samp,structure,style,subclass_origin,submultiset,substring,sum,symmetric,system,system_user,table,table_name,tablesample,temporary,then,ties,time,timestamp,timezone_hour,timezone_minute,to,top_level_count,trailing,transaction,transaction_active,transactions_committed,transactions_rolled_back,transform,transforms,translate,translation,treat,trigger,trigger_catalog,trigger_name,trigger_schema,trim,true,type,unbounded,undefined,uncommitted,under,union,unique,unknown,unnamed,unnest,update,upper,usage,user,user_defined_type_catalog,user_defined_type_code,user_defined_type_name,user_defined_type_schema,using,value,values,var_pop,var_samp,varchar,varying,view,when,whenever,where,width_bucket,window,with,within,without,work,write,year,zone}'::text[])", + "SELECT version()", + "SELECT * FROM pg_catalog.pg_enum WHERE 1<>1 LIMIT 1", + "SELECT reltype FROM pg_catalog.pg_class WHERE 1<>1 LIMIT 1", + "SELECT t.oid,t.*,c.relkind,format_type(nullif(t.typbasetype, 0), t.typtypmod) as
base_type_name, d.description FROM pg_catalog.pg_type t LEFT OUTER JOIN pg_catalog.pg_type et ON et.oid=t.typelem LEFT OUTER JOIN pg_catalog.pg_class c ON c.oid=t.typrelid LEFT OUTER JOIN pg_catalog.pg_description d ON t.oid=d.objoid WHERE t.typname IS NOT NULL AND (c.relkind IS NULL OR c.relkind = 'c') AND (et.typcategory IS NULL OR et.typcategory <> 'C')", + "SELECT c.oid,c.*,d.description,pg_catalog.pg_get_expr(c.relpartbound, c.oid) as partition_expr, pg_catalog.pg_get_partkeydef(c.oid) as partition_key + FROM pg_catalog.pg_class c + LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=c.oid AND d.objsubid=0 AND d.classoid='pg_class'::regclass + WHERE c.relnamespace=11 AND c.relkind not in ('i','I','c')" +]; + #[tokio::test] pub async fn test_dbeaver_startup_sql() { + env_logger::init(); let service = setup_handlers(); let mut client = MockClient::new(); - SimpleQueryHandler::do_query(&service, &mut client, "SELECT 1") - .await - .expect("failed to run sql"); + for query in DBEAVER_QUERIES { + SimpleQueryHandler::do_query(&service, &mut client, query) + .await + .expect(&format!("failed to run sql: {query}")); + } } diff --git a/flake.nix b/flake.nix index 44ac8bb..bda9fbd 100644 --- a/flake.nix +++ b/flake.nix @@ -15,7 +15,8 @@ let pkgs = nixpkgs.legacyPackages.${system}; pythonEnv = pkgs.python3.withPackages (ps: with ps; [ - psycopg + psycopg2-binary + pyarrow ]); buildInputs = with pkgs; [ llvmPackages.libclang diff --git a/pg_catalog_arrow_exports/pg_get_keywords.feather b/pg_catalog_arrow_exports/pg_get_keywords.feather new file mode 100644 index 0000000..17099bd Binary files /dev/null and b/pg_catalog_arrow_exports/pg_get_keywords.feather differ diff --git a/pg_to_arrow.py b/pg_to_arrow.py new file mode 100644 index 0000000..a364a34 --- /dev/null +++ b/pg_to_arrow.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +Export PostgreSQL query results to Arrow IPC Feather format. 
+Minimal dependencies: psycopg2, pyarrow +""" + +import argparse +import psycopg2 +import pyarrow as pa +import pyarrow.feather as feather +from psycopg2.extras import RealDictCursor +from typing import Dict, Any, List, Optional +import sys + +def map_postgresql_to_arrow_type(type_oid: int) -> pa.DataType: + """Map PostgreSQL data types to Arrow data types.""" + # Map OIDs to Arrow types + type_mapping = { + # Integer types (OIDs from PostgreSQL documentation) + 20: pa.int64(), # int8 (bigint) + 21: pa.int16(), # int2 (smallint) + 23: pa.int32(), # int4 (integer) + 26: pa.int32(), # oid + + # Floating point types + 700: pa.float32(), # float4 (real) + 701: pa.float64(), # float8 (double precision) + 1700: pa.float64(), # numeric (decimal) + + # Boolean + 16: pa.bool_(), # bool + + # String types + 25: pa.string(), # text + 1043: pa.string(), # varchar + 18: pa.string(), # char + 19: pa.string(), # name + + # Date/time types + 1082: pa.date32(), # date + 1114: pa.timestamp('us'), # timestamp without time zone + 1184: pa.timestamp('us', tz='UTC'), # timestamp with time zone + 1083: pa.time64('us'), # time without time zone + 1266: pa.time64('us'), # time with time zone + + # Binary data + 17: pa.binary(), # bytea + + # JSON types + 114: pa.string(), # json + 3802: pa.string(), # jsonb + + # UUID + 2950: pa.string(), # uuid (Arrow doesn't have native UUID type) + + # Network types + 869: pa.string(), # inet + 650: pa.string(), # cidr + 829: pa.string(), # macaddr + } + + return type_mapping.get(type_oid, pa.string()) # Fallback to string + +def export_query_to_feather( + connection_string: str, + query: str, + output_file: str, + batch_size: int = 10000 +) -> None: + """Execute PostgreSQL query and export results to Arrow Feather format.""" + + try: + # Connect to PostgreSQL + conn = psycopg2.connect(connection_string) + cursor = conn.cursor(cursor_factory=RealDictCursor) + + # Execute query + cursor.execute(query) + + # Get column information + columns = [] + arrow_types = [] + column_names = [] + + for desc in cursor.description: + col_name = desc.name + col_oid = desc.type_code + + arrow_type = map_postgresql_to_arrow_type(col_oid) + + columns.append(col_name) + arrow_types.append(arrow_type) + column_names.append(col_name) + + # Process data in batches + all_data = {col: [] for col in columns} + rows_processed = 0 + + while True: + batch = cursor.fetchmany(batch_size) + if not batch: + break + + for row in batch: + for col in columns: + all_data[col].append(row[col]) + + rows_processed += len(batch) + print(f"Processed {rows_processed} rows...", end='\r') + + print(f"\nTotal rows processed: {rows_processed}") + + if rows_processed > 0: + # Convert to Arrow Table + arrays = [] + for col, arrow_type in zip(columns, arrow_types): + try: + array = pa.array(all_data[col], type=arrow_type) + except (pa.ArrowInvalid, pa.ArrowTypeError) as e: + print(f"Warning: Could not convert column '{col}' to {arrow_type}: {e}") + print("Falling back to string type") + array = pa.array([str(x) if x is not None else None for x in all_data[col]], type=pa.string()) + arrays.append(array) + + # Create table and write to feather + table = pa.Table.from_arrays(arrays, names=column_names) + feather.write_feather(table, output_file) + + print(f"Successfully exported {rows_processed} rows to {output_file}") + print(f"Schema: {table.schema}") + else: + print("No data found for the query.") + + except psycopg2.Error as e: + print(f"PostgreSQL error: {e}") + sys.exit(1) + except Exception as e: + print(f"Error: {e}") + 
import traceback + traceback.print_exc() + sys.exit(1) + finally: + if 'cursor' in locals(): + cursor.close() + if 'conn' in locals(): + conn.close() + +def main(): + parser = argparse.ArgumentParser(description='Export PostgreSQL query to Arrow Feather format') + + # Connection options + parser.add_argument('--host', default='localhost', help='PostgreSQL host') + parser.add_argument('--port', type=int, default=5432, help='PostgreSQL port') + parser.add_argument('--database', default='postgres', help='Database name') + parser.add_argument('--user', default='postgres', help='Database user') + parser.add_argument('--password', default='', help='Database password') + + # Alternative: connection string + parser.add_argument('--connection-string', help='PostgreSQL connection string (overrides individual connection params)') + + parser.add_argument('--query', required=True, help='SQL query to execute') + parser.add_argument('--output', required=True, help='Output feather file path') + parser.add_argument('--batch-size', type=int, default=10000, help='Batch size for processing') + + args = parser.parse_args() + + # Build connection string + if args.connection_string: + connection_string = args.connection_string + else: + connection_string = f"host={args.host} port={args.port} dbname={args.database} user={args.user} password={args.password}" + + export_query_to_feather( + connection_string=connection_string, + query=args.query, + output_file=args.output, + batch_size=args.batch_size + ) + +if __name__ == "__main__": + main()
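For reference, a minimal usage sketch of the exporter above (assumptions: a PostgreSQL instance reachable on localhost with the default postgres user, and the repository root as the working directory; the query and output path are illustrative, mirroring the pg_get_keywords.feather export added in this change):

    # sketch only: drive the exporter from Python instead of the CLI flags
    from pg_to_arrow import export_query_to_feather

    export_query_to_feather(
        connection_string="host=localhost port=5432 dbname=postgres user=postgres password=",
        query="SELECT * FROM pg_get_keywords()",
        output_file="pg_catalog_arrow_exports/pg_get_keywords.feather",
    )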