@@ -27,9 +27,20 @@ use tokio::sync::Mutex;
2727use  crate :: auth:: AuthManager ; 
2828use  arrow_pg:: datatypes:: df; 
2929use  arrow_pg:: datatypes:: { arrow_schema_to_pg_fields,  into_pg_type} ; 
30+ use  datafusion:: sql:: sqlparser; 
3031use  datafusion_pg_catalog:: pg_catalog:: context:: { Permission ,  ResourceType } ; 
3132use  datafusion_pg_catalog:: sql:: PostgresCompatibilityParser ; 
3233
34+ #[ async_trait]  
35+ pub  trait  QueryHook :  Send  + Sync  { 
36+     async  fn  handle_query ( 
37+         & self , 
38+         statement :  & sqlparser:: ast:: Statement , 
39+         session_context :  & SessionContext , 
40+         client :  & dyn  ClientInfo , 
41+     )  -> Option < PgWireResult < Response > > ; 
42+ } 
43+ 
3344// Metadata keys for session-level settings 
3445const  METADATA_STATEMENT_TIMEOUT :  & str  = "statement_timeout_ms" ; 
3546
@@ -45,9 +56,16 @@ pub struct HandlerFactory {
4556} 
4657
4758impl  HandlerFactory  { 
48-     pub  fn  new ( session_context :  Arc < SessionContext > ,  auth_manager :  Arc < AuthManager > )  -> Self  { 
49-         let  session_service =
50-             Arc :: new ( DfSessionService :: new ( session_context,  auth_manager. clone ( ) ) ) ; 
59+     pub  fn  new ( 
60+         session_context :  Arc < SessionContext > , 
61+         auth_manager :  Arc < AuthManager > , 
62+         query_hooks :  Vec < Arc < dyn  QueryHook > > , 
63+     )  -> Self  { 
64+         let  session_service = Arc :: new ( DfSessionService :: new ( 
65+             session_context, 
66+             auth_manager. clone ( ) , 
67+             query_hooks, 
68+         ) ) ; 
5169        HandlerFactory  {  session_service } 
5270    } 
5371} 
@@ -87,12 +105,14 @@ pub struct DfSessionService {
87105    parser :  Arc < Parser > , 
88106    timezone :  Arc < Mutex < String > > , 
89107    auth_manager :  Arc < AuthManager > , 
108+     query_hooks :  Vec < Arc < dyn  QueryHook > > , 
90109} 
91110
92111impl  DfSessionService  { 
93112    pub  fn  new ( 
94113        session_context :  Arc < SessionContext > , 
95114        auth_manager :  Arc < AuthManager > , 
115+         query_hooks :  Vec < Arc < dyn  QueryHook > > , 
96116    )  -> DfSessionService  { 
97117        let  parser = Arc :: new ( Parser  { 
98118            session_context :  session_context. clone ( ) , 
@@ -103,6 +123,7 @@ impl DfSessionService {
103123            parser, 
104124            timezone :  Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) , 
105125            auth_manager, 
126+             query_hooks, 
106127        } 
107128    } 
108129
@@ -468,6 +489,16 @@ impl SimpleQueryHandler for DfSessionService {
468489                self . check_query_permission ( client,  & query) . await ?; 
469490            } 
470491
492+             // Call query hooks with the parsed statement 
493+             for  hook in  & self . query_hooks  { 
494+                 if  let  Some ( result)  = hook
495+                     . handle_query ( & statement,  & self . session_context ,  client) 
496+                     . await 
497+                 { 
498+                     return  result. map ( |response| vec ! [ response] ) ; 
499+                 } 
500+             } 
501+ 
471502            if  let  Some ( resp)  = self 
472503                . try_respond_set_statements ( client,  & query_lower) 
473504                . await ?
@@ -610,6 +641,26 @@ impl ExtendedQueryHandler for DfSessionService {
610641            . to_string ( ) ; 
611642        log:: debug!( "Received execute extended query: {query}" ) ;  // Log for debugging 
612643
644+         // Check query hooks first 
645+         for  hook in  & self . query_hooks  { 
646+             // Parse the SQL to get the Statement for the hook 
647+             let  sql = & portal. statement . statement . 0 ; 
648+             let  statements = self 
649+                 . parser 
650+                 . sql_parser 
651+                 . parse ( sql) 
652+                 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?; 
653+ 
654+             if  let  Some ( statement)  = statements. into_iter ( ) . next ( )  { 
655+                 if  let  Some ( result)  = hook
656+                     . handle_query ( & statement,  & self . session_context ,  client) 
657+                     . await 
658+                 { 
659+                     return  result; 
660+                 } 
661+             } 
662+         } 
663+ 
613664        // Check permissions for the query (skip for SET and SHOW statements) 
614665        if  !query. starts_with ( "set" )  && !query. starts_with ( "show" )  { 
615666            self . check_query_permission ( client,  & portal. statement . statement . 0 ) 
@@ -909,7 +960,7 @@ mod tests {
909960    async  fn  test_statement_timeout_set_and_show ( )  { 
910961        let  session_context = Arc :: new ( SessionContext :: new ( ) ) ; 
911962        let  auth_manager = Arc :: new ( AuthManager :: new ( ) ) ; 
912-         let  service = DfSessionService :: new ( session_context,  auth_manager) ; 
963+         let  service = DfSessionService :: new ( session_context,  auth_manager,   vec ! [ ] ) ; 
913964        let  mut  client = MockClient :: new ( ) ; 
914965
915966        // Test setting timeout to 5000ms 
@@ -935,7 +986,7 @@ mod tests {
935986    async  fn  test_statement_timeout_disable ( )  { 
936987        let  session_context = Arc :: new ( SessionContext :: new ( ) ) ; 
937988        let  auth_manager = Arc :: new ( AuthManager :: new ( ) ) ; 
938-         let  service = DfSessionService :: new ( session_context,  auth_manager) ; 
989+         let  service = DfSessionService :: new ( session_context,  auth_manager,   vec ! [ ] ) ; 
939990        let  mut  client = MockClient :: new ( ) ; 
940991
941992        // Set timeout first 
@@ -953,4 +1004,46 @@ mod tests {
9531004        let  timeout = DfSessionService :: get_statement_timeout ( & client) ; 
9541005        assert_eq ! ( timeout,  None ) ; 
9551006    } 
1007+ 
1008+     struct  TestHook ; 
1009+ 
1010+     #[ async_trait]  
1011+     impl  QueryHook  for  TestHook  { 
1012+         async  fn  handle_query ( 
1013+             & self , 
1014+             statement :  & sqlparser:: ast:: Statement , 
1015+             _ctx :  & SessionContext , 
1016+             _client :  & dyn  ClientInfo , 
1017+         )  -> Option < PgWireResult < Response > >  { 
1018+             if  statement. to_string ( ) . contains ( "magic" )  { 
1019+                 Some ( Ok ( Response :: EmptyQuery ) ) 
1020+             }  else  { 
1021+                 None 
1022+             } 
1023+         } 
1024+     } 
1025+ 
1026+     #[ tokio:: test]  
1027+     async  fn  test_query_hooks ( )  { 
1028+         let  hook = TestHook ; 
1029+         let  ctx = SessionContext :: new ( ) ; 
1030+         let  client = MockClient :: new ( ) ; 
1031+ 
1032+         // Parse a statement that contains "magic" 
1033+         let  parser = PostgresCompatibilityParser :: new ( ) ; 
1034+         let  statements = parser. parse ( "SELECT magic" ) . unwrap ( ) ; 
1035+         let  stmt = & statements[ 0 ] ; 
1036+ 
1037+         // Hook should intercept 
1038+         let  result = hook. handle_query ( stmt,  & ctx,  & client) . await ; 
1039+         assert ! ( result. is_some( ) ) ; 
1040+ 
1041+         // Parse a normal statement 
1042+         let  statements = parser. parse ( "SELECT 1" ) . unwrap ( ) ; 
1043+         let  stmt = & statements[ 0 ] ; 
1044+ 
1045+         // Hook should not intercept 
1046+         let  result = hook. handle_query ( stmt,  & ctx,  & client) . await ; 
1047+         assert ! ( result. is_none( ) ) ; 
1048+     } 
9561049} 
0 commit comments