Skip to content
108 changes: 103 additions & 5 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

/// Extension point that lets a caller intercept a parsed statement before the
/// built-in query handling in this file runs.
///
/// Hooks are stored as `Vec<Arc<dyn QueryHook>>` and are visited in vector
/// order by the query handlers; the first hook that returns `Some(..)` wins.
#[async_trait]
pub trait QueryHook: Send + Sync {
/// Offers `statement` to the hook.
///
/// Return `None` to decline, in which case the caller moves on to the next
/// hook and then to its normal processing. Return `Some(result)` to take
/// over the statement: the simple-query path returns `result` as-is, while
/// the extended-query path propagates an `Err` and otherwise uses only the
/// first `Response` of the vector.
async fn handle_query(
&self,
statement: &Statement,
session_context: &SessionContext,
client: &dyn ClientInfo,
) -> Option<PgWireResult<Vec<Response>>>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the input is a single Statement, the return value should be at most one Response

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I've updated the code to return a single Response.

}

// Metadata keys for session-level settings
const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms";

Expand All @@ -45,9 +55,16 @@ pub struct HandlerFactory {
}

// Constructor for `HandlerFactory`. It builds one `DfSessionService`, wraps it
// in an `Arc`, and stores it as `session_service`. The three-argument form
// also accepts `query_hooks`, which are passed through unchanged to
// `DfSessionService::new`; `auth_manager` is cloned (an `Arc` clone) before
// being handed over.
impl HandlerFactory {
pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
let session_service =
Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
) -> Self {
let session_service = Arc::new(DfSessionService::new(
session_context,
auth_manager.clone(),
query_hooks,
));
HandlerFactory { session_service }
}
}
Expand Down Expand Up @@ -87,12 +104,14 @@ pub struct DfSessionService {
parser: Arc<Parser>,
timezone: Arc<Mutex<String>>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
}

impl DfSessionService {
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
) -> DfSessionService {
let parser = Arc::new(Parser {
session_context: session_context.clone(),
Expand All @@ -103,6 +122,7 @@ impl DfSessionService {
parser,
timezone: Arc::new(Mutex::new("UTC".to_string())),
auth_manager,
query_hooks,
}
}

Expand Down Expand Up @@ -468,6 +488,17 @@ impl SimpleQueryHandler for DfSessionService {
self.check_query_permission(client, &query).await?;
}

// Call query hooks with the parsed statement
for hook in &self.query_hooks {
let wrapped_statement = Statement::Statement(Box::new(statement.clone()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need wrap here? Can we feed statement directly to hook.handle_query?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I simply felt we would want to expose DataFusion's Statement and types to users instead of sqlparser's, but not wrapping might also be cleaner. I've updated the code.

if let Some(result) = hook
.handle_query(&wrapped_statement, &self.session_context, client)
.await
{
return result;
}
}

if let Some(resp) = self
.try_respond_set_statements(client, &query_lower)
.await?
Expand Down Expand Up @@ -610,6 +641,31 @@ impl ExtendedQueryHandler for DfSessionService {
.to_string();
log::debug!("Received execute extended query: {query}"); // Log for debugging

// Check query hooks first
for hook in &self.query_hooks {
// Parse the SQL to get the Statement for the hook
let sql = &portal.statement.statement.0;
let statements = self
.parser
.sql_parser
.parse(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

if let Some(statement) = statements.into_iter().next() {
let wrapped_statement = Statement::Statement(Box::new(statement));
if let Some(result) = hook
.handle_query(&wrapped_statement, &self.session_context, client)
.await
{
// Convert Vec<Response> to single Response
// For extended query, we expect a single response
if let Some(response) = result?.into_iter().next() {
return Ok(response);
}
}
}
}

// Check permissions for the query (skip for SET and SHOW statements)
if !query.starts_with("set") && !query.starts_with("show") {
self.check_query_permission(client, &portal.statement.statement.0)
Expand Down Expand Up @@ -909,7 +965,7 @@ mod tests {
async fn test_statement_timeout_set_and_show() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let service = DfSessionService::new(session_context, auth_manager);
let service = DfSessionService::new(session_context, auth_manager, vec![]);
let mut client = MockClient::new();

// Test setting timeout to 5000ms
Expand All @@ -935,7 +991,7 @@ mod tests {
async fn test_statement_timeout_disable() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let service = DfSessionService::new(session_context, auth_manager);
let service = DfSessionService::new(session_context, auth_manager, vec![]);
let mut client = MockClient::new();

// Set timeout first
Expand All @@ -953,4 +1009,46 @@ mod tests {
let timeout = DfSessionService::get_statement_timeout(&client);
assert_eq!(timeout, None);
}

/// Minimal hook used by the tests: it claims any statement whose rendered SQL
/// contains the text "magic" and declines everything else.
struct TestHook;

#[async_trait]
impl QueryHook for TestHook {
    /// Answers with a single `Response::EmptyQuery` when the statement text
    /// contains "magic"; otherwise returns `None` so the caller continues.
    async fn handle_query(
        &self,
        statement: &Statement,
        _ctx: &SessionContext,
        _client: &dyn ClientInfo,
    ) -> Option<PgWireResult<Vec<Response>>> {
        let rendered = statement.to_string();
        // `bool::then` produces `Some(..)` only when the marker is present.
        rendered
            .contains("magic")
            .then(|| Ok(vec![Response::EmptyQuery]))
    }
}

/// Checks that `TestHook` intercepts a statement mentioning "magic" and
/// declines an ordinary one.
#[tokio::test]
async fn test_query_hooks() {
    let hook = TestHook;
    let ctx = SessionContext::new();
    let client = MockClient::new();
    let parser = PostgresCompatibilityParser::new();

    // (SQL text, whether the hook is expected to intercept it)
    let cases = [("SELECT magic", true), ("SELECT 1", false)];

    for (sql, should_intercept) in cases {
        let parsed = parser.parse(sql).unwrap();
        let stmt = Statement::Statement(Box::new(parsed[0].clone()));

        let outcome = hook.handle_query(&stmt, &ctx, &client).await;
        assert_eq!(
            outcome.is_some(),
            should_intercept,
            "unexpected hook decision for {sql:?}"
        );
    }
}
}
4 changes: 2 additions & 2 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor;

use crate::auth::AuthManager;
use handlers::HandlerFactory;
pub use handlers::{DfSessionService, Parser};
pub use handlers::{DfSessionService, Parser, QueryHook};

/// re-exports
pub use arrow_pg;
Expand Down Expand Up @@ -85,7 +85,7 @@ pub async fn serve(
auth_manager: Arc<AuthManager>,
) -> Result<(), std::io::Error> {
// Create the handler factory with authentication
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, vec![]));

serve_with_handlers(factory, opts).await
}
Expand Down
6 changes: 5 additions & 1 deletion datafusion-postgres/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ pub fn setup_handlers() -> DfSessionService {
)
.expect("Failed to setup sesession context");

DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()))
DfSessionService::new(
Arc::new(session_context),
Arc::new(AuthManager::new()),
vec![],
)
}

#[derive(Debug, Default)]
Expand Down
Loading