Merged
1 change: 1 addition & 0 deletions Cargo.lock

1 change: 1 addition & 0 deletions README.md
@@ -22,6 +22,7 @@ project.
- Permission control
- Built-in `pg_catalog` tables
- Built-in postgres functions for common meta queries
+- [x] DBeaver compatibility
- `datafusion-postgres-cli`: A cli tool starts a postgres compatible server for
datafusion supported file formats, just like python's `SimpleHTTPServer`.
- `arrow-pg`: A data type mapping, encoding/decoding library for arrow and
19 changes: 19 additions & 0 deletions arrow-pg/src/list_encoder.rs
@@ -1,5 +1,6 @@
use std::{str::FromStr, sync::Arc};

+use arrow::array::{BinaryViewArray, StringViewArray};
#[cfg(not(feature = "datafusion"))]
use arrow::{
array::{
@@ -150,6 +151,15 @@ pub(crate) fn encode_list(
.collect();
encode_field(&value, type_, format)
}
+DataType::Utf8View => {
+let value: Vec<Option<&str>> = arr
+.as_any()
+.downcast_ref::<StringViewArray>()
+.unwrap()
+.iter()
+.collect();
+encode_field(&value, type_, format)
+}
DataType::Binary => {
let value: Vec<Option<_>> = arr
.as_any()
@@ -168,6 +178,15 @@
.collect();
encode_field(&value, type_, format)
}
+DataType::BinaryView => {
+let value: Vec<Option<_>> = arr
+.as_any()
+.downcast_ref::<BinaryViewArray>()
+.unwrap()
+.iter()
+.collect();
+encode_field(&value, type_, format)
+}

DataType::Date32 => {
let value: Vec<Option<_>> = arr
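Both new arms follow the same downcast-and-collect pattern the existing `Utf8`/`Binary` arms use. A self-contained sketch of that pattern against public `arrow` APIs (sample values are illustrative, not from this PR):

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, BinaryViewArray, StringViewArray};

fn main() {
    // Utf8View: elements iterate as Option<&str>, the shape encode_field
    // expects for text values.
    let strings: ArrayRef = Arc::new(
        vec![Some("a"), None, Some("long enough to be stored out of line")]
            .into_iter()
            .collect::<StringViewArray>(),
    );
    let text: Vec<Option<&str>> = strings
        .as_any()
        .downcast_ref::<StringViewArray>()
        .unwrap()
        .iter()
        .collect();
    assert_eq!(text[1], None);

    // BinaryView: same pattern, iterating as Option<&[u8]>.
    let bytes: ArrayRef = Arc::new(
        vec![Some(b"\x01\x02".as_ref()), None]
            .into_iter()
            .collect::<BinaryViewArray>(),
    );
    let bin: Vec<Option<&[u8]>> = bytes
        .as_any()
        .downcast_ref::<BinaryViewArray>()
        .unwrap()
        .iter()
        .collect();
    assert_eq!(bin[0].map(|b| b.len()), Some(2));
}
```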
3 changes: 3 additions & 0 deletions datafusion-postgres/Cargo.toml
@@ -28,3 +28,6 @@ tokio = { version = "1.47", features = ["sync", "net"] }
tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] }
rustls-pemfile = "2.0"
rustls-pki-types = "1.0"
+
+[dev-dependencies]
+env_logger = "0.11"
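The dev-dependency presumably exists so test runs can surface the `log::warn!` output that handlers.rs now emits. A minimal sketch, assuming the `log` crate is already in the dependency tree; the test itself is hypothetical:

```rust
#[cfg(test)]
mod tests {
    #[test]
    fn logging_is_visible_in_tests() {
        // `is_test(true)` routes output through the test harness so it is
        // captured per test; run with RUST_LOG=warn to see the warnings.
        let _ = env_logger::builder().is_test(true).try_init();
        log::warn!("visible with RUST_LOG=warn");
    }
}
```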
35 changes: 20 additions & 15 deletions datafusion-postgres/src/handlers.rs
@@ -3,14 +3,16 @@ use std::sync::Arc;

use crate::auth::{AuthManager, Permission, ResourceType};
use crate::sql::{
-parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes,
-ResolveUnqualifiedIdentifer, SqlStatementRewriteRule,
+parse, rewrite, AliasDuplicatedProjectionRewrite, FixArrayLiteral, PrependUnqualifiedTableName,
+RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer,
+RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
};
use async_trait::async_trait;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
use datafusion::sql::parser::Statement;
+use log::warn;
use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::auth::StartupHandler;
use pgwire::api::portal::{Format, Portal};
@@ -80,6 +82,10 @@ impl DfSessionService {
Arc::new(AliasDuplicatedProjectionRewrite),
Arc::new(ResolveUnqualifiedIdentifer),
Arc::new(RemoveUnsupportedTypes::new()),
+Arc::new(RewriteArrayAnyAllOperation),
+Arc::new(PrependUnqualifiedTableName::new()),
+Arc::new(FixArrayLiteral),
+Arc::new(RemoveTableFunctionQualifier),
];
let parser = Arc::new(Parser {
session_context: session_context.clone(),
@@ -211,14 +217,12 @@
}
} else {
// pass SET query to datafusion
-let df = self
-.session_context
-.sql(query_lower)
-.await
-.map_err(|err| PgWireError::ApiError(Box::new(err)))?;
-
-let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
-Ok(Some(Response::Query(resp)))
+if let Err(e) = self.session_context.sql(query_lower).await {
+warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored");
+}
+
+// Always return SET success
+Ok(Some(Response::Execution(Tag::new("SET"))))
}
} else {
Ok(None)
@@ -297,8 +301,8 @@ impl DfSessionService {
Ok(Some(Response::Query(resp)))
}
"show search_path" => {
-let default_catalog = "datafusion";
-let resp = Self::mock_show_response("search_path", default_catalog)?;
+let default_schema = "public";
+let resp = Self::mock_show_response("search_path", default_schema)?;
Ok(Some(Response::Query(resp)))
}
_ => Err(PgWireError::UserError(Box::new(
@@ -331,7 +335,8 @@ impl SimpleQueryHandler for DfSessionService {
statement = rewrite(statement, &self.sql_rewrite_rules);

// TODO: improve statement check by using statement directly
-let query_lower = statement.to_string().to_lowercase().trim().to_string();
+let query = statement.to_string();
+let query_lower = query.to_lowercase().trim().to_string();

// Check permissions for the query (skip for SET, transaction, and SHOW statements)
if !query_lower.starts_with("set")
@@ -343,7 +348,7 @@ impl SimpleQueryHandler for DfSessionService {
&& !query_lower.starts_with("abort")
&& !query_lower.starts_with("show")
{
-self.check_query_permission(client, query).await?;
+self.check_query_permission(client, &query).await?;
}

if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
@@ -373,7 +378,7 @@ impl SimpleQueryHandler for DfSessionService {
)));
}

-let df_result = self.session_context.sql(query).await;
+let df_result = self.session_context.sql(&query).await;

// Handle query execution errors and transaction state
let df = match df_result {
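The observable effect of the SET change: statements DataFusion cannot execute are logged and acknowledged instead of failing the session, which is what connection-time `SET` chatter from JDBC-based clients such as DBeaver needs. A hypothetical client-side check, assuming a server from this repo is listening locally and using the `tokio-postgres` crate (connection parameters are illustrative):

```rust
use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    let (client, conn) =
        tokio_postgres::connect("host=127.0.0.1 port=5432 user=postgres", NoTls).await?;
    // Drive the connection in the background.
    tokio::spawn(conn);

    // DataFusion has no `extra_float_digits` setting, so the handler logs a
    // warning and still replies with a `SET` command tag instead of an error.
    client.batch_execute("SET extra_float_digits = 3").await?;
    Ok(())
}
```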
128 changes: 119 additions & 9 deletions datafusion-postgres/src/pg_catalog.rs
@@ -10,20 +10,21 @@ use datafusion::arrow::array::{
use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
use datafusion::arrow::ipc::reader::FileReader;
use datafusion::catalog::streaming::StreamingTable;
-use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider};
+use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider, TableFunctionImpl};
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::datasource::{TableProvider, ViewTable};
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility};
use datafusion::physical_plan::streaming::PartitionStream;
-use datafusion::prelude::{create_udf, SessionContext};
+use datafusion::prelude::{create_udf, Expr, SessionContext};
use postgres_types::Oid;
use tokio::sync::RwLock;

mod pg_attribute;
mod pg_class;
mod pg_database;
mod pg_namespace;
+mod pg_settings;

const PG_CATALOG_TABLE_PG_AGGREGATE: &str = "pg_aggregate";
const PG_CATALOG_TABLE_PG_AM: &str = "pg_am";
@@ -86,6 +87,7 @@ const PG_CATALOG_TABLE_PG_SUBSCRIPTION_REL: &str = "pg_subscription_rel";
const PG_CATALOG_TABLE_PG_TABLESPACE: &str = "pg_tablespace";
const PG_CATALOG_TABLE_PG_TRIGGER: &str = "pg_trigger";
const PG_CATALOG_TABLE_PG_USER_MAPPING: &str = "pg_user_mapping";
+const PG_CATALOG_VIEW_PG_SETTINGS: &str = "pg_settings";

/// Determine PostgreSQL table type (relkind) from DataFusion TableProvider
fn get_table_type(table: &Arc<dyn TableProvider>) -> &'static str {
@@ -180,6 +182,7 @@ pub const PG_CATALOG_TABLES: &[&str] = &[
PG_CATALOG_TABLE_PG_TABLESPACE,
PG_CATALOG_TABLE_PG_TRIGGER,
PG_CATALOG_TABLE_PG_USER_MAPPING,
+PG_CATALOG_VIEW_PG_SETTINGS,
];

#[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
@@ -196,7 +199,7 @@ pub struct PgCatalogSchemaProvider {
catalog_list: Arc<dyn CatalogProviderList>,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
-static_tables: PgCatalogStaticTables,
+static_tables: Arc<PgCatalogStaticTables>,
}

#[async_trait]
@@ -345,6 +348,10 @@ impl SchemaProvider for PgCatalogSchemaProvider {
StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
)))
}
+PG_CATALOG_VIEW_PG_SETTINGS => {
+let table = pg_settings::PgSettingsView::try_new()?;
+Ok(Some(Arc::new(table.try_into_memtable()?)))
+}

_ => Ok(None),
}
@@ -356,12 +363,15 @@
}

impl PgCatalogSchemaProvider {
-pub fn try_new(catalog_list: Arc<dyn CatalogProviderList>) -> Result<PgCatalogSchemaProvider> {
+pub fn try_new(
+catalog_list: Arc<dyn CatalogProviderList>,
+static_tables: Arc<PgCatalogStaticTables>,
+) -> Result<PgCatalogSchemaProvider> {
Ok(Self {
catalog_list,
oid_counter: Arc::new(AtomicU32::new(16384)),
oid_cache: Arc::new(RwLock::new(HashMap::new())),
-static_tables: PgCatalogStaticTables::try_new()?,
+static_tables,
})
}
}
@@ -399,10 +409,17 @@ impl ArrowTable {
}
}

+impl TableFunctionImpl for ArrowTable {
+fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
+let table = self.clone().try_into_memtable()?;
+Ok(Arc::new(table))
+}
+}
+
/// pg_catalog table as datafusion table provider
///
/// This implementation only contains static tables
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct PgCatalogStaticTables {
pub pg_aggregate: Arc<dyn TableProvider>,
pub pg_am: Arc<dyn TableProvider>,
@@ -461,6 +478,8 @@ pub struct PgCatalogStaticTables {
pub pg_tablespace: Arc<dyn TableProvider>,
pub pg_trigger: Arc<dyn TableProvider>,
pub pg_user_mapping: Arc<dyn TableProvider>,
+
+pub pg_get_keywords: Arc<dyn TableFunctionImpl>,
}

impl PgCatalogStaticTables {
@@ -647,6 +666,10 @@ impl PgCatalogStaticTables {
pg_user_mapping: Self::create_arrow_table(
include_bytes!("../../pg_catalog_arrow_exports/pg_user_mapping.feather").to_vec(),
)?,
+
+pg_get_keywords: Self::create_arrow_table_function(
+include_bytes!("../../pg_catalog_arrow_exports/pg_get_keywords.feather").to_vec(),
+)?,
})
}

@@ -656,6 +679,11 @@ impl PgCatalogStaticTables {
let mem_table = table.try_into_memtable()?;
Ok(Arc::new(mem_table))
}
+
+fn create_arrow_table_function(data_bytes: Vec<u8>) -> Result<Arc<dyn TableFunctionImpl>> {
+let table = ArrowTable::from_ipc_data(data_bytes)?;
+Ok(Arc::new(table))
+}
}

pub fn create_current_schemas_udf() -> ScalarUDF {
@@ -862,7 +890,78 @@ pub fn create_format_type_udf() -> ScalarUDF {

create_udf(
"format_type",
-vec![DataType::Int32, DataType::Int32],
+vec![DataType::Int64, DataType::Int32],
DataType::Utf8,
Volatility::Stable,
Arc::new(func),
)
}
+
+pub fn create_session_user_udf() -> ScalarUDF {
+let func = move |_args: &[ColumnarValue]| {
+let mut builder = StringBuilder::new();
+// TODO: return real user
+builder.append_value("postgres");
+
+let array: ArrayRef = Arc::new(builder.finish());
+
+Ok(ColumnarValue::Array(array))
+};
+
+create_udf(
+"session_user",
+vec![],
+DataType::Utf8,
+Volatility::Stable,
+Arc::new(func),
+)
+}
+
+pub fn create_pg_get_expr_udf() -> ScalarUDF {
+let func = move |args: &[ColumnarValue]| {
+let args = ColumnarValue::values_to_arrays(args)?;
+let expr = &args[0];
+let _oid = &args[1];
+
+// For now, return an empty string for every expression
+let mut builder = StringBuilder::new();
+for _ in 0..expr.len() {
+builder.append_value("");
+}
+
+let array: ArrayRef = Arc::new(builder.finish());
+
+Ok(ColumnarValue::Array(array))
+};
+
+create_udf(
+"pg_catalog.pg_get_expr",
+vec![DataType::Utf8, DataType::Int32],
+DataType::Utf8,
+Volatility::Stable,
+Arc::new(func),
+)
+}
+
+pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
+let func = move |args: &[ColumnarValue]| {
+let args = ColumnarValue::values_to_arrays(args)?;
+let oid = &args[0];
+
+// For now, return an empty string for every row
+let mut builder = StringBuilder::new();
+for _ in 0..oid.len() {
+builder.append_value("");
+}
+
+let array: ArrayRef = Arc::new(builder.finish());
+
+Ok(ColumnarValue::Array(array))
+};
+
+create_udf(
+"pg_catalog.pg_get_partkeydef",
+vec![DataType::Utf8],
+DataType::Utf8,
+Volatility::Stable,
+Arc::new(func),
@@ -874,8 +973,11 @@ pub fn setup_pg_catalog(
session_context: &SessionContext,
catalog_name: &str,
) -> Result<(), Box<DataFusionError>> {
-let pg_catalog =
-PgCatalogSchemaProvider::try_new(session_context.state().catalog_list().clone())?;
+let static_tables = Arc::new(PgCatalogStaticTables::try_new()?);
+let pg_catalog = PgCatalogSchemaProvider::try_new(
+session_context.state().catalog_list().clone(),
+static_tables.clone(),
+)?;
session_context
.catalog(catalog_name)
.ok_or_else(|| {
@@ -892,6 +994,10 @@ pub fn setup_pg_catalog(
session_context.register_udf(create_has_table_privilege_2param_udf());
session_context.register_udf(create_pg_table_is_visible());
session_context.register_udf(create_format_type_udf());
+session_context.register_udf(create_session_user_udf());
+session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone());
+session_context.register_udf(create_pg_get_expr_udf());
+session_context.register_udf(create_pg_get_partkeydef_udf());

Ok(())
}
@@ -1145,5 +1251,9 @@ mod test {
include_bytes!("../../pg_catalog_arrow_exports/pg_user_mapping.feather").to_vec(),
)
.expect("Failed to load ipc data");
+let _ = ArrowTable::from_ipc_data(
+include_bytes!("../../pg_catalog_arrow_exports/pg_get_keywords.feather").to_vec(),
+)
+.expect("Failed to load ipc data");
}
}
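For reference, the `TableFunctionImpl` contract that the new `ArrowTable` impl and the `register_udtf` call rely on, as a self-contained sketch against public DataFusion APIs and a tokio runtime. `KeywordsTable` is a hypothetical stand-in for `ArrowTable`; like the real table function it ignores its arguments and returns a `MemTable`, so `SELECT * FROM pg_get_keywords()` resolves to a plain table scan:

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::{MemTable, TableFunctionImpl};
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::prelude::{Expr, SessionContext};

#[derive(Debug)]
struct KeywordsTable;

impl TableFunctionImpl for KeywordsTable {
    // Arguments are ignored, matching the zero-argument pg_get_keywords().
    fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        let schema = Arc::new(Schema::new(vec![Field::new("word", DataType::Utf8, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(StringArray::from(vec!["select", "insert"])) as ArrayRef],
        )?;
        Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_udtf("pg_get_keywords", Arc::new(KeywordsTable));
    ctx.sql("SELECT * FROM pg_get_keywords()").await?.show().await?;
    Ok(())
}
```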