diff --git a/benches/parse.rs b/benches/parse.rs index 97a6204..d3d9c9d 100644 --- a/benches/parse.rs +++ b/benches/parse.rs @@ -113,7 +113,7 @@ fn uri(c: &mut Criterion) { .throughput(Throughput::Bytes(input.len() as u64)) .bench_function(name, |b| b.iter(|| { let mut b = httparse::_benchable::Bytes::new(black_box(input)); - httparse::_benchable::parse_uri(&mut b).unwrap() + httparse::_benchable::parse_uri(&mut b, false).unwrap() })); } diff --git a/src/lib.rs b/src/lib.rs index 4ccd783..bc2f32a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,9 +90,32 @@ static URI_MAP: [bool; 256] = byte_map![ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]; +static URI_NON_COMPLIANT_MAP: [bool; 256] = byte_map![ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, +]; + #[inline] -pub(crate) fn is_uri_token(b: u8) -> bool { - URI_MAP[b as usize] +pub(crate) fn is_uri_token(b: u8, allow_non_compliant: bool) -> bool { + if allow_non_compliant { + URI_NON_COMPLIANT_MAP[b as usize] + } else { + URI_MAP[b as usize] + } } static HEADER_NAME_MAP: [bool; 256] = byte_map![ @@ -260,6 +283,7 @@ pub struct ParserConfig { allow_multiple_spaces_in_request_line_delimiters: bool, allow_multiple_spaces_in_response_status_delimiters: bool, allow_space_before_first_header_name: bool, + allow_rfc3986_non_compliant_path: bool, ignore_invalid_headers_in_responses: bool, ignore_invalid_headers_in_requests: bool, } @@ -539,7 +563,7 @@ impl<'h, 'b> Request<'h, 'b> { if config.allow_multiple_spaces_in_request_line_delimiters { complete!(skip_spaces(&mut bytes)); } - self.path = Some(complete!(parse_uri(&mut bytes))); + self.path = Some(complete!(parse_uri(&mut bytes, config.allow_rfc3986_non_compliant_path))); if config.allow_multiple_spaces_in_request_line_delimiters { complete!(skip_spaces(&mut bytes)); } @@ -952,9 +976,9 @@ fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { #[doc(hidden)] #[allow(missing_docs)] // WARNING: Exported for internal benchmarks, not fit for public consumption -pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { +pub fn parse_uri<'a>(bytes: &mut Bytes<'a>, allow_non_compliant: bool) -> Result<&'a str> { let start = bytes.pos(); - simd::match_uri_vectored(bytes); + simd::match_uri_vectored(bytes, allow_non_compliant); let end = bytes.pos(); if next!(bytes) == b' ' { @@ -2676,4 +2700,32 @@ mod tests { assert_eq!(response.headers[0].name, "foo"); assert_eq!(response.headers[0].value, &b"bar"[..]); } + + #[test] + fn test_rfc3986_non_compliant_path_ko() { + let mut headers = [EMPTY_HEADER; 1]; + let mut request = Request::new(&mut headers[..]); + + let result = crate::ParserConfig::default().parse_request(&mut request, b"GET /test?post=I\xE2\x80\x99msorryIforkedyou HTTP/1.1\r\nHost: example.org\r\n\r\n"); + + assert_eq!(result, Err(crate::Error::Token)); + } + + #[test] + fn test_rfc3986_non_compliant_path_ok() { + let mut headers = [EMPTY_HEADER; 1]; + let mut request = Request::new(&mut headers[..]); + let mut config = crate::ParserConfig::default(); + config.allow_rfc3986_non_compliant_path = true; + + let result = config.parse_request(&mut request, b"GET /test?post=I\xE2\x80\x99msorryIforkedyou HTTP/1.1\r\nHost: example.org\r\n\r\n"); + + assert_eq!(result, Ok(Status::Complete(67))); + assert_eq!(request.version.unwrap(), 1); + assert_eq!(request.method.unwrap(), "GET"); + assert_eq!(request.path.unwrap(), "/test?post=I’msorryIforkedyou"); + assert_eq!(request.headers.len(), 1); + assert_eq!(request.headers[0].name, "Host"); + assert_eq!(request.headers[0].value, &b"example.org"[..]); + } } diff --git a/src/simd/avx2.rs b/src/simd/avx2.rs index c1a41f9..57e584d 100644 --- a/src/simd/avx2.rs +++ b/src/simd/avx2.rs @@ -2,9 +2,15 @@ use crate::iter::Bytes; #[inline] #[target_feature(enable = "avx2")] -pub unsafe fn match_uri_vectored(bytes: &mut Bytes) { +pub unsafe fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) { while bytes.as_ref().len() >= 32 { - let advance = match_url_char_32_avx(bytes.as_ref()); + + let advance = if allow_non_compliant { + match_url_char_non_compliant_32_avx(bytes.as_ref()) + } else { + match_url_char_32_avx(bytes.as_ref()) + }; + bytes.advance(advance); if advance != 32 { @@ -12,7 +18,7 @@ pub unsafe fn match_uri_vectored(bytes: &mut Bytes) { } } // NOTE: use SWAR for <32B, more efficient than falling back to SSE4.2 - super::swar::match_uri_vectored(bytes) + super::swar::match_uri_vectored(bytes, allow_non_compliant) } #[inline(always)] @@ -56,6 +62,33 @@ unsafe fn match_url_char_32_avx(buf: &[u8]) -> usize { r.trailing_zeros() as usize } +#[inline(always)] +#[allow(non_snake_case, overflowing_literals)] +#[allow(unused)] +unsafe fn match_url_char_non_compliant_32_avx(buf: &[u8]) -> usize { + debug_assert!(buf.len() >= 32); + + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; + + let ptr = buf.as_ptr(); + + // %x21-%x7e %x80-%xff + let DEL: __m256i = _mm256_set1_epi8(0x7f); + let LOW: __m256i = _mm256_set1_epi8(0x21); + + let dat = _mm256_lddqu_si256(ptr as *const _); + // unsigned comparison dat >= LOW + let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat); + let del = _mm256_cmpeq_epi8(dat, DEL); + let bit = _mm256_andnot_si256(del, low); + let res = _mm256_movemask_epi8(bit) as u32; + // TODO: use .trailing_ones() once MSRV >= 1.46 + (!res).trailing_zeros() as usize +} + #[target_feature(enable = "avx2")] pub unsafe fn match_header_value_vectored(bytes: &mut Bytes) { while bytes.as_ref().len() >= 32 { @@ -107,11 +140,11 @@ fn avx2_code_matches_uri_chars_table() { #[allow(clippy::undocumented_unsafe_blocks)] unsafe { - assert!(byte_is_allowed(b'_', match_uri_vectored)); + assert!(byte_is_allowed(b'_', |b| match_uri_vectored(b, false))); for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() { assert_eq!( - byte_is_allowed(b as u8, match_uri_vectored), allowed, + byte_is_allowed(b as u8, |b| match_uri_vectored(b, false)), allowed, "byte_is_allowed({:?}) should be {:?}", b, allowed, ); } diff --git a/src/simd/runtime.rs b/src/simd/runtime.rs index c523a92..20f6017 100644 --- a/src/simd/runtime.rs +++ b/src/simd/runtime.rs @@ -34,13 +34,13 @@ pub fn match_header_name_vectored(bytes: &mut Bytes) { super::swar::match_header_name_vectored(bytes); } -pub fn match_uri_vectored(bytes: &mut Bytes) { +pub fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) { // SAFETY: calls are guarded by a feature check unsafe { match get_runtime_feature() { - AVX2 => avx2::match_uri_vectored(bytes), - SSE42 => sse42::match_uri_vectored(bytes), - _ /* NOP */ => super::swar::match_uri_vectored(bytes), + AVX2 => avx2::match_uri_vectored(bytes, allow_non_compliant), + SSE42 => sse42::match_uri_vectored(bytes, allow_non_compliant), + _ /* NOP */ => super::swar::match_uri_vectored(bytes, allow_non_compliant), } } } diff --git a/src/simd/sse42.rs b/src/simd/sse42.rs index d6fbf02..4ab8e00 100644 --- a/src/simd/sse42.rs +++ b/src/simd/sse42.rs @@ -1,16 +1,21 @@ use crate::iter::Bytes; #[target_feature(enable = "sse4.2")] -pub unsafe fn match_uri_vectored(bytes: &mut Bytes) { +pub unsafe fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) { while bytes.as_ref().len() >= 16 { - let advance = match_url_char_16_sse(bytes.as_ref()); + let advance = if allow_non_compliant { + match_url_char_non_compliant_16_sse(bytes.as_ref()) + } else { + match_url_char_16_sse(bytes.as_ref()) + }; + bytes.advance(advance); if advance != 16 { return; } } - super::swar::match_uri_vectored(bytes); + super::swar::match_uri_vectored(bytes, allow_non_compliant); } #[inline(always)] @@ -61,6 +66,33 @@ unsafe fn match_url_char_16_sse(buf: &[u8]) -> usize { r.trailing_zeros() as usize } +#[inline(always)] +#[allow(non_snake_case)] +unsafe fn match_url_char_non_compliant_16_sse(buf: &[u8]) -> usize { + debug_assert!(buf.len() >= 16); + + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; + + let ptr = buf.as_ptr(); + + // %x21-%x7e %x80-%xff + let DEL: __m128i = _mm_set1_epi8(0x7f); + let LOW: __m128i = _mm_set1_epi8(0x21); + + let dat = _mm_lddqu_si128(ptr as *const _); + // unsigned comparison dat >= LOW + let low = _mm_cmpeq_epi8(_mm_max_epu8(dat, LOW), dat); + let del = _mm_cmpeq_epi8(dat, DEL); + let bit = _mm_andnot_si128(del, low); + let res = _mm_movemask_epi8(bit) as u16; + + // TODO: use .trailing_ones() once MSRV >= 1.46 + (!res).trailing_zeros() as usize +} + #[target_feature(enable = "sse4.2")] pub unsafe fn match_header_value_vectored(bytes: &mut Bytes) { while bytes.as_ref().len() >= 16 { @@ -111,11 +143,11 @@ fn sse_code_matches_uri_chars_table() { #[allow(clippy::undocumented_unsafe_blocks)] unsafe { - assert!(byte_is_allowed(b'_', match_uri_vectored)); + assert!(byte_is_allowed(b'_', |b| match_uri_vectored(b, false))); for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() { assert_eq!( - byte_is_allowed(b as u8, match_uri_vectored), allowed, + byte_is_allowed(b as u8, |b| match_uri_vectored(b, false)), allowed, "byte_is_allowed({:?}) should be {:?}", b, allowed, ); } diff --git a/src/simd/swar.rs b/src/simd/swar.rs index 857fc58..d745318 100644 --- a/src/simd/swar.rs +++ b/src/simd/swar.rs @@ -7,7 +7,7 @@ const BLOCK_SIZE: usize = core::mem::size_of::(); type ByteBlock = [u8; BLOCK_SIZE]; #[inline] -pub fn match_uri_vectored(bytes: &mut Bytes) { +pub fn match_uri_vectored(bytes: &mut Bytes, allow_non_compliant: bool) { loop { if let Some(bytes8) = bytes.peek_n::(BLOCK_SIZE) { let n = match_uri_char_8_swar(bytes8); @@ -21,7 +21,7 @@ pub fn match_uri_vectored(bytes: &mut Bytes) { } } if let Some(b) = bytes.peek() { - if is_uri_token(b) { + if is_uri_token(b, allow_non_compliant) { // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte // in bytes, so calling advance is safe. unsafe {