-
Couldn't load subscription status.
- Fork 22
Introduce QueryHook so users can process custom queries with own logic. #204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
e9189fb
0006fe1
2142e40
81d0f42
a373572
c174747
b3e3360
3bd7513
822d624
b366428
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,16 @@ use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; | |
| use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; | ||
| use datafusion_pg_catalog::sql::PostgresCompatibilityParser; | ||
|
|
||
| #[async_trait] | ||
| pub trait QueryHook: Send + Sync { | ||
| async fn handle_query( | ||
| &self, | ||
| statement: &Statement, | ||
| session_context: &SessionContext, | ||
| client: &dyn ClientInfo, | ||
| ) -> Option<PgWireResult<Vec<Response>>>; | ||
| } | ||
|
|
||
| // Metadata keys for session-level settings | ||
| const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms"; | ||
|
|
||
|
|
@@ -46,8 +56,11 @@ pub struct HandlerFactory { | |
|
|
||
| impl HandlerFactory { | ||
| pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self { | ||
| let session_service = | ||
| Arc::new(DfSessionService::new(session_context, auth_manager.clone())); | ||
| let session_service = Arc::new(DfSessionService::new( | ||
| session_context, | ||
| auth_manager.clone(), | ||
| None, | ||
| )); | ||
| HandlerFactory { session_service } | ||
| } | ||
| } | ||
|
|
@@ -87,12 +100,14 @@ pub struct DfSessionService { | |
| parser: Arc<Parser>, | ||
| timezone: Arc<Mutex<String>>, | ||
| auth_manager: Arc<AuthManager>, | ||
| query_hook: Option<Arc<dyn QueryHook>>, | ||
sunng87 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| impl DfSessionService { | ||
| pub fn new( | ||
| session_context: Arc<SessionContext>, | ||
| auth_manager: Arc<AuthManager>, | ||
| query_hook: Option<Arc<dyn QueryHook>>, | ||
| ) -> DfSessionService { | ||
| let parser = Arc::new(Parser { | ||
| session_context: session_context.clone(), | ||
|
|
@@ -103,6 +118,7 @@ impl DfSessionService { | |
| parser, | ||
| timezone: Arc::new(Mutex::new("UTC".to_string())), | ||
| auth_manager, | ||
| query_hook, | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -468,6 +484,17 @@ impl SimpleQueryHandler for DfSessionService { | |
| self.check_query_permission(client, &query).await?; | ||
| } | ||
|
|
||
| // Call query hook with the parsed statement | ||
| if let Some(hook) = &self.query_hook { | ||
|
||
| let wrapped_statement = Statement::Statement(Box::new(statement.clone())); | ||
|
||
| if let Some(result) = hook | ||
| .handle_query(&wrapped_statement, &self.session_context, client) | ||
| .await | ||
| { | ||
| return result; | ||
| } | ||
| } | ||
|
|
||
| if let Some(resp) = self | ||
sunng87 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| .try_respond_set_statements(client, &query_lower) | ||
| .await? | ||
|
|
@@ -610,6 +637,31 @@ impl ExtendedQueryHandler for DfSessionService { | |
| .to_string(); | ||
| log::debug!("Received execute extended query: {query}"); // Log for debugging | ||
|
|
||
| // Check query hook first | ||
| if let Some(hook) = &self.query_hook { | ||
| // Parse the SQL to get the Statement for the hook | ||
| let sql = &portal.statement.statement.0; | ||
| let statements = self | ||
| .parser | ||
| .sql_parser | ||
| .parse(sql) | ||
| .map_err(|e| PgWireError::ApiError(Box::new(e)))?; | ||
|
|
||
| if let Some(statement) = statements.into_iter().next() { | ||
| let wrapped_statement = Statement::Statement(Box::new(statement)); | ||
| if let Some(result) = hook | ||
| .handle_query(&wrapped_statement, &self.session_context, client) | ||
| .await | ||
| { | ||
| // Convert Vec<Response> to single Response | ||
| // For extended query, we expect a single response | ||
| if let Some(response) = result?.into_iter().next() { | ||
| return Ok(response); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Check permissions for the query (skip for SET and SHOW statements) | ||
| if !query.starts_with("set") && !query.starts_with("show") { | ||
| self.check_query_permission(client, &portal.statement.statement.0) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because the input is a single
Statement, the return value should be at most oneResponseThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. I've updated the code to return a single Response.