Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arrow-pg/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ mod tests {
{
let mut bytes = BytesMut::new();
let _sql_text = value.to_sql_text(data_type, &mut bytes);
let string = String::from_utf8((&bytes).to_vec());
let string = String::from_utf8(bytes.to_vec());
self.encoded_value = string.unwrap();
Ok(())
}
Expand Down
156 changes: 151 additions & 5 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub struct DfSessionService {
timezone: Arc<Mutex<String>>,
auth_manager: Arc<AuthManager>,
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
statement_timeout: Arc<Mutex<Option<std::time::Duration>>>,
}

impl DfSessionService {
Expand All @@ -97,6 +98,7 @@ impl DfSessionService {
timezone: Arc::new(Mutex::new("UTC".to_string())),
auth_manager,
sql_rewrite_rules,
statement_timeout: Arc::new(Mutex::new(None)),
}
}

Expand Down Expand Up @@ -215,6 +217,52 @@ impl DfSessionService {
),
)))
}
} else if query_lower.starts_with("set statement_timeout") {
let parts: Vec<&str> = query_lower.split_whitespace().collect();
if parts.len() >= 3 {
let timeout_str = parts[2].trim_matches('"').trim_matches('\'');
let mut statement_timeout = self.statement_timeout.lock().await;

if timeout_str == "0" || timeout_str.is_empty() {
*statement_timeout = None;
} else {
// Parse timeout value (supports ms, s, min formats)
let timeout_ms = if timeout_str.ends_with("ms") {
timeout_str.trim_end_matches("ms").parse::<u64>()
} else if timeout_str.ends_with("s") {
timeout_str
.trim_end_matches("s")
.parse::<u64>()
.map(|s| s * 1000)
} else if timeout_str.ends_with("min") {
timeout_str
.trim_end_matches("min")
.parse::<u64>()
.map(|m| m * 60 * 1000)
} else {
// Default to milliseconds
timeout_str.parse::<u64>()
};

match timeout_ms {
Ok(ms) if ms > 0 => {
*statement_timeout = Some(std::time::Duration::from_millis(ms));
}
_ => {
*statement_timeout = None;
}
}
}
Ok(Some(Response::Execution(Tag::new("SET"))))
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42601".to_string(),
"Invalid SET statement_timeout syntax".to_string(),
),
)))
}
} else {
// pass SET query to datafusion
if let Err(e) = self.session_context.sql(query_lower).await {
Expand Down Expand Up @@ -305,6 +353,15 @@ impl DfSessionService {
let resp = Self::mock_show_response("search_path", default_schema)?;
Ok(Some(Response::Query(resp)))
}
"show statement_timeout" => {
let timeout = *self.statement_timeout.lock().await;
let timeout_str = match timeout {
Some(duration) => format!("{}ms", duration.as_millis()),
None => "0".to_string(),
};
let resp = Self::mock_show_response("statement_timeout", &timeout_str)?;
Ok(Some(Response::Query(resp)))
}
_ => Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
Expand Down Expand Up @@ -378,7 +435,22 @@ impl SimpleQueryHandler for DfSessionService {
)));
}

let df_result = self.session_context.sql(&query).await;
let df_result = {
let timeout = *self.statement_timeout.lock().await;
if let Some(timeout_duration) = timeout {
tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to statement timeout".to_string(),
)))
})?
} else {
self.session_context.sql(&query).await
}
};

// Handle query execution errors and transaction state
let df = match df_result {
Expand Down Expand Up @@ -540,10 +612,29 @@ impl ExtendedQueryHandler for DfSessionService {
.optimize(&plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let dataframe = match self.session_context.execute_logical_plan(optimised).await {
Ok(df) => df,
Err(e) => {
return Err(PgWireError::ApiError(Box::new(e)));
let dataframe = {
let timeout = *self.statement_timeout.lock().await;
if let Some(timeout_duration) = timeout {
tokio::time::timeout(
timeout_duration,
self.session_context.execute_logical_plan(optimised),
)
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to statement timeout".to_string(),
)))
})?
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
} else {
match self.session_context.execute_logical_plan(optimised).await {
Ok(df) => df,
Err(e) => {
return Err(PgWireError::ApiError(Box::new(e)));
}
}
}
};
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
Expand Down Expand Up @@ -593,3 +684,58 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
types.sort_by(|a, b| a.0.cmp(b.0));
types.into_iter().map(|pt| pt.1.as_ref()).collect()
}

#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthManager;
use datafusion::prelude::SessionContext;
use std::time::Duration;

#[tokio::test]
async fn test_statement_timeout_set_and_show() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let service = DfSessionService::new(session_context, auth_manager);

// Test setting timeout to 5000ms
let set_response = service
.try_respond_set_statements("set statement_timeout '5000ms'")
.await
.unwrap();
assert!(set_response.is_some());

// Verify the timeout was set
let timeout = *service.statement_timeout.lock().await;
assert_eq!(timeout, Some(Duration::from_millis(5000)));

// Test SHOW statement_timeout
let show_response = service
.try_respond_show_statements("show statement_timeout")
.await
.unwrap();
assert!(show_response.is_some());
}

#[tokio::test]
async fn test_statement_timeout_disable() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let service = DfSessionService::new(session_context, auth_manager);

// Set timeout first
service
.try_respond_set_statements("set statement_timeout '1000ms'")
.await
.unwrap();

// Disable timeout with 0
service
.try_respond_set_statements("set statement_timeout '0'")
.await
.unwrap();

let timeout = *service.statement_timeout.lock().await;
assert_eq!(timeout, None);
}
}
4 changes: 2 additions & 2 deletions datafusion-postgres/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ struct RemoveUnsupportedTypesVisitor<'a> {
unsupported_types: &'a HashSet<String>,
}

impl<'a> VisitorMut for RemoveUnsupportedTypesVisitor<'a> {
impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
type Break = ();

fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
Expand Down Expand Up @@ -444,7 +444,7 @@ struct PrependUnqualifiedTableNameVisitor<'a> {
table_names: &'a HashSet<String>,
}

impl<'a> VisitorMut for PrependUnqualifiedTableNameVisitor<'a> {
impl VisitorMut for PrependUnqualifiedTableNameVisitor<'_> {
type Break = ();

fn pre_visit_table_factor(
Expand Down
2 changes: 1 addition & 1 deletion datafusion-postgres/tests/dbeaver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ pub async fn test_dbeaver_startup_sql() {
for query in DBEAVER_QUERIES {
SimpleQueryHandler::do_query(&service, &mut client, query)
.await
.expect(&format!("failed to run sql: {query}"));
.unwrap_or_else(|_| panic!("failed to run sql: {query}"));
}
}
Loading