diff --git a/arrow-pg/src/datatypes.rs b/arrow-pg/src/datatypes.rs index 7af25f1..c3c6276 100644 --- a/arrow-pg/src/datatypes.rs +++ b/arrow-pg/src/datatypes.rs @@ -42,8 +42,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { DataType::Float16 | DataType::Float32 => Type::FLOAT4, DataType::Float64 => Type::FLOAT8, DataType::Decimal128(_, _) => Type::NUMERIC, - DataType::Utf8 => Type::VARCHAR, - DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { match field.data_type() { DataType::Boolean => Type::BOOL_ARRAY, @@ -67,8 +66,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { | DataType::BinaryView => Type::BYTEA_ARRAY, DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, DataType::Float64 => Type::FLOAT8_ARRAY, - DataType::Utf8 => Type::VARCHAR_ARRAY, - DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, struct_type @ DataType::Struct(_) => Type::new( Type::RECORD_ARRAY.name().into(), Type::RECORD_ARRAY.oid(), diff --git a/arrow-pg/src/datatypes/df.rs b/arrow-pg/src/datatypes/df.rs index af98b99..741d31c 100644 --- a/arrow-pg/src/datatypes/df.rs +++ b/arrow-pg/src/datatypes/df.rs @@ -66,11 +66,9 @@ where } else if let Some(infer_type) = inferenced_type { into_pg_type(infer_type) } else { - Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "FATAL".to_string(), - "XX000".to_string(), - "Unknown parameter type".to_string(), - )))) + // Default to TEXT/VARCHAR for untyped parameters + // This allows arithmetic operations to work with implicit casting + Ok(Type::TEXT) } } diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index 5490e1f..8ac10da 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -574,7 +574,7 @@ mod tests { { let mut bytes = BytesMut::new(); let _sql_text = value.to_sql_text(data_type, &mut bytes); - let string = String::from_utf8((&bytes).to_vec()); + let string = String::from_utf8(bytes.to_vec()); self.encoded_value = string.unwrap(); Ok(()) } diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 7d2ccb3..219252b 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -43,9 +43,16 @@ pub struct HandlerFactory { } impl HandlerFactory { - pub fn new(session_context: Arc, auth_manager: Arc) -> Self { - let session_service = - Arc::new(DfSessionService::new(session_context, auth_manager.clone())); + pub fn new( + session_context: Arc, + auth_manager: Arc, + query_timeout: Option, + ) -> Self { + let session_service = Arc::new(DfSessionService::new( + session_context, + auth_manager.clone(), + query_timeout, + )); HandlerFactory { session_service } } } @@ -71,12 +78,14 @@ pub struct DfSessionService { timezone: Arc>, auth_manager: Arc, sql_rewrite_rules: Vec>, + query_timeout: Option, } impl DfSessionService { pub fn new( session_context: Arc, auth_manager: Arc, + query_timeout: Option, ) -> DfSessionService { let sql_rewrite_rules: Vec> = vec![ Arc::new(AliasDuplicatedProjectionRewrite), @@ -97,6 +106,7 @@ impl DfSessionService { timezone: Arc::new(Mutex::new("UTC".to_string())), auth_manager, sql_rewrite_rules, + query_timeout, } } @@ -378,7 +388,19 @@ impl SimpleQueryHandler for DfSessionService { ))); } - let df_result = self.session_context.sql(&query).await; + let df_result = if let Some(timeout) = self.query_timeout { + tokio::time::timeout(timeout, self.session_context.sql(&query)) + .await + .map_err(|_| { + PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "57014".to_string(), // query_canceled error code + "canceling statement due to query timeout".to_string(), + ))) + })? + } else { + self.session_context.sql(&query).await + }; // Handle query execution errors and transaction state let df = match df_result { @@ -540,10 +562,26 @@ impl ExtendedQueryHandler for DfSessionService { .optimize(&plan) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let dataframe = match self.session_context.execute_logical_plan(optimised).await { - Ok(df) => df, - Err(e) => { - return Err(PgWireError::ApiError(Box::new(e))); + let dataframe = if let Some(timeout) = self.query_timeout { + tokio::time::timeout( + timeout, + self.session_context.execute_logical_plan(optimised), + ) + .await + .map_err(|_| { + PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "57014".to_string(), // query_canceled error code + "canceling statement due to query timeout".to_string(), + ))) + })? + .map_err(|e| PgWireError::ApiError(Box::new(e)))? + } else { + match self.session_context.execute_logical_plan(optimised).await { + Ok(df) => df, + Err(e) => { + return Err(PgWireError::ApiError(Box::new(e))); + } } }; let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?; @@ -593,3 +631,78 @@ fn ordered_param_types(types: &HashMap>) -> Vec, tls_key_path: Option, + max_connections: usize, + query_timeout: Option, } impl ServerOptions { pub fn new() -> ServerOptions { ServerOptions::default() } + + /// Set query timeout from seconds. Use 0 for no timeout. + pub fn with_query_timeout_secs(mut self, timeout_secs: u64) -> Self { + self.query_timeout = if timeout_secs == 0 { + None + } else { + Some(std::time::Duration::from_secs(timeout_secs)) + }; + self + } } impl Default for ServerOptions { @@ -49,6 +62,8 @@ impl Default for ServerOptions { port: 5432, tls_cert_path: None, tls_key_path: None, + max_connections: 1000, + query_timeout: Some(std::time::Duration::from_secs(30)), } } } @@ -85,7 +100,11 @@ pub async fn serve( let auth_manager = Arc::new(AuthManager::new()); // Create the handler factory with authentication - let factory = Arc::new(HandlerFactory::new(session_context, auth_manager)); + let factory = Arc::new(HandlerFactory::new( + session_context, + auth_manager, + opts.query_timeout, + )); serve_with_handlers(factory, opts).await } @@ -126,17 +145,31 @@ pub async fn serve_with_handlers( info!("Listening on {server_addr} (unencrypted)"); } + // Connection limiter to prevent resource exhaustion + let connection_semaphore = Arc::new(Semaphore::new(opts.max_connections)); + // Accept incoming connections loop { match listener.accept().await { - Ok((socket, _addr)) => { + Ok((socket, addr)) => { let factory_ref = handlers.clone(); let tls_acceptor_ref = tls_acceptor.clone(); + let semaphore_ref = connection_semaphore.clone(); tokio::spawn(async move { + // Acquire connection permit to limit concurrency + let _permit = match semaphore_ref.try_acquire() { + Ok(permit) => permit, + Err(_) => { + warn!("Connection rejected from {addr}: max connections reached"); + return; + } + }; + if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await { - warn!("Error processing socket: {e}"); + warn!("Error processing socket from {addr}: {e}"); } + // Permit is automatically released when _permit is dropped }); } Err(e) => { diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 65c8021..2ae841f 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -296,7 +296,7 @@ struct RemoveUnsupportedTypesVisitor<'a> { unsupported_types: &'a HashSet, } -impl<'a> VisitorMut for RemoveUnsupportedTypesVisitor<'a> { +impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> { type Break = (); fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { @@ -444,7 +444,7 @@ struct PrependUnqualifiedTableNameVisitor<'a> { table_names: &'a HashSet, } -impl<'a> VisitorMut for PrependUnqualifiedTableNameVisitor<'a> { +impl VisitorMut for PrependUnqualifiedTableNameVisitor<'_> { type Break = (); fn pre_visit_table_factor( diff --git a/datafusion-postgres/tests/common/mod.rs b/datafusion-postgres/tests/common/mod.rs index 7c7df52..5f53588 100644 --- a/datafusion-postgres/tests/common/mod.rs +++ b/datafusion-postgres/tests/common/mod.rs @@ -14,7 +14,11 @@ pub fn setup_handlers() -> DfSessionService { let session_context = SessionContext::new(); setup_pg_catalog(&session_context, "datafusion").expect("Failed to setup sesession context"); - DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new())) + DfSessionService::new( + Arc::new(session_context), + Arc::new(AuthManager::new()), + Some(std::time::Duration::from_secs(30)), + ) } #[derive(Debug, Default)] diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index e132b91..24e5ab8 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -34,6 +34,6 @@ pub async fn test_dbeaver_startup_sql() { for query in DBEAVER_QUERIES { SimpleQueryHandler::do_query(&service, &mut client, query) .await - .expect(&format!("failed to run sql: {query}")); + .unwrap_or_else(|_| panic!("failed to run sql: {query}")); } }