Skip to content

Commit

Permalink
Merge pull request #88 from llogiq/wasm
Browse files Browse the repository at this point in the history
implement wasm32 simd, optimize aarch64 num_chars, bump version
  • Loading branch information
llogiq authored Oct 22, 2023
2 parents 150b4aa + d5d3acc commit 2d41959
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ authors = ["Andre Bogus <[email protected]>", "Joshua Landau <[email protected]
description = "count occurrences of a given byte, or the number of UTF-8 code points, in a byte slice, fast"
edition = "2018"
name = "bytecount"
version = "0.6.4"
version = "0.6.5"
license = "Apache-2.0/MIT"
repository = "https://github.com/llogiq/bytecount"
categories = ["algorithms", "no-std"]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The [newlinebench](https://github.com/llogiq/newlinebench) repository has furthe

To use bytecount in your crate, if you have [cargo-edit](https://github.com/killercup/cargo-edit), just type
`cargo add bytecount` in a terminal with the crate root as the current path. Otherwise you can manually edit your
`Cargo.toml` to add `bytecount = 0.6.4` to your `[dependencies]` section.
`Cargo.toml` to add `bytecount = 0.6.5` to your `[dependencies]` section.

In your crate root (`lib.rs` or `main.rs`, depending on if you are writing a
library or application), add `extern crate bytecount;`. Now you can simply use
Expand Down
15 changes: 15 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod integer_simd;
any(target_arch = "x86", target_arch = "x86_64")
),
target_arch = "aarch64",
target_arch = "wasm32",
feature = "generic-simd"
))]
mod simd;
Expand Down Expand Up @@ -96,6 +97,13 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
return simd::aarch64::chunk_count(haystack, needle);
}
}

#[cfg(target_arch = "wasm32")]
{
unsafe {
return simd::wasm::chunk_count(haystack, needle);
}
}
}

if haystack.len() >= mem::size_of::<usize>() {
Expand Down Expand Up @@ -151,6 +159,13 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
return simd::aarch64::chunk_num_chars(utf8_chars);
}
}

#[cfg(target_arch = "wasm32")]
{
unsafe {
return simd::wasm::chunk_num_chars(utf8_chars);
}
}
}

if utf8_chars.len() >= mem::size_of::<usize>() {
Expand Down
63 changes: 34 additions & 29 deletions src/simd/aarch64.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::arch::aarch64::{
uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4,
vmvnq_u8, vsubq_u8,
vsubq_u8,
};

const MASK: [u8; 32] = [
Expand Down Expand Up @@ -35,6 +35,10 @@ unsafe fn sum(u8s: uint8x16_t) -> usize {
vaddlvq_u8(u8s) as usize
}

unsafe fn sum4(u1: uint8x16_t, u2: uint8x16_t, u3: uint8x16_t, u4: uint8x16_t) -> usize {
((vaddlvq_u8(u1) + vaddlvq_u8(u2)) + (vaddlvq_u8(u3) + vaddlvq_u8(u4))) as usize
}

#[target_feature(enable = "neon")]
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
assert!(haystack.len() >= 16);
Expand All @@ -56,7 +60,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
offset += 64;
}
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
count += sum4(count1, count2, count3, count4);
}

// 64
Expand All @@ -70,7 +74,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
offset += 64;
}
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
count += sum4(count1, count2, count3, count4);

let mut counts = vdupq_n_u8(0);
// 16
Expand All @@ -93,11 +97,11 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
}

#[target_feature(enable = "neon")]
unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
vmvnq_u8(vceqq_u8(
unsafe fn is_following_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
vceqq_u8(
vandq_u8(u8s, vdupq_n_u8(0b1100_0000)),
vdupq_n_u8(0b1000_0000),
))
)
}

#[target_feature(enable = "neon")]
Expand All @@ -108,50 +112,51 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
let mut count = 0;

// 4080
while utf8_chars.len() >= offset + 16 * 255 {
let mut counts = vdupq_n_u8(0);
while utf8_chars.len() >= offset + 64 * 255 {
let (mut count1, mut count2, mut count3, mut count4) =
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));

for _ in 0..255 {
counts = vsubq_u8(
counts,
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
);
offset += 16;
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(utf8_chars, offset);
count1 = vsubq_u8(count1,is_following_utf8_byte(h1));
count2 = vsubq_u8(count2,is_following_utf8_byte(h2));
count3 = vsubq_u8(count3,is_following_utf8_byte(h3));
count4 = vsubq_u8(count4,is_following_utf8_byte(h4));
offset += 64;
}
count += sum(counts);
count += sum4(count1, count2, count3, count4);
}

// 2048
if utf8_chars.len() >= offset + 16 * 128 {
let mut counts = vdupq_n_u8(0);
for _ in 0..128 {
counts = vsubq_u8(
counts,
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
);
offset += 16;
// 4080
let (mut count1, mut count2, mut count3, mut count4) =
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
for _ in 0..(utf8_chars.len() - offset) / 64 {
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(utf8_chars, offset);
count1 = vsubq_u8(count1, is_following_utf8_byte(h1));
count2 = vsubq_u8(count2, is_following_utf8_byte(h2));
count3 = vsubq_u8(count3, is_following_utf8_byte(h3));
count4 = vsubq_u8(count4, is_following_utf8_byte(h4));
offset += 64;
}
count += sum(counts);
}

count += sum4(count1, count2, count3, count4);
// 16
let mut counts = vdupq_n_u8(0);
for i in 0..(utf8_chars.len() - offset) / 16 {
counts = vsubq_u8(
counts,
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
is_following_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
);
}
if utf8_chars.len() % 16 != 0 {
counts = vsubq_u8(
counts,
vandq_u8(
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
is_following_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
u8x16_from_offset(&MASK, utf8_chars.len() % 16),
),
);
}
count += sum(counts);

count
utf8_chars.len() - count
}
3 changes: 3 additions & 0 deletions src/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ pub mod x86_avx2;
/// Modern ARM machines are also quite capable thanks to NEON
#[cfg(target_arch = "aarch64")]
pub mod aarch64;

#[cfg(target_arch = "wasm32")]
pub mod wasm;
184 changes: 184 additions & 0 deletions src/simd/wasm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use core::arch::wasm32::*;

const MASK: [u8; 32] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255,
];

#[target_feature(enable = "simd128")]
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> v128 {
debug_assert!(
offset + 16 <= slice.len(),
"{} + 16 ≥ {}",
offset,
slice.len()
);
v128_load(slice.as_ptr().add(offset) as *const _)
}

#[target_feature(enable = "simd128")]
unsafe fn u8x16x4_from_offset(slice: &[u8], offset: usize) -> (v128, v128, v128, v128) {
debug_assert!(
offset + 64 <= slice.len(),
"{} + 64 ≥ {}",
offset,
slice.len()
);
(
v128_load(slice.as_ptr().add(offset + 0) as *const _),
v128_load(slice.as_ptr().add(offset + 16) as *const _),
v128_load(slice.as_ptr().add(offset + 32) as *const _),
v128_load(slice.as_ptr().add(offset + 48) as *const _),
)
}

// TODO: We might want to amortize some additions by
// keeping in multiple u16s and u32s respectively for a few ns
#[target_feature(enable = "simd128")]
unsafe fn sum(u8s: v128) -> usize {
let u16s = u16x8_extadd_pairwise_u8x16(u8s);
let u32s = u32x4_extadd_pairwise_u16x8(u16s);
let (u1, u2, u3, u4) = (
u32x4_extract_lane::<1>(u32s),
u32x4_extract_lane::<2>(u32s),
u32x4_extract_lane::<3>(u32s),
u32x4_extract_lane::<4>(u32s),
);
((u1 + u2) + (u3 + u4)) as usize
}

#[target_feature(enable = "simd128")]
unsafe fn sum4(u1: v128, u2: v128, u3: v128, u4: v128) -> usize {
// sum < (2^2 * 2^3 * 2^8 = 2^13) < 2^16, therefore no overflow here
let u16s = u16x8_add(
u16x8_add(u16x8_extadd_pairwise_u8x16(u1), u16x8_extadd_pairwise_u8x16(u2)),
u16x8_add(u16x8_extadd_pairwise_u8x16(u3), u16x8_extadd_pairwise_u8x16(u4)),
);
let u32s = u32x4_extadd_pairwise_u16x8(u16s);
let (u1, u2, u3, u4) = (
u32x4_extract_lane::<1>(u32s),
u32x4_extract_lane::<2>(u32s),
u32x4_extract_lane::<3>(u32s),
u32x4_extract_lane::<4>(u32s),
);
((u1 + u2) + (u3 + u4)) as usize
}

#[target_feature(enable = "simd128")]
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
let needles = u8x16_splat(needle);
let mut count = 0;
let mut offset = 0;

while haystack.len() >= offset + 16 * 255 {
let (mut count1, mut count2, mut count3, mut count4) =
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
for _ in 0..255 {
let (h1, h2, h3, h4) = u8x16x4_from_offset(haystack, offset);
count1 = u8x16_sub(count1, u8x16_eq(h1, needles));
count2 = u8x16_sub(count2, u8x16_eq(h2, needles));
count3 = u8x16_sub(count3, u8x16_eq(h3, needles));
count4 = u8x16_sub(count4, u8x16_eq(h4, needles));
offset += 64;
}
count += sum4(count1, count2, count3, count4);
}

// 64
let (mut count1, mut count2, mut count3, mut count4) =
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
for _ in 0..(haystack.len() - offset) / 64 {
let (h1, h2, h3, h4) = u8x16x4_from_offset(haystack, offset);
count1 = u8x16_sub(count1, u8x16_eq(h1, needles));
count2 = u8x16_sub(count2, u8x16_eq(h2, needles));
count3 = u8x16_sub(count3, u8x16_eq(h3, needles));
count4 = u8x16_sub(count4, u8x16_eq(h4, needles));
offset += 64;
}
count += sum4(count1, count2, count3, count4);

let mut counts = u8x16_splat(0);
// 16
for i in 0..(haystack.len() - offset) / 16 {
counts = u8x16_sub(
counts,
u8x16_eq(u8x16_from_offset(haystack, offset + i * 16), needles),
);
}
if haystack.len() % 16 != 0 {
counts = u8x16_sub(
counts,
v128_and(
u8x16_eq(u8x16_from_offset(haystack, haystack.len() - 16), needles),
u8x16_from_offset(&MASK, haystack.len() % 16),
),
);
}
count + sum(counts)
}

#[target_feature(enable = "simd128")]
unsafe fn is_leading_utf8_byte(u8s: v128) -> v128 {
u8x16_ne(
v128_and(u8s, u8x16_splat(0b1100_0000)),
u8x16_splat(0b1000_0000),
)
}

#[target_feature(enable = "simd128")]
pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
assert!(utf8_chars.len() >= 16);

let mut offset = 0;
let mut count = 0;

// 4080
while utf8_chars.len() >= offset + 64 * 255 {
let (mut count1, mut count2, mut count3, mut count4) =
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));

for _ in 0..255 {
let (h1, h2, h3, h4) = u8x16x4_from_offset(utf8_chars, offset);
count1 = u8x16_sub(count1,is_leading_utf8_byte(h1));
count2 = u8x16_sub(count2,is_leading_utf8_byte(h2));
count3 = u8x16_sub(count3,is_leading_utf8_byte(h3));
count4 = u8x16_sub(count4,is_leading_utf8_byte(h4));
offset += 64;
}
count += sum4(count1, count2, count3, count4);
}

// 4080
let (mut count1, mut count2, mut count3, mut count4) =
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
for _ in 0..(utf8_chars.len() - offset) / 64 {
let (h1, h2, h3, h4) = u8x16x4_from_offset(utf8_chars, offset);
count1 = u8x16_sub(count1, is_leading_utf8_byte(h1));
count2 = u8x16_sub(count2, is_leading_utf8_byte(h2));
count3 = u8x16_sub(count3, is_leading_utf8_byte(h3));
count4 = u8x16_sub(count4, is_leading_utf8_byte(h4));
offset += 64;
}
count += sum4(count1, count2, count3, count4);

// 16
let mut counts = u8x16_splat(0);
for i in 0..(utf8_chars.len() - offset) / 16 {
counts = u8x16_sub(
counts,
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
);
}
if utf8_chars.len() % 16 != 0 {
counts = u8x16_sub(
counts,
v128_and(
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
u8x16_from_offset(&MASK, utf8_chars.len() % 16),
),
);
}
count += sum(counts);

count
}

0 comments on commit 2d41959

Please sign in to comment.