Skip to content

Commit 2b9a11c

Browse files
authored
Introduce QueryHook so users can process custom queries with their own logic. (#204)
* initial query hook support
* introduce QueryHook
* switch to accepting datafusion statements
* remove transaction state
* accept a vector of hooks instead of options of hooks
* add a small test
* run cargo fmt
* update trait to accept sqlparser::ast::Statement instead, so we don't need to wrap in a datafusion::Statement
* since handle_query only accepts a single statement, only a single response is expected.
1 parent 49c3cee commit 2b9a11c

File tree

3 files changed

+105
-8
lines changed

3 files changed

+105
-8
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,20 @@ use tokio::sync::Mutex;
2727
use crate::auth::AuthManager;
2828
use arrow_pg::datatypes::df;
2929
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
30+
use datafusion::sql::sqlparser;
3031
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
3132
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
3233

34+
#[async_trait]
35+
pub trait QueryHook: Send + Sync {
36+
async fn handle_query(
37+
&self,
38+
statement: &sqlparser::ast::Statement,
39+
session_context: &SessionContext,
40+
client: &dyn ClientInfo,
41+
) -> Option<PgWireResult<Response>>;
42+
}
43+
3344
// Metadata keys for session-level settings
3445
const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms";
3546

@@ -45,9 +56,16 @@ pub struct HandlerFactory {
4556
}
4657

4758
impl HandlerFactory {
48-
pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
49-
let session_service =
50-
Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
59+
pub fn new(
60+
session_context: Arc<SessionContext>,
61+
auth_manager: Arc<AuthManager>,
62+
query_hooks: Vec<Arc<dyn QueryHook>>,
63+
) -> Self {
64+
let session_service = Arc::new(DfSessionService::new(
65+
session_context,
66+
auth_manager.clone(),
67+
query_hooks,
68+
));
5169
HandlerFactory { session_service }
5270
}
5371
}
@@ -87,12 +105,14 @@ pub struct DfSessionService {
87105
parser: Arc<Parser>,
88106
timezone: Arc<Mutex<String>>,
89107
auth_manager: Arc<AuthManager>,
108+
query_hooks: Vec<Arc<dyn QueryHook>>,
90109
}
91110

92111
impl DfSessionService {
93112
pub fn new(
94113
session_context: Arc<SessionContext>,
95114
auth_manager: Arc<AuthManager>,
115+
query_hooks: Vec<Arc<dyn QueryHook>>,
96116
) -> DfSessionService {
97117
let parser = Arc::new(Parser {
98118
session_context: session_context.clone(),
@@ -103,6 +123,7 @@ impl DfSessionService {
103123
parser,
104124
timezone: Arc::new(Mutex::new("UTC".to_string())),
105125
auth_manager,
126+
query_hooks,
106127
}
107128
}
108129

@@ -468,6 +489,16 @@ impl SimpleQueryHandler for DfSessionService {
468489
self.check_query_permission(client, &query).await?;
469490
}
470491

492+
// Call query hooks with the parsed statement
493+
for hook in &self.query_hooks {
494+
if let Some(result) = hook
495+
.handle_query(&statement, &self.session_context, client)
496+
.await
497+
{
498+
return result.map(|response| vec![response]);
499+
}
500+
}
501+
471502
if let Some(resp) = self
472503
.try_respond_set_statements(client, &query_lower)
473504
.await?
@@ -610,6 +641,26 @@ impl ExtendedQueryHandler for DfSessionService {
610641
.to_string();
611642
log::debug!("Received execute extended query: {query}"); // Log for debugging
612643

644+
// Check query hooks first
645+
for hook in &self.query_hooks {
646+
// Parse the SQL to get the Statement for the hook
647+
let sql = &portal.statement.statement.0;
648+
let statements = self
649+
.parser
650+
.sql_parser
651+
.parse(sql)
652+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
653+
654+
if let Some(statement) = statements.into_iter().next() {
655+
if let Some(result) = hook
656+
.handle_query(&statement, &self.session_context, client)
657+
.await
658+
{
659+
return result;
660+
}
661+
}
662+
}
663+
613664
// Check permissions for the query (skip for SET and SHOW statements)
614665
if !query.starts_with("set") && !query.starts_with("show") {
615666
self.check_query_permission(client, &portal.statement.statement.0)
@@ -909,7 +960,7 @@ mod tests {
909960
async fn test_statement_timeout_set_and_show() {
910961
let session_context = Arc::new(SessionContext::new());
911962
let auth_manager = Arc::new(AuthManager::new());
912-
let service = DfSessionService::new(session_context, auth_manager);
963+
let service = DfSessionService::new(session_context, auth_manager, vec![]);
913964
let mut client = MockClient::new();
914965

915966
// Test setting timeout to 5000ms
@@ -935,7 +986,7 @@ mod tests {
935986
async fn test_statement_timeout_disable() {
936987
let session_context = Arc::new(SessionContext::new());
937988
let auth_manager = Arc::new(AuthManager::new());
938-
let service = DfSessionService::new(session_context, auth_manager);
989+
let service = DfSessionService::new(session_context, auth_manager, vec![]);
939990
let mut client = MockClient::new();
940991

941992
// Set timeout first
@@ -953,4 +1004,46 @@ mod tests {
9531004
let timeout = DfSessionService::get_statement_timeout(&client);
9541005
assert_eq!(timeout, None);
9551006
}
1007+
1008+
struct TestHook;
1009+
1010+
#[async_trait]
1011+
impl QueryHook for TestHook {
1012+
async fn handle_query(
1013+
&self,
1014+
statement: &sqlparser::ast::Statement,
1015+
_ctx: &SessionContext,
1016+
_client: &dyn ClientInfo,
1017+
) -> Option<PgWireResult<Response>> {
1018+
if statement.to_string().contains("magic") {
1019+
Some(Ok(Response::EmptyQuery))
1020+
} else {
1021+
None
1022+
}
1023+
}
1024+
}
1025+
1026+
#[tokio::test]
1027+
async fn test_query_hooks() {
1028+
let hook = TestHook;
1029+
let ctx = SessionContext::new();
1030+
let client = MockClient::new();
1031+
1032+
// Parse a statement that contains "magic"
1033+
let parser = PostgresCompatibilityParser::new();
1034+
let statements = parser.parse("SELECT magic").unwrap();
1035+
let stmt = &statements[0];
1036+
1037+
// Hook should intercept
1038+
let result = hook.handle_query(stmt, &ctx, &client).await;
1039+
assert!(result.is_some());
1040+
1041+
// Parse a normal statement
1042+
let statements = parser.parse("SELECT 1").unwrap();
1043+
let stmt = &statements[0];
1044+
1045+
// Hook should not intercept
1046+
let result = hook.handle_query(stmt, &ctx, &client).await;
1047+
assert!(result.is_none());
1048+
}
9561049
}

datafusion-postgres/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor;
2020

2121
use crate::auth::AuthManager;
2222
use handlers::HandlerFactory;
23-
pub use handlers::{DfSessionService, Parser};
23+
pub use handlers::{DfSessionService, Parser, QueryHook};
2424

2525
/// re-exports
2626
pub use arrow_pg;
@@ -85,7 +85,7 @@ pub async fn serve(
8585
auth_manager: Arc<AuthManager>,
8686
) -> Result<(), std::io::Error> {
8787
// Create the handler factory with authentication
88-
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
88+
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, vec![]));
8989

9090
serve_with_handlers(factory, opts).await
9191
}

datafusion-postgres/tests/common/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ pub fn setup_handlers() -> DfSessionService {
2020
)
2121
.expect("Failed to setup sesession context");
2222

23-
DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()))
23+
DfSessionService::new(
24+
Arc::new(session_context),
25+
Arc::new(AuthManager::new()),
26+
vec![],
27+
)
2428
}
2529

2630
#[derive(Debug, Default)]

0 commit comments

Comments
 (0)