Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sender thread consistently to send msgs to a peer #3067

Merged
merged 1 commit into from
Oct 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 34 additions & 90 deletions p2p/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
use crate::core::ser;
use crate::core::ser::{FixedLength, ProtocolVersion};
use crate::msg::{
read_body, read_discard, read_header, read_item, write_to_buf, MsgHeader, MsgHeaderWrapper,
Type,
read_body, read_discard, read_header, read_item, write_message, Msg, MsgHeader,
MsgHeaderWrapper,
};
use crate::types::Error;
use crate::util::{RateCounter, RwLock};
use std::fs::File;
use std::io::{self, Read, Write};
use std::net::{Shutdown, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
Expand All @@ -44,12 +43,7 @@ const IO_TIMEOUT: Duration = Duration::from_millis(1000);
/// A trait to be implemented in order to receive messages from the
/// connection. Allows providing an optional response.
pub trait MessageHandler: Send + 'static {
fn consume<'a>(
&self,
msg: Message<'a>,
writer: &'a mut dyn Write,
tracker: Arc<Tracker>,
) -> Result<Option<Response<'a>>, Error>;
fn consume<'a>(&self, msg: Message<'a>, tracker: Arc<Tracker>) -> Result<Option<Msg>, Error>;
}

// Macro to simplify the boilerplate around async I/O error handling,
Expand Down Expand Up @@ -121,64 +115,6 @@ impl<'a> Message<'a> {
}
}

/// Response to a `Message`.
pub struct Response<'a> {
resp_type: Type,
body: Vec<u8>,
version: ProtocolVersion,
stream: &'a mut dyn Write,
attachment: Option<File>,
}

impl<'a> Response<'a> {
pub fn new<T: ser::Writeable>(
resp_type: Type,
version: ProtocolVersion,
body: T,
stream: &'a mut dyn Write,
) -> Result<Response<'a>, Error> {
let body = ser::ser_vec(&body, version)?;
Ok(Response {
resp_type,
body,
version,
stream,
attachment: None,
})
}

fn write(mut self, tracker: Arc<Tracker>) -> Result<(), Error> {
let mut msg = ser::ser_vec(
&MsgHeader::new(self.resp_type, self.body.len() as u64),
self.version,
)?;
msg.append(&mut self.body);
self.stream.write_all(&msg[..])?;
tracker.inc_sent(msg.len() as u64);

if let Some(mut file) = self.attachment {
let mut buf = [0u8; 8000];
loop {
match file.read(&mut buf[..]) {
Ok(0) => break,
Ok(n) => {
self.stream.write_all(&buf[..n])?;
// Increase sent bytes "quietly" without incrementing the counter.
// (In a loop here for the single attachment).
tracker.inc_quiet_sent(n as u64);
}
Err(e) => return Err(From::from(e)),
}
}
}
Ok(())
}

pub fn add_attachment(&mut self, file: File) {
self.attachment = Some(file);
}
}

pub const SEND_CHANNEL_CAP: usize = 100;

pub struct StopHandle {
Expand Down Expand Up @@ -220,20 +156,16 @@ impl StopHandle {
}
}

#[derive(Clone)]
pub struct ConnHandle {
/// Channel to allow sending data through the connection
pub send_channel: mpsc::SyncSender<Vec<u8>>,
pub send_channel: mpsc::SyncSender<Msg>,
}

impl ConnHandle {
pub fn send<T>(&self, body: T, msg_type: Type, version: ProtocolVersion) -> Result<u64, Error>
where
T: ser::Writeable,
{
let buf = write_to_buf(body, msg_type, version)?;
let buf_len = buf.len();
self.send_channel.try_send(buf)?;
Ok(buf_len as u64)
pub fn send(&self, msg: Msg) -> Result<(), Error> {
self.send_channel.try_send(msg)?;
Ok(())
}
}

Expand Down Expand Up @@ -294,13 +226,22 @@ where

let stopped = Arc::new(AtomicBool::new(false));

let (reader_thread, writer_thread) =
poll(stream, version, handler, send_rx, stopped.clone(), tracker)?;
let conn_handle = ConnHandle {
send_channel: send_tx,
};

let (reader_thread, writer_thread) = poll(
stream,
conn_handle.clone(),
version,
handler,
send_rx,
stopped.clone(),
tracker,
)?;

Ok((
ConnHandle {
send_channel: send_tx,
},
conn_handle,
StopHandle {
stopped,
reader_thread: Some(reader_thread),
Expand All @@ -311,9 +252,10 @@ where

fn poll<H>(
conn: TcpStream,
conn_handle: ConnHandle,
version: ProtocolVersion,
handler: H,
send_rx: mpsc::Receiver<Vec<u8>>,
send_rx: mpsc::Receiver<Msg>,
stopped: Arc<AtomicBool>,
tracker: Arc<Tracker>,
) -> io::Result<(JoinHandle<()>, JoinHandle<()>)>
Expand All @@ -323,9 +265,11 @@ where
// Split out tcp stream out into separate reader/writer halves.
let mut reader = conn.try_clone().expect("clone conn for reader failed");
let mut writer = conn.try_clone().expect("clone conn for writer failed");
let mut responder = conn.try_clone().expect("clone conn for writer failed");
let reader_stopped = stopped.clone();

let reader_tracker = tracker.clone();
let writer_tracker = tracker.clone();

let reader_thread = thread::Builder::new()
.name("peer_read".to_string())
.spawn(move || {
Expand All @@ -342,17 +286,16 @@ where
);

// Increase received bytes counter
tracker.inc_received(MsgHeader::LEN as u64 + msg.header.msg_len);
reader_tracker.inc_received(MsgHeader::LEN as u64 + msg.header.msg_len);

if let Some(Some(resp)) =
try_break!(handler.consume(msg, &mut responder, tracker.clone()))
{
try_break!(resp.write(tracker.clone()));
let resp_msg = try_break!(handler.consume(msg, reader_tracker.clone()));
if let Some(Some(resp_msg)) = resp_msg {
try_break!(conn_handle.send(resp_msg));
}
}
Some(MsgHeaderWrapper::Unknown(msg_len)) => {
// Increase received bytes counter
tracker.inc_received(MsgHeader::LEN as u64 + msg_len);
reader_tracker.inc_received(MsgHeader::LEN as u64 + msg_len);

try_break!(read_discard(msg_len, &mut reader));
}
Expand Down Expand Up @@ -383,7 +326,8 @@ where
let maybe_data = retry_send.or_else(|_| send_rx.recv_timeout(IO_TIMEOUT));
retry_send = Err(());
if let Ok(data) = maybe_data {
let written = try_break!(writer.write_all(&data[..]).map_err(&From::from));
let written =
try_break!(write_message(&mut writer, &data, writer_tracker.clone()));
if written.is_none() {
retry_send = Ok(data);
}
Expand Down
12 changes: 9 additions & 3 deletions p2p/src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::conn::Tracker;
use crate::core::core::hash::Hash;
use crate::core::pow::Difficulty;
use crate::core::ser::ProtocolVersion;
use crate::msg::{read_message, write_message, Hand, Shake, Type, USER_AGENT};
use crate::msg::{read_message, write_message, Hand, Msg, Shake, Type, USER_AGENT};
use crate::peer::Peer;
use crate::types::{Capabilities, Direction, Error, P2PConfig, PeerAddr, PeerInfo, PeerLiveInfo};
use crate::util::RwLock;
Expand Down Expand Up @@ -47,6 +48,7 @@ pub struct Handshake {
genesis: Hash,
config: P2PConfig,
protocol_version: ProtocolVersion,
tracker: Arc<Tracker>,
}

impl Handshake {
Expand All @@ -58,6 +60,7 @@ impl Handshake {
genesis,
config,
protocol_version: ProtocolVersion::local(),
tracker: Arc::new(Tracker::new()),
}
}

Expand Down Expand Up @@ -99,7 +102,8 @@ impl Handshake {
};

// write and read the handshake response
write_message(conn, hand, Type::Hand, self.protocol_version)?;
let msg = Msg::new(Type::Hand, hand, self.protocol_version)?;
write_message(conn, &msg, self.tracker.clone())?;

let shake: Shake = read_message(conn, self.protocol_version, Type::Shake)?;
if shake.genesis != self.genesis {
Expand Down Expand Up @@ -196,7 +200,9 @@ impl Handshake {
user_agent: USER_AGENT.to_string(),
};

write_message(conn, shake, Type::Shake, negotiated_version)?;
let msg = Msg::new(Type::Shake, shake, negotiated_version)?;
write_message(conn, &msg, self.tracker.clone())?;

trace!("Success handshake with {}.", peer_info.addr);

Ok(peer_info)
Expand Down
75 changes: 52 additions & 23 deletions p2p/src/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

//! Message types that transit over the network and related serialization code.

use crate::conn::Tracker;
use crate::core::core::hash::Hash;
use crate::core::core::BlockHeader;
use crate::core::pow::Difficulty;
Expand All @@ -25,7 +26,9 @@ use crate::types::{
Capabilities, Error, PeerAddr, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS,
};
use num::FromPrimitive;
use std::fs::File;
use std::io::{Read, Write};
use std::sync::Arc;

/// Grin's user agent with current version
pub const USER_AGENT: &'static str = concat!("MW/Grin ", env!("CARGO_PKG_VERSION"));
Expand Down Expand Up @@ -114,6 +117,33 @@ fn magic() -> [u8; 2] {
}
}

pub struct Msg {
header: MsgHeader,
body: Vec<u8>,
attachment: Option<File>,
version: ProtocolVersion,
}

impl Msg {
pub fn new<T: Writeable>(
msg_type: Type,
msg: T,
version: ProtocolVersion,
) -> Result<Msg, Error> {
let body = ser::ser_vec(&msg, version)?;
Ok(Msg {
header: MsgHeader::new(msg_type, body.len() as u64),
body,
attachment: None,
version,
})
}

pub fn add_attachment(&mut self, attachment: File) {
self.attachment = Some(attachment)
}
}

/// Read a header from the provided stream without blocking if the
/// underlying stream is async. Typically headers will be polled for, so
/// we do not want to block.
Expand Down Expand Up @@ -182,32 +212,31 @@ pub fn read_message<T: Readable>(
}
}

pub fn write_to_buf<T: Writeable>(
msg: T,
msg_type: Type,
version: ProtocolVersion,
) -> Result<Vec<u8>, Error> {
// prepare the body first so we know its serialized length
let mut body_buf = vec![];
ser::serialize(&mut body_buf, version, &msg)?;

// build and serialize the header using the body size
let mut msg_buf = vec![];
let blen = body_buf.len() as u64;
ser::serialize(&mut msg_buf, version, &MsgHeader::new(msg_type, blen))?;
msg_buf.append(&mut body_buf);

Ok(msg_buf)
}

pub fn write_message<T: Writeable>(
pub fn write_message(
stream: &mut dyn Write,
msg: T,
msg_type: Type,
version: ProtocolVersion,
msg: &Msg,
tracker: Arc<Tracker>,
) -> Result<(), Error> {
let buf = write_to_buf(msg, msg_type, version)?;
let mut buf = ser::ser_vec(&msg.header, msg.version)?;
buf.extend(&msg.body[..]);
stream.write_all(&buf[..])?;
tracker.inc_sent(buf.len() as u64);
if let Some(file) = &msg.attachment {
let mut file = file.try_clone()?;
let mut buf = [0u8; 8000];
loop {
match file.read(&mut buf[..]) {
Ok(0) => break,
Ok(n) => {
stream.write_all(&buf[..n])?;
// Increase sent bytes "quietly" without incrementing the counter.
// (In a loop here for the single attachment).
tracker.inc_quiet_sent(n as u64);
}
Err(e) => return Err(From::from(e)),
}
}
}
Ok(())
}

Expand Down
10 changes: 3 additions & 7 deletions p2p/src/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::core::ser::Writeable;
use crate::core::{core, global};
use crate::handshake::Handshake;
use crate::msg::{
self, BanReason, GetPeerAddrs, KernelDataRequest, Locator, Ping, TxHashSetRequest, Type,
self, BanReason, GetPeerAddrs, KernelDataRequest, Locator, Msg, Ping, TxHashSetRequest, Type,
};
use crate::protocol::Protocol;
use crate::types::{
Expand Down Expand Up @@ -233,12 +233,8 @@ impl Peer {

/// Send a msg with given msg_type to our peer via the connection.
fn send<T: Writeable>(&self, msg: T, msg_type: Type) -> Result<(), Error> {
let bytes = self
.send_handle
.lock()
.send(msg, msg_type, self.info.version)?;
self.tracker.inc_sent(bytes);
Ok(())
let msg = Msg::new(msg_type, msg, self.info.version)?;
self.send_handle.lock().send(msg)
}

/// Send a ping to the remote peer, providing our local difficulty and
Expand Down
Loading