Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ base64 = "0.4.0"
byteorder = "1.0.0"
bytes = "0.4.1"
httparse = "1.2.1"
env_logger = "0.4.2"
log = "0.3.7"
rand = "0.3.15"
sha1 = "0.2.0"
Expand All @@ -30,3 +29,6 @@ utf-8 = "0.7.0"
[dependencies.native-tls]
optional = true
version = "0.1.1"

[dev-dependencies]
env_logger = "0.4.2"
4 changes: 2 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream
TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
})
.map(|s| StreamSwitcher::Tls(s))
.map(StreamSwitcher::Tls)
}
}
}
Expand All @@ -73,7 +73,7 @@ fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStrea
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
where A: Iterator<Item=SocketAddr>
{
let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?;
let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
Expand Down
2 changes: 1 addition & 1 deletion src/handshake/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl Request {
/// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key")
.ok_or(Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let reply = format!("\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Expand Down
3 changes: 2 additions & 1 deletion src/input_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ impl InputBuffer {

/// Reserve the given amount of space.
pub fn reserve(&mut self, space: usize, limit: usize) -> Result<(), SizeLimit>{
if self.inp_mut().remaining_mut() >= space {
let remaining = self.inp_mut().capacity() - self.inp_mut().len();
if remaining >= space {
// We have enough space right now.
Ok(())
} else {
Expand Down
118 changes: 74 additions & 44 deletions src/protocol/frame/frame.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::fmt;
use std::mem::transmute;
use std::io::{Cursor, Read, Write};
use std::io::{Cursor, Read, Write, ErrorKind};
use std::default::Default;
use std::iter::FromIterator;
use std::string::{String, FromUtf8Error};
use std::result::Result as StdResult;
use byteorder::{ByteOrder, NetworkEndian};
use byteorder::{ByteOrder, ReadBytesExt, NetworkEndian};
use bytes::BufMut;

use rand;
Expand All @@ -14,15 +13,34 @@ use error::{Error, Result};
use super::coding::{OpCode, Control, Data, CloseCode};

fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
let iter = buf.iter_mut().zip(mask.iter().cycle());
for (byte, &key) in iter {
*byte ^= key
for (i, byte) in buf.iter_mut().enumerate() {
*byte ^= mask[i & 3];
}
}

/// Faster version of `apply_mask()` which operates on 4-byte blocks.
///
/// Safety: `buf` must be at least 4-bytes aligned.
unsafe fn apply_mask_aligned32(buf: &mut [u8], mask: &[u8; 4]) {
debug_assert_eq!(buf.as_ptr() as usize % 4, 0);

let mask_u32 = transmute(*mask);

let mut ptr = buf.as_mut_ptr() as *mut u32;
for _ in 0..(buf.len() / 4) {
*ptr ^= mask_u32;
ptr = ptr.offset(1);
}

// Possible last block with less than 4 bytes.
let last_block_start = buf.len() & !3;
let last_block = &mut buf[last_block_start..];
apply_mask(last_block, mask);
}

#[inline]
fn generate_mask() -> [u8; 4] {
unsafe { transmute(rand::random::<u32>()) }
rand::random()
}

/// A struct representing a WebSocket frame.
Expand Down Expand Up @@ -175,7 +193,10 @@ impl Frame {
#[inline]
pub fn remove_mask(&mut self) {
self.mask.and_then(|mask| {
Some(apply_mask(&mut self.payload, &mask))
// Assumes Vec's backing memory is at least 4-bytes aligned.
unsafe {
Some(apply_mask_aligned32(&mut self.payload, &mask))
}
});
self.mask = None;
}
Expand Down Expand Up @@ -252,10 +273,7 @@ impl Frame {
let u: u16 = code.into();
transmute(u.to_be())
};
Vec::from_iter(
raw[..].iter()
.chain(reason.as_bytes().iter())
.map(|&b| b))
[&raw[..], reason.as_bytes()].concat()
} else {
Vec::new()
};
Expand Down Expand Up @@ -301,29 +319,24 @@ impl Frame {

let mut length = (second & 0x7F) as u64;

if length == 126 {
let mut length_bytes = [0u8; 2];
if try!(cursor.read(&mut length_bytes)) != 2 {
cursor.set_position(initial);
return Ok(None)
}

length = unsafe {
let mut wide: u16 = transmute(length_bytes);
wide = u16::from_be(wide);
wide
} as u64;
header_length += 2;
} else if length == 127 {
let mut length_bytes = [0u8; 8];
if try!(cursor.read(&mut length_bytes)) != 8 {
cursor.set_position(initial);
return Ok(None)
}

unsafe { length = transmute(length_bytes); }
length = u64::from_be(length);
header_length += 8;
if let Some(length_nbytes) = match length {
126 => Some(2),
127 => Some(8),
_ => None,
} {
match cursor.read_uint::<NetworkEndian>(length_nbytes) {
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
cursor.set_position(initial);
return Ok(None);
}
Err(err) => {
return Err(Error::from(err));
}
Ok(read) => {
length = read;
}
};
header_length += length_nbytes as u64;
}
trace!("Payload length: {}", length);

Expand Down Expand Up @@ -407,18 +420,14 @@ impl Frame {
try!(w.write(&headers));
} else if self.payload.len() <= 65535 {
two |= 126;
let length_bytes: [u8; 2] = unsafe {
let short = self.payload.len() as u16;
transmute(short.to_be())
};
let mut length_bytes = [0u8; 2];
NetworkEndian::write_u16(&mut length_bytes, self.payload.len() as u16);
let headers = [one, two, length_bytes[0], length_bytes[1]];
try!(w.write(&headers));
} else {
two |= 127;
let length_bytes: [u8; 8] = unsafe {
let long = self.payload.len() as u64;
transmute(long.to_be())
};
let mut length_bytes = [0u8; 8];
NetworkEndian::write_u64(&mut length_bytes, self.payload.len() as u64);
let headers = [
one,
two,
Expand All @@ -436,7 +445,10 @@ impl Frame {

if self.is_masked() {
let mask = self.mask.take().unwrap();
apply_mask(&mut self.payload, &mask);
// Assumes Vec's backing memory is at least 4-bytes aligned.
unsafe {
apply_mask_aligned32(&mut self.payload, &mask);
}
try!(w.write(&mask));
}

Expand Down Expand Up @@ -490,6 +502,24 @@ mod tests {
use super::super::coding::{OpCode, Data};
use std::io::Cursor;

#[test]
fn test_apply_mask() {
let mask = [
0x6d, 0xb6, 0xb2, 0x80,
];
let unmasked = vec![
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00,
];

let mut masked = unmasked.clone();
apply_mask(&mut masked, &mask);

let mut masked_aligned = unmasked.clone();
unsafe { apply_mask_aligned32(&mut masked_aligned, &mask) };

assert_eq!(masked, masked_aligned);
}

#[test]
fn parse() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
Expand Down
15 changes: 6 additions & 9 deletions src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,12 @@ impl<Stream: Read + Write> WebSocket<Stream> {
/// This function guarantees that the close frame will be queued.
/// There is no need to call it again, just like write_message().
pub fn close(&mut self) -> Result<()> {
match self.state {
WebSocketState::Active => {
self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(None);
self.send_queue.push_back(frame);
}
_ => {
// already closed, nothing to do
}
if let WebSocketState::Active = self.state {
self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(None);
self.send_queue.push_back(frame);
} else {
// Already closed, nothing to do.
}
self.write_pending()
}
Expand Down