Skip to content

Commit

Permalink
Expose quinn::EndpointConfig for {Client,Server}Config
Browse files Browse the repository at this point in the history
  • Loading branch information
pablosichert authored and BiagioFesta committed Oct 3, 2024
1 parent bf3a540 commit 41f9fb7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 22 deletions.
1 change: 1 addition & 0 deletions wtransport/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ allowed_external_types = [
"quinn::send_stream::SendStream",
"quinn::send_stream::WriteError",
"quinn_proto::config::ClientConfig",
"quinn_proto::config::EndpointConfig",
"quinn_proto::config::ServerConfig",
"quinn_proto::config::TransportConfig",
"quinn_proto::connection::ConnectionError",
Expand Down
71 changes: 61 additions & 10 deletions wtransport/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

use crate::tls::build_native_cert_store;
use crate::tls::Identity;
use quinn::EndpointConfig;
use quinn::TransportConfig;
use std::fmt::Debug;
use std::fmt::Display;
Expand Down Expand Up @@ -230,6 +231,7 @@ pub struct InvalidIdleTimeout;
pub struct ServerConfig {
pub(crate) bind_address: SocketAddr,
pub(crate) dual_stack_config: Ipv6DualStackConfig,
pub(crate) endpoint_config: quinn::EndpointConfig,
pub(crate) quic_config: quinn::ServerConfig,
}

Expand All @@ -241,6 +243,20 @@ impl ServerConfig {
ServerConfigBuilder::default()
}

/// Returns a reference to the inner QUIC endpoint configuration.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn quic_endpoint_config(&self) -> &quinn::EndpointConfig {
&self.endpoint_config
}

/// Returns a mutable reference to the inner QUIC endpoint configuration.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn quic_endpoint_config_mut(&mut self) -> &mut quinn::EndpointConfig {
&mut self.endpoint_config
}

/// Returns a reference to the inner QUIC configuration.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
Expand Down Expand Up @@ -360,9 +376,10 @@ impl ServerConfigBuilder<states::WantsIdentity> {
use crate::tls::server::build_default_tls_config;

let tls_config = build_default_tls_config(identity);
let quic_endpoint_config = EndpointConfig::default();
let quic_transport_config = TransportConfig::default();

self.with(tls_config, quic_transport_config)
self.with(tls_config, quic_endpoint_config, quic_transport_config)
}

/// Allows for manual configuration of a custom TLS setup using a provided
Expand Down Expand Up @@ -401,9 +418,10 @@ impl ServerConfigBuilder<states::WantsIdentity> {
self,
tls_config: TlsServerConfig,
) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
let quic_endpoint_config = EndpointConfig::default();
let quic_transport_config = TransportConfig::default();

self.with(tls_config, quic_transport_config)
self.with(tls_config, quic_endpoint_config, quic_transport_config)
}

/// Configures the server with a custom QUIC transport configuration and a default TLS setup
Expand Down Expand Up @@ -453,8 +471,9 @@ impl ServerConfigBuilder<states::WantsIdentity> {
use crate::tls::server::build_default_tls_config;

let tls_config = build_default_tls_config(identity);
let quic_endpoint_config = EndpointConfig::default();

self.with(tls_config, quic_transport_config)
self.with(tls_config, quic_endpoint_config, quic_transport_config)
}

/// Configures the server with both a custom TLS configuration and a custom QUIC transport
Expand All @@ -478,7 +497,8 @@ impl ServerConfigBuilder<states::WantsIdentity> {
tls_config: TlsServerConfig,
quic_transport_config: QuicTransportConfig,
) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
self.with(tls_config, quic_transport_config)
let quic_endpoint_config = EndpointConfig::default();
self.with(tls_config, quic_endpoint_config, quic_transport_config)
}

/// Directly builds [`ServerConfig`] skipping TLS and transport configuration.
Expand All @@ -490,19 +510,22 @@ impl ServerConfigBuilder<states::WantsIdentity> {
ServerConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
endpoint_config: EndpointConfig::default(),
quic_config,
}
}

fn with(
self,
tls_config: TlsServerConfig,
endpoint_config: EndpointConfig,
transport_config: TransportConfig,
) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
ServerConfigBuilder(states::WantsTransportConfigServer {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
tls_config,
endpoint_config,
transport_config,
migration: true,
})
Expand Down Expand Up @@ -530,6 +553,7 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
ServerConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
endpoint_config: self.0.endpoint_config,
quic_config,
}
}
Expand Down Expand Up @@ -680,6 +704,7 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
pub struct ClientConfig {
pub(crate) bind_address: SocketAddr,
pub(crate) dual_stack_config: Ipv6DualStackConfig,
pub(crate) endpoint_config: quinn::EndpointConfig,
pub(crate) quic_config: quinn::ClientConfig,
pub(crate) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
}
Expand All @@ -702,6 +727,20 @@ impl ClientConfig {
self.dns_resolver = Arc::new(dns_resolver);
}

/// Returns a reference to the inner QUIC endpoint configuration.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn quic_endpoint_config(&self) -> &quinn::EndpointConfig {
&self.endpoint_config
}

/// Returns a mutable reference to the inner QUIC endpoint configuration.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn quic_endpoint_config_mut(&mut self) -> &mut quinn::EndpointConfig {
&mut self.endpoint_config
}

/// Returns a reference to the inner QUIC configuration.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
Expand Down Expand Up @@ -805,9 +844,10 @@ impl ClientConfigBuilder<states::WantsRootStore> {
use crate::tls::client::build_default_tls_config;

let tls_config = build_default_tls_config(Arc::new(build_native_cert_store()), None);
let endpoint_config = EndpointConfig::default();
let transport_config = TransportConfig::default();

self.with(tls_config, transport_config)
self.with(tls_config, endpoint_config, transport_config)
}

/// Configures the client to skip server certificate validation, potentially
Expand Down Expand Up @@ -845,9 +885,10 @@ impl ClientConfigBuilder<states::WantsRootStore> {
Some(Arc::new(NoServerVerification::new())),
);

let endpoint_config = EndpointConfig::default();
let transport_config = TransportConfig::default();

self.with(tls_config, transport_config)
self.with(tls_config, endpoint_config, transport_config)
}

/// Configures the client to skip *some* server certificates validation.
Expand Down Expand Up @@ -883,9 +924,10 @@ impl ClientConfigBuilder<states::WantsRootStore> {
Some(Arc::new(ServerHashVerification::new(hashes))),
);

let endpoint_config = EndpointConfig::default();
let transport_config = TransportConfig::default();

self.with(tls_config, transport_config)
self.with(tls_config, endpoint_config, transport_config)
}

/// Allows for manual configuration of a custom TLS setup using a provided
Expand All @@ -904,9 +946,10 @@ impl ClientConfigBuilder<states::WantsRootStore> {
self,
tls_config: TlsClientConfig,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
let endpoint_config = EndpointConfig::default();
let transport_config = TransportConfig::default();

self.with(tls_config, transport_config)
self.with(tls_config, endpoint_config, transport_config)
}

/// Similar to [`with_native_certs`](Self::with_native_certs), but it allows specifying a custom
Expand Down Expand Up @@ -942,8 +985,9 @@ impl ClientConfigBuilder<states::WantsRootStore> {
use crate::tls::client::build_default_tls_config;

let tls_config = build_default_tls_config(Arc::new(build_native_cert_store()), None);
let quic_endpoint_config = EndpointConfig::default();

self.with(tls_config, quic_transport_config)
self.with(tls_config, quic_endpoint_config, quic_transport_config)
}

/// Configures the client with both a custom TLS configuration and a custom QUIC transport
Expand All @@ -967,7 +1011,8 @@ impl ClientConfigBuilder<states::WantsRootStore> {
tls_config: TlsClientConfig,
quic_transport_config: QuicTransportConfig,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
self.with(tls_config, quic_transport_config)
let quic_endpoint_config = EndpointConfig::default();
self.with(tls_config, quic_endpoint_config, quic_transport_config)
}

/// Directly builds [`ClientConfig`] skipping TLS and transport configuration.
Expand All @@ -979,6 +1024,7 @@ impl ClientConfigBuilder<states::WantsRootStore> {
ClientConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
endpoint_config: EndpointConfig::default(),
quic_config,
dns_resolver: Arc::<TokioDnsResolver>::default(),
}
Expand All @@ -987,12 +1033,14 @@ impl ClientConfigBuilder<states::WantsRootStore> {
fn with(
self,
tls_config: TlsClientConfig,
endpoint_config: EndpointConfig,
transport_config: TransportConfig,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
ClientConfigBuilder(states::WantsTransportConfigClient {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
tls_config,
endpoint_config,
transport_config,
dns_resolver: Arc::<TokioDnsResolver>::default(),
})
Expand All @@ -1016,6 +1064,7 @@ impl ClientConfigBuilder<states::WantsTransportConfigClient> {
ClientConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
endpoint_config: self.0.endpoint_config,
quic_config,
dns_resolver: self.0.dns_resolver,
}
Expand Down Expand Up @@ -1102,6 +1151,7 @@ pub mod states {
pub(super) bind_address: SocketAddr,
pub(super) dual_stack_config: Ipv6DualStackConfig,
pub(super) tls_config: TlsServerConfig,
pub(super) endpoint_config: quinn::EndpointConfig,
pub(super) transport_config: quinn::TransportConfig,
pub(super) migration: bool,
}
Expand All @@ -1111,6 +1161,7 @@ pub mod states {
pub(super) bind_address: SocketAddr,
pub(super) dual_stack_config: Ipv6DualStackConfig,
pub(super) tls_config: TlsClientConfig,
pub(super) endpoint_config: quinn::EndpointConfig,
pub(super) transport_config: quinn::TransportConfig,
pub(super) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
}
Expand Down
17 changes: 5 additions & 12 deletions wtransport/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,14 @@ impl<Side> Endpoint<Side> {
impl Endpoint<endpoint_side::Server> {
/// Constructs a *server* endpoint.
pub fn server(server_config: ServerConfig) -> std::io::Result<Self> {
let endpoint_config = server_config.endpoint_config;
let quic_config = server_config.quic_config;
let socket =
Self::bind_socket(server_config.bind_address, server_config.dual_stack_config)?;
let runtime = Arc::new(TokioRuntime);

let endpoint = quinn::Endpoint::new(
quinn::EndpointConfig::default(),
Some(quic_config),
socket.into(),
runtime,
)?;
let endpoint =
quinn::Endpoint::new(endpoint_config, Some(quic_config), socket.into(), runtime)?;

Ok(Self {
endpoint,
Expand Down Expand Up @@ -213,17 +210,13 @@ impl Endpoint<endpoint_side::Server> {
impl Endpoint<endpoint_side::Client> {
/// Constructs a *client* endpoint.
pub fn client(client_config: ClientConfig) -> std::io::Result<Self> {
let endpoint_config = client_config.endpoint_config;
let quic_config = client_config.quic_config;
let socket =
Self::bind_socket(client_config.bind_address, client_config.dual_stack_config)?;
let runtime = Arc::new(TokioRuntime);

let mut endpoint = quinn::Endpoint::new(
quinn::EndpointConfig::default(),
None,
socket.into(),
runtime,
)?;
let mut endpoint = quinn::Endpoint::new(endpoint_config, None, socket.into(), runtime)?;

endpoint.set_default_client_config(quic_config);

Expand Down

0 comments on commit 41f9fb7

Please sign in to comment.