Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap MsgHeader in MsgHeaderWrapper for Known/Unknown msg type support #2791

Merged
merged 5 commits into from
May 14, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 28 additions & 17 deletions p2p/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use std::{cmp, thread, time};

use crate::core::ser;
use crate::core::ser::FixedLength;
use crate::msg::{read_body, read_header, read_item, write_to_buf, MsgHeader, Type};
use crate::msg::{
read_body, read_discard, read_header, read_item, write_to_buf, MsgHeader, MsgHeaderWrapper,
Type,
};
use crate::types::Error;
use crate::util::read_write::{read_exact, write_all};
use crate::util::{RateCounter, RwLock};
Expand Down Expand Up @@ -252,27 +255,35 @@ fn poll<H>(
let mut retry_send = Err(());
loop {
// check the read end
if let Some(h) = try_break!(error_tx, read_header(&mut reader, None)) {
let msg = Message::from_header(h, &mut reader);
match try_break!(error_tx, read_header(&mut reader, None)) {
Some(MsgHeaderWrapper::Known(header)) => {
let msg = Message::from_header(header, &mut reader);

trace!(
"Received message header, type {:?}, len {}.",
msg.header.msg_type,
msg.header.msg_len
);
trace!(
"Received message header, type {:?}, len {}.",
msg.header.msg_type,
msg.header.msg_len
);

// Increase received bytes counter
received_bytes
.write()
.inc(MsgHeader::LEN as u64 + msg.header.msg_len);

// Increase received bytes counter
let received = received_bytes.clone();
{
let mut received_bytes = received_bytes.write();
received_bytes.inc(MsgHeader::LEN as u64 + msg.header.msg_len);
if let Some(Some(resp)) = try_break!(
error_tx,
handler.consume(msg, &mut writer, received_bytes.clone())
) {
try_break!(error_tx, resp.write(sent_bytes.clone()));
}
}
Some(MsgHeaderWrapper::Unknown(msg_len)) => {
// Increase received bytes counter
received_bytes.write().inc(MsgHeader::LEN as u64 + msg_len);

if let Some(Some(resp)) =
try_break!(error_tx, handler.consume(msg, &mut writer, received))
{
try_break!(error_tx, resp.write(sent_bytes.clone()));
try_break!(error_tx, read_discard(msg_len, &mut reader));
}
None => {}
}

// check the write end, use or_else so try_recv is lazily eval'd
Expand Down
104 changes: 79 additions & 25 deletions p2p/src/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,24 +119,20 @@ fn magic() -> [u8; 2] {
/// Read a header from the provided stream without blocking if the
/// underlying stream is async. Typically headers will be polled for, so
/// we do not want to block.
pub fn read_header(stream: &mut dyn Read, msg_type: Option<Type>) -> Result<MsgHeader, Error> {
///
/// Note: We return a MsgHeaderWrapper here as we may encounter an unknown msg type.
///
pub fn read_header(
stream: &mut dyn Read,
msg_type: Option<Type>,
) -> Result<MsgHeaderWrapper, Error> {
let mut head = vec![0u8; MsgHeader::LEN];
if Some(Type::Hand) == msg_type {
read_exact(stream, &mut head, time::Duration::from_millis(10), true)?;
} else {
read_exact(stream, &mut head, time::Duration::from_secs(10), false)?;
}
let header = ser::deserialize::<MsgHeader>(&mut &head[..])?;
let max_len = max_msg_size(header.msg_type);

// TODO 4x the limits for now to leave ourselves space to change things
if header.msg_len > max_len * 4 {
error!(
"Too large read {}, had {}, wanted {}.",
header.msg_type as u8, max_len, header.msg_len
);
return Err(Error::Serialization(ser::Error::TooLargeReadErr));
}
let header = ser::deserialize::<MsgHeaderWrapper>(&mut &head[..])?;
Ok(header)
}

Expand All @@ -158,13 +154,28 @@ pub fn read_body<T: Readable>(h: &MsgHeader, stream: &mut dyn Read) -> Result<T,
ser::deserialize(&mut &body[..]).map_err(From::from)
}

/// Read (an unknown) message from the provided stream and discard it.
pub fn read_discard(msg_len: u64, stream: &mut dyn Read) -> Result<(), Error> {
let mut buffer = vec![0u8; msg_len as usize];
antiochp marked this conversation as resolved.
Show resolved Hide resolved
read_exact(stream, &mut buffer, time::Duration::from_secs(20), true)?;
Ok(())
}

/// Reads a full message from the underlying stream.
pub fn read_message<T: Readable>(stream: &mut dyn Read, msg_type: Type) -> Result<T, Error> {
let header = read_header(stream, Some(msg_type))?;
if header.msg_type != msg_type {
return Err(Error::BadMessage);
match read_header(stream, Some(msg_type))? {
MsgHeaderWrapper::Known(header) => {
if header.msg_type == msg_type {
read_body(&header, stream)
} else {
Err(Error::BadMessage)
}
}
MsgHeaderWrapper::Unknown(msg_len) => {
read_discard(msg_len, stream)?;
Err(Error::BadMessage)
}
}
read_body(&header, stream)
}

pub fn write_to_buf<T: Writeable>(msg: T, msg_type: Type) -> Result<Vec<u8>, Error> {
Expand All @@ -191,7 +202,19 @@ pub fn write_message<T: Writeable>(
Ok(())
}

/// A wrapper around a message header. If the header is for an unknown msg type
/// then we will be unable to parse the msg itself (just a bunch of random bytes).
/// But we need to know how many bytes to discard to discard the full message.
#[derive(Clone)]
pub enum MsgHeaderWrapper {
/// A "known" msg type with deserialized msg header.
Known(MsgHeader),
/// An unknown msg type with corresponding msg size in bytes.
Unknown(u64),
}

/// Header of any protocol message, used to identify incoming messages.
#[derive(Clone)]
pub struct MsgHeader {
magic: [u8; 2],
/// Type of the message.
Expand Down Expand Up @@ -228,19 +251,50 @@ impl Writeable for MsgHeader {
}
}

impl Readable for MsgHeader {
fn read(reader: &mut dyn Reader) -> Result<MsgHeader, ser::Error> {
impl Readable for MsgHeaderWrapper {
fn read(reader: &mut dyn Reader) -> Result<MsgHeaderWrapper, ser::Error> {
let m = magic();
reader.expect_u8(m[0])?;
reader.expect_u8(m[1])?;
let (t, len) = ser_multiread!(reader, read_u8, read_u64);

// Read the msg header.
// We do not yet know if the msg type is one we support locally.
let (t, msg_len) = ser_multiread!(reader, read_u8, read_u64);

// Attempt to convert the msg type byte into one of our known msg type enum variants.
// Check the msg_len while we are at it.
match Type::from_u8(t) {
Some(ty) => Ok(MsgHeader {
magic: m,
msg_type: ty,
msg_len: len,
}),
None => Err(ser::Error::CorruptedData),
Some(msg_type) => {
// TODO 4x the limits for now to leave ourselves space to change things.
let max_len = max_msg_size(msg_type) * 4;
if msg_len > max_len {
error!(
"Too large read {:?}, max_len: {}, msg_len: {}.",
msg_type, max_len, msg_len
);
return Err(ser::Error::TooLargeReadErr);
}

Ok(MsgHeaderWrapper::Known(MsgHeader {
magic: m,
msg_type,
msg_len,
}))
}
None => {
// Unknown msg type, but we still want to limit how big the msg is.
// Default to max_block_size (4x as above for space to change things).
let max_len = max_block_size() * 4;
antiochp marked this conversation as resolved.
Show resolved Hide resolved
if msg_len > max_len {
antiochp marked this conversation as resolved.
Show resolved Hide resolved
error!(
"Too large read (unknown msg type) {:?}, max_len: {}, msg_len: {}.",
t, max_len, msg_len
);
return Err(ser::Error::TooLargeReadErr);
}

Ok(MsgHeaderWrapper::Unknown(msg_len))
}
}
}
}
Expand Down
13 changes: 10 additions & 3 deletions p2p/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,16 @@ impl MessageHandler for Protocol {

Ok(None)
}

_ => {
debug!("unknown message type {:?}", msg.header.msg_type);
Type::Error => {
antiochp marked this conversation as resolved.
Show resolved Hide resolved
debug!("Received an unexpected msg: {:?}", msg.header.msg_type);
Ok(None)
}
Type::Hand => {
debug!("Received an unexpected msg: {:?}", msg.header.msg_type);
Ok(None)
}
Type::Shake => {
debug!("Received an unexpected msg: {:?}", msg.header.msg_type);
Ok(None)
}
}
Expand Down