Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arrow-pg/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ mod tests {
{
let mut bytes = BytesMut::new();
let _sql_text = value.to_sql_text(data_type, &mut bytes);
let string = String::from_utf8((&bytes).to_vec());
let string = String::from_utf8(bytes.to_vec());
self.encoded_value = string.unwrap();
Ok(())
}
Expand Down
282 changes: 269 additions & 13 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ use tokio::sync::Mutex;
use arrow_pg::datatypes::df;
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};

// Metadata keys for session-level settings.
// The statement timeout is stored in the client metadata map under this key
// as a decimal number of milliseconds; the key is removed when no timeout is set.
const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms";

/// Simple startup handler that does no authentication.
///
/// For production, use `DfAuthSource` with proper pgwire authentication handlers.
pub struct SimpleStartupHandler;
Expand Down Expand Up @@ -100,6 +103,34 @@ impl DfSessionService {
}
}

/// Look up the session's statement timeout in the client metadata.
///
/// Returns `None` when the `METADATA_STATEMENT_TIMEOUT` entry is missing or
/// its value does not parse as a `u64`; otherwise the parsed number is
/// interpreted as milliseconds.
fn get_statement_timeout<C>(client: &C) -> Option<std::time::Duration>
where
    C: ClientInfo,
{
    let raw = client.metadata().get(METADATA_STATEMENT_TIMEOUT)?;
    let millis = raw.parse::<u64>().ok()?;
    Some(std::time::Duration::from_millis(millis))
}

/// Store or clear the session's statement timeout in the client metadata.
///
/// `Some(duration)` writes the duration, in whole milliseconds rendered as a
/// decimal string, under `METADATA_STATEMENT_TIMEOUT`; `None` removes the
/// entry.
fn set_statement_timeout<C>(client: &mut C, timeout: Option<std::time::Duration>)
where
    C: ClientInfo,
{
    let metadata = client.metadata_mut();
    match timeout {
        Some(duration) => {
            let millis = duration.as_millis().to_string();
            metadata.insert(METADATA_STATEMENT_TIMEOUT.to_string(), millis);
        }
        None => {
            metadata.remove(METADATA_STATEMENT_TIMEOUT);
        }
    }
}

/// Check if the current user has permission to execute a query
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
where
Expand Down Expand Up @@ -194,10 +225,14 @@ impl DfSessionService {
Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
}

async fn try_respond_set_statements<'a>(
async fn try_respond_set_statements<'a, C>(
&self,
client: &mut C,
query_lower: &str,
) -> PgWireResult<Option<Response<'a>>> {
) -> PgWireResult<Option<Response<'a>>>
where
C: ClientInfo,
{
if query_lower.starts_with("set") {
if query_lower.starts_with("set time zone") {
let parts: Vec<&str> = query_lower.split_whitespace().collect();
Expand All @@ -215,6 +250,49 @@ impl DfSessionService {
),
)))
}
} else if query_lower.starts_with("set statement_timeout") {
let parts: Vec<&str> = query_lower.split_whitespace().collect();
if parts.len() >= 3 {
let timeout_str = parts[2].trim_matches('"').trim_matches('\'');

let timeout = if timeout_str == "0" || timeout_str.is_empty() {
None
} else {
// Parse timeout value (supports ms, s, min formats)
let timeout_ms = if timeout_str.ends_with("ms") {
timeout_str.trim_end_matches("ms").parse::<u64>()
} else if timeout_str.ends_with("s") {
timeout_str
.trim_end_matches("s")
.parse::<u64>()
.map(|s| s * 1000)
} else if timeout_str.ends_with("min") {
timeout_str
.trim_end_matches("min")
.parse::<u64>()
.map(|m| m * 60 * 1000)
} else {
// Default to milliseconds
timeout_str.parse::<u64>()
};

match timeout_ms {
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
_ => None,
}
};

Self::set_statement_timeout(client, timeout);
Ok(Some(Response::Execution(Tag::new("SET"))))
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42601".to_string(),
"Invalid SET statement_timeout syntax".to_string(),
),
)))
}
} else {
// pass SET query to datafusion
if let Err(e) = self.session_context.sql(query_lower).await {
Expand Down Expand Up @@ -274,10 +352,14 @@ impl DfSessionService {
}
}

async fn try_respond_show_statements<'a>(
async fn try_respond_show_statements<'a, C>(
&self,
client: &C,
query_lower: &str,
) -> PgWireResult<Option<Response<'a>>> {
) -> PgWireResult<Option<Response<'a>>>
where
C: ClientInfo,
{
if query_lower.starts_with("show ") {
match query_lower.strip_suffix(";").unwrap_or(query_lower) {
"show time zone" => {
Expand Down Expand Up @@ -305,6 +387,15 @@ impl DfSessionService {
let resp = Self::mock_show_response("search_path", default_schema)?;
Ok(Some(Response::Query(resp)))
}
"show statement_timeout" => {
let timeout = Self::get_statement_timeout(client);
let timeout_str = match timeout {
Some(duration) => format!("{}ms", duration.as_millis()),
None => "0".to_string(),
};
let resp = Self::mock_show_response("statement_timeout", &timeout_str)?;
Ok(Some(Response::Query(resp)))
}
_ => Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
Expand Down Expand Up @@ -351,7 +442,10 @@ impl SimpleQueryHandler for DfSessionService {
self.check_query_permission(client, &query).await?;
}

if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
if let Some(resp) = self
.try_respond_set_statements(client, &query_lower)
.await?
{
return Ok(vec![resp]);
}

Expand All @@ -362,7 +456,10 @@ impl SimpleQueryHandler for DfSessionService {
return Ok(vec![resp]);
}

if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
if let Some(resp) = self
.try_respond_show_statements(client, &query_lower)
.await?
{
return Ok(vec![resp]);
}

Expand All @@ -378,7 +475,22 @@ impl SimpleQueryHandler for DfSessionService {
)));
}

let df_result = self.session_context.sql(&query).await;
let df_result = {
let timeout = Self::get_statement_timeout(client);
if let Some(timeout_duration) = timeout {
tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to statement timeout".to_string(),
)))
})?
} else {
self.session_context.sql(&query).await
}
};

// Handle query execution errors and transaction state
let df = match df_result {
Expand Down Expand Up @@ -496,7 +608,7 @@ impl ExtendedQueryHandler for DfSessionService {
.await?;
}

if let Some(resp) = self.try_respond_set_statements(&query).await? {
if let Some(resp) = self.try_respond_set_statements(client, &query).await? {
return Ok(resp);
}

Expand All @@ -507,7 +619,7 @@ impl ExtendedQueryHandler for DfSessionService {
return Ok(resp);
}

if let Some(resp) = self.try_respond_show_statements(&query).await? {
if let Some(resp) = self.try_respond_show_statements(client, &query).await? {
return Ok(resp);
}

Expand Down Expand Up @@ -540,10 +652,29 @@ impl ExtendedQueryHandler for DfSessionService {
.optimize(&plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let dataframe = match self.session_context.execute_logical_plan(optimised).await {
Ok(df) => df,
Err(e) => {
return Err(PgWireError::ApiError(Box::new(e)));
let dataframe = {
let timeout = Self::get_statement_timeout(client);
if let Some(timeout_duration) = timeout {
tokio::time::timeout(
timeout_duration,
self.session_context.execute_logical_plan(optimised),
)
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to statement timeout".to_string(),
)))
})?
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
} else {
match self.session_context.execute_logical_plan(optimised).await {
Ok(df) => df,
Err(e) => {
return Err(PgWireError::ApiError(Box::new(e)));
}
}
}
};
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
Expand Down Expand Up @@ -593,3 +724,128 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
types.sort_by(|a, b| a.0.cmp(b.0));
types.into_iter().map(|pt| pt.1.as_ref()).collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::AuthManager;
    use datafusion::prelude::SessionContext;
    use std::collections::HashMap;
    use std::time::Duration;

    /// In-memory `ClientInfo` stand-in: only the metadata map carries state,
    /// every other accessor returns a fixed value and every other setter is a
    /// no-op.
    struct MockClient {
        metadata: HashMap<String, String>,
    }

    impl MockClient {
        /// Creates a client with an empty metadata map.
        fn new() -> Self {
            let metadata = HashMap::new();
            Self { metadata }
        }
    }

    impl ClientInfo for MockClient {
        fn socket_addr(&self) -> std::net::SocketAddr {
            "127.0.0.1:5432".parse().unwrap()
        }

        fn is_secure(&self) -> bool {
            false
        }

        fn protocol_version(&self) -> pgwire::messages::ProtocolVersion {
            pgwire::messages::ProtocolVersion::PROTOCOL3_0
        }

        fn set_protocol_version(&mut self, _version: pgwire::messages::ProtocolVersion) {}

        fn pid_and_secret_key(&self) -> (i32, pgwire::messages::startup::SecretKey) {
            (0, pgwire::messages::startup::SecretKey::I32(0))
        }

        fn set_pid_and_secret_key(
            &mut self,
            _pid: i32,
            _secret_key: pgwire::messages::startup::SecretKey,
        ) {
        }

        fn state(&self) -> pgwire::api::PgWireConnectionState {
            pgwire::api::PgWireConnectionState::ReadyForQuery
        }

        fn set_state(&mut self, _new_state: pgwire::api::PgWireConnectionState) {}

        fn transaction_status(&self) -> pgwire::messages::response::TransactionStatus {
            pgwire::messages::response::TransactionStatus::Idle
        }

        fn set_transaction_status(
            &mut self,
            _new_status: pgwire::messages::response::TransactionStatus,
        ) {
        }

        // The two metadata accessors are the only ones backed by real state.
        fn metadata(&self) -> &HashMap<String, String> {
            &self.metadata
        }

        fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
            &mut self.metadata
        }

        fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
            None
        }
    }

    /// Builds a service over a fresh session context and auth manager.
    fn new_service() -> DfSessionService {
        let session_context = Arc::new(SessionContext::new());
        let auth_manager = Arc::new(AuthManager::new());
        DfSessionService::new(session_context, auth_manager)
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let service = new_service();
        let mut client = MockClient::new();

        // SET with an explicit `ms` suffix must be handled by the service.
        let handled_set = service
            .try_respond_set_statements(&mut client, "set statement_timeout '5000ms'")
            .await
            .unwrap();
        assert!(handled_set.is_some());

        // The parsed value must be stored in the client metadata.
        assert_eq!(
            DfSessionService::get_statement_timeout(&client),
            Some(Duration::from_millis(5000))
        );

        // SHOW must be handled by the service as well.
        let handled_show = service
            .try_respond_show_statements(&client, "show statement_timeout")
            .await
            .unwrap();
        assert!(handled_show.is_some());
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let service = new_service();
        let mut client = MockClient::new();

        // Install a timeout, then disable it again with '0'.
        for statement in [
            "set statement_timeout '1000ms'",
            "set statement_timeout '0'",
        ] {
            service
                .try_respond_set_statements(&mut client, statement)
                .await
                .unwrap();
        }

        assert_eq!(DfSessionService::get_statement_timeout(&client), None);
    }
}
Loading
Loading