Skip to content

Commit c44a3ea

Browse files
authored
feat: add initial copy api design (#105)
* feat: add initial copy api design * refactor: update copy api return types * feat: add more copy apis * feat: redesigned copy handler * refactor: update main entrypoint api for copy_handler
1 parent 2f4ee5f commit c44a3ea

File tree

11 files changed

+161
-2
lines changed

11 files changed

+161
-2
lines changed

examples/bench.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use futures::StreamExt;
66
use tokio::net::TcpListener;
77

88
use pgwire::api::auth::noop::NoopStartupHandler;
9+
use pgwire::api::copy::NoopCopyHandler;
910
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
1011
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response};
1112
use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type};
@@ -73,6 +74,7 @@ pub async fn main() {
7374
PlaceholderExtendedQueryHandler,
7475
)));
7576
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));
77+
let noop_copy_handler = Arc::new(NoopCopyHandler);
7678

7779
let server_addr = "127.0.0.1:5433";
7880
let listener = TcpListener::bind(server_addr).await.unwrap();
@@ -82,13 +84,16 @@ pub async fn main() {
8284
let authenticator_ref = authenticator.make();
8385
let processor_ref = processor.make();
8486
let placeholder_ref = placeholder.make();
87+
let copy_handler_ref = noop_copy_handler.clone();
88+
8589
tokio::spawn(async move {
8690
process_socket(
8791
incoming_socket.0,
8892
None,
8993
authenticator_ref,
9094
processor_ref,
9195
placeholder_ref,
96+
copy_handler_ref,
9297
)
9398
.await
9499
});

examples/duckdb.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use futures::stream;
88
use futures::Stream;
99
use pgwire::api::auth::md5pass::{hash_md5_password, MakeMd5PasswordAuthStartupHandler};
1010
use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password};
11+
use pgwire::api::copy::NoopCopyHandler;
1112
use pgwire::api::portal::{Format, Portal};
1213
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
1314
use pgwire::api::results::{
@@ -350,6 +351,7 @@ pub async fn main() {
350351
Arc::new(parameters),
351352
));
352353
let processor = Arc::new(MakeDuckDBBackend::new());
354+
let noop_copy_handler = Arc::new(NoopCopyHandler);
353355

354356
let server_addr = "127.0.0.1:5432";
355357
let listener = TcpListener::bind(server_addr).await.unwrap();
@@ -358,13 +360,16 @@ pub async fn main() {
358360
let incoming_socket = listener.accept().await.unwrap();
359361
let authenticator_ref = authenticator.make();
360362
let processor_ref = processor.make();
363+
let copy_handler_ref = noop_copy_handler.clone();
364+
361365
tokio::spawn(async move {
362366
process_socket(
363367
incoming_socket.0,
364368
None,
365369
authenticator_ref,
366370
processor_ref.clone(),
367371
processor_ref,
372+
copy_handler_ref,
368373
)
369374
.await
370375
});

examples/gluesql.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use tokio::net::TcpListener;
66

77
use gluesql::prelude::*;
88
use pgwire::api::auth::noop::NoopStartupHandler;
9+
use pgwire::api::copy::NoopCopyHandler;
910
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
1011
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
1112
use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type};
@@ -170,6 +171,7 @@ pub async fn main() {
170171
PlaceholderExtendedQueryHandler,
171172
)));
172173
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));
174+
let noop_copy_handler = Arc::new(NoopCopyHandler);
173175

174176
let server_addr = "127.0.0.1:5432";
175177
let listener = TcpListener::bind(server_addr).await.unwrap();
@@ -179,13 +181,16 @@ pub async fn main() {
179181
let authenticator_ref = authenticator.make();
180182
let processor_ref = processor.make();
181183
let placeholder_ref = placeholder.make();
184+
let copy_handler_ref = noop_copy_handler.clone();
185+
182186
tokio::spawn(async move {
183187
process_socket(
184188
incoming_socket.0,
185189
None,
186190
authenticator_ref,
187191
processor_ref,
188192
placeholder_ref,
193+
copy_handler_ref,
189194
)
190195
.await
191196
});

examples/scram.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use tokio_rustls::TlsAcceptor;
1212

1313
use pgwire::api::auth::scram::{gen_salted_password, MakeSASLScramAuthStartupHandler};
1414
use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password};
15+
use pgwire::api::copy::NoopCopyHandler;
1516
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
1617
use pgwire::api::results::{Response, Tag};
1718

@@ -79,6 +80,7 @@ pub async fn main() {
7980
let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(
8081
PlaceholderExtendedQueryHandler,
8182
)));
83+
let noop_copy_handler = Arc::new(NoopCopyHandler);
8284
let mut authenticator = MakeSASLScramAuthStartupHandler::new(
8385
Arc::new(DummyAuthDB),
8486
Arc::new(DefaultServerParameterProvider::default()),
@@ -99,13 +101,16 @@ pub async fn main() {
99101
let authenticator_ref = authenticator.make();
100102
let processor_ref = processor.make();
101103
let placeholder_ref = placeholder.make();
104+
let copy_handler_ref = noop_copy_handler.clone();
105+
102106
tokio::spawn(async move {
103107
process_socket(
104108
incoming_socket.0,
105109
Some(tls_acceptor_ref),
106110
authenticator_ref,
107111
processor_ref,
108112
placeholder_ref,
113+
copy_handler_ref,
109114
)
110115
.await
111116
});

examples/secure_server.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use tokio_rustls::rustls::ServerConfig;
1111
use tokio_rustls::TlsAcceptor;
1212

1313
use pgwire::api::auth::noop::NoopStartupHandler;
14+
use pgwire::api::copy::NoopCopyHandler;
1415
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
1516
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
1617
use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type};
@@ -84,6 +85,7 @@ pub async fn main() {
8485
PlaceholderExtendedQueryHandler,
8586
)));
8687
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));
88+
let noop_copy_handler = Arc::new(NoopCopyHandler);
8789

8890
let server_addr = "127.0.0.1:5433";
8991
let tls_acceptor = Arc::new(setup_tls().unwrap());
@@ -96,13 +98,15 @@ pub async fn main() {
9698
let authenticator_ref = authenticator.make();
9799
let processor_ref = processor.make();
98100
let placeholder_ref = placeholder.make();
101+
let copy_handler_ref = noop_copy_handler.clone();
99102
tokio::spawn(async move {
100103
process_socket(
101104
incoming_socket.0,
102105
Some(tls_acceptor_ref),
103106
authenticator_ref,
104107
processor_ref,
105108
placeholder_ref,
109+
copy_handler_ref,
106110
)
107111
.await
108112
});

examples/server.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use futures::{stream, Sink, SinkExt, StreamExt};
66
use tokio::net::TcpListener;
77

88
use pgwire::api::auth::noop::NoopStartupHandler;
9+
use pgwire::api::copy::NoopCopyHandler;
910
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
1011
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
1112
use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type};
@@ -76,6 +77,7 @@ pub async fn main() {
7677
PlaceholderExtendedQueryHandler,
7778
)));
7879
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));
80+
let noop_copy_handler = Arc::new(NoopCopyHandler);
7981

8082
let server_addr = "127.0.0.1:5432";
8183
let listener = TcpListener::bind(server_addr).await.unwrap();
@@ -85,13 +87,15 @@ pub async fn main() {
8587
let authenticator_ref = authenticator.make();
8688
let processor_ref = processor.make();
8789
let placeholder_ref = placeholder.make();
90+
let copy_handler_ref = noop_copy_handler.clone();
8891
tokio::spawn(async move {
8992
process_socket(
9093
incoming_socket.0,
9194
None,
9295
authenticator_ref,
9396
processor_ref,
9497
placeholder_ref,
98+
copy_handler_ref,
9599
)
96100
.await
97101
});

examples/sqlite.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ use std::sync::{Arc, Mutex};
33
use async_trait::async_trait;
44
use futures::stream;
55
use futures::Stream;
6+
67
use pgwire::api::auth::md5pass::{hash_md5_password, MakeMd5PasswordAuthStartupHandler};
78
use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password};
9+
use pgwire::api::copy::NoopCopyHandler;
810
use pgwire::api::portal::{Format, Portal};
911
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
1012
use pgwire::api::results::{
@@ -306,6 +308,7 @@ pub async fn main() {
306308
Arc::new(parameters),
307309
));
308310
let processor = Arc::new(MakeSqliteBackend::new());
311+
let noop_copy_handler = Arc::new(NoopCopyHandler);
309312

310313
let server_addr = "127.0.0.1:5432";
311314
let listener = TcpListener::bind(server_addr).await.unwrap();
@@ -314,13 +317,16 @@ pub async fn main() {
314317
let incoming_socket = listener.accept().await.unwrap();
315318
let authenticator_ref = authenticator.make();
316319
let processor_ref = processor.make();
320+
let copy_handler_ref = noop_copy_handler.clone();
321+
317322
tokio::spawn(async move {
318323
process_socket(
319324
incoming_socket.0,
320325
None,
321326
authenticator_ref,
322327
processor_ref.clone(),
323328
processor_ref,
329+
copy_handler_ref,
324330
)
325331
.await
326332
});

src/api/copy.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
use async_trait::async_trait;
2+
use futures::sink::{Sink, SinkExt};
3+
use std::fmt::Debug;
4+
5+
use crate::error::{PgWireError, PgWireResult};
6+
use crate::messages::copy::{
7+
CopyBothResponse, CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse,
8+
};
9+
use crate::messages::PgWireBackendMessage;
10+
11+
use super::ClientInfo;
12+
13+
/// handler for copy messages
14+
#[async_trait]
15+
pub trait CopyHandler: Send + Sync {
16+
async fn on_copy_data<C>(&self, _client: &mut C, _copy_data: CopyData) -> PgWireResult<()>
17+
where
18+
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
19+
C::Error: Debug,
20+
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
21+
{
22+
Ok(())
23+
}
24+
25+
async fn on_copy_done<C>(&self, _client: &mut C, _done: CopyDone) -> PgWireResult<()>
26+
where
27+
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
28+
C::Error: Debug,
29+
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
30+
{
31+
Ok(())
32+
}
33+
34+
async fn on_copy_fail<C>(&self, _client: &mut C, _fail: CopyFail) -> PgWireResult<()>
35+
where
36+
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
37+
C::Error: Debug,
38+
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
39+
{
40+
Ok(())
41+
}
42+
}
43+
44+
pub async fn send_copy_in_response<C>(
45+
client: &mut C,
46+
overall_format: i8,
47+
columns: usize,
48+
column_formats: Vec<i16>,
49+
) -> PgWireResult<()>
50+
where
51+
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
52+
C::Error: Debug,
53+
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
54+
{
55+
let resp = CopyInResponse::new(overall_format, columns as i16, column_formats);
56+
client
57+
.send(PgWireBackendMessage::CopyInResponse(resp))
58+
.await?;
59+
Ok(())
60+
}
61+
62+
pub async fn send_copy_out_response<C>(
63+
client: &mut C,
64+
overall_format: i8,
65+
columns: usize,
66+
column_formats: Vec<i16>,
67+
) -> PgWireResult<()>
68+
where
69+
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
70+
C::Error: Debug,
71+
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
72+
{
73+
let resp = CopyOutResponse::new(overall_format, columns as i16, column_formats);
74+
client
75+
.send(PgWireBackendMessage::CopyOutResponse(resp))
76+
.await?;
77+
Ok(())
78+
}
79+
80+
pub async fn send_copy_both_response<C>(
81+
client: &mut C,
82+
overall_format: i8,
83+
columns: usize,
84+
column_formats: Vec<i16>,
85+
) -> PgWireResult<()>
86+
where
87+
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
88+
C::Error: Debug,
89+
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
90+
{
91+
let resp = CopyBothResponse::new(overall_format, columns as i16, column_formats);
92+
client
93+
.send(PgWireBackendMessage::CopyBothResponse(resp))
94+
.await?;
95+
Ok(())
96+
}
97+
98+
#[derive(Clone, Copy, Debug, Default)]
99+
pub struct NoopCopyHandler;
100+
101+
impl CopyHandler for NoopCopyHandler {}

src/api/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::Arc;
77
pub use postgres_types::Type;
88

99
pub mod auth;
10+
pub mod copy;
1011
pub mod portal;
1112
pub mod query;
1213
pub mod results;
@@ -22,6 +23,7 @@ pub enum PgWireConnectionState {
2223
AuthenticationInProgress,
2324
ReadyForQuery,
2425
QueryInProgress,
26+
CopyInProgress,
2527
AwaitingSync,
2628
}
2729

0 commit comments

Comments
 (0)