diff --git a/zlib-rs/src/deflate/slide_hash.rs b/zlib-rs/src/deflate/slide_hash.rs
index 57f9ae83..71fd2174 100644
--- a/zlib-rs/src/deflate/slide_hash.rs
+++ b/zlib-rs/src/deflate/slide_hash.rs
@@ -11,7 +11,8 @@ pub fn slide_hash(state: &mut crate::deflate::State) {
 fn slide_hash_chain(table: &mut [u16], wsize: u16) {
     #[cfg(target_arch = "x86_64")]
     if crate::cpu_features::is_enabled_avx2() {
-        return avx2::slide_hash_chain(table, wsize);
+        // SAFETY: the avx2 target feature is enabled.
+        return unsafe { avx2::slide_hash_chain(table, wsize) };
     }
 
     #[cfg(target_arch = "aarch64")]
@@ -27,13 +28,37 @@ fn slide_hash_chain(table: &mut [u16], wsize: u16) {
     rust::slide_hash_chain(table, wsize);
 }
 
+#[inline(always)]
+fn generic_slide_hash_chain<const N: usize>(table: &mut [u16], wsize: u16) {
+    debug_assert_eq!(table.len() % N, 0);
+
+    for chunk in table.chunks_exact_mut(N) {
+        for m in chunk.iter_mut() {
+            *m = m.saturating_sub(wsize);
+        }
+    }
+}
+
 mod rust {
     pub fn slide_hash_chain(table: &mut [u16], wsize: u16) {
-        for chunk in table.chunks_exact_mut(32) {
-            for m in chunk.iter_mut() {
-                *m = m.saturating_sub(wsize);
-            }
-        }
+        // 32 means that 4 128-bit values can be processed per iteration. That appears to be the
+        // optimal amount on x86_64 (SSE) and aarch64 (NEON).
+        super::generic_slide_hash_chain::<32>(table, wsize);
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+mod avx2 {
+    /// # Safety
+    ///
+    /// Behavior is undefined if the `avx2` target feature is not enabled
+    #[target_feature(enable = "avx2")]
+    pub unsafe fn slide_hash_chain(table: &mut [u16], wsize: u16) {
+        // 64 means that 4 256-bit values can be processed per iteration.
+        // That appears to be the optimal amount for avx2.
+        //
+        // This vectorizes well https://godbolt.org/z/sGbdYba7K
+        super::generic_slide_hash_chain::<64>(table, wsize);
     }
 }
 
@@ -85,38 +110,6 @@ mod neon {
     }
 }
 
-#[cfg(target_arch = "x86_64")]
-mod avx2 {
-    use core::arch::x86_64::{
-        __m256i, _mm256_loadu_si256, _mm256_set1_epi16, _mm256_storeu_si256, _mm256_subs_epu16,
-    };
-
-    pub fn slide_hash_chain(table: &mut [u16], wsize: u16) {
-        assert!(crate::cpu_features::is_enabled_avx2());
-        unsafe { slide_hash_chain_internal(table, wsize) }
-    }
-
-    /// # Safety
-    ///
-    /// Behavior is undefined if the `avx` target feature is not enabled
-    #[target_feature(enable = "avx2")]
-    unsafe fn slide_hash_chain_internal(table: &mut [u16], wsize: u16) {
-        debug_assert_eq!(table.len() % 16, 0);
-
-        let ymm_wsize = unsafe { _mm256_set1_epi16(wsize as i16) };
-
-        for chunk in table.chunks_exact_mut(16) {
-            let chunk = chunk.as_mut_ptr() as *mut __m256i;
-
-            unsafe {
-                let value = _mm256_loadu_si256(chunk);
-                let result = _mm256_subs_epu16(value, ymm_wsize);
-                _mm256_storeu_si256(chunk, result);
-            }
-        }
-    }
-}
-
 #[cfg(target_arch = "wasm32")]
 mod wasm {
     use core::arch::wasm32::{u16x8_splat, u16x8_sub_sat, v128, v128_load, v128_store};
@@ -180,7 +173,7 @@ mod tests {
 
         if crate::cpu_features::is_enabled_avx2() {
             let mut input = INPUT;
-            avx2::slide_hash_chain(&mut input, WSIZE);
+            unsafe { avx2::slide_hash_chain(&mut input, WSIZE) };
            assert_eq!(input, OUTPUT);
         }
 
@@ -212,8 +205,8 @@ mod tests {
     quickcheck::quickcheck! {
         fn slide_is_rust_slide(v: Vec<u16>, wsize: u16) -> bool {
-            // pad to a multiple of 32
-            let difference = v.len().next_multiple_of(32) - v.len();
+            // pad to a multiple of 64 (the biggest chunk size currently in use)
+            let difference = v.len().next_multiple_of(64) - v.len();
 
             let mut v = v;
             v.extend(core::iter::repeat(u16::MAX).take(difference));
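
The key idea behind this change: marking a function `#[target_feature(enable = "avx2")]` makes LLVM compile its body (including the `#[inline(always)]` generic loop) with AVX2 enabled, so the scalar code auto-vectorizes and the hand-written `_mm256_*` intrinsics module can be deleted. Below is a minimal, self-contained sketch of that pattern, not the crate's actual API: the names `sub_sat`, `sub_sat_avx2`, and `generic_sub_sat` are hypothetical, and `is_x86_feature_detected!` stands in for the crate's `crate::cpu_features::is_enabled_avx2()` runtime check.

```rust
// Illustrative sketch of the diff's pattern: one generic scalar loop,
// monomorphized with different chunk sizes per SIMD width.

#[inline(always)]
fn generic_sub_sat<const N: usize>(table: &mut [u16], wsize: u16) {
    debug_assert_eq!(table.len() % N, 0);

    // Plain scalar loop; chunking gives LLVM a fixed trip count to
    // auto-vectorize with whatever SIMD width the caller enables.
    for chunk in table.chunks_exact_mut(N) {
        for m in chunk.iter_mut() {
            *m = m.saturating_sub(wsize);
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn sub_sat_avx2(table: &mut [u16], wsize: u16) {
    // Compiled with AVX2 enabled, so the inlined scalar loop is
    // auto-vectorized to 256-bit instructions (64 u16s = 4 registers).
    generic_sub_sat::<64>(table, wsize);
}

fn sub_sat(table: &mut [u16], wsize: u16) {
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") {
        // SAFETY: the runtime check above proved avx2 is available.
        return unsafe { sub_sat_avx2(table, wsize) };
    }

    // Portable fallback: 32 u16s per chunk (4 x 128 bits, SSE/NEON-sized).
    generic_sub_sat::<32>(table, wsize);
}

fn main() {
    let mut table = vec![100u16; 64];
    sub_sat(&mut table, 150);
    assert!(table.iter().all(|&m| m == 0)); // 100.saturating_sub(150) == 0
}
```

The design trade-off is that the generated code is only as good as the auto-vectorizer, which is why the diff's comments pin the chunk sizes (32 for 128-bit SSE/NEON, 64 for 256-bit AVX2) and link a godbolt snapshot showing the loop vectorizes as intended.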