Skip to content

Commit 97745c3

Browse files
committed
Move server types into server module
1 parent 78a4ea8 commit 97745c3

File tree

2 files changed

+283
-283
lines changed

2 files changed

+283
-283
lines changed

src/lib.rs

Lines changed: 2 additions & 281 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,17 @@
3838
3939
#![warn(unreachable_pub)]
4040

41-
use std::future::Future;
4241
use std::io;
4342
#[cfg(unix)]
4443
use std::os::unix::io::{AsRawFd, RawFd};
4544
#[cfg(windows)]
4645
use std::os::windows::io::{AsRawSocket, RawSocket};
4746
use std::pin::Pin;
48-
use std::sync::Arc;
4947
use std::task::{Context, Poll};
5048

5149
pub use rustls;
5250

53-
use rustls::server::AcceptedAlert;
54-
use rustls::{CommonState, ServerConfig, ServerConnection};
51+
use rustls::CommonState;
5552
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
5653

5754
macro_rules! ready {
@@ -66,284 +63,8 @@ macro_rules! ready {
6663
pub mod client;
6764
pub use client::{Connect, FallibleConnect, TlsConnector, TlsConnectorWithAlpn};
6865
mod common;
69-
use common::{MidHandshake, TlsState};
7066
pub mod server;
71-
72-
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
73-
#[derive(Clone)]
74-
pub struct TlsAcceptor {
75-
inner: Arc<ServerConfig>,
76-
}
77-
78-
impl From<Arc<ServerConfig>> for TlsAcceptor {
79-
fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
80-
TlsAcceptor { inner }
81-
}
82-
}
83-
84-
impl TlsAcceptor {
85-
#[inline]
86-
pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
87-
where
88-
IO: AsyncRead + AsyncWrite + Unpin,
89-
{
90-
self.accept_with(stream, |_| ())
91-
}
92-
93-
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
94-
where
95-
IO: AsyncRead + AsyncWrite + Unpin,
96-
F: FnOnce(&mut ServerConnection),
97-
{
98-
let mut session = match ServerConnection::new(self.inner.clone()) {
99-
Ok(session) => session,
100-
Err(error) => {
101-
return Accept(MidHandshake::Error {
102-
io: stream,
103-
// TODO(eliza): should this really return an `io::Error`?
104-
// Probably not...
105-
error: io::Error::new(io::ErrorKind::Other, error),
106-
});
107-
}
108-
};
109-
f(&mut session);
110-
111-
Accept(MidHandshake::Handshaking(server::TlsStream {
112-
session,
113-
io: stream,
114-
state: TlsState::Stream,
115-
need_flush: false,
116-
}))
117-
}
118-
119-
/// Get a read-only reference to underlying config
120-
pub fn config(&self) -> &Arc<ServerConfig> {
121-
&self.inner
122-
}
123-
}
124-
125-
pub struct LazyConfigAcceptor<IO> {
126-
acceptor: rustls::server::Acceptor,
127-
io: Option<IO>,
128-
alert: Option<(rustls::Error, AcceptedAlert)>,
129-
}
130-
131-
impl<IO> LazyConfigAcceptor<IO>
132-
where
133-
IO: AsyncRead + AsyncWrite + Unpin,
134-
{
135-
#[inline]
136-
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
137-
Self {
138-
acceptor,
139-
io: Some(io),
140-
alert: None,
141-
}
142-
}
143-
144-
/// Takes back the client connection. Will return `None` if called more than once or if the
145-
/// connection has been accepted.
146-
///
147-
/// # Example
148-
///
149-
/// ```no_run
150-
/// # fn choose_server_config(
151-
/// # _: rustls::server::ClientHello,
152-
/// # ) -> std::sync::Arc<rustls::ServerConfig> {
153-
/// # unimplemented!();
154-
/// # }
155-
/// # #[allow(unused_variables)]
156-
/// # async fn listen() {
157-
/// use tokio::io::AsyncWriteExt;
158-
/// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
159-
/// let (stream, _) = listener.accept().await.unwrap();
160-
///
161-
/// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
162-
/// tokio::pin!(acceptor);
163-
///
164-
/// match acceptor.as_mut().await {
165-
/// Ok(start) => {
166-
/// let clientHello = start.client_hello();
167-
/// let config = choose_server_config(clientHello);
168-
/// let stream = start.into_stream(config).await.unwrap();
169-
/// // Proceed with handling the ServerConnection...
170-
/// }
171-
/// Err(err) => {
172-
/// if let Some(mut stream) = acceptor.take_io() {
173-
/// stream
174-
/// .write_all(
175-
/// format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
176-
/// .as_bytes()
177-
/// )
178-
/// .await
179-
/// .unwrap();
180-
/// }
181-
/// }
182-
/// }
183-
/// # }
184-
/// ```
185-
pub fn take_io(&mut self) -> Option<IO> {
186-
self.io.take()
187-
}
188-
}
189-
190-
impl<IO> Future for LazyConfigAcceptor<IO>
191-
where
192-
IO: AsyncRead + AsyncWrite + Unpin,
193-
{
194-
type Output = Result<StartHandshake<IO>, io::Error>;
195-
196-
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197-
let this = self.get_mut();
198-
loop {
199-
let io = match this.io.as_mut() {
200-
Some(io) => io,
201-
None => {
202-
return Poll::Ready(Err(io::Error::new(
203-
io::ErrorKind::Other,
204-
"acceptor cannot be polled after acceptance",
205-
)))
206-
}
207-
};
208-
209-
if let Some((err, mut alert)) = this.alert.take() {
210-
match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
211-
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
212-
this.alert = Some((err, alert));
213-
return Poll::Pending;
214-
}
215-
Ok(0) | Err(_) => {
216-
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
217-
}
218-
Ok(_) => {
219-
this.alert = Some((err, alert));
220-
continue;
221-
}
222-
};
223-
}
224-
225-
let mut reader = common::SyncReadAdapter { io, cx };
226-
match this.acceptor.read_tls(&mut reader) {
227-
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
228-
Ok(_) => {}
229-
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
230-
Err(e) => return Err(e).into(),
231-
}
232-
233-
match this.acceptor.accept() {
234-
Ok(Some(accepted)) => {
235-
let io = this.io.take().unwrap();
236-
return Poll::Ready(Ok(StartHandshake { accepted, io }));
237-
}
238-
Ok(None) => {}
239-
Err((err, alert)) => {
240-
this.alert = Some((err, alert));
241-
}
242-
}
243-
}
244-
}
245-
}
246-
247-
/// An incoming connection received through [`LazyConfigAcceptor`].
248-
///
249-
/// This contains the generic `IO` asynchronous transport,
250-
/// [`ClientHello`](rustls::server::ClientHello) data,
251-
/// and all the state required to continue the TLS handshake (e.g. via
252-
/// [`StartHandshake::into_stream`]).
253-
#[derive(Debug)]
254-
pub struct StartHandshake<IO> {
255-
accepted: rustls::server::Accepted,
256-
io: IO,
257-
}
258-
259-
impl<IO> StartHandshake<IO>
260-
where
261-
IO: AsyncRead + AsyncWrite + Unpin,
262-
{
263-
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
264-
self.accepted.client_hello()
265-
}
266-
267-
pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
268-
self.into_stream_with(config, |_| ())
269-
}
270-
271-
pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
272-
where
273-
F: FnOnce(&mut ServerConnection),
274-
{
275-
let mut conn = match self.accepted.into_connection(config) {
276-
Ok(conn) => conn,
277-
Err((error, alert)) => {
278-
return Accept(MidHandshake::SendAlert {
279-
io: self.io,
280-
alert,
281-
// TODO(eliza): should this really return an `io::Error`?
282-
// Probably not...
283-
error: io::Error::new(io::ErrorKind::InvalidData, error),
284-
});
285-
}
286-
};
287-
f(&mut conn);
288-
289-
Accept(MidHandshake::Handshaking(server::TlsStream {
290-
session: conn,
291-
io: self.io,
292-
state: TlsState::Stream,
293-
need_flush: false,
294-
}))
295-
}
296-
}
297-
298-
/// Future returned from `TlsAcceptor::accept` which will resolve
299-
/// once the accept handshake has finished.
300-
pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
301-
302-
/// Like [Accept], but returns `IO` on failure.
303-
pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
304-
305-
impl<IO> Accept<IO> {
306-
#[inline]
307-
pub fn into_fallible(self) -> FallibleAccept<IO> {
308-
FallibleAccept(self.0)
309-
}
310-
311-
pub fn get_ref(&self) -> Option<&IO> {
312-
match &self.0 {
313-
MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
314-
MidHandshake::SendAlert { io, .. } => Some(io),
315-
MidHandshake::Error { io, .. } => Some(io),
316-
MidHandshake::End => None,
317-
}
318-
}
319-
320-
pub fn get_mut(&mut self) -> Option<&mut IO> {
321-
match &mut self.0 {
322-
MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
323-
MidHandshake::SendAlert { io, .. } => Some(io),
324-
MidHandshake::Error { io, .. } => Some(io),
325-
MidHandshake::End => None,
326-
}
327-
}
328-
}
329-
330-
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
331-
type Output = io::Result<server::TlsStream<IO>>;
332-
333-
#[inline]
334-
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
335-
Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
336-
}
337-
}
338-
339-
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
340-
type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
341-
342-
#[inline]
343-
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
344-
Pin::new(&mut self.0).poll(cx)
345-
}
346-
}
67+
pub use server::{Accept, FallibleAccept, LazyConfigAcceptor, StartHandshake, TlsAcceptor};
34768

34869
/// Unified TLS stream type
34970
///

0 commit comments

Comments
 (0)