diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index b7aec42..a3a8374 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -27,9 +27,20 @@ use tokio::sync::Mutex; use crate::auth::AuthManager; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; +use datafusion::sql::sqlparser; use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; use datafusion_pg_catalog::sql::PostgresCompatibilityParser; +#[async_trait] +pub trait QueryHook: Send + Sync { + async fn handle_query( + &self, + statement: &sqlparser::ast::Statement, + session_context: &SessionContext, + client: &dyn ClientInfo, + ) -> Option>; +} + // Metadata keys for session-level settings const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms"; @@ -45,9 +56,16 @@ pub struct HandlerFactory { } impl HandlerFactory { - pub fn new(session_context: Arc, auth_manager: Arc) -> Self { - let session_service = - Arc::new(DfSessionService::new(session_context, auth_manager.clone())); + pub fn new( + session_context: Arc, + auth_manager: Arc, + query_hooks: Vec>, + ) -> Self { + let session_service = Arc::new(DfSessionService::new( + session_context, + auth_manager.clone(), + query_hooks, + )); HandlerFactory { session_service } } } @@ -87,12 +105,14 @@ pub struct DfSessionService { parser: Arc, timezone: Arc>, auth_manager: Arc, + query_hooks: Vec>, } impl DfSessionService { pub fn new( session_context: Arc, auth_manager: Arc, + query_hooks: Vec>, ) -> DfSessionService { let parser = Arc::new(Parser { session_context: session_context.clone(), @@ -103,6 +123,7 @@ impl DfSessionService { parser, timezone: Arc::new(Mutex::new("UTC".to_string())), auth_manager, + query_hooks, } } @@ -468,6 +489,16 @@ impl SimpleQueryHandler for DfSessionService { self.check_query_permission(client, &query).await?; } + // Call query hooks with the parsed statement + for hook in &self.query_hooks { + if let Some(result) = hook + .handle_query(&statement, &self.session_context, client) + .await + { + return result.map(|response| vec![response]); + } + } + if let Some(resp) = self .try_respond_set_statements(client, &query_lower) .await? @@ -610,6 +641,26 @@ impl ExtendedQueryHandler for DfSessionService { .to_string(); log::debug!("Received execute extended query: {query}"); // Log for debugging + // Check query hooks first + for hook in &self.query_hooks { + // Parse the SQL to get the Statement for the hook + let sql = &portal.statement.statement.0; + let statements = self + .parser + .sql_parser + .parse(sql) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + if let Some(statement) = statements.into_iter().next() { + if let Some(result) = hook + .handle_query(&statement, &self.session_context, client) + .await + { + return result; + } + } + } + // Check permissions for the query (skip for SET and SHOW statements) if !query.starts_with("set") && !query.starts_with("show") { self.check_query_permission(client, &portal.statement.statement.0) @@ -909,7 +960,7 @@ mod tests { async fn test_statement_timeout_set_and_show() { let session_context = Arc::new(SessionContext::new()); let auth_manager = Arc::new(AuthManager::new()); - let service = DfSessionService::new(session_context, auth_manager); + let service = DfSessionService::new(session_context, auth_manager, vec![]); let mut client = MockClient::new(); // Test setting timeout to 5000ms @@ -935,7 +986,7 @@ mod tests { async fn test_statement_timeout_disable() { let session_context = Arc::new(SessionContext::new()); let auth_manager = Arc::new(AuthManager::new()); - let service = DfSessionService::new(session_context, auth_manager); + let service = DfSessionService::new(session_context, auth_manager, vec![]); let mut client = MockClient::new(); // Set timeout first @@ -953,4 +1004,46 @@ mod tests { let timeout = DfSessionService::get_statement_timeout(&client); assert_eq!(timeout, None); } + + struct TestHook; + + #[async_trait] + impl QueryHook for TestHook { + async fn handle_query( + &self, + statement: &sqlparser::ast::Statement, + _ctx: &SessionContext, + _client: &dyn ClientInfo, + ) -> Option> { + if statement.to_string().contains("magic") { + Some(Ok(Response::EmptyQuery)) + } else { + None + } + } + } + + #[tokio::test] + async fn test_query_hooks() { + let hook = TestHook; + let ctx = SessionContext::new(); + let client = MockClient::new(); + + // Parse a statement that contains "magic" + let parser = PostgresCompatibilityParser::new(); + let statements = parser.parse("SELECT magic").unwrap(); + let stmt = &statements[0]; + + // Hook should intercept + let result = hook.handle_query(stmt, &ctx, &client).await; + assert!(result.is_some()); + + // Parse a normal statement + let statements = parser.parse("SELECT 1").unwrap(); + let stmt = &statements[0]; + + // Hook should not intercept + let result = hook.handle_query(stmt, &ctx, &client).await; + assert!(result.is_none()); + } } diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index 4ed0a16..cdfe6dd 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor; use crate::auth::AuthManager; use handlers::HandlerFactory; -pub use handlers::{DfSessionService, Parser}; +pub use handlers::{DfSessionService, Parser, QueryHook}; /// re-exports pub use arrow_pg; @@ -85,7 +85,7 @@ pub async fn serve( auth_manager: Arc, ) -> Result<(), std::io::Error> { // Create the handler factory with authentication - let factory = Arc::new(HandlerFactory::new(session_context, auth_manager)); + let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, vec![])); serve_with_handlers(factory, opts).await } diff --git a/datafusion-postgres/tests/common/mod.rs b/datafusion-postgres/tests/common/mod.rs index 054b38d..6c646ff 100644 --- a/datafusion-postgres/tests/common/mod.rs +++ b/datafusion-postgres/tests/common/mod.rs @@ -20,7 +20,11 @@ pub fn setup_handlers() -> DfSessionService { ) .expect("Failed to setup sesession context"); - DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new())) + DfSessionService::new( + Arc::new(session_context), + Arc::new(AuthManager::new()), + vec![], + ) } #[derive(Debug, Default)]