Skip to content
56 changes: 54 additions & 2 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

#[async_trait]
pub trait QueryHook: Send + Sync {
async fn handle_query(
&self,
statement: &Statement,
session_context: &SessionContext,
client: &dyn ClientInfo,
) -> Option<PgWireResult<Vec<Response>>>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the input is a single Statement, the return value should be at most one Response

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I've updated the code to return a single Response.

}

// Metadata keys for session-level settings
const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms";

Expand All @@ -46,8 +56,11 @@ pub struct HandlerFactory {

impl HandlerFactory {
pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
let session_service =
Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
let session_service = Arc::new(DfSessionService::new(
session_context,
auth_manager.clone(),
None,
));
HandlerFactory { session_service }
}
}
Expand Down Expand Up @@ -87,12 +100,14 @@ pub struct DfSessionService {
parser: Arc<Parser>,
timezone: Arc<Mutex<String>>,
auth_manager: Arc<AuthManager>,
query_hook: Option<Arc<dyn QueryHook>>,
}

impl DfSessionService {
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_hook: Option<Arc<dyn QueryHook>>,
) -> DfSessionService {
let parser = Arc::new(Parser {
session_context: session_context.clone(),
Expand All @@ -103,6 +118,7 @@ impl DfSessionService {
parser,
timezone: Arc::new(Mutex::new("UTC".to_string())),
auth_manager,
query_hook,
}
}

Expand Down Expand Up @@ -468,6 +484,17 @@ impl SimpleQueryHandler for DfSessionService {
self.check_query_permission(client, &query).await?;
}

// Call query hook with the parsed statement
if let Some(hook) = &self.query_hook {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loop over query_hooks and make early return if one of them returns non-empty response

let wrapped_statement = Statement::Statement(Box::new(statement.clone()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need wrap here? Can we feed statement directly to hook.handle_query?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was simply felt we would want to expose the datafusion Statement and types to users instead of the sqlparser's. But not wrapping might also be cleaner. I've updated the code.

if let Some(result) = hook
.handle_query(&wrapped_statement, &self.session_context, client)
.await
{
return result;
}
}

if let Some(resp) = self
.try_respond_set_statements(client, &query_lower)
.await?
Expand Down Expand Up @@ -610,6 +637,31 @@ impl ExtendedQueryHandler for DfSessionService {
.to_string();
log::debug!("Received execute extended query: {query}"); // Log for debugging

// Check query hook first
if let Some(hook) = &self.query_hook {
// Parse the SQL to get the Statement for the hook
let sql = &portal.statement.statement.0;
let statements = self
.parser
.sql_parser
.parse(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

if let Some(statement) = statements.into_iter().next() {
let wrapped_statement = Statement::Statement(Box::new(statement));
if let Some(result) = hook
.handle_query(&wrapped_statement, &self.session_context, client)
.await
{
// Convert Vec<Response> to single Response
// For extended query, we expect a single response
if let Some(response) = result?.into_iter().next() {
return Ok(response);
}
}
}
}

// Check permissions for the query (skip for SET and SHOW statements)
if !query.starts_with("set") && !query.starts_with("show") {
self.check_query_permission(client, &portal.statement.statement.0)
Expand Down
2 changes: 1 addition & 1 deletion datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor;

use crate::auth::AuthManager;
use handlers::HandlerFactory;
pub use handlers::{DfSessionService, Parser};
pub use handlers::{DfSessionService, Parser, QueryHook};

/// re-exports
pub use arrow_pg;
Expand Down
Loading