Skip to content
103 changes: 98 additions & 5 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,20 @@ use tokio::sync::Mutex;
use crate::auth::AuthManager;
use arrow_pg::datatypes::df;
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
use datafusion::sql::sqlparser;
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

/// Extension point for intercepting a parsed SQL statement before the
/// built-in query handling runs.
///
/// Implementations must be `Send + Sync` because they are stored behind
/// `Arc<dyn QueryHook>` and shared by the session service.
#[async_trait]
pub trait QueryHook: Send + Sync {
/// Inspects `statement` and optionally takes over handling of it.
///
/// * `statement` - the parsed SQL statement being processed.
/// * `session_context` - the DataFusion session the statement targets.
/// * `client` - information about the connected client.
///
/// Returns `Some(result)` to short-circuit processing with that response
/// or error, or `None` to decline so that the next hook (and then the
/// regular handling) gets a chance to run.
async fn handle_query(
&self,
statement: &sqlparser::ast::Statement,
session_context: &SessionContext,
client: &dyn ClientInfo,
) -> Option<PgWireResult<Response>>;
}

// Metadata keys for session-level settings
const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms";

Expand All @@ -45,9 +56,16 @@ pub struct HandlerFactory {
}

impl HandlerFactory {
pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
let session_service =
Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
) -> Self {
let session_service = Arc::new(DfSessionService::new(
session_context,
auth_manager.clone(),
query_hooks,
));
HandlerFactory { session_service }
}
}
Expand Down Expand Up @@ -87,12 +105,14 @@ pub struct DfSessionService {
parser: Arc<Parser>,
timezone: Arc<Mutex<String>>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
}

impl DfSessionService {
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
) -> DfSessionService {
let parser = Arc::new(Parser {
session_context: session_context.clone(),
Expand All @@ -103,6 +123,7 @@ impl DfSessionService {
parser,
timezone: Arc::new(Mutex::new("UTC".to_string())),
auth_manager,
query_hooks,
}
}

Expand Down Expand Up @@ -468,6 +489,16 @@ impl SimpleQueryHandler for DfSessionService {
self.check_query_permission(client, &query).await?;
}

// Call query hooks with the parsed statement. Hooks are tried in the order
// they appear in `query_hooks`; the first one that returns `Some` ends the
// handling here, and its response (or error) is returned wrapped in a
// single-element `Vec`. Hooks that return `None` are skipped.
for hook in &self.query_hooks {
if let Some(result) = hook
.handle_query(&statement, &self.session_context, client)
.await
{
return result.map(|response| vec![response]);
}
}

if let Some(resp) = self
.try_respond_set_statements(client, &query_lower)
.await?
Expand Down Expand Up @@ -610,6 +641,26 @@ impl ExtendedQueryHandler for DfSessionService {
.to_string();
log::debug!("Received execute extended query: {query}"); // Log for debugging

// Check query hooks first
for hook in &self.query_hooks {
// Parse the SQL to get the Statement for the hook
let sql = &portal.statement.statement.0;
let statements = self
.parser
.sql_parser
.parse(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

if let Some(statement) = statements.into_iter().next() {
if let Some(result) = hook
.handle_query(&statement, &self.session_context, client)
.await
{
return result;
}
}
}

// Check permissions for the query (skip for SET and SHOW statements)
if !query.starts_with("set") && !query.starts_with("show") {
self.check_query_permission(client, &portal.statement.statement.0)
Expand Down Expand Up @@ -909,7 +960,7 @@ mod tests {
async fn test_statement_timeout_set_and_show() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let service = DfSessionService::new(session_context, auth_manager);
let service = DfSessionService::new(session_context, auth_manager, vec![]);
let mut client = MockClient::new();

// Test setting timeout to 5000ms
Expand All @@ -935,7 +986,7 @@ mod tests {
async fn test_statement_timeout_disable() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let service = DfSessionService::new(session_context, auth_manager);
let service = DfSessionService::new(session_context, auth_manager, vec![]);
let mut client = MockClient::new();

// Set timeout first
Expand All @@ -953,4 +1004,46 @@ mod tests {
let timeout = DfSessionService::get_statement_timeout(&client);
assert_eq!(timeout, None);
}

/// Test-only hook that claims any statement whose rendered SQL contains
/// the text "magic" and declines everything else.
struct TestHook;

#[async_trait]
impl QueryHook for TestHook {
    /// Returns an empty-query response for "magic" statements, `None` otherwise.
    async fn handle_query(
        &self,
        statement: &sqlparser::ast::Statement,
        _ctx: &SessionContext,
        _client: &dyn ClientInfo,
    ) -> Option<PgWireResult<Response>> {
        let rendered_sql = statement.to_string();
        rendered_sql
            .contains("magic")
            .then(|| Ok(Response::EmptyQuery))
    }
}

/// Verifies that `TestHook` intercepts statements mentioning "magic" and
/// leaves ordinary statements alone.
#[tokio::test]
async fn test_query_hooks() {
    let hook = TestHook;
    let ctx = SessionContext::new();
    let client = MockClient::new();
    let parser = PostgresCompatibilityParser::new();

    // (SQL text, whether the hook is expected to intercept it)
    let cases = [("SELECT magic", true), ("SELECT 1", false)];

    for (sql, expect_intercept) in cases {
        let parsed = parser.parse(sql).unwrap();
        let first_statement = &parsed[0];

        let outcome = hook.handle_query(first_statement, &ctx, &client).await;
        assert_eq!(outcome.is_some(), expect_intercept);
    }
}
}
4 changes: 2 additions & 2 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor;

use crate::auth::AuthManager;
use handlers::HandlerFactory;
pub use handlers::{DfSessionService, Parser};
pub use handlers::{DfSessionService, Parser, QueryHook};

/// re-exports
pub use arrow_pg;
Expand Down Expand Up @@ -85,7 +85,7 @@ pub async fn serve(
auth_manager: Arc<AuthManager>,
) -> Result<(), std::io::Error> {
// Create the handler factory with authentication
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, vec![]));

serve_with_handlers(factory, opts).await
}
Expand Down
6 changes: 5 additions & 1 deletion datafusion-postgres/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ pub fn setup_handlers() -> DfSessionService {
)
.expect("Failed to setup sesession context");

DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()))
DfSessionService::new(
Arc::new(session_context),
Arc::new(AuthManager::new()),
vec![],
)
}

#[derive(Debug, Default)]
Expand Down
Loading