Skip to content

Commit

Permalink
use sender thread consistently to send msgs to a peer
Browse files Browse the repository at this point in the history
  • Loading branch information
antiochp committed Sep 30, 2019
1 parent 751ca06 commit f9a3a57
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 159 deletions.
124 changes: 34 additions & 90 deletions p2p/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
use crate::core::ser;
use crate::core::ser::{FixedLength, ProtocolVersion};
use crate::msg::{
read_body, read_discard, read_header, read_item, write_to_buf, MsgHeader, MsgHeaderWrapper,
Type,
read_body, read_discard, read_header, read_item, write_message, Msg, MsgHeader,
MsgHeaderWrapper,
};
use crate::types::Error;
use crate::util::{RateCounter, RwLock};
use std::fs::File;
use std::io::{self, Read, Write};
use std::net::{Shutdown, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
Expand All @@ -44,12 +43,7 @@ const IO_TIMEOUT: Duration = Duration::from_millis(1000);
/// A trait to be implemented in order to receive messages from the
/// connection. Allows providing an optional response.
pub trait MessageHandler: Send + 'static {
fn consume<'a>(
&self,
msg: Message<'a>,
writer: &'a mut dyn Write,
tracker: Arc<Tracker>,
) -> Result<Option<Response<'a>>, Error>;
fn consume<'a>(&self, msg: Message<'a>, tracker: Arc<Tracker>) -> Result<Option<Msg>, Error>;
}

// Macro to simplify the boilerplate around async I/O error handling,
Expand Down Expand Up @@ -121,64 +115,6 @@ impl<'a> Message<'a> {
}
}

/// Response to a `Message`.
pub struct Response<'a> {
resp_type: Type,
body: Vec<u8>,
version: ProtocolVersion,
stream: &'a mut dyn Write,
attachment: Option<File>,
}

impl<'a> Response<'a> {
pub fn new<T: ser::Writeable>(
resp_type: Type,
version: ProtocolVersion,
body: T,
stream: &'a mut dyn Write,
) -> Result<Response<'a>, Error> {
let body = ser::ser_vec(&body, version)?;
Ok(Response {
resp_type,
body,
version,
stream,
attachment: None,
})
}

fn write(mut self, tracker: Arc<Tracker>) -> Result<(), Error> {
let mut msg = ser::ser_vec(
&MsgHeader::new(self.resp_type, self.body.len() as u64),
self.version,
)?;
msg.append(&mut self.body);
self.stream.write_all(&msg[..])?;
tracker.inc_sent(msg.len() as u64);

if let Some(mut file) = self.attachment {
let mut buf = [0u8; 8000];
loop {
match file.read(&mut buf[..]) {
Ok(0) => break,
Ok(n) => {
self.stream.write_all(&buf[..n])?;
// Increase sent bytes "quietly" without incrementing the counter.
// (In a loop here for the single attachment).
tracker.inc_quiet_sent(n as u64);
}
Err(e) => return Err(From::from(e)),
}
}
}
Ok(())
}

pub fn add_attachment(&mut self, file: File) {
self.attachment = Some(file);
}
}

pub const SEND_CHANNEL_CAP: usize = 100;

pub struct StopHandle {
Expand Down Expand Up @@ -220,20 +156,16 @@ impl StopHandle {
}
}

#[derive(Clone)]
pub struct ConnHandle {
/// Channel to allow sending data through the connection
pub send_channel: mpsc::SyncSender<Vec<u8>>,
pub send_channel: mpsc::SyncSender<Msg>,
}

impl ConnHandle {
pub fn send<T>(&self, body: T, msg_type: Type, version: ProtocolVersion) -> Result<u64, Error>
where
T: ser::Writeable,
{
let buf = write_to_buf(body, msg_type, version)?;
let buf_len = buf.len();
self.send_channel.try_send(buf)?;
Ok(buf_len as u64)
pub fn send(&self, msg: Msg) -> Result<(), Error> {
self.send_channel.try_send(msg)?;
Ok(())
}
}

Expand Down Expand Up @@ -294,13 +226,22 @@ where

let stopped = Arc::new(AtomicBool::new(false));

let (reader_thread, writer_thread) =
poll(stream, version, handler, send_rx, stopped.clone(), tracker)?;
let conn_handle = ConnHandle {
send_channel: send_tx,
};

let (reader_thread, writer_thread) = poll(
stream,
conn_handle.clone(),
version,
handler,
send_rx,
stopped.clone(),
tracker,
)?;

Ok((
ConnHandle {
send_channel: send_tx,
},
conn_handle,
StopHandle {
stopped,
reader_thread: Some(reader_thread),
Expand All @@ -311,9 +252,10 @@ where

fn poll<H>(
conn: TcpStream,
conn_handle: ConnHandle,
version: ProtocolVersion,
handler: H,
send_rx: mpsc::Receiver<Vec<u8>>,
send_rx: mpsc::Receiver<Msg>,
stopped: Arc<AtomicBool>,
tracker: Arc<Tracker>,
) -> io::Result<(JoinHandle<()>, JoinHandle<()>)>
Expand All @@ -323,9 +265,11 @@ where
// Split out tcp stream out into separate reader/writer halves.
let mut reader = conn.try_clone().expect("clone conn for reader failed");
let mut writer = conn.try_clone().expect("clone conn for writer failed");
let mut responder = conn.try_clone().expect("clone conn for writer failed");
let reader_stopped = stopped.clone();

let reader_tracker = tracker.clone();
let writer_tracker = tracker.clone();

let reader_thread = thread::Builder::new()
.name("peer_read".to_string())
.spawn(move || {
Expand All @@ -342,17 +286,16 @@ where
);

// Increase received bytes counter
tracker.inc_received(MsgHeader::LEN as u64 + msg.header.msg_len);
reader_tracker.inc_received(MsgHeader::LEN as u64 + msg.header.msg_len);

if let Some(Some(resp)) =
try_break!(handler.consume(msg, &mut responder, tracker.clone()))
{
try_break!(resp.write(tracker.clone()));
let resp_msg = try_break!(handler.consume(msg, reader_tracker.clone()));
if let Some(Some(resp_msg)) = resp_msg {
try_break!(conn_handle.send(resp_msg));
}
}
Some(MsgHeaderWrapper::Unknown(msg_len)) => {
// Increase received bytes counter
tracker.inc_received(MsgHeader::LEN as u64 + msg_len);
reader_tracker.inc_received(MsgHeader::LEN as u64 + msg_len);

try_break!(read_discard(msg_len, &mut reader));
}
Expand Down Expand Up @@ -383,7 +326,8 @@ where
let maybe_data = retry_send.or_else(|_| send_rx.recv_timeout(IO_TIMEOUT));
retry_send = Err(());
if let Ok(data) = maybe_data {
let written = try_break!(writer.write_all(&data[..]).map_err(&From::from));
let written =
try_break!(write_message(&mut writer, &data, writer_tracker.clone()));
if written.is_none() {
retry_send = Ok(data);
}
Expand Down
12 changes: 9 additions & 3 deletions p2p/src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::conn::Tracker;
use crate::core::core::hash::Hash;
use crate::core::pow::Difficulty;
use crate::core::ser::ProtocolVersion;
use crate::msg::{read_message, write_message, Hand, Shake, Type, USER_AGENT};
use crate::msg::{read_message, write_message, Hand, Msg, Shake, Type, USER_AGENT};
use crate::peer::Peer;
use crate::types::{Capabilities, Direction, Error, P2PConfig, PeerAddr, PeerInfo, PeerLiveInfo};
use crate::util::RwLock;
Expand Down Expand Up @@ -47,6 +48,7 @@ pub struct Handshake {
genesis: Hash,
config: P2PConfig,
protocol_version: ProtocolVersion,
tracker: Arc<Tracker>,
}

impl Handshake {
Expand All @@ -58,6 +60,7 @@ impl Handshake {
genesis,
config,
protocol_version: ProtocolVersion::local(),
tracker: Arc::new(Tracker::new()),
}
}

Expand Down Expand Up @@ -99,7 +102,8 @@ impl Handshake {
};

// write and read the handshake response
write_message(conn, hand, Type::Hand, self.protocol_version)?;
let msg = Msg::new(Type::Hand, hand, self.protocol_version)?;
write_message(conn, &msg, self.tracker.clone())?;

let shake: Shake = read_message(conn, self.protocol_version, Type::Shake)?;
if shake.genesis != self.genesis {
Expand Down Expand Up @@ -196,7 +200,9 @@ impl Handshake {
user_agent: USER_AGENT.to_string(),
};

write_message(conn, shake, Type::Shake, negotiated_version)?;
let msg = Msg::new(Type::Shake, shake, negotiated_version)?;
write_message(conn, &msg, self.tracker.clone())?;

trace!("Success handshake with {}.", peer_info.addr);

Ok(peer_info)
Expand Down
75 changes: 52 additions & 23 deletions p2p/src/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

//! Message types that transit over the network and related serialization code.
use crate::conn::Tracker;
use crate::core::core::hash::Hash;
use crate::core::core::BlockHeader;
use crate::core::pow::Difficulty;
Expand All @@ -25,7 +26,9 @@ use crate::types::{
Capabilities, Error, PeerAddr, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS,
};
use num::FromPrimitive;
use std::fs::File;
use std::io::{Read, Write};
use std::sync::Arc;

/// Grin's user agent with current version
pub const USER_AGENT: &'static str = concat!("MW/Grin ", env!("CARGO_PKG_VERSION"));
Expand Down Expand Up @@ -114,6 +117,33 @@ fn magic() -> [u8; 2] {
}
}

pub struct Msg {
header: MsgHeader,
body: Vec<u8>,
attachment: Option<File>,
version: ProtocolVersion,
}

impl Msg {
pub fn new<T: Writeable>(
msg_type: Type,
msg: T,
version: ProtocolVersion,
) -> Result<Msg, Error> {
let body = ser::ser_vec(&msg, version)?;
Ok(Msg {
header: MsgHeader::new(msg_type, body.len() as u64),
body,
attachment: None,
version,
})
}

pub fn add_attachment(&mut self, attachment: File) {
self.attachment = Some(attachment)
}
}

/// Read a header from the provided stream without blocking if the
/// underlying stream is async. Typically headers will be polled for, so
/// we do not want to block.
Expand Down Expand Up @@ -182,32 +212,31 @@ pub fn read_message<T: Readable>(
}
}

pub fn write_to_buf<T: Writeable>(
msg: T,
msg_type: Type,
version: ProtocolVersion,
) -> Result<Vec<u8>, Error> {
// prepare the body first so we know its serialized length
let mut body_buf = vec![];
ser::serialize(&mut body_buf, version, &msg)?;

// build and serialize the header using the body size
let mut msg_buf = vec![];
let blen = body_buf.len() as u64;
ser::serialize(&mut msg_buf, version, &MsgHeader::new(msg_type, blen))?;
msg_buf.append(&mut body_buf);

Ok(msg_buf)
}

pub fn write_message<T: Writeable>(
pub fn write_message(
stream: &mut dyn Write,
msg: T,
msg_type: Type,
version: ProtocolVersion,
msg: &Msg,
tracker: Arc<Tracker>,
) -> Result<(), Error> {
let buf = write_to_buf(msg, msg_type, version)?;
let mut buf = ser::ser_vec(&msg.header, msg.version)?;
buf.extend(&msg.body[..]);
stream.write_all(&buf[..])?;
tracker.inc_sent(buf.len() as u64);
if let Some(file) = &msg.attachment {
let mut file = file.try_clone()?;
let mut buf = [0u8; 8000];
loop {
match file.read(&mut buf[..]) {
Ok(0) => break,
Ok(n) => {
stream.write_all(&buf[..n])?;
// Increase sent bytes "quietly" without incrementing the counter.
// (In a loop here for the single attachment).
tracker.inc_quiet_sent(n as u64);
}
Err(e) => return Err(From::from(e)),
}
}
}
Ok(())
}

Expand Down
10 changes: 3 additions & 7 deletions p2p/src/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::core::ser::Writeable;
use crate::core::{core, global};
use crate::handshake::Handshake;
use crate::msg::{
self, BanReason, GetPeerAddrs, KernelDataRequest, Locator, Ping, TxHashSetRequest, Type,
self, BanReason, GetPeerAddrs, KernelDataRequest, Locator, Msg, Ping, TxHashSetRequest, Type,
};
use crate::protocol::Protocol;
use crate::types::{
Expand Down Expand Up @@ -233,12 +233,8 @@ impl Peer {

/// Send a msg with given msg_type to our peer via the connection.
fn send<T: Writeable>(&self, msg: T, msg_type: Type) -> Result<(), Error> {
let bytes = self
.send_handle
.lock()
.send(msg, msg_type, self.info.version)?;
self.tracker.inc_sent(bytes);
Ok(())
let msg = Msg::new(msg_type, msg, self.info.version)?;
self.send_handle.lock().send(msg)
}

/// Send a ping to the remote peer, providing our local difficulty and
Expand Down
Loading

0 comments on commit f9a3a57

Please sign in to comment.