feat(utf8): use a feature instead, add a specific error if utf8 is invalid
joelwurtz committed Sep 3, 2024
1 parent 6cb1b83 commit 47380b3
Showing 6 changed files with 65 additions and 47 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -13,7 +13,8 @@ edition = "2018"
build = "build.rs"

[features]
default = ["std"]
default = ["std", "utf8_in_path"]
utf8_in_path = []
std = []

[dev-dependencies]
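Since utf8_in_path joins the default feature set, downstream users get the relaxed path handling automatically and have to opt out to keep strict RFC 3986 paths. A hedged sketch of a consumer's Cargo.toml; the crate name and version spec are assumptions, not taken from this diff:

[dependencies]
# Default features (std + utf8_in_path): UTF-8 bytes allowed in the request path.
httparse = "1"

# Strict ASCII-only paths instead: drop the defaults and re-add only what is needed.
# httparse = { version = "1", default-features = false, features = ["std"] }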
59 changes: 40 additions & 19 deletions src/lib.rs
@@ -62,6 +62,7 @@ fn is_token(b: u8) -> bool {
// ASCII codes to accept URI string.
// i.e. A-Z a-z 0-9 !#$%&'*+-._();:@=,/?[]~^
// TODO: Make a stricter checking for URI string?
#[cfg(not(feature = "utf8_in_path"))]
static URI_MAP: [bool; 256] = byte_map![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// \0 \n
@@ -90,7 +91,8 @@ static URI_MAP: [bool; 256] = byte_map![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
];

static URI_NON_COMPLIANT_MAP: [bool; 256] = byte_map![
#[cfg(feature = "utf8_in_path")]
static URI_MAP: [bool; 256] = byte_map![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
@@ -110,12 +112,8 @@ static URI_NON_COMPLIANT_MAP: [bool; 256] = byte_map![
];

#[inline]
pub(crate) fn is_uri_token(b: u8, allow_non_compliant: bool) -> bool {
if allow_non_compliant {
URI_NON_COMPLIANT_MAP[b as usize]
} else {
URI_MAP[b as usize]
}
pub(crate) fn is_uri_token(b: u8) -> bool {
URI_MAP[b as usize]
}
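The strict and relaxed lookup tables now share the name URI_MAP, with mutually exclusive cfg gates choosing one definition at compile time, so is_uri_token no longer needs a runtime flag. A minimal self-contained sketch of that pattern, with placeholder tables and a shortened, illustrative feature name:

// Two definitions of one static; exactly one survives cfg evaluation.
#[cfg(not(feature = "relaxed"))]
static ALLOWED: [bool; 256] = [false; 256]; // stand-in for the strict table

#[cfg(feature = "relaxed")]
static ALLOWED: [bool; 256] = [true; 256]; // stand-in for the relaxed table

pub fn is_allowed(b: u8) -> bool {
    ALLOWED[b as usize]
}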

static HEADER_NAME_MAP: [bool; 256] = byte_map![
@@ -184,6 +182,9 @@ pub enum Error {
TooManyHeaders,
/// Invalid byte in HTTP version.
Version,
#[cfg(feature = "utf8_in_path")]
/// Invalid UTF-8 in path.
Utf8Error,
}

impl Error {
@@ -197,6 +198,8 @@ impl Error {
Error::Token => "invalid token",
Error::TooManyHeaders => "too many headers",
Error::Version => "invalid HTTP version",
#[cfg(feature = "utf8_in_path")]
Error::Utf8Error => "invalid UTF-8 in path",
}
}
}
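Error::Utf8Error only exists when the feature is compiled in, so downstream code that names the variant needs a matching gate; a consumer can forward the feature through its own [features] table (utf8_in_path = ["httparse/utf8_in_path"]) and then write something like this sketch, which assumes the crate is consumed as httparse:

#[cfg(feature = "utf8_in_path")]
fn is_bad_path_encoding(err: &httparse::Error) -> bool {
    matches!(err, httparse::Error::Utf8Error)
}

#[cfg(not(feature = "utf8_in_path"))]
fn is_bad_path_encoding(_err: &httparse::Error) -> bool {
    // Without the feature the variant does not exist; a rejected path byte
    // surfaces as Error::Token from parse_uri instead.
    false
}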
@@ -283,7 +286,6 @@ pub struct ParserConfig {
allow_multiple_spaces_in_request_line_delimiters: bool,
allow_multiple_spaces_in_response_status_delimiters: bool,
allow_space_before_first_header_name: bool,
allow_rfc3986_non_compliant_path: bool,
ignore_invalid_headers_in_responses: bool,
ignore_invalid_headers_in_requests: bool,
}
@@ -563,7 +565,7 @@ impl<'h, 'b> Request<'h, 'b> {
if config.allow_multiple_spaces_in_request_line_delimiters {
complete!(skip_spaces(&mut bytes));
}
self.path = Some(complete!(parse_uri(&mut bytes, config.allow_rfc3986_non_compliant_path)));
self.path = Some(complete!(parse_uri(&mut bytes)));
if config.allow_multiple_spaces_in_request_line_delimiters {
complete!(skip_spaces(&mut bytes));
}
@@ -976,9 +978,9 @@ fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> {
#[doc(hidden)]
#[allow(missing_docs)]
// WARNING: Exported for internal benchmarks, not fit for public consumption
pub fn parse_uri<'a>(bytes: &mut Bytes<'a>, allow_non_compliant: bool) -> Result<&'a str> {
pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> {
let start = bytes.pos();
simd::match_uri_vectored(bytes, allow_non_compliant);
simd::match_uri_vectored(bytes);
let end = bytes.pos();

if next!(bytes) == b' ' {
@@ -987,6 +989,14 @@ pub fn parse_uri<'a>(bytes: &mut Bytes<'a>, allow_non_compliant: bool) -> Result
return Err(Error::Token);
}

#[cfg(feature = "utf8_in_path")]
// SAFETY: all bytes up till `i` must have been `is_token` and therefore also utf-8.
return match str::from_utf8(unsafe { bytes.slice_skip(1) }) {
Ok(uri) => Ok(Status::Complete(uri)),
Err(_) => Err(Error::Utf8Error),
};

#[cfg(not(feature = "utf8_in_path"))]
return Ok(Status::Complete(
// SAFETY: all bytes up till `i` must have been `is_token` and therefore also utf-8.
unsafe { str::from_utf8_unchecked(bytes.slice_skip(1)) },
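The conversion of the matched byte range into a &str is now the compile-time fork: with utf8_in_path the slice is validated and a failure becomes the new Error::Utf8Error, without it the unchecked conversion stays sound because the strict table admits only ASCII. A self-contained sketch of just that step; the function name and error type are illustrative, not the crate's internals:

use std::str;

// Conversion step over a byte range that already passed the URI byte filter.
fn to_path(raw: &[u8]) -> Result<&str, ()> {
    #[cfg(feature = "utf8_in_path")]
    {
        // The relaxed table lets non-ASCII bytes through, so validate here.
        str::from_utf8(raw).map_err(|_| ())
    }
    #[cfg(not(feature = "utf8_in_path"))]
    {
        // SAFETY: every byte was accepted by the ASCII-only table.
        Ok(unsafe { str::from_utf8_unchecked(raw) })
    }
}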
@@ -2077,7 +2087,7 @@ mod tests {
assert_eq!(parse_chunk_size(b"567f8a\rfoo"), Err(crate::InvalidChunkSize));
assert_eq!(parse_chunk_size(b"567f8a\rfoo"), Err(crate::InvalidChunkSize));
assert_eq!(parse_chunk_size(b"567xf8a\r\n"), Err(crate::InvalidChunkSize));
assert_eq!(parse_chunk_size(b"ffffffffffffffff\r\n"), Ok(Status::Complete((18, std::u64::MAX))));
assert_eq!(parse_chunk_size(b"ffffffffffffffff\r\n"), Ok(Status::Complete((18, u64::MAX))));
assert_eq!(parse_chunk_size(b"1ffffffffffffffff\r\n"), Err(crate::InvalidChunkSize));
assert_eq!(parse_chunk_size(b"Affffffffffffffff\r\n"), Err(crate::InvalidChunkSize));
assert_eq!(parse_chunk_size(b"fffffffffffffffff\r\n"), Err(crate::InvalidChunkSize));
@@ -2185,7 +2195,7 @@ mod tests {
assert_eq!(result, Err(crate::Error::Token));
}

static REQUEST_WITH_MULTIPLE_SPACES_AND_BAD_PATH: &[u8] = b"GET /foo>ohno HTTP/1.1\r\n\r\n";
static REQUEST_WITH_MULTIPLE_SPACES_AND_BAD_PATH: &[u8] = b"GET /foo ohno HTTP/1.1\r\n\r\n";

#[test]
fn test_request_with_multiple_spaces_and_bad_path() {
@@ -2194,7 +2204,7 @@
let result = crate::ParserConfig::default()
.allow_multiple_spaces_in_request_line_delimiters(true)
.parse_request(&mut request, REQUEST_WITH_MULTIPLE_SPACES_AND_BAD_PATH);
assert_eq!(result, Err(crate::Error::Token));
assert_eq!(result, Err(crate::Error::Version));
}

static RESPONSE_WITH_SPACES_IN_CODE: &[u8] = b"HTTP/1.1 99 200 OK\r\n\r\n";
@@ -2702,7 +2712,8 @@ mod tests {
}

#[test]
fn test_rfc3986_non_compliant_path_ko() {
#[cfg(not(feature = "utf8_in_path"))]
fn test_utf8_in_path_ko() {
let mut headers = [EMPTY_HEADER; 1];
let mut request = Request::new(&mut headers[..]);

@@ -2712,13 +2723,12 @@
}

#[test]
fn test_rfc3986_non_compliant_path_ok() {
#[cfg(feature = "utf8_in_path")]
fn test_utf8_in_path_ok() {
let mut headers = [EMPTY_HEADER; 1];
let mut request = Request::new(&mut headers[..]);
let mut config = crate::ParserConfig::default();
config.allow_rfc3986_non_compliant_path = true;

let result = config.parse_request(&mut request, b"GET /test?post=I\xE2\x80\x99msorryIforkedyou HTTP/1.1\r\nHost: example.org\r\n\r\n");
let result = crate::ParserConfig::default().parse_request(&mut request, b"GET /test?post=I\xE2\x80\x99msorryIforkedyou HTTP/1.1\r\nHost: example.org\r\n\r\n");

assert_eq!(result, Ok(Status::Complete(67)));
assert_eq!(request.version.unwrap(), 1);
Expand All @@ -2728,4 +2738,15 @@ mod tests {
assert_eq!(request.headers[0].name, "Host");
assert_eq!(request.headers[0].value, &b"example.org"[..]);
}

#[test]
#[cfg(feature = "utf8_in_path")]
fn test_bad_utf8_in_path() {
let mut headers = [EMPTY_HEADER; 1];
let mut request = Request::new(&mut headers[..]);

let result = crate::ParserConfig::default().parse_request(&mut request, b"GET /test?post=I\xE2msorryIforkedyou HTTP/1.1\r\nHost: example.org\r\n\r\n");

assert_eq!(result, Err(crate::Error::Utf8Error));
}
}
18 changes: 8 additions & 10 deletions src/simd/avx2.rs
@@ -2,14 +2,10 @@ use crate::iter::Bytes;

#[inline]
#[target_feature(enable = "avx2", enable = "sse4.2")]
pub unsafe fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) {
pub unsafe fn match_uri_vectored(bytes: &mut Bytes) {
while bytes.as_ref().len() >= 32 {

let advance = if allow_non_compliant {
match_url_char_non_compliant_32_avx(bytes.as_ref())
} else {
match_url_char_32_avx(bytes.as_ref())
};
let advance = match_url_char_32_avx(bytes.as_ref());

bytes.advance(advance);

@@ -18,9 +14,10 @@ pub unsafe fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) {
}
}
// do both, since avx2 only works when bytes.len() >= 32
super::sse42::match_uri_vectored(bytes, allow_non_compliant)
super::sse42::match_uri_vectored(bytes)
}
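The "do both" hand-off above is the general shape of the SIMD tiers: each routine consumes as many full-width blocks as it can, then delegates the tail to the next narrower one (32-byte AVX2, then 16-byte SSE4.2, then word-sized SWAR). A scalar, self-contained sketch of that chaining; the per-byte predicate stands in for the vectorised table lookups:

// Returns the index of the first byte rejected by `ok`, scanning in blocks.
fn scan_32(bytes: &[u8], ok: fn(u8) -> bool) -> usize {
    let mut i = 0;
    while bytes.len() - i >= 32 {
        match bytes[i..i + 32].iter().position(|&b| !ok(b)) {
            Some(off) => return i + off, // stopped inside this block
            None => i += 32,             // whole block matched, keep going
        }
    }
    i + scan_16(&bytes[i..], ok) // tail: hand off to the 16-byte routine
}

fn scan_16(bytes: &[u8], ok: fn(u8) -> bool) -> usize {
    let mut i = 0;
    while bytes.len() - i >= 16 {
        match bytes[i..i + 16].iter().position(|&b| !ok(b)) {
            Some(off) => return i + off,
            None => i += 16,
        }
    }
    // Tail: the real code switches to the SWAR word-at-a-time scan here.
    i + bytes[i..].iter().position(|&b| !ok(b)).unwrap_or(bytes.len() - i)
}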

#[cfg(not(feature = "utf8_in_path"))]
#[inline(always)]
#[allow(non_snake_case, overflowing_literals)]
#[allow(unused)]
Expand Down Expand Up @@ -62,10 +59,11 @@ unsafe fn match_url_char_32_avx(buf: &[u8]) -> usize {
r.trailing_zeros() as usize
}

#[cfg(feature = "utf8_in_path")]
#[inline(always)]
#[allow(non_snake_case, overflowing_literals)]
#[allow(unused)]
unsafe fn match_url_char_non_compliant_32_avx(buf: &[u8]) -> usize {
unsafe fn match_url_char_32_avx(buf: &[u8]) -> usize {
debug_assert!(buf.len() >= 32);

#[cfg(target_arch = "x86")]
@@ -140,11 +138,11 @@ fn avx2_code_matches_uri_chars_table() {

#[allow(clippy::undocumented_unsafe_blocks)]
unsafe {
assert!(byte_is_allowed(b'_', |b| match_uri_vectored(b, false)));
assert!(byte_is_allowed(b'_', match_uri_vectored));

for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() {
assert_eq!(
byte_is_allowed(b as u8, |b| match_uri_vectored(b, false)), allowed,
byte_is_allowed(b as u8, match_uri_vectored), allowed,
"byte_is_allowed({:?}) should be {:?}", b, allowed,
);
}
8 changes: 4 additions & 4 deletions src/simd/runtime.rs
@@ -34,13 +34,13 @@ pub fn match_header_name_vectored(bytes: &mut Bytes) {
super::swar::match_header_name_vectored(bytes);
}

pub fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) {
pub fn match_uri_vectored(bytes: &mut Bytes) {
// SAFETY: calls are guarded by a feature check
unsafe {
match get_runtime_feature() {
AVX2 => avx2::match_uri_vectored(bytes, allow_non_compliant),
SSE42 => sse42::match_uri_vectored(bytes, allow_non_compliant),
_ /* NOP */ => super::swar::match_uri_vectored(bytes, allow_non_compliant),
AVX2 => avx2::match_uri_vectored(bytes),
SSE42 => sse42::match_uri_vectored(bytes),
_ /* NOP */ => super::swar::match_uri_vectored(bytes),
}
}
}
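With the flag gone, the runtime module's only job is to pick an implementation once from CPU feature detection. A self-contained sketch of that dispatch shape using std's is_x86_feature_detected!; the scanners here are scalar stand-ins that only fix the common signature, and the real crate caches the detection result instead of re-querying it:

// Stand-in for all three implementations (scans up to the first space).
fn scan_scalar(bytes: &[u8]) -> usize {
    bytes.iter().position(|&b| b == b' ').unwrap_or(bytes.len())
}

fn pick_scanner() -> fn(&[u8]) -> usize {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("avx2") {
            return scan_scalar; // real code: avx2::match_uri_vectored
        }
        if is_x86_feature_detected!("sse4.2") {
            return scan_scalar; // real code: sse42::match_uri_vectored
        }
    }
    scan_scalar // real code: the portable swar::match_uri_vectored
}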
18 changes: 8 additions & 10 deletions src/simd/sse42.rs
@@ -1,23 +1,20 @@
use crate::iter::Bytes;

#[target_feature(enable = "sse4.2")]
pub unsafe fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) {
pub unsafe fn match_uri_vectored(bytes: &mut Bytes) {
while bytes.as_ref().len() >= 16 {
let advance = if allow_non_compliant {
match_url_char_non_compliant_16_sse(bytes.as_ref())
} else {
match_url_char_16_sse(bytes.as_ref())
};
let advance = match_url_char_16_sse(bytes.as_ref());

bytes.advance(advance);

if advance != 16 {
return;
}
}
super::swar::match_uri_vectored(bytes, allow_non_compliant);
super::swar::match_uri_vectored(bytes);
}

#[cfg(not(feature = "utf8_in_path"))]
#[inline(always)]
#[allow(non_snake_case, overflowing_literals)]
unsafe fn match_url_char_16_sse(buf: &[u8]) -> usize {
@@ -66,9 +63,10 @@ unsafe fn match_url_char_16_sse(buf: &[u8]) -> usize {
r.trailing_zeros() as usize
}

#[cfg(feature = "utf8_in_path")]
#[inline(always)]
#[allow(non_snake_case)]
unsafe fn match_url_char_non_compliant_16_sse(buf: &[u8]) -> usize {
unsafe fn match_url_char_16_sse(buf: &[u8]) -> usize {
debug_assert!(buf.len() >= 16);

#[cfg(target_arch = "x86")]
@@ -143,11 +141,11 @@ fn sse_code_matches_uri_chars_table() {

#[allow(clippy::undocumented_unsafe_blocks)]
unsafe {
assert!(byte_is_allowed(b'_', |b| match_uri_vectored(b, false)));
assert!(byte_is_allowed(b'_', match_uri_vectored));

for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() {
assert_eq!(
byte_is_allowed(b as u8, |b| match_uri_vectored(b, false)), allowed,
byte_is_allowed(b as u8, match_uri_vectored), allowed,
"byte_is_allowed({:?}) should be {:?}", b, allowed,
);
}
6 changes: 3 additions & 3 deletions src/simd/swar.rs
@@ -7,7 +7,7 @@ const BLOCK_SIZE: usize = core::mem::size_of::<usize>();
type ByteBlock = [u8; BLOCK_SIZE];

#[inline]
pub fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) {
pub fn match_uri_vectored(bytes: &mut Bytes) {
loop {
if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
let n = match_uri_char_8_swar(bytes8);
@@ -21,7 +21,7 @@ pub fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) {
}
}
if let Some(b) = bytes.peek() {
if is_uri_token(b, allow_non_compliant) {
if is_uri_token(b) {
// SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
// in bytes, so calling advance is safe.
unsafe {
@@ -106,7 +106,7 @@ fn match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize {
// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
// creates a u64 whose bytes are each equal to b
const fn uniform_block(b: u8) -> usize {
(b as u64 * 0x01_01_01_01_01_01_01_01 /* [1_u8; 8] */) as usize
(b as u64 * 0x01_01_01_01_01_01_01_01 /* [1_u8; 8] */) as usize
}
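uniform_block broadcasts one byte into every lane of a machine word so the SWAR code can compare all lanes at once; the multiply by 0x01_01_01_01_01_01_01_01 is the broadcast. A short worked check of that trick, fixed to u64 here for clarity (the crate uses the native word size):

fn main() {
    let b: u8 = 0x23;
    let block = (b as u64) * 0x0101_0101_0101_0101; // never overflows for a u8
    assert_eq!(block, 0x2323_2323_2323_2323);
    assert!(block.to_ne_bytes().iter().all(|&x| x == b));
}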

// A byte-wise range-check on an entire word/block,
