diff --git a/Cargo.lock b/Cargo.lock index 66afc83..1285a9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1886,9 +1886,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.22.0" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3770f56e1e8a608c6de40011b9a00c6b669c14d121024411701b4bc3b2a5be99" +checksum = "22fb7a8b4570b74080587c5f3e187553375d18e72a38c72ca7f70a065972c65d" dependencies = [ "async-trait", "aws-lc-rs", diff --git a/Cargo.toml b/Cargo.toml index db850ca..4b78e42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ rust-version = "1.73" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -pgwire = "0.22" +pgwire = "0.23" datafusion = "39" futures = "0.3" async-trait = "0.1" diff --git a/src/main.rs b/src/main.rs index 471e5a7..e5862fe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,8 @@ use datafusion::execution::options::{ }; use datafusion::prelude::SessionContext; use pgwire::api::auth::noop::NoopStartupHandler; -use pgwire::api::{MakeHandler, StatelessMakeHandler}; +use pgwire::api::copy::NoopCopyHandler; +use pgwire::api::PgWireHandlerFactory; use pgwire::tokio::process_socket; use structopt::StructOpt; use tokio::net::TcpListener; @@ -42,6 +43,31 @@ fn parse_table_def(table_def: &str) -> (&str, &str) { .expect("Use this pattern to register table: table_name:file_path") } +struct HandlerFactory(Arc); + +impl PgWireHandlerFactory for HandlerFactory { + type StartupHandler = NoopStartupHandler; + type SimpleQueryHandler = handlers::DfSessionService; + type ExtendedQueryHandler = handlers::DfSessionService; + type CopyHandler = NoopCopyHandler; + + fn simple_query_handler(&self) -> Arc { + self.0.clone() + } + + fn extended_query_handler(&self) -> Arc { + self.0.clone() + } + + fn startup_handler(&self) -> Arc { + Arc::new(NoopStartupHandler) + } + + fn copy_handler(&self) -> Arc { + Arc::new(NoopCopyHandler) + } +} + #[tokio::main] async fn main() { let opts = Opt::from_args(); @@ -96,27 +122,17 @@ async fn main() { println!("Loaded {} as table {}", table_path, table_name); } - let processor = Arc::new(StatelessMakeHandler::new(Arc::new( - handlers::DfSessionService::new(session_context), - ))); - let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); + let factory = Arc::new(HandlerFactory(Arc::new(handlers::DfSessionService::new( + session_context, + )))); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); - let authenticator_ref = authenticator.make(); - let processor_ref = processor.make(); - tokio::spawn(async move { - process_socket( - incoming_socket.0, - None, - authenticator_ref, - processor_ref.clone(), - processor_ref, - ) - .await - }); + let factory_ref = factory.clone(); + + tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await }); } }