Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use pgwire::tokio::process_socket;
use rustls_pemfile::{certs, pkcs8_private_keys};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tokio_rustls::rustls::{self, ServerConfig};
use tokio_rustls::TlsAcceptor;

Expand All @@ -34,6 +35,7 @@ pub struct ServerOptions {
port: u16,
tls_cert_path: Option<String>,
tls_key_path: Option<String>,
max_connections: usize,
}

impl ServerOptions {
Expand All @@ -49,6 +51,7 @@ impl Default for ServerOptions {
port: 5432,
tls_cert_path: None,
tls_key_path: None,
max_connections: 0, // 0 = no limit
}
}
}
Expand Down Expand Up @@ -126,17 +129,40 @@ pub async fn serve_with_handlers(
info!("Listening on {server_addr} (unencrypted)");
}

// Connection limiter (if configured)
let max_conn_count = opts.max_connections;
let connection_limiter = if max_conn_count > 0 {
Some(Arc::new(Semaphore::new(max_conn_count)))
} else {
None
};

// Accept incoming connections
loop {
match listener.accept().await {
Ok((socket, _addr)) => {
Ok((socket, addr)) => {
let factory_ref = handlers.clone();
let tls_acceptor_ref = tls_acceptor.clone();
let limiter_ref = connection_limiter.clone();

tokio::spawn(async move {
// Check connection limit if configured
let _permit = if let Some(ref semaphore) = limiter_ref {
match semaphore.try_acquire() {
Ok(permit) => Some(permit),
Err(_) => {
warn!("Connection rejected from {addr}: max connections ({max_conn_count}) reached");
return;
}
}
} else {
None
};

if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
warn!("Error processing socket: {e}");
warn!("Error processing socket from {addr}: {e}");
}
// Permit is automatically released when _permit is dropped
});
}
Err(e) => {
Expand All @@ -145,3 +171,24 @@ pub async fn serve_with_handlers(
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_server_options_default_max_connections() {
let opts = ServerOptions::default();
assert_eq!(opts.max_connections, 0); // No limit by default
}

#[test]
fn test_server_options_max_connections_configuration() {
let opts = ServerOptions::new().with_max_connections(500);
assert_eq!(opts.max_connections, 500);

// Test that 0 means no limit
let opts_no_limit = ServerOptions::new().with_max_connections(0);
assert_eq!(opts_no_limit.max_connections, 0);
}
}
Loading