diff --git a/core/src/ser.rs b/core/src/ser.rs index 367614dc1f..88dcd74e01 100644 --- a/core/src/ser.rs +++ b/core/src/ser.rs @@ -21,7 +21,6 @@ use crate::core::hash::{DefaultHashable, Hash, Hashed}; use crate::keychain::{BlindingFactor, Identifier, IDENTIFIER_SIZE}; -use crate::util::read_write::read_exact; use crate::util::secp::constants::{ AGG_SIGNATURE_SIZE, MAX_PROOF_SIZE, PEDERSEN_COMMITMENT_SIZE, SECRET_KEY_SIZE, }; @@ -31,7 +30,6 @@ use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use std::fmt::Debug; use std::io::{self, Read, Write}; use std::marker; -use std::time::Duration; use std::{cmp, error, fmt}; /// Possible errors deriving from serializing or deserializing. @@ -366,17 +364,14 @@ impl<'a> Reader for BinReader<'a> { pub struct StreamingReader<'a> { total_bytes_read: u64, stream: &'a mut dyn Read, - timeout: Duration, } impl<'a> StreamingReader<'a> { /// Create a new streaming reader with the provided underlying stream. - /// Also takes a duration to be used for each individual read_exact call. - pub fn new(stream: &'a mut dyn Read, timeout: Duration) -> StreamingReader<'a> { + pub fn new(stream: &'a mut dyn Read) -> StreamingReader<'a> { StreamingReader { total_bytes_read: 0, stream, - timeout, } } @@ -427,7 +422,7 @@ impl<'a> Reader for StreamingReader<'a> { /// Read a fixed number of bytes. fn read_fixed_bytes(&mut self, len: usize) -> Result, Error> { let mut buf = vec![0u8; len]; - read_exact(&mut self.stream, &mut buf, self.timeout, true)?; + self.stream.read_exact(&mut buf)?; self.total_bytes_read += len as u64; Ok(buf) } diff --git a/p2p/src/conn.rs b/p2p/src/conn.rs index d6313929a2..4b771dba52 100644 --- a/p2p/src/conn.rs +++ b/p2p/src/conn.rs @@ -20,18 +20,18 @@ //! forces us to go through some additional gymnastic to loop over the async //! stream and make sure we get the right number of bytes out. -use std::fs::File; -use std::io::{self, Read, Write}; -use std::net::{Shutdown, TcpStream}; -use std::sync::{mpsc, Arc}; -use std::{cmp, thread, time}; - use crate::core::ser; use crate::core::ser::FixedLength; use crate::msg::{read_body, read_header, read_item, write_to_buf, MsgHeader, Type}; use crate::types::Error; -use crate::util::read_write::{read_exact, write_all}; use crate::util::{RateCounter, RwLock}; +use std::fs::File; +use std::io::{self, Read, Write}; +use std::net::{Shutdown, TcpStream}; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::{mpsc, Arc}; +use std::time::Duration; +use std::{cmp, thread}; /// A trait to be implemented in order to receive messages from the /// connection. Allows providing an optional response. @@ -40,7 +40,7 @@ pub trait MessageHandler: Send + 'static { &self, msg: Message<'a>, writer: &'a mut dyn Write, - received_bytes: Arc>, + tracker: Arc, ) -> Result>, Error>; } @@ -50,7 +50,11 @@ macro_rules! try_break { ($chan:ident, $inner:expr) => { match $inner { Ok(v) => Some(v), - Err(Error::Connection(ref e)) if e.kind() == io::ErrorKind::WouldBlock => None, + Err(Error::Connection(ref e)) + if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => + { + None + } Err(e) => { let _ = $chan.send(e); break; @@ -87,12 +91,7 @@ impl<'a> Message<'a> { while written < len { let read_len = cmp::min(8000, len - written); let mut buf = vec![0u8; read_len]; - read_exact( - &mut self.stream, - &mut buf[..], - time::Duration::from_secs(10), - true, - )?; + self.stream.read_exact(&mut buf[..])?; writer.write_all(&mut buf)?; written += read_len; } @@ -123,26 +122,21 @@ impl<'a> Response<'a> { }) } - fn write(mut self, sent_bytes: Arc>) -> Result<(), Error> { + fn write(mut self, tracker: Arc) -> Result<(), Error> { let mut msg = ser::ser_vec(&MsgHeader::new(self.resp_type, self.body.len() as u64))?; msg.append(&mut self.body); - write_all(&mut self.stream, &msg[..], time::Duration::from_secs(10))?; - // Increase sent bytes counter - { - let mut sent_bytes = sent_bytes.write(); - sent_bytes.inc(msg.len() as u64); - } + self.stream.write_all(&msg[..])?; + tracker.sent_bytes_inc(msg.len() as u64); if let Some(mut file) = self.attachment { let mut buf = [0u8; 8000]; loop { match file.read(&mut buf[..]) { Ok(0) => break, Ok(n) => { - write_all(&mut self.stream, &buf[..n], time::Duration::from_secs(10))?; + self.stream.write_all(&buf[..n])?; // Increase sent bytes "quietly" without incrementing the counter. // (In a loop here for the single attachment). - let mut sent_bytes = sent_bytes.write(); - sent_bytes.inc_quiet(n as u64); + tracker.sent_bytes_inc_quiet(n as u64); } Err(e) => return Err(From::from(e)), } @@ -165,10 +159,10 @@ pub struct Tracker { pub received_bytes: Arc>, /// Channel to allow sending data through the connection pub send_channel: mpsc::SyncSender>, - /// Channel to close the connection - pub close_channel: mpsc::Sender<()>, /// Channel to check for errors on the connection - pub error_channel: mpsc::Receiver, + //pub error_channel: mpsc::Receiver, + closed: AtomicBool, + pub number_threads: AtomicU8, } impl Tracker { @@ -179,24 +173,48 @@ impl Tracker { let buf = write_to_buf(body, msg_type)?; let buf_len = buf.len(); self.send_channel.try_send(buf)?; + self.sent_bytes_inc(buf_len as u64); + Ok(()) + } - // Increase sent bytes counter + pub fn is_closed(&self) -> bool { + self.closed.load(Ordering::Relaxed) + } + + pub fn close(&self) { + self.closed.store(true, Ordering::Relaxed) + } + + pub fn received_bytes_inc(&self, bytes: u64) { + let mut received_bytes = self.received_bytes.write(); + received_bytes.inc(bytes); + } + + pub fn received_bytes_inc_quiet(&self, bytes: u64) { + let mut received_bytes = self.received_bytes.write(); + received_bytes.inc_quiet(bytes); + } + + pub fn sent_bytes_inc(&self, bytes: u64) { let mut sent_bytes = self.sent_bytes.write(); - sent_bytes.inc(buf_len as u64); + sent_bytes.inc(bytes); + } - Ok(()) + pub fn sent_bytes_inc_quiet(&self, bytes: u64) { + let mut sent_bytes = self.sent_bytes.write(); + sent_bytes.inc_quiet(bytes); } } +const IO_TIMEOUT: Duration = Duration::from_millis(1000); /// Start listening on the provided connection and wraps it. Does not hang /// the current thread, instead just returns a future and the Connection /// itself. -pub fn listen(stream: TcpStream, handler: H) -> Tracker +pub fn listen(stream: TcpStream, handler: H) -> (Arc, mpsc::Receiver) where H: MessageHandler, { let (send_tx, send_rx) = mpsc::sync_channel(SEND_CHANNEL_CAP); - let (close_tx, close_rx) = mpsc::channel(); let (error_tx, error_rx) = mpsc::channel(); // Counter of number of bytes received @@ -205,25 +223,25 @@ where let sent_bytes = Arc::new(RwLock::new(RateCounter::new())); stream - .set_nonblocking(true) - .expect("Non-blocking IO not available."); - poll( - stream, - handler, - send_rx, - error_tx, - close_rx, - received_bytes.clone(), - sent_bytes.clone(), - ); - - Tracker { + .set_read_timeout(Some(IO_TIMEOUT)) + .expect("can't set read timeout"); + stream + .set_write_timeout(Some(IO_TIMEOUT)) + .expect("can't set write timeout"); + //stream + // .set_nonblocking(true) + // .expect("Non-blocking IO not available."); + // + + let tracker = Arc::new(Tracker { sent_bytes: sent_bytes.clone(), received_bytes: received_bytes.clone(), send_channel: send_tx, - close_channel: close_tx, - error_channel: error_rx, - } + closed: AtomicBool::new(false), + number_threads: AtomicU8::new(0), + }); + poll(stream, handler, send_rx, error_tx, tracker.clone()); + (tracker, error_rx) } fn poll( @@ -231,24 +249,23 @@ fn poll( handler: H, send_rx: mpsc::Receiver>, error_tx: mpsc::Sender, - close_rx: mpsc::Receiver<()>, - received_bytes: Arc>, - sent_bytes: Arc>, + tracker: Arc, ) where H: MessageHandler, { // Split out tcp stream out into separate reader/writer halves. let mut reader = conn.try_clone().expect("clone conn for reader failed"); + let mut responder = conn.try_clone().expect("clone conn for reader failed"); let mut writer = conn.try_clone().expect("clone conn for writer failed"); + let tracker_read = tracker.clone(); + let error_read_tx = error_tx.clone(); let _ = thread::Builder::new() - .name("peer".to_string()) + .name("peer_read".to_string()) .spawn(move || { - let sleep_time = time::Duration::from_millis(5); - let mut retry_send = Err(()); + tracker_read.number_threads.fetch_add(1, Ordering::Relaxed); loop { - // check the read end - if let Some(h) = try_break!(error_tx, read_header(&mut reader, None)) { + if let Some(h) = try_break!(error_read_tx, read_header(&mut reader, None)) { let msg = Message::from_header(h, &mut reader); trace!( @@ -258,21 +275,38 @@ fn poll( ); // Increase received bytes counter - let received = received_bytes.clone(); - { - let mut received_bytes = received_bytes.write(); - received_bytes.inc(MsgHeader::LEN as u64 + msg.header.msg_len); - } + tracker_read.received_bytes_inc(MsgHeader::LEN as u64 + msg.header.msg_len); - if let Some(Some(resp)) = - try_break!(error_tx, handler.consume(msg, &mut writer, received)) - { - try_break!(error_tx, resp.write(sent_bytes.clone())); + if let Some(Some(resp)) = try_break!( + error_read_tx, + handler.consume(msg, &mut responder, tracker_read.clone()) + ) { + try_break!(error_read_tx, resp.write(tracker_read.clone())); } } - // check the write end, use or_else so try_recv is lazily eval'd - let maybe_data = retry_send.or_else(|_| send_rx.try_recv()); + if tracker_read.is_closed() { + debug!( + "Connection close with {} initiated by us", + conn.peer_addr() + .map(|a| a.to_string()) + .unwrap_or("?".to_owned()) + ); + break; + } + } + tracker_read.number_threads.fetch_sub(1, Ordering::Relaxed); + let _ = conn.shutdown(Shutdown::Both); + }); + + let _ = thread::Builder::new() + .name("peer_write".to_string()) + .spawn(move || { + tracker.number_threads.fetch_add(1, Ordering::Relaxed); + let mut retry_send = Err(()); + // check the write end, use or_else so try_recv is lazily eval'd + loop { + let maybe_data = retry_send.or_else(|_| send_rx.recv_timeout(IO_TIMEOUT)); retry_send = Err(()); if let Ok(data) = maybe_data { let written = @@ -281,20 +315,11 @@ fn poll( retry_send = Ok(data); } } - - // check the close channel - if let Ok(_) = close_rx.try_recv() { - debug!( - "Connection close with {} initiated by us", - conn.peer_addr() - .map(|a| a.to_string()) - .unwrap_or("?".to_owned()) - ); + if tracker.is_closed() { + debug!("Connection close with initiated by us, closing writer end",); break; } - - thread::sleep(sleep_time); } - let _ = conn.shutdown(Shutdown::Both); + tracker.number_threads.fetch_sub(1, Ordering::Relaxed); }); } diff --git a/p2p/src/msg.rs b/p2p/src/msg.rs index 30a7892af1..7e7578ef7c 100644 --- a/p2p/src/msg.rs +++ b/p2p/src/msg.rs @@ -14,10 +14,6 @@ //! Message types that transit over the network and related serialization code. -use num::FromPrimitive; -use std::io::{Read, Write}; -use std::time; - use crate::core::core::hash::Hash; use crate::core::core::BlockHeader; use crate::core::pow::Difficulty; @@ -26,7 +22,8 @@ use crate::core::{consensus, global}; use crate::types::{ Capabilities, Error, PeerAddr, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS, }; -use crate::util::read_write::read_exact; +use num::FromPrimitive; +use std::io::{Read, Write}; /// Our local node protocol version. /// We will increment the protocol version with every change to p2p msg serialization @@ -122,9 +119,9 @@ fn magic() -> [u8; 2] { pub fn read_header(stream: &mut dyn Read, msg_type: Option) -> Result { let mut head = vec![0u8; MsgHeader::LEN]; if Some(Type::Hand) == msg_type { - read_exact(stream, &mut head, time::Duration::from_millis(10), true)?; + stream.read_exact(&mut head)?; } else { - read_exact(stream, &mut head, time::Duration::from_secs(10), false)?; + stream.read_exact(&mut head)?; } let header = ser::deserialize::(&mut &head[..])?; let max_len = max_msg_size(header.msg_type); @@ -144,8 +141,7 @@ pub fn read_header(stream: &mut dyn Read, msg_type: Option) -> Result(stream: &mut dyn Read) -> Result<(T, u64), Error> { - let timeout = time::Duration::from_secs(20); - let mut reader = StreamingReader::new(stream, timeout); + let mut reader = StreamingReader::new(stream); let res = T::read(&mut reader)?; Ok((res, reader.total_bytes_read())) } @@ -154,7 +150,7 @@ pub fn read_item(stream: &mut dyn Read) -> Result<(T, u64), Error> /// until we have a result (or timeout). pub fn read_body(h: &MsgHeader, stream: &mut dyn Read) -> Result { let mut body = vec![0u8; h.msg_len as usize]; - read_exact(stream, &mut body, time::Duration::from_secs(20), true)?; + stream.read_exact(&mut body)?; ser::deserialize(&mut &body[..]).map_err(From::from) } diff --git a/p2p/src/peer.rs b/p2p/src/peer.rs index 52eb672c7f..8e48a1ffd3 100644 --- a/p2p/src/peer.rs +++ b/p2p/src/peer.rs @@ -12,13 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::util::{Mutex, RwLock}; -use std::fmt; -use std::fs::File; -use std::net::{Shutdown, TcpStream}; -use std::path::PathBuf; -use std::sync::Arc; - use crate::chain; use crate::conn; use crate::core::core::hash::{Hash, Hashed}; @@ -31,7 +24,14 @@ use crate::types::{ Capabilities, ChainAdapter, Error, NetAdapter, P2PConfig, PeerAddr, PeerInfo, ReasonForBan, TxHashSetRead, }; +use crate::util::{Mutex, RwLock}; use chrono::prelude::{DateTime, Utc}; +use std::fmt; +use std::fs::File; +use std::net::{Shutdown, TcpStream}; +use std::path::PathBuf; +use std::sync::mpsc; +use std::sync::Arc; const MAX_TRACK_SIZE: usize = 30; const MAX_PEER_MSG_PER_MIN: u64 = 500; @@ -54,13 +54,14 @@ pub struct Peer { state: Arc>, // set of all hashes known to this peer (so no need to send) tracking_adapter: TrackingAdapter, - connection: Option>, + connection: Option>, + errors: Option>>, } macro_rules! connection { ($holder:expr) => { match $holder.connection.as_ref() { - Some(conn) => conn.lock(), + Some(conn) => conn, None => return Err(Error::Internal), } }; @@ -80,6 +81,7 @@ impl Peer { state: Arc::new(RwLock::new(State::Connected)), tracking_adapter: TrackingAdapter::new(adapter), connection: None, + errors: None, } } @@ -139,7 +141,9 @@ impl Peer { pub fn start(&mut self, conn: TcpStream) { let adapter = Arc::new(self.tracking_adapter.clone()); let handler = Protocol::new(adapter, self.info.clone()); - self.connection = Some(Mutex::new(conn::listen(conn, handler))); + let (tracker, errors) = conn::listen(conn, handler); + self.connection = Some(tracker); + self.errors = Some(Mutex::new(errors)); } pub fn is_denied(config: &P2PConfig, peer_addr: PeerAddr) -> bool { @@ -198,7 +202,6 @@ impl Peer { /// Whether the peer is considered abusive, mostly for spammy nodes pub fn is_abusive(&self) -> bool { if let Some(ref conn) = self.connection { - let conn = conn.lock(); let rec = conn.received_bytes.read(); let sent = conn.sent_bytes.read(); rec.count_per_min() > MAX_PEER_MSG_PER_MIN @@ -211,8 +214,7 @@ impl Peer { /// Number of bytes sent to the peer pub fn last_min_sent_bytes(&self) -> Option { if let Some(ref tracker) = self.connection { - let conn = tracker.lock(); - let sent_bytes = conn.sent_bytes.read(); + let sent_bytes = tracker.sent_bytes.read(); return Some(sent_bytes.bytes_per_min()); } None @@ -221,8 +223,7 @@ impl Peer { /// Number of bytes received from the peer pub fn last_min_received_bytes(&self) -> Option { if let Some(ref tracker) = self.connection { - let conn = tracker.lock(); - let received_bytes = conn.received_bytes.read(); + let received_bytes = tracker.received_bytes.read(); return Some(received_bytes.bytes_per_min()); } None @@ -230,9 +231,8 @@ impl Peer { pub fn last_min_message_counts(&self) -> Option<(u64, u64)> { if let Some(ref tracker) = self.connection { - let conn = tracker.lock(); - let received_bytes = conn.received_bytes.read(); - let sent_bytes = conn.sent_bytes.read(); + let received_bytes = tracker.received_bytes.read(); + let sent_bytes = tracker.sent_bytes.read(); return Some((sent_bytes.count_per_min(), received_bytes.count_per_min())); } None @@ -409,16 +409,20 @@ impl Peer { /// Stops the peer, closing its connection pub fn stop(&self) { if let Some(conn) = self.connection.as_ref() { - stop_with_connection(&conn.lock()); + stop_with_connection(&conn); } } fn check_connection(&self) -> bool { - let connection = match self.connection.as_ref() { - Some(conn) => conn.lock(), + let conn = match self.connection.as_ref() { + Some(conn) => conn, + None => return false, + }; + let errors = match self.errors.as_ref() { + Some(errors) => errors.lock(), None => return false, }; - match connection.error_channel.try_recv() { + match errors.try_recv() { Ok(Error::Serialization(e)) => { let need_stop = { let mut state = self.state.write(); @@ -434,7 +438,7 @@ impl Peer { "Client {} corrupted, will disconnect ({:?}).", self.info.addr, e ); - stop_with_connection(&connection); + stop_with_connection(&conn); } false } @@ -450,7 +454,7 @@ impl Peer { }; if need_stop { debug!("Client {} connection lost: {:?}", self.info.addr, e); - stop_with_connection(&connection); + stop_with_connection(conn); } false } @@ -463,7 +467,15 @@ impl Peer { } fn stop_with_connection(connection: &conn::Tracker) { - let _ = connection.close_channel.send(()); + connection.close(); + // check server shutdown + // while connection + // .number_threads + // .load(std::sync::atomic::Ordering::Relaxed) + // != 0 + // { + // std::thread::sleep(std::time::Duration::from_millis(200)); + // } } /// Adapter implementation that forwards everything to an underlying adapter diff --git a/p2p/src/peers.rs b/p2p/src/peers.rs index 6b5986b6be..d9cfabac3c 100644 --- a/p2p/src/peers.rs +++ b/p2p/src/peers.rs @@ -434,9 +434,20 @@ impl Peers { } pub fn stop(&self) { + let mut handles = vec![]; let mut peers = self.peers.write(); for (_, peer) in peers.drain() { - peer.stop(); + handles.push( + std::thread::Builder::new() + .name("peer_stop".to_string()) + .spawn(move || { + peer.stop(); + }) + .unwrap(), + ); + } + for h in handles { + let _ = h.join(); } } diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index 200d1f3703..0b92c3390b 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -12,17 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::conn::{Message, MessageHandler, Response, Tracker}; +use crate::core::core::{self, hash::Hash, CompactBlock}; +use chrono::prelude::Utc; use rand::{thread_rng, Rng}; use std::cmp; use std::fs::{self, File, OpenOptions}; use std::io::{BufWriter, Write}; use std::sync::Arc; -use crate::conn::{Message, MessageHandler, Response}; -use crate::core::core::{self, hash::Hash, CompactBlock}; -use crate::util::{RateCounter, RwLock}; -use chrono::prelude::Utc; - use crate::msg::{ BanReason, GetPeerAddrs, Headers, Locator, PeerAddrs, Ping, Pong, TxHashSetArchive, TxHashSetRequest, Type, @@ -45,7 +43,7 @@ impl MessageHandler for Protocol { &self, mut msg: Message<'a>, writer: &'a mut dyn Write, - received_bytes: Arc>, + tracker: Arc, ) -> Result>, Error> { let adapter = &self.adapter; @@ -312,10 +310,7 @@ impl MessageHandler for Protocol { // Increase received bytes quietly (without affecting the counters). // Otherwise we risk banning a peer as "abusive". - { - let mut received_bytes = received_bytes.write(); - received_bytes.inc_quiet(size as u64); - } + tracker.received_bytes_inc_quiet(size as u64) } tmp_zip .into_inner() diff --git a/store/src/types.rs b/store/src/types.rs index 5f681d8a6b..9e0520a484 100644 --- a/store/src/types.rs +++ b/store/src/types.rs @@ -22,7 +22,6 @@ use std::fs::{self, File, OpenOptions}; use std::io::{self, BufReader, BufWriter, Write}; use std::marker; use std::path::{Path, PathBuf}; -use std::time; /// Represents a single entry in the size_file. /// Offset (in bytes) and size (in bytes) of a variable sized entry @@ -446,8 +445,7 @@ where { let reader = File::open(&self.path)?; let mut buf_reader = BufReader::new(reader); - let mut streaming_reader = - StreamingReader::new(&mut buf_reader, time::Duration::from_secs(1)); + let mut streaming_reader = StreamingReader::new(&mut buf_reader); let mut buf_writer = BufWriter::new(File::create(&tmp_path)?); let mut bin_writer = BinWriter::new(&mut buf_writer); @@ -493,8 +491,7 @@ where { let reader = File::open(&self.path)?; let mut buf_reader = BufReader::new(reader); - let mut streaming_reader = - StreamingReader::new(&mut buf_reader, time::Duration::from_secs(1)); + let mut streaming_reader = StreamingReader::new(&mut buf_reader); let mut buf_writer = BufWriter::new(File::create(&tmp_path)?); let mut bin_writer = BinWriter::new(&mut buf_writer); diff --git a/util/src/lib.rs b/util/src/lib.rs index 93883e951f..e91bb348af 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -47,9 +47,6 @@ pub use crate::types::{LogLevel, LoggingConfig, ZeroingString}; pub mod macros; -// read_exact and write_all impls -pub mod read_write; - // other utils #[allow(unused_imports)] use std::ops::Deref; diff --git a/util/src/read_write.rs b/util/src/read_write.rs deleted file mode 100644 index 15e3f3f72a..0000000000 --- a/util/src/read_write.rs +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2018 The Grin Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Custom impls of read_exact and write_all to work around async stream restrictions. - -use std::io; -use std::io::prelude::*; -use std::thread; -use std::time::Duration; - -/// The default implementation of read_exact is useless with an async stream (TcpStream) as -/// it will return as soon as something has been read, regardless of -/// whether the buffer has been filled (and then errors). This implementation -/// will block until it has read exactly `len` bytes and returns them as a -/// `vec`. Except for a timeout, this implementation will never return a -/// partially filled buffer. -/// -/// The timeout in milliseconds aborts the read when it's met. Note that the -/// time is not guaranteed to be exact. To support cases where we want to poll -/// instead of blocking, a `block_on_empty` boolean, when false, ensures -/// `read_exact` returns early with a `io::ErrorKind::WouldBlock` if nothing -/// has been read from the socket. -pub fn read_exact( - stream: &mut dyn Read, - mut buf: &mut [u8], - timeout: Duration, - block_on_empty: bool, -) -> io::Result<()> { - let sleep_time = Duration::from_micros(10); - let mut count = Duration::new(0, 0); - - let mut read = 0; - loop { - match stream.read(buf) { - Ok(0) => { - return Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "read_exact", - )); - } - Ok(n) => { - let tmp = buf; - buf = &mut tmp[n..]; - read += n; - } - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if read == 0 && !block_on_empty { - return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact")); - } - } - Err(e) => return Err(e), - } - if !buf.is_empty() { - thread::sleep(sleep_time); - count += sleep_time; - } else { - break; - } - if count > timeout { - return Err(io::Error::new( - io::ErrorKind::TimedOut, - "reading from stream", - )); - } - } - Ok(()) -} - -/// Same as `read_exact` but for writing. -pub fn write_all(stream: &mut dyn Write, mut buf: &[u8], timeout: Duration) -> io::Result<()> { - let sleep_time = Duration::from_micros(10); - let mut count = Duration::new(0, 0); - - while !buf.is_empty() { - match stream.write(buf) { - Ok(0) => { - return Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write whole buffer", - )); - } - Ok(n) => buf = &buf[n..], - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - Err(e) => return Err(e), - } - if !buf.is_empty() { - thread::sleep(sleep_time); - count += sleep_time; - } else { - break; - } - if count > timeout { - return Err(io::Error::new(io::ErrorKind::TimedOut, "writing to stream")); - } - } - Ok(()) -}