@@ -7,7 +7,7 @@ use crate::{
77 header_crypto:: { LONG_HEADER_MASK , SHORT_HEADER_MASK } ,
88 tls, CryptoSuite , HeaderKey , Key ,
99 } ,
10- transport,
10+ endpoint , transport,
1111} ;
1212use bytes:: Bytes ;
1313use core:: {
@@ -113,12 +113,12 @@ impl<S: tls::Session, C: tls::Session> Pair<S, C> {
113113 use crate :: crypto:: InitialKey ;
114114
115115 let server = server_endpoint. new_server_session ( & TEST_SERVER_TRANSPORT_PARAMS ) ;
116- let mut server_context = Context :: default ( ) ;
116+ let mut server_context = Context :: new ( endpoint :: Type :: Server ) ;
117117 server_context. initial . crypto = Some ( S :: InitialKey :: new_server ( server_name. as_bytes ( ) ) ) ;
118118
119119 let client =
120120 client_endpoint. new_client_session ( & TEST_CLIENT_TRANSPORT_PARAMS , server_name. clone ( ) ) ;
121- let mut client_context = Context :: default ( ) ;
121+ let mut client_context = Context :: new ( endpoint :: Type :: Client ) ;
122122 client_context. initial . crypto = Some ( C :: InitialKey :: new_client ( server_name. as_bytes ( ) ) ) ;
123123
124124 Self {
@@ -215,8 +215,14 @@ impl<S: tls::Session, C: tls::Session> Pair<S, C> {
215215 TEST_CLIENT_TRANSPORT_PARAMS ,
216216 "server did not receive the client transport parameters"
217217 ) ;
218- // TODO fix sni bug in s2n-quic-rustls
219- // assert_eq!(self.client.1.server_name.as_ref().expect("missing SNI on client"), &self.server_name[..]);
218+ assert_eq ! (
219+ self . client
220+ . 1
221+ . server_name
222+ . as_ref( )
223+ . expect( "missing SNI on client" ) ,
224+ & self . server_name[ ..]
225+ ) ;
220226 assert_eq ! (
221227 self . server
222228 . 1
@@ -240,26 +246,10 @@ pub struct Context<C: CryptoSuite> {
240246 pub server_name : Option < Bytes > ,
241247 pub application_protocol : Option < Bytes > ,
242248 pub transport_parameters : Option < Bytes > ,
249+ endpoint : endpoint:: Type ,
243250 waker : Waker ,
244251}
245252
246- impl < C : CryptoSuite > Default for Context < C > {
247- fn default ( ) -> Self {
248- let ( waker, _wake_counter) = new_count_waker ( ) ;
249- Self {
250- initial : Space :: default ( ) ,
251- handshake : Space :: default ( ) ,
252- application : Space :: default ( ) ,
253- zero_rtt_crypto : None ,
254- handshake_complete : false ,
255- server_name : None ,
256- application_protocol : None ,
257- transport_parameters : None ,
258- waker,
259- }
260- }
261- }
262-
263253impl < C : CryptoSuite > fmt:: Debug for Context < C > {
264254 fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
265255 f. debug_struct ( "Context" )
@@ -276,6 +266,22 @@ impl<C: CryptoSuite> fmt::Debug for Context<C> {
276266}
277267
278268impl < C : CryptoSuite > Context < C > {
269+ fn new ( endpoint : endpoint:: Type ) -> Self {
270+ let ( waker, _wake_counter) = new_count_waker ( ) ;
271+ Self {
272+ initial : Space :: default ( ) ,
273+ handshake : Space :: default ( ) ,
274+ application : Space :: default ( ) ,
275+ zero_rtt_crypto : None ,
276+ handshake_complete : false ,
277+ server_name : None ,
278+ application_protocol : None ,
279+ transport_parameters : None ,
280+ endpoint,
281+ waker,
282+ }
283+ }
284+
279285 /// Transfers incoming and outgoing buffers between two contexts
280286 pub fn transfer < O : CryptoSuite > ( & mut self , other : & mut Context < O > ) {
281287 self . initial . transfer ( & mut other. initial ) ;
@@ -288,11 +294,10 @@ impl<C: CryptoSuite> Context<C> {
288294 self . assert_done ( ) ;
289295 other. assert_done ( ) ;
290296
291- // TODO fix sni bug in s2n-quic-rustls
292- //assert_eq!(
293- // self.sni, other.sni,
294- // "sni is not consistent between endpoints"
295- //);
297+ assert_eq ! (
298+ self . server_name, other. server_name,
299+ "sni is not consistent between endpoints"
300+ ) ;
296301 assert_eq ! (
297302 self . application_protocol, other. application_protocol,
298303 "application_protocol is not consistent between endpoints"
@@ -322,13 +327,16 @@ impl<C: CryptoSuite> Context<C> {
322327 }
323328
324329 fn on_application_params ( & mut self , params : tls:: ApplicationParameters ) {
325- self . application_protocol = Some ( Bytes :: copy_from_slice ( params. application_protocol ) ) ;
326- self . server_name = params. server_name . map ( |sni| sni. into_bytes ( ) ) ;
327330 self . transport_parameters = Some ( Bytes :: copy_from_slice ( params. transport_parameters ) ) ;
328331 }
329332
330333 fn log ( & self , event : & str ) {
331- eprintln ! ( "{}: {}" , core:: any:: type_name:: <C >( ) , event) ;
334+ eprintln ! (
335+ "{:?}: {}: {}" ,
336+ self . endpoint,
337+ core:: any:: type_name:: <C >( ) ,
338+ event,
339+ ) ;
332340 }
333341}
334342
@@ -507,11 +515,33 @@ impl<C: CryptoSuite> tls::Context<C> for Context<C> {
507515 Ok ( ( ) )
508516 }
509517
518+ fn on_server_name (
519+ & mut self ,
520+ server_name : crate :: application:: ServerName ,
521+ ) -> Result < ( ) , transport:: Error > {
522+ self . log ( "server name" ) ;
523+ self . server_name = Some ( server_name. into_bytes ( ) ) ;
524+ Ok ( ( ) )
525+ }
526+
527+ fn on_application_protocol (
528+ & mut self ,
529+ application_protocol : Bytes ,
530+ ) -> Result < ( ) , transport:: Error > {
531+ self . log ( "application protocol" ) ;
532+ self . application_protocol = Some ( application_protocol) ;
533+ Ok ( ( ) )
534+ }
535+
510536 fn on_handshake_complete ( & mut self ) -> Result < ( ) , transport:: Error > {
511537 assert ! (
512538 !self . handshake_complete,
513539 "handshake complete called multiple times"
514540 ) ;
541+ assert ! (
542+ !self . application_protocol. as_ref( ) . unwrap( ) . is_empty( ) ,
543+ "application_protocol is empty at handshake complete"
544+ ) ;
515545 self . handshake_complete = true ;
516546 self . log ( "handshake complete" ) ;
517547 Ok ( ( ) )
0 commit comments