Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions arrow-pg/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
DataType::Float64 => Type::FLOAT8,
DataType::Decimal128(_, _) => Type::NUMERIC,
DataType::Utf8 => Type::VARCHAR,
DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
match field.data_type() {
DataType::Boolean => Type::BOOL_ARRAY,
Expand All @@ -67,8 +66,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
| DataType::BinaryView => Type::BYTEA_ARRAY,
DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
DataType::Float64 => Type::FLOAT8_ARRAY,
DataType::Utf8 => Type::VARCHAR_ARRAY,
DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
struct_type @ DataType::Struct(_) => Type::new(
Type::RECORD_ARRAY.name().into(),
Type::RECORD_ARRAY.oid(),
Expand Down
8 changes: 3 additions & 5 deletions arrow-pg/src/datatypes/df.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ where
} else if let Some(infer_type) = inferenced_type {
into_pg_type(infer_type)
} else {
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_string(),
"XX000".to_string(),
"Unknown parameter type".to_string(),
))))
// Default to TEXT/VARCHAR for untyped parameters
// This allows arithmetic operations to work with implicit casting
Ok(Type::TEXT)
}
}

Expand Down
2 changes: 1 addition & 1 deletion arrow-pg/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ mod tests {
{
let mut bytes = BytesMut::new();
let _sql_text = value.to_sql_text(data_type, &mut bytes);
let string = String::from_utf8((&bytes).to_vec());
let string = String::from_utf8(bytes.to_vec());
self.encoded_value = string.unwrap();
Ok(())
}
Expand Down
129 changes: 121 additions & 8 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,16 @@ pub struct HandlerFactory {
}

impl HandlerFactory {
pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
let session_service =
Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_timeout: Option<std::time::Duration>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest you to put this into a new pull request. Also it should be configured at per-session level using SET statement_timeout

) -> Self {
let session_service = Arc::new(DfSessionService::new(
session_context,
auth_manager.clone(),
query_timeout,
));
HandlerFactory { session_service }
}
}
Expand All @@ -71,12 +78,14 @@ pub struct DfSessionService {
timezone: Arc<Mutex<String>>,
auth_manager: Arc<AuthManager>,
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
query_timeout: Option<std::time::Duration>,
}

impl DfSessionService {
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_timeout: Option<std::time::Duration>,
) -> DfSessionService {
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
Arc::new(AliasDuplicatedProjectionRewrite),
Expand All @@ -97,6 +106,7 @@ impl DfSessionService {
timezone: Arc::new(Mutex::new("UTC".to_string())),
auth_manager,
sql_rewrite_rules,
query_timeout,
}
}

Expand Down Expand Up @@ -378,7 +388,19 @@ impl SimpleQueryHandler for DfSessionService {
)));
}

let df_result = self.session_context.sql(&query).await;
let df_result = if let Some(timeout) = self.query_timeout {
tokio::time::timeout(timeout, self.session_context.sql(&query))
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to query timeout".to_string(),
)))
})?
} else {
self.session_context.sql(&query).await
};

// Handle query execution errors and transaction state
let df = match df_result {
Expand Down Expand Up @@ -540,10 +562,26 @@ impl ExtendedQueryHandler for DfSessionService {
.optimize(&plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let dataframe = match self.session_context.execute_logical_plan(optimised).await {
Ok(df) => df,
Err(e) => {
return Err(PgWireError::ApiError(Box::new(e)));
let dataframe = if let Some(timeout) = self.query_timeout {
tokio::time::timeout(
timeout,
self.session_context.execute_logical_plan(optimised),
)
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to query timeout".to_string(),
)))
})?
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
} else {
match self.session_context.execute_logical_plan(optimised).await {
Ok(df) => df,
Err(e) => {
return Err(PgWireError::ApiError(Box::new(e)));
}
}
};
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
Expand Down Expand Up @@ -593,3 +631,78 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
types.sort_by(|a, b| a.0.cmp(b.0));
types.into_iter().map(|pt| pt.1.as_ref()).collect()
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{auth::AuthManager, ServerOptions};
use datafusion::prelude::SessionContext;
use std::time::Duration;

#[test]
fn test_server_options_default_timeout() {
let opts = ServerOptions::default();
assert_eq!(opts.query_timeout, Some(Duration::from_secs(30)));
}

#[test]
fn test_server_options_no_timeout() {
let mut opts = ServerOptions::new();
opts.query_timeout = None;
assert_eq!(opts.query_timeout, None);
}

#[test]
fn test_handler_factory_with_timeout() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let timeout = Some(Duration::from_secs(60));

let factory = HandlerFactory::new(session_context, auth_manager, timeout);
assert_eq!(factory.session_service.query_timeout, timeout);
}

#[test]
fn test_session_service_timeout_configuration() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());

// Test with timeout
let service_with_timeout = DfSessionService::new(
session_context.clone(),
auth_manager.clone(),
Some(Duration::from_secs(45)),
);
assert_eq!(
service_with_timeout.query_timeout,
Some(Duration::from_secs(45))
);

// Test without timeout (None)
let service_no_timeout = DfSessionService::new(session_context, auth_manager, None);
assert_eq!(service_no_timeout.query_timeout, None);
}

#[test]
fn test_timeout_configuration_from_seconds() {
// Test 0 seconds = no timeout
let opts_no_timeout = ServerOptions::new().with_query_timeout_secs(0);
assert_eq!(opts_no_timeout.query_timeout, None);

// Test positive seconds = Some(Duration)
let opts_with_timeout = ServerOptions::new().with_query_timeout_secs(60);
assert_eq!(
opts_with_timeout.query_timeout,
Some(Duration::from_secs(60))
);
}

#[test]
fn test_max_connections_configuration() {
let opts = ServerOptions::new().with_max_connections(500);
assert_eq!(opts.max_connections, 500);

let opts_default = ServerOptions::default();
assert_eq!(opts_default.max_connections, 1000);
}
}
39 changes: 36 additions & 3 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use pgwire::tokio::process_socket;
use rustls_pemfile::{certs, pkcs8_private_keys};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tokio_rustls::rustls::{self, ServerConfig};
use tokio_rustls::TlsAcceptor;

Expand All @@ -34,12 +35,24 @@ pub struct ServerOptions {
port: u16,
tls_cert_path: Option<String>,
tls_key_path: Option<String>,
max_connections: usize,
query_timeout: Option<std::time::Duration>,
}

impl ServerOptions {
pub fn new() -> ServerOptions {
ServerOptions::default()
}

/// Set query timeout from seconds. Use 0 for no timeout.
pub fn with_query_timeout_secs(mut self, timeout_secs: u64) -> Self {
self.query_timeout = if timeout_secs == 0 {
None
} else {
Some(std::time::Duration::from_secs(timeout_secs))
};
self
}
}

impl Default for ServerOptions {
Expand All @@ -49,6 +62,8 @@ impl Default for ServerOptions {
port: 5432,
tls_cert_path: None,
tls_key_path: None,
max_connections: 1000,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to 0, which means no restriction on this.

query_timeout: Some(std::time::Duration::from_secs(30)),
}
}
}
Expand Down Expand Up @@ -85,7 +100,11 @@ pub async fn serve(
let auth_manager = Arc::new(AuthManager::new());

// Create the handler factory with authentication
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
let factory = Arc::new(HandlerFactory::new(
session_context,
auth_manager,
opts.query_timeout,
));

serve_with_handlers(factory, opts).await
}
Expand Down Expand Up @@ -126,17 +145,31 @@ pub async fn serve_with_handlers(
info!("Listening on {server_addr} (unencrypted)");
}

// Connection limiter to prevent resource exhaustion
let connection_semaphore = Arc::new(Semaphore::new(opts.max_connections));

// Accept incoming connections
loop {
match listener.accept().await {
Ok((socket, _addr)) => {
Ok((socket, addr)) => {
let factory_ref = handlers.clone();
let tls_acceptor_ref = tls_acceptor.clone();
let semaphore_ref = connection_semaphore.clone();

tokio::spawn(async move {
// Acquire connection permit to limit concurrency
let _permit = match semaphore_ref.try_acquire() {
Ok(permit) => permit,
Err(_) => {
warn!("Connection rejected from {addr}: max connections reached");
return;
}
};

if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
warn!("Error processing socket: {e}");
warn!("Error processing socket from {addr}: {e}");
}
// Permit is automatically released when _permit is dropped
});
}
Err(e) => {
Expand Down
4 changes: 2 additions & 2 deletions datafusion-postgres/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ struct RemoveUnsupportedTypesVisitor<'a> {
unsupported_types: &'a HashSet<String>,
}

impl<'a> VisitorMut for RemoveUnsupportedTypesVisitor<'a> {
impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
type Break = ();

fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
Expand Down Expand Up @@ -444,7 +444,7 @@ struct PrependUnqualifiedTableNameVisitor<'a> {
table_names: &'a HashSet<String>,
}

impl<'a> VisitorMut for PrependUnqualifiedTableNameVisitor<'a> {
impl VisitorMut for PrependUnqualifiedTableNameVisitor<'_> {
type Break = ();

fn pre_visit_table_factor(
Expand Down
6 changes: 5 additions & 1 deletion datafusion-postgres/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ pub fn setup_handlers() -> DfSessionService {
let session_context = SessionContext::new();
setup_pg_catalog(&session_context, "datafusion").expect("Failed to setup sesession context");

DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()))
DfSessionService::new(
Arc::new(session_context),
Arc::new(AuthManager::new()),
Some(std::time::Duration::from_secs(30)),
)
}

#[derive(Debug, Default)]
Expand Down
2 changes: 1 addition & 1 deletion datafusion-postgres/tests/dbeaver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ pub async fn test_dbeaver_startup_sql() {
for query in DBEAVER_QUERIES {
SimpleQueryHandler::do_query(&service, &mut client, query)
.await
.expect(&format!("failed to run sql: {query}"));
.unwrap_or_else(|_| panic!("failed to run sql: {query}"));
}
}
Loading