Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup: drop SWAR's 64-bit assumptions #140

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 41 additions & 38 deletions src/simd/swar.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
/// SWAR: SIMD Within A Register
/// SIMD validator backend that validates register-sized chunks of data at a time.
// TODO: current impl assumes 64-bit registers, optimize for 32-bit
use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes};

// Adapt block-size to match native register size, i.e: 32bit => 4, 64bit => 8
const BLOCK_SIZE: usize = core::mem::size_of::<usize>();
type ByteBlock = [u8; BLOCK_SIZE];

#[inline]
pub fn match_uri_vectored(bytes: &mut Bytes) {
loop {
if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
let n = match_uri_char_8_swar(bytes8);
unsafe {
bytes.advance(n);
}
if n == 8 {
if n == BLOCK_SIZE {
continue;
}
}
Expand All @@ -28,12 +31,12 @@ pub fn match_uri_vectored(bytes: &mut Bytes) {
#[inline]
pub fn match_header_value_vectored(bytes: &mut Bytes) {
loop {
if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
let n = match_header_value_char_8_swar(bytes8);
unsafe {
bytes.advance(n);
}
if n == 8 {
if n == BLOCK_SIZE {
continue;
}
}
Expand All @@ -49,19 +52,19 @@ pub fn match_header_value_vectored(bytes: &mut Bytes) {

#[inline]
pub fn match_header_name_vectored(bytes: &mut Bytes) {
while let Some(block) = bytes.peek_n::<[u8; 8]>(8) {
while let Some(block) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
let n = match_block(is_header_name_token, block);
unsafe {
bytes.advance(n);
}
if n != 8 {
if n != BLOCK_SIZE {
return;
}
}
unsafe { bytes.advance(match_tail(is_header_name_token, bytes.as_ref())) };
}

// Matches "tail", i.e: when we have <8 bytes in the buffer, should be uncommon
// Matches "tail", i.e: when we have <BLOCK_SIZE bytes in the buffer, should be uncommon
#[cold]
#[inline]
fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize {
Expand All @@ -75,35 +78,35 @@ fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize {

// Naive fallback block matcher
#[inline(always)]
fn match_block(f: impl Fn(u8) -> bool, block: [u8; 8]) -> usize {
fn match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize {
for (i, &b) in block.iter().enumerate() {
if !f(b) {
return i;
}
}
8
BLOCK_SIZE
}

/// // A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
// creates a u64 whose bytes are each equal to b
const fn uniform_block(b: u8) -> u64 {
b as u64 * 0x01_01_01_01_01_01_01_01 // [1_u8; 8]
const fn uniform_block(b: u8) -> usize {
(b as u64 * 0x01_01_01_01_01_01_01_01 /* [1_u8; 8] */) as usize
}

// A byte-wise range-check on an enire word/block,
// ensuring all bytes in the word satisfy
// `33 <= x <= 126 && x != '>' && x != '<'`
// IMPORTANT: it false negatives if the block contains '?'
#[inline]
fn match_uri_char_8_swar(block: [u8; 8]) -> usize {
fn match_uri_char_8_swar(block: ByteBlock) -> usize {
// 33 <= x <= 126
const M: u8 = 0x21;
const N: u8 = 0x7E;
const BM: u64 = uniform_block(M);
const BN: u64 = uniform_block(127 - N);
const M128: u64 = uniform_block(128);
const BM: usize = uniform_block(M);
const BN: usize = uniform_block(127 - N);
const M128: usize = uniform_block(128);

let x = u64::from_ne_bytes(block); // Really just a transmute
let x = usize::from_ne_bytes(block); // Really just a transmute
let lt = x.wrapping_sub(BM) & !x; // <= m
let gt = x.wrapping_add(BN) | x; // >= n

Expand All @@ -130,8 +133,8 @@ fn match_uri_char_8_swar(block: [u8; 8]) -> usize {
// }
// (xordist(b'<', 2), xordist(b'>', 2))
// ```
const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap
const BGT: u64 = uniform_block(b'>');
const B3: usize = uniform_block(3); // (dist <= 2) + 1 to wrap
const BGT: usize = uniform_block(b'>');

let xgt = x ^ BGT;
let ltgtq = xgt.wrapping_sub(B3) & !xgt;
Expand All @@ -143,15 +146,15 @@ fn match_uri_char_8_swar(block: [u8; 8]) -> usize {
// ensuring all bytes in the word satisfy `32 <= x <= 126`
// IMPORTANT: false negatives if obs-text is present (0x80..=0xFF)
#[inline]
fn match_header_value_char_8_swar(block: [u8; 8]) -> usize {
fn match_header_value_char_8_swar(block: ByteBlock) -> usize {
// 32 <= x <= 126
const M: u8 = 0x20;
const N: u8 = 0x7E;
const BM: u64 = uniform_block(M);
const BN: u64 = uniform_block(127 - N);
const M128: u64 = uniform_block(128);
const BM: usize = uniform_block(M);
const BN: usize = uniform_block(127 - N);
const M128: usize = uniform_block(128);

let x = u64::from_ne_bytes(block); // Really just a transmute
let x = usize::from_ne_bytes(block); // Really just a transmute
let lt = x.wrapping_sub(BM) & !x; // <= m
let gt = x.wrapping_add(BN) | x; // >= n
offsetnz((lt | gt) & M128)
Expand All @@ -160,10 +163,10 @@ fn match_header_value_char_8_swar(block: [u8; 8]) -> usize {
/// Check block to find offset of first non-zero byte
// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit
#[inline]
fn offsetnz(block: u64) -> usize {
fn offsetnz(block: usize) -> usize {
// fast path optimistic case (common for long valid sequences)
if block == 0 {
return 8;
return BLOCK_SIZE;
}

// perf: rust will unroll this loop
Expand All @@ -177,19 +180,19 @@ fn offsetnz(block: u64) -> usize {

#[test]
fn test_is_header_value_block() {
let is_header_value_block = |b| match_header_value_char_8_swar(b) == 8;
let is_header_value_block = |b| match_header_value_char_8_swar(b) == BLOCK_SIZE;

// 0..32 => false
for b in 0..32_u8 {
assert_eq!(is_header_value_block([b; 8]), false, "b={}", b);
assert_eq!(is_header_value_block([b; BLOCK_SIZE]), false, "b={}", b);
}
// 32..127 => true
for b in 32..127_u8 {
assert_eq!(is_header_value_block([b; 8]), true, "b={}", b);
assert_eq!(is_header_value_block([b; BLOCK_SIZE]), true, "b={}", b);
}
// 127..=255 => false
for b in 127..=255_u8 {
assert_eq!(is_header_value_block([b; 8]), false, "b={}", b);
assert_eq!(is_header_value_block([b; BLOCK_SIZE]), false, "b={}", b);
}

// A few sanity checks on non-uniform bytes for safe-measure
Expand All @@ -199,30 +202,30 @@ fn test_is_header_value_block() {

#[test]
fn test_is_uri_block() {
let is_uri_block = |b| match_uri_char_8_swar(b) == 8;
let is_uri_block = |b| match_uri_char_8_swar(b) == BLOCK_SIZE;

// 0..33 => false
for b in 0..33_u8 {
assert_eq!(is_uri_block([b; 8]), false, "b={}", b);
assert_eq!(is_uri_block([b; BLOCK_SIZE]), false, "b={}", b);
}
// 33..127 => true if b not in { '<', '?', '>' }
let falsy = |b| b"<?>".contains(&b);
for b in 33..127_u8 {
assert_eq!(is_uri_block([b; 8]), !falsy(b), "b={}", b);
assert_eq!(is_uri_block([b; BLOCK_SIZE]), !falsy(b), "b={}", b);
}
// 127..=255 => false
for b in 127..=255_u8 {
assert_eq!(is_uri_block([b; 8]), false, "b={}", b);
assert_eq!(is_uri_block([b; BLOCK_SIZE]), false, "b={}", b);
}
}

#[test]
fn test_offsetnz() {
let seq = [0_u8; 8];
for i in 0..8 {
let seq = [0_u8; BLOCK_SIZE];
for i in 0..BLOCK_SIZE {
let mut seq = seq.clone();
seq[i] = 1;
let x = u64::from_ne_bytes(seq);
let x = usize::from_ne_bytes(seq);
assert_eq!(offsetnz(x), i);
}
}