181 changes: 108 additions & 73 deletions datafusion-postgres/src/pg_catalog.rs
@@ -1,3 +1,5 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
@@ -16,6 +18,8 @@ use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::streaming::PartitionStream;
use datafusion::prelude::{create_udf, SessionContext};
use postgres_types::Oid;
use tokio::sync::RwLock;

const PG_CATALOG_TABLE_PG_TYPE: &str = "pg_type";
const PG_CATALOG_TABLE_PG_CLASS: &str = "pg_class";
@@ -208,10 +212,19 @@ impl PgTypesData {
}
}

#[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
enum OidCacheKey {
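/// Schema by name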
Schema(String),
/// Table by schema and table name
Table(String, String),
}
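
// Illustrative key construction (the schema and table names here are hypothetical):
//   OidCacheKey::Schema("public".to_string())
//   OidCacheKey::Table("public".to_string(), "users".to_string())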

// Create custom schema provider for pg_catalog
#[derive(Debug)]
pub struct PgCatalogSchemaProvider {
catalog_list: Arc<dyn CatalogProviderList>,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
}

#[async_trait]
@@ -229,13 +242,21 @@ impl SchemaProvider for PgCatalogSchemaProvider {
PG_CATALOG_TABLE_PG_TYPE => Ok(Some(self.create_pg_type_table())),
PG_CATALOG_TABLE_PG_AM => Ok(Some(self.create_pg_am_table())),
PG_CATALOG_TABLE_PG_CLASS => {
let table = Arc::new(PgClassTable::new(self.catalog_list.clone()));
let table = Arc::new(PgClassTable::new(
self.catalog_list.clone(),
self.oid_counter.clone(),
self.oid_cache.clone(),
));
Ok(Some(Arc::new(
StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
)))
}
PG_CATALOG_TABLE_PG_NAMESPACE => {
let table = Arc::new(PgNamespaceTable::new(self.catalog_list.clone()));
let table = Arc::new(PgNamespaceTable::new(
self.catalog_list.clone(),
self.oid_counter.clone(),
self.oid_cache.clone(),
));
Ok(Some(Arc::new(
StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
)))
@@ -266,7 +287,11 @@ impl SchemaProvider for PgCatalogSchemaProvider {

impl PgCatalogSchemaProvider {
pub fn new(catalog_list: Arc<dyn CatalogProviderList>) -> PgCatalogSchemaProvider {
Self { catalog_list }
Self {
catalog_list,
oid_counter: Arc::new(AtomicU32::new(0)),
oid_cache: Arc::new(RwLock::new(HashMap::new())),
}
}

/// Create a populated pg_type table with standard PostgreSQL data types
@@ -1033,14 +1058,20 @@ impl PgProcData {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct PgClassTable {
schema: SchemaRef,
catalog_list: Arc<dyn CatalogProviderList>,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
}

impl PgClassTable {
fn new(catalog_list: Arc<dyn CatalogProviderList>) -> PgClassTable {
fn new(
catalog_list: Arc<dyn CatalogProviderList>,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
) -> PgClassTable {
// Define the schema for pg_class
// This matches key columns from PostgreSQL's pg_class
let schema = Arc::new(Schema::new(vec![
@@ -1079,14 +1110,13 @@ impl PgClassTable {
Self {
schema,
catalog_list,
oid_counter,
oid_cache,
}
}

/// Generate record batches based on the current state of the catalog
async fn get_data(
schema: SchemaRef,
catalog_list: Arc<dyn CatalogProviderList>,
) -> Result<RecordBatch> {
async fn get_data(this: PgClassTable) -> Result<RecordBatch> {
// Vectors to store column data
let mut oids = Vec::new();
let mut relnames = Vec::new();
@@ -1119,24 +1149,37 @@ impl PgClassTable {
let mut relfrozenxids = Vec::new();
let mut relminmxids = Vec::new();

// Start OID counter (this is simplistic and would need to be more robust in practice)
let mut next_oid = 10000;
let mut oid_cache = this.oid_cache.write().await;
// Each time pg_catalog is queried we build a fresh cache and swap it in,
// dropping the original in case schemas or tables were removed.
let mut swap_cache = HashMap::new();
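
// Lookup-or-allocate sketch: reuse the OID cached for a key that survives
// from the previous snapshot, otherwise draw a fresh value from the shared
// counter; every live key is recorded in swap_cache so stale entries age
// out when the caches are swapped at the end of this function.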

// Iterate through all catalogs and schemas
for catalog_name in catalog_list.catalog_names() {
if let Some(catalog) = catalog_list.catalog(&catalog_name) {
for catalog_name in this.catalog_list.catalog_names() {
if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
for schema_name in catalog.schema_names() {
if let Some(schema) = catalog.schema(&schema_name) {
let schema_oid = next_oid;
next_oid += 1;
let cache_key = OidCacheKey::Schema(schema_name.clone());
let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
this.oid_counter.fetch_add(1, Ordering::Relaxed)
};
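// Note: Ordering::Relaxed suffices here (and below for tables); the counter
// only needs to hand out unique values, and no other memory is synchronized
// through it.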
swap_cache.insert(cache_key, schema_oid);

// Add an entry for the schema itself (as a namespace)
// (In a full implementation, this would go in pg_namespace)

// Now process all tables in this schema
for table_name in schema.table_names() {
let table_oid = next_oid;
next_oid += 1;
let cache_key =
OidCacheKey::Table(schema_name.clone(), table_name.clone());
let table_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
this.oid_counter.fetch_add(1, Ordering::Relaxed)
};
swap_cache.insert(cache_key, table_oid);

if let Some(table) = schema.table(&table_name).await? {
// Determine the correct table type based on the table provider and context
@@ -1147,14 +1190,14 @@ impl PgClassTable {
let column_count = table.schema().fields().len() as i16;

// Add table entry
oids.push(table_oid);
oids.push(table_oid as i32);
relnames.push(table_name.clone());
relnamespaces.push(schema_oid);
relnamespaces.push(schema_oid as i32);
reltypes.push(0); // Simplified: we're not tracking data types
reloftypes.push(None);
relowners.push(0); // Simplified: no owner tracking
relams.push(0); // Default access method
relfilenodes.push(table_oid); // Use OID as filenode
relfilenodes.push(table_oid as i32); // Use OID as filenode
reltablespaces.push(0); // Default tablespace
relpages.push(1); // Default page count
reltuples.push(0.0); // No row count stats
@@ -1184,6 +1227,8 @@
}
}

*oid_cache = swap_cache;

// Create Arrow arrays from the collected data
let arrays: Vec<ArrayRef> = vec![
Arc::new(Int32Array::from(oids)),
@@ -1219,7 +1264,7 @@ impl PgClassTable {
];

// Create a record batch
let batch = RecordBatch::try_new(schema.clone(), arrays)?;
let batch = RecordBatch::try_new(this.schema.clone(), arrays)?;

Ok(batch)
}
@@ -1231,23 +1276,28 @@ impl PartitionStream for PgClassTable {
}

fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
let catalog_list = self.catalog_list.clone();
let schema = Arc::clone(&self.schema);
let this = self.clone();
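// Cloning self is cheap: schema, catalog_list, oid_counter and oid_cache
// are all Arc-backed, so the stream shares state instead of copying it.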
Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::once(async move { Self::get_data(schema, catalog_list).await }),
this.schema.clone(),
futures::stream::once(async move { PgClassTable::get_data(this).await }),
))
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct PgNamespaceTable {
schema: SchemaRef,
catalog_list: Arc<dyn CatalogProviderList>,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
}

impl PgNamespaceTable {
pub fn new(catalog_list: Arc<dyn CatalogProviderList>) -> Self {
pub fn new(
catalog_list: Arc<dyn CatalogProviderList>,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
) -> Self {
// Define the schema for pg_namespace
// This matches the columns from PostgreSQL's pg_namespace
let schema = Arc::new(Schema::new(vec![
@@ -1261,62 +1311,38 @@ impl PgNamespaceTable {
Self {
schema,
catalog_list,
oid_counter,
oid_cache,
}
}

/// Generate record batches based on the current state of the catalog
async fn get_data(
schema: SchemaRef,
catalog_list: Arc<dyn CatalogProviderList>,
) -> Result<RecordBatch> {
async fn get_data(this: PgNamespaceTable) -> Result<RecordBatch> {
// Vectors to store column data
let mut oids = Vec::new();
let mut nspnames = Vec::new();
let mut nspowners = Vec::new();
let mut nspacls: Vec<Option<String>> = Vec::new();
let mut options: Vec<Option<String>> = Vec::new();

// Start OID counter (should be consistent with the values used in pg_class)
let mut next_oid = 10000;
// Temporarily collect schema-to-OID mappings before merging them into the global OID cache
let mut schema_oid_cache = HashMap::new();

// Add standard PostgreSQL system schemas
// pg_catalog schema (OID 11)
oids.push(11);
nspnames.push("pg_catalog".to_string());
nspowners.push(10); // Default superuser
nspacls.push(None);
options.push(None);

// public schema (OID 2200)
oids.push(2200);
nspnames.push("public".to_string());
nspowners.push(10); // Default superuser
nspacls.push(None);
options.push(None);

// information_schema (OID 12)
oids.push(12);
nspnames.push("information_schema".to_string());
nspowners.push(10); // Default superuser
nspacls.push(None);
options.push(None);
let mut oid_cache = this.oid_cache.write().await;

// Now add all schemas from DataFusion catalogs
for catalog_name in catalog_list.catalog_names() {
if let Some(catalog) = catalog_list.catalog(&catalog_name) {
for catalog_name in this.catalog_list.catalog_names() {
if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
for schema_name in catalog.schema_names() {
// Skip schemas we've already added as system schemas
if schema_name == "pg_catalog"
|| schema_name == "public"
|| schema_name == "information_schema"
{
continue;
}

let schema_oid = next_oid;
next_oid += 1;

oids.push(schema_oid);
let cache_key = OidCacheKey::Schema(schema_name.clone());
let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
this.oid_counter.fetch_add(1, Ordering::Relaxed)
};
schema_oid_cache.insert(cache_key, schema_oid);

oids.push(schema_oid as i32);
nspnames.push(schema_name.clone());
nspowners.push(10); // Default owner
nspacls.push(None);
@@ -1325,6 +1351,16 @@ impl PgNamespaceTable {
}
}

// Drop all cached schema entries (fresh ones are re-added below), plus
// any table entry whose schema no longer exists
oid_cache.retain(|key, _| match key {
OidCacheKey::Schema(_) => false,
OidCacheKey::Table(schema_name, _) => {
schema_oid_cache.contains_key(&OidCacheKey::Schema(schema_name.clone()))
}
});
// Re-add the schema entries collected above
oid_cache.extend(schema_oid_cache);
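// Net effect (hypothetical example): if schema "s1" was dropped since the
// previous call, Schema("s1") and every Table("s1", _) entry disappear,
// while surviving schemas keep the OIDs they were assigned before.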

// Create Arrow arrays from the collected data
let arrays: Vec<ArrayRef> = vec![
Arc::new(Int32Array::from(oids)),
@@ -1335,7 +1371,7 @@ impl PgNamespaceTable {
];

// Create a full record batch
let batch = RecordBatch::try_new(schema.clone(), arrays)?;
let batch = RecordBatch::try_new(this.schema.clone(), arrays)?;

Ok(batch)
}
@@ -1347,11 +1383,10 @@ impl PartitionStream for PgNamespaceTable {
}

fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
let catalog_list = self.catalog_list.clone();
let schema = Arc::clone(&self.schema);
let this = self.clone();
Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::once(async move { Self::get_data(schema, catalog_list).await }),
this.schema.clone(),
futures::stream::once(async move { Self::get_data(this).await }),
))
}
}
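
A minimal wiring sketch, assuming DataFusion's default catalog name "datafusion" and its CatalogProvider::register_schema API; everything outside the diff is illustrative:

use std::sync::Arc;
use datafusion::prelude::SessionContext;

let ctx = SessionContext::new();
// Reuse the session's catalog list so pg_catalog reflects registered tables.
let catalog_list = ctx.state().catalog_list().clone();
let pg_catalog = PgCatalogSchemaProvider::new(catalog_list);
ctx.catalog("datafusion")
    .expect("default catalog exists")
    .register_schema("pg_catalog", Arc::new(pg_catalog))
    .expect("pg_catalog registration succeeds");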