@@ -271,18 +271,35 @@ pub fn setup_inbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<pee
271271///
272272/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
273273pub fn setup_outbound < CMH : ChannelMessageHandler + ' static > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > > > , event_notify : mpsc:: Sender < ( ) > , their_node_id : PublicKey , stream : TcpStream ) -> impl std:: future:: Future < Output =( ) > {
274- let ( reader, write_receiver, read_receiver, us) = Connection :: new ( event_notify, stream) ;
274+ let ( reader, mut write_receiver, read_receiver, us) = Connection :: new ( event_notify, stream) ;
275275 #[ cfg( debug_assertions) ]
276276 let last_us = Arc :: clone ( & us) ;
277277
278278 let handle_opt = if let Ok ( initial_send) = peer_manager. new_outbound_connection ( their_node_id, SocketDescriptor :: new ( us. clone ( ) ) ) {
279279 Some ( tokio:: spawn ( async move {
280- if SocketDescriptor :: new ( us. clone ( ) ) . send_data ( & initial_send, true ) != initial_send. len ( ) {
281- // We should essentially always have enough room in a TCP socket buffer to send the
282- // initial 10s of bytes, if not, just give up as hopeless.
283- eprintln ! ( "Failed to write first full message to socket!" ) ;
284- peer_manager. socket_disconnected ( & SocketDescriptor :: new ( Arc :: clone ( & us) ) ) ;
285- } else {
280+ // We should essentially always have enough room in a TCP socket buffer to send the
281+ // initial 10s of bytes. However, tokio running in single-threaded mode will always
282+ // fail writes and wake us back up later to write. Thus, we handle a single
283+ // std::task::Poll::Pending but still expect to write the full set of bytes at once
284+ // and use a relatively tight timeout.
285+ if let Ok ( Ok ( ( ) ) ) = tokio:: time:: timeout ( Duration :: from_millis ( 100 ) , async {
286+ loop {
287+ match SocketDescriptor :: new ( us. clone ( ) ) . send_data ( & initial_send, true ) {
288+ v if v == initial_send. len ( ) => break Ok ( ( ) ) ,
289+ 0 => {
290+ write_receiver. recv ( ) . await ;
291+ // In theory we could check for if we've been instructed to disconnect
292+ // the peer here, but its OK to just skip it - we'll check for it in
293+ // schedule_read prior to any relevant calls into RL.
294+ } ,
295+ _ => {
296+ eprintln ! ( "Failed to write first full message to socket!" ) ;
297+ peer_manager. socket_disconnected ( & SocketDescriptor :: new ( Arc :: clone ( & us) ) ) ;
298+ break Err ( ( ) ) ;
299+ }
300+ }
301+ }
302+ } ) . await {
286303 Connection :: schedule_read ( peer_manager, us, reader, read_receiver, write_receiver) . await ;
287304 }
288305 } ) )
@@ -531,8 +548,7 @@ mod tests {
531548 }
532549 }
533550
534- #[ tokio:: test( threaded_scheduler) ]
535- async fn basic_connection_test ( ) {
551+ async fn do_basic_connection_test ( ) {
536552 let secp_ctx = Secp256k1 :: new ( ) ;
537553 let a_key = SecretKey :: from_slice ( & [ 1 ; 32 ] ) . unwrap ( ) ;
538554 let b_key = SecretKey :: from_slice ( & [ 1 ; 32 ] ) . unwrap ( ) ;
@@ -597,4 +613,13 @@ mod tests {
597613 fut_a. await ;
598614 fut_b. await ;
599615 }
616+
617+ #[ tokio:: test( threaded_scheduler) ]
618+ async fn basic_threaded_connection_test ( ) {
619+ do_basic_connection_test ( ) . await ;
620+ }
621+ #[ tokio:: test]
622+ async fn basic_unthreaded_connection_test ( ) {
623+ do_basic_connection_test ( ) . await ;
624+ }
600625}
0 commit comments