From 4508aeeb34170ac99a6757f5b7612100b29aa20c Mon Sep 17 00:00:00 2001 From: Nazar Mokrynskyi Date: Wed, 28 May 2025 05:21:57 +0300 Subject: [PATCH 1/3] Faster PoT verification for CPUs that support AVX512F+VAES --- crates/subspace-proof-of-time/src/aes.rs | 55 ++++++++ .../subspace-proof-of-time/src/aes/x86_64.rs | 133 ++++++++++++++---- crates/subspace-proof-of-time/src/lib.rs | 1 + 3 files changed, 162 insertions(+), 27 deletions(-) diff --git a/crates/subspace-proof-of-time/src/aes.rs b/crates/subspace-proof-of-time/src/aes.rs index 6a7ad05ee39..21aea776568 100644 --- a/crates/subspace-proof-of-time/src/aes.rs +++ b/crates/subspace-proof-of-time/src/aes.rs @@ -51,6 +51,25 @@ pub(crate) fn verify_sequential( ) -> bool { assert_eq!(checkpoint_iterations % 2, 0); + #[cfg(target_arch = "x86_64")] + { + cpufeatures::new!(has_aes, "avx512f", "vaes"); + if has_aes::get() { + return unsafe { + x86_64::verify_sequential_avx512f(&seed, &key, checkpoints, checkpoint_iterations) + }; + } + } + + verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations) +} + +fn verify_sequential_generic( + seed: PotSeed, + key: PotKey, + checkpoints: &PotCheckpoints, + checkpoint_iterations: u32, +) -> bool { let key = Array::from(*key); let cipher = Aes128::new(&key); @@ -113,6 +132,12 @@ mod tests { &checkpoints, checkpoint_iterations, )); + assert!(verify_sequential_generic( + seed, + key, + &checkpoints, + checkpoint_iterations, + )); // Decryption of invalid cipher text fails. let mut checkpoints_1 = checkpoints; @@ -123,6 +148,12 @@ mod tests { &checkpoints_1, checkpoint_iterations, )); + assert!(!verify_sequential_generic( + seed, + key, + &checkpoints_1, + checkpoint_iterations, + )); // Decryption with wrong number of iterations fails. 
assert!(!verify_sequential( @@ -131,12 +162,24 @@ mod tests { &checkpoints, checkpoint_iterations + 2, )); + assert!(!verify_sequential_generic( + seed, + key, + &checkpoints, + checkpoint_iterations + 2, + )); assert!(!verify_sequential( seed, key, &checkpoints, checkpoint_iterations - 2, )); + assert!(!verify_sequential_generic( + seed, + key, + &checkpoints, + checkpoint_iterations - 2, + )); // Decryption with wrong seed fails. assert!(!verify_sequential( @@ -145,6 +188,12 @@ mod tests { &checkpoints, checkpoint_iterations, )); + assert!(!verify_sequential_generic( + PotSeed::from(SEED_1), + key, + &checkpoints, + checkpoint_iterations, + )); // Decryption with wrong key fails. assert!(!verify_sequential( @@ -153,5 +202,11 @@ mod tests { &checkpoints, checkpoint_iterations, )); + assert!(!verify_sequential_generic( + seed, + PotKey::from(KEY_1), + &checkpoints, + checkpoint_iterations, + )); } } diff --git a/crates/subspace-proof-of-time/src/aes/x86_64.rs b/crates/subspace-proof-of-time/src/aes/x86_64.rs index f1873ee8999..79a74a338ad 100644 --- a/crates/subspace-proof-of-time/src/aes/x86_64.rs +++ b/crates/subspace-proof-of-time/src/aes/x86_64.rs @@ -1,6 +1,8 @@ use core::arch::x86_64::*; -use core::mem; -use subspace_core_primitives::pot::PotCheckpoints; +use core::{array, mem}; +use subspace_core_primitives::pot::{PotCheckpoints, PotOutput}; + +const NUM_ROUND_KEYS: usize = 11; /// Create PoT proof with checkpoints #[target_feature(enable = "aes")] @@ -12,40 +14,116 @@ pub(super) unsafe fn create( ) -> PotCheckpoints { let mut checkpoints = PotCheckpoints::default(); - let keys_reg = expand_key(key); - let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]); - let mut seed_reg = _mm_loadu_si128(seed.as_ptr() as *const __m128i); - seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]); - for checkpoint in checkpoints.iter_mut() { - for _ in 0..checkpoint_iterations { - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]); - seed_reg = _mm_aesenc_si128(seed_reg, 
keys_reg[2]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]); - seed_reg = _mm_aesenclast_si128(seed_reg, xor_key); - } + unsafe { + let keys_reg = expand_key(key); + let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]); + let mut seed_reg = _mm_loadu_si128(seed.as_ptr() as *const __m128i); + seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]); + for checkpoint in checkpoints.iter_mut() { + for _ in 0..checkpoint_iterations { + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]); + seed_reg = _mm_aesenclast_si128(seed_reg, xor_key); + } - let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]); - _mm_storeu_si128( - checkpoint.as_mut().as_mut_ptr() as *mut __m128i, - checkpoint_reg, - ); + let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]); + _mm_storeu_si128(checkpoint.as_mut_ptr() as *mut __m128i, checkpoint_reg); + } } checkpoints } +/// Verification mimics `create` function, but also has decryption half for better performance +#[target_feature(enable = "avx512f,vaes")] +#[inline] +pub(super) unsafe fn verify_sequential_avx512f( + seed: &[u8; 16], + key: &[u8; 16], + checkpoints: &PotCheckpoints, + checkpoint_iterations: u32, +) -> bool { + let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice()); 
+ + unsafe { + let keys_reg = expand_key(key); + let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]); + let xor_key_512 = _mm512_broadcast_i32x4(xor_key); + + // Invert keys for decryption + let mut inv_keys = keys_reg; + for i in 1..10 { + inv_keys[i] = _mm_aesimc_si128(keys_reg[10 - i]); + } + + let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys_reg[i])); + let inv_keys_512 = + array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i])); + + let mut input_0 = [[0u8; 16]; 4]; + input_0[0] = *seed; + input_0[1..].copy_from_slice(&checkpoints[..3]); + let mut input_0 = _mm512_loadu_si512(input_0.as_ptr() as *const __m512i); + let mut input_1 = _mm512_loadu_si512(checkpoints[3..7].as_ptr() as *const __m512i); + + let mut output_0 = _mm512_loadu_si512(checkpoints[0..4].as_ptr() as *const __m512i); + let mut output_1 = _mm512_loadu_si512(checkpoints[4..8].as_ptr() as *const __m512i); + + input_0 = _mm512_xor_si512(input_0, keys_512[0]); + input_1 = _mm512_xor_si512(input_1, keys_512[0]); + + output_0 = _mm512_xor_si512(output_0, keys_512[10]); + output_1 = _mm512_xor_si512(output_1, keys_512[10]); + + for _ in 0..checkpoint_iterations / 2 { + for i in 1..10 { + input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]); + input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]); + + output_0 = _mm512_aesdec_epi128(output_0, inv_keys_512[i]); + output_1 = _mm512_aesdec_epi128(output_1, inv_keys_512[i]); + } + + input_0 = _mm512_aesenclast_epi128(input_0, xor_key_512); + input_1 = _mm512_aesenclast_epi128(input_1, xor_key_512); + + output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512); + output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512); + } + + // Code below is a more efficient version of this: + // input_0 = _mm512_xor_si512(input_0, keys_512[0]); + // input_1 = _mm512_xor_si512(input_1, keys_512[0]); + // output_0 = _mm512_xor_si512(output_0, keys_512[10]); + // output_1 = _mm512_xor_si512(output_1, 
keys_512[10]); + // + // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0); + // let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1); + + let diff_0 = _mm512_xor_si512(input_0, output_0); + let diff_1 = _mm512_xor_si512(input_1, output_1); + + let mask0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512); + let mask1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512); + + // All inputs match outputs + (mask0 & mask1) == u8::MAX + } +} + // Below code copied with minor changes from following place under MIT/Apache-2.0 license by Artyom // Pavlov: // https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2/aes/src/ni/aes128.rs /// AES-128 round keys -type RoundKeys = [__m128i; 11]; +type RoundKeys = [__m128i; NUM_ROUND_KEYS]; macro_rules! expand_round { ($keys:expr, $pos:expr, $round:expr) => { @@ -72,9 +150,10 @@ macro_rules! expand_round { unsafe fn expand_key(key: &[u8; 16]) -> RoundKeys { // SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized // with all zeroes. - let mut keys: RoundKeys = mem::zeroed(); + let mut keys: RoundKeys = unsafe { mem::zeroed() }; - let k = _mm_loadu_si128(key.as_ptr() as *const __m128i); + // SAFETY: No alignment requirement in `_mm_loadu_si128` + let k = unsafe { _mm_loadu_si128(key.as_ptr() as *const __m128i) }; keys[0] = k; expand_round!(keys, 1, 0x01); diff --git a/crates/subspace-proof-of-time/src/lib.rs b/crates/subspace-proof-of-time/src/lib.rs index 182bf90a81b..a5908679fb8 100644 --- a/crates/subspace-proof-of-time/src/lib.rs +++ b/crates/subspace-proof-of-time/src/lib.rs @@ -1,5 +1,6 @@ //! Proof of time implementation. 
+#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] #![no_std] mod aes; From bb4d9af885425e2265c6b5c0d96439855a7d8e0f Mon Sep 17 00:00:00 2001 From: Nazar Mokrynskyi Date: Thu, 29 May 2025 06:56:50 +0300 Subject: [PATCH 2/3] Reduce `unsafe` in `subspace-proof-of-time` --- crates/subspace-proof-of-time/src/aes.rs | 6 +- .../subspace-proof-of-time/src/aes/x86_64.rs | 195 +++++++++--------- crates/subspace-proof-of-time/src/lib.rs | 1 + 3 files changed, 99 insertions(+), 103 deletions(-) diff --git a/crates/subspace-proof-of-time/src/aes.rs b/crates/subspace-proof-of-time/src/aes.rs index 21aea776568..0796ae8c5dc 100644 --- a/crates/subspace-proof-of-time/src/aes.rs +++ b/crates/subspace-proof-of-time/src/aes.rs @@ -15,6 +15,7 @@ pub(crate) fn create(seed: PotSeed, key: PotKey, checkpoint_iterations: u32) -> { cpufeatures::new!(has_aes, "aes"); if has_aes::get() { + // SAFETY: Checked `aes` feature return unsafe { x86_64::create(seed.as_ref(), key.as_ref(), checkpoint_iterations) }; } } @@ -53,8 +54,9 @@ pub(crate) fn verify_sequential( #[cfg(target_arch = "x86_64")] { - cpufeatures::new!(has_aes, "avx512f", "vaes"); - if has_aes::get() { + cpufeatures::new!(has_avx512f_vaes, "avx512f", "vaes"); + if has_avx512f_vaes::get() { + // SAFETY: Checked `avx512f` and `vaes` features return unsafe { x86_64::verify_sequential_avx512f(&seed, &key, checkpoints, checkpoint_iterations) }; diff --git a/crates/subspace-proof-of-time/src/aes/x86_64.rs b/crates/subspace-proof-of-time/src/aes/x86_64.rs index 79a74a338ad..0765972e68a 100644 --- a/crates/subspace-proof-of-time/src/aes/x86_64.rs +++ b/crates/subspace-proof-of-time/src/aes/x86_64.rs @@ -1,5 +1,6 @@ use core::arch::x86_64::*; -use core::{array, mem}; +use core::array; +use core::simd::{u8x16, u8x64}; use subspace_core_primitives::pot::{PotCheckpoints, PotOutput}; const NUM_ROUND_KEYS: usize = 11; @@ -7,35 +8,33 @@ const NUM_ROUND_KEYS: usize = 11; /// Create PoT proof with checkpoints #[target_feature(enable 
= "aes")] #[inline] -pub(super) unsafe fn create( +pub(super) fn create( seed: &[u8; 16], key: &[u8; 16], checkpoint_iterations: u32, ) -> PotCheckpoints { let mut checkpoints = PotCheckpoints::default(); - unsafe { - let keys_reg = expand_key(key); - let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]); - let mut seed_reg = _mm_loadu_si128(seed.as_ptr() as *const __m128i); - seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]); - for checkpoint in checkpoints.iter_mut() { - for _ in 0..checkpoint_iterations { - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]); - seed_reg = _mm_aesenclast_si128(seed_reg, xor_key); - } - - let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]); - _mm_storeu_si128(checkpoint.as_mut_ptr() as *mut __m128i, checkpoint_reg); + let keys_reg = expand_key(key); + let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]); + let mut seed_reg = __m128i::from(u8x16::from_array(*seed)); + seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]); + for checkpoint in checkpoints.iter_mut() { + for _ in 0..checkpoint_iterations { + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]); + seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]); + seed_reg = 
_mm_aesenclast_si128(seed_reg, xor_key); } + + let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]); + **checkpoint = u8x16::from(checkpoint_reg).to_array(); } checkpoints @@ -44,7 +43,7 @@ pub(super) unsafe fn create( /// Verification mimics `create` function, but also has decryption half for better performance #[target_feature(enable = "avx512f,vaes")] #[inline] -pub(super) unsafe fn verify_sequential_avx512f( +pub(super) fn verify_sequential_avx512f( seed: &[u8; 16], key: &[u8; 16], checkpoints: &PotCheckpoints, @@ -52,37 +51,40 @@ pub(super) unsafe fn verify_sequential_avx512f( ) -> bool { let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice()); - unsafe { - let keys_reg = expand_key(key); - let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]); - let xor_key_512 = _mm512_broadcast_i32x4(xor_key); + let keys = expand_key(key); + let xor_key = _mm_xor_si128(keys[10], keys[0]); + let xor_key_512 = _mm512_broadcast_i32x4(xor_key); - // Invert keys for decryption - let mut inv_keys = keys_reg; - for i in 1..10 { - inv_keys[i] = _mm_aesimc_si128(keys_reg[10 - i]); - } + // Invert keys for decryption, the first and last element is not used below, hence they are + // copied as is from encryption keys (otherwise the first and last element would need to be + // swapped) + let mut inv_keys = keys; + for i in 1..10 { + inv_keys[i] = _mm_aesimc_si128(keys[10 - i]); + } - let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys_reg[i])); - let inv_keys_512 = - array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i])); + let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys[i])); + let inv_keys_512 = + array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i])); - let mut input_0 = [[0u8; 16]; 4]; - input_0[0] = *seed; - input_0[1..].copy_from_slice(&checkpoints[..3]); - let mut input_0 = _mm512_loadu_si512(input_0.as_ptr() as *const __m512i); - let mut 
input_1 = _mm512_loadu_si512(checkpoints[3..7].as_ptr() as *const __m512i); + let mut input_0 = [[0u8; 16]; 4]; + input_0[0] = *seed; + input_0[1..].copy_from_slice(&checkpoints[..3]); + let mut input_0 = __m512i::from(u8x64::from_slice(input_0.as_flattened())); + let mut input_1 = __m512i::from(u8x64::from_slice(checkpoints[3..7].as_flattened())); - let mut output_0 = _mm512_loadu_si512(checkpoints[0..4].as_ptr() as *const __m512i); - let mut output_1 = _mm512_loadu_si512(checkpoints[4..8].as_ptr() as *const __m512i); + let mut output_0 = __m512i::from(u8x64::from_slice(checkpoints[0..4].as_flattened())); + let mut output_1 = __m512i::from(u8x64::from_slice(checkpoints[4..8].as_flattened())); - input_0 = _mm512_xor_si512(input_0, keys_512[0]); - input_1 = _mm512_xor_si512(input_1, keys_512[0]); + input_0 = _mm512_xor_si512(input_0, keys_512[0]); + input_1 = _mm512_xor_si512(input_1, keys_512[0]); - output_0 = _mm512_xor_si512(output_0, keys_512[10]); - output_1 = _mm512_xor_si512(output_1, keys_512[10]); + output_0 = _mm512_xor_si512(output_0, keys_512[10]); + output_1 = _mm512_xor_si512(output_1, keys_512[10]); - for _ in 0..checkpoint_iterations / 2 { + for _ in 0..checkpoint_iterations / 2 { + // TODO: Shouldn't be unsafe: https://github.com/rust-lang/rust/issues/141718 + unsafe { for i in 1..10 { input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]); input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]); @@ -97,75 +99,66 @@ pub(super) unsafe fn verify_sequential_avx512f( output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512); output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512); } + } - // Code below is a more efficient version of this: - // input_0 = _mm512_xor_si512(input_0, keys_512[0]); - // input_1 = _mm512_xor_si512(input_1, keys_512[0]); - // output_0 = _mm512_xor_si512(output_0, keys_512[10]); - // output_1 = _mm512_xor_si512(output_1, keys_512[10]); - // - // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0); - // let mask1 = 
_mm512_cmpeq_epu64_mask(input_1, output_1); + // Code below is a more efficient version of this: + // input_0 = _mm512_xor_si512(input_0, keys_512[0]); + // input_1 = _mm512_xor_si512(input_1, keys_512[0]); + // output_0 = _mm512_xor_si512(output_0, keys_512[10]); + // output_1 = _mm512_xor_si512(output_1, keys_512[10]); + // + // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0); + // let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1); - let diff_0 = _mm512_xor_si512(input_0, output_0); - let diff_1 = _mm512_xor_si512(input_1, output_1); + let diff_0 = _mm512_xor_si512(input_0, output_0); + let diff_1 = _mm512_xor_si512(input_1, output_1); - let mask0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512); - let mask1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512); + let mask0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512); + let mask1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512); - // All inputs match outputs - (mask0 & mask1) == u8::MAX - } + // All inputs match outputs + (mask0 & mask1) == u8::MAX } -// Below code copied with minor changes from following place under MIT/Apache-2.0 license by Artyom -// Pavlov: -// https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2/aes/src/ni/aes128.rs +// Below code copied with minor changes from the following place under MIT/Apache-2.0 license by +// Artyom Pavlov: +// https://github.com/RustCrypto/block-ciphers/blob/fbb68f40b122909d92e40ee8a50112b6e5d0af8f/aes/src/ni/expand.rs -/// AES-128 round keys -type RoundKeys = [__m128i; NUM_ROUND_KEYS]; - -macro_rules! 
expand_round { - ($keys:expr, $pos:expr, $round:expr) => { - let mut t1 = $keys[$pos - 1]; +#[target_feature(enable = "aes")] +fn expand_key(key: &[u8; 16]) -> [__m128i; NUM_ROUND_KEYS] { + #[target_feature(enable = "aes")] + fn expand_round<const RK: i32>(keys: &mut [__m128i; NUM_ROUND_KEYS], pos: usize) { + let mut t1 = keys[pos - 1]; let mut t2; let mut t3; - t2 = _mm_aeskeygenassist_si128(t1, $round); - t2 = _mm_shuffle_epi32(t2, 0xff); - t3 = _mm_slli_si128(t1, 0x4); + t2 = _mm_aeskeygenassist_si128::<RK>(t1); + t2 = _mm_shuffle_epi32::<0xff>(t2); + t3 = _mm_slli_si128::<0x4>(t1); t1 = _mm_xor_si128(t1, t3); - t3 = _mm_slli_si128(t3, 0x4); + t3 = _mm_slli_si128::<0x4>(t3); t1 = _mm_xor_si128(t1, t3); - t3 = _mm_slli_si128(t3, 0x4); + t3 = _mm_slli_si128::<0x4>(t3); t1 = _mm_xor_si128(t1, t3); t1 = _mm_xor_si128(t1, t2); - $keys[$pos] = t1; - }; -} + keys[pos] = t1; + } -#[target_feature(enable = "aes")] -#[inline] -unsafe fn expand_key(key: &[u8; 16]) -> RoundKeys { - // SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized - // with all zeroes. 
- let mut keys: RoundKeys = unsafe { mem::zeroed() }; - - // SAFETY: No alignment requirement in `_mm_loadu_si128` - let k = unsafe { _mm_loadu_si128(key.as_ptr() as *const __m128i) }; - keys[0] = k; - - expand_round!(keys, 1, 0x01); - expand_round!(keys, 2, 0x02); - expand_round!(keys, 3, 0x04); - expand_round!(keys, 4, 0x08); - expand_round!(keys, 5, 0x10); - expand_round!(keys, 6, 0x20); - expand_round!(keys, 7, 0x40); - expand_round!(keys, 8, 0x80); - expand_round!(keys, 9, 0x1B); - expand_round!(keys, 10, 0x36); + let mut keys = [_mm_setzero_si128(); NUM_ROUND_KEYS]; + keys[0] = __m128i::from(u8x16::from(*key)); + + let kr = &mut keys; + expand_round::<0x01>(kr, 1); + expand_round::<0x02>(kr, 2); + expand_round::<0x04>(kr, 3); + expand_round::<0x08>(kr, 4); + expand_round::<0x10>(kr, 5); + expand_round::<0x20>(kr, 6); + expand_round::<0x40>(kr, 7); + expand_round::<0x80>(kr, 8); + expand_round::<0x1B>(kr, 9); + expand_round::<0x36>(kr, 10); keys } diff --git a/crates/subspace-proof-of-time/src/lib.rs b/crates/subspace-proof-of-time/src/lib.rs index a5908679fb8..0f3ca68a6bd 100644 --- a/crates/subspace-proof-of-time/src/lib.rs +++ b/crates/subspace-proof-of-time/src/lib.rs @@ -1,6 +1,7 @@ //! Proof of time implementation. 
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] +#![feature(portable_simd)] #![no_std] mod aes; From 7a1ed142d69d58b8b57cf331154c038d515aae3c Mon Sep 17 00:00:00 2001 From: Nazar Mokrynskyi Date: Sat, 31 May 2025 05:26:47 +0300 Subject: [PATCH 3/3] Implement PoT verification optimized for AVX2+VAES and AES+SSE4.1 --- crates/subspace-proof-of-time/src/aes.rs | 135 +++++++----- .../subspace-proof-of-time/src/aes/x86_64.rs | 195 +++++++++++++++++- 2 files changed, 275 insertions(+), 55 deletions(-) diff --git a/crates/subspace-proof-of-time/src/aes.rs b/crates/subspace-proof-of-time/src/aes.rs index 0796ae8c5dc..f3e25563379 100644 --- a/crates/subspace-proof-of-time/src/aes.rs +++ b/crates/subspace-proof-of-time/src/aes.rs @@ -58,7 +58,28 @@ pub(crate) fn verify_sequential( if has_avx512f_vaes::get() { // SAFETY: Checked `avx512f` and `vaes` features return unsafe { - x86_64::verify_sequential_avx512f(&seed, &key, checkpoints, checkpoint_iterations) + x86_64::verify_sequential_avx512f_vaes( + &seed, + &key, + checkpoints, + checkpoint_iterations, + ) + }; + } + + cpufeatures::new!(has_avx2_vaes, "avx2", "vaes"); + if has_avx2_vaes::get() { + // SAFETY: Checked `avx2` and `vaes` features + return unsafe { + x86_64::verify_sequential_avx2_vaes(&seed, &key, checkpoints, checkpoint_iterations) + }; + } + + cpufeatures::new!(has_aes_sse41, "aes", "sse4.1"); + if has_aes_sse41::get() { + // SAFETY: Checked `aes` and `sse4.1` features + return unsafe { + x86_64::verify_sequential_aes_sse41(&seed, &key, checkpoints, checkpoint_iterations) }; } } @@ -115,6 +136,65 @@ mod tests { ]; const BAD_CIPHER: [u8; 16] = [22; 16]; + fn verify_test( + seed: PotSeed, + key: PotKey, + checkpoints: &PotCheckpoints, + checkpoint_iterations: u32, + ) -> bool { + let sequential = verify_sequential(seed, key, checkpoints, checkpoint_iterations); + let sequential_generic = + verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations); + assert_eq!(sequential, 
sequential_generic); + + #[cfg(target_arch = "x86_64")] + { + cpufeatures::new!(has_avx512f_vaes, "avx512f", "vaes"); + if has_avx512f_vaes::get() { + // SAFETY: Checked `avx512f` and `vaes` features + let avx512f_vaes = unsafe { + x86_64::verify_sequential_avx512f_vaes( + &seed, + &key, + checkpoints, + checkpoint_iterations, + ) + }; + assert_eq!(sequential, avx512f_vaes); + } + + cpufeatures::new!(has_avx2_vaes, "avx2", "vaes"); + if has_avx2_vaes::get() { + // SAFETY: Checked `avx2` and `vaes` features + let avx2_vaes = unsafe { + x86_64::verify_sequential_avx2_vaes( + &seed, + &key, + checkpoints, + checkpoint_iterations, + ) + }; + assert_eq!(sequential, avx2_vaes); + } + + cpufeatures::new!(has_aes_sse41, "aes", "sse4.1"); + if has_aes_sse41::get() { + // SAFETY: Checked `aes` and `sse4.1` features + let aes = unsafe { + x86_64::verify_sequential_aes_sse41( + &seed, + &key, + checkpoints, + checkpoint_iterations, + ) + }; + assert_eq!(sequential, aes); + } + } + + sequential + } + #[test] fn test_create_verify() { let seed = PotSeed::from(SEED); @@ -128,29 +208,12 @@ mod tests { assert_eq!(checkpoints, generic_checkpoints); } - assert!(verify_sequential( - seed, - key, - &checkpoints, - checkpoint_iterations, - )); - assert!(verify_sequential_generic( - seed, - key, - &checkpoints, - checkpoint_iterations, - )); + assert!(verify_test(seed, key, &checkpoints, checkpoint_iterations,)); // Decryption of invalid cipher text fails. let mut checkpoints_1 = checkpoints; checkpoints_1[0] = PotOutput::from(BAD_CIPHER); - assert!(!verify_sequential( - seed, - key, - &checkpoints_1, - checkpoint_iterations, - )); - assert!(!verify_sequential_generic( + assert!(!verify_test( seed, key, &checkpoints_1, @@ -158,25 +221,13 @@ mod tests { )); // Decryption with wrong number of iterations fails. 
- assert!(!verify_sequential( - seed, - key, - &checkpoints, - checkpoint_iterations + 2, - )); - assert!(!verify_sequential_generic( + assert!(!verify_test( seed, key, &checkpoints, checkpoint_iterations + 2, )); - assert!(!verify_sequential( - seed, - key, - &checkpoints, - checkpoint_iterations - 2, - )); - assert!(!verify_sequential_generic( + assert!(!verify_test( seed, key, &checkpoints, @@ -184,13 +235,7 @@ mod tests { )); // Decryption with wrong seed fails. - assert!(!verify_sequential( - PotSeed::from(SEED_1), - key, - &checkpoints, - checkpoint_iterations, - )); - assert!(!verify_sequential_generic( + assert!(!verify_test( PotSeed::from(SEED_1), key, &checkpoints, @@ -198,13 +243,7 @@ mod tests { )); // Decryption with wrong key fails. - assert!(!verify_sequential( - seed, - PotKey::from(KEY_1), - &checkpoints, - checkpoint_iterations, - )); - assert!(!verify_sequential_generic( + assert!(!verify_test( seed, PotKey::from(KEY_1), &checkpoints, diff --git a/crates/subspace-proof-of-time/src/aes/x86_64.rs b/crates/subspace-proof-of-time/src/aes/x86_64.rs index 0765972e68a..7bf70b99afe 100644 --- a/crates/subspace-proof-of-time/src/aes/x86_64.rs +++ b/crates/subspace-proof-of-time/src/aes/x86_64.rs @@ -1,6 +1,6 @@ use core::arch::x86_64::*; use core::array; -use core::simd::{u8x16, u8x64}; +use core::simd::{u8x16, u8x32, u8x64}; use subspace_core_primitives::pot::{PotCheckpoints, PotOutput}; const NUM_ROUND_KEYS: usize = 11; @@ -40,10 +40,191 @@ pub(super) fn create( checkpoints } +/// Verification mimics `create` function, but also has decryption half for better performance +#[target_feature(enable = "aes,sse4.1")] +#[inline] +pub(super) fn verify_sequential_aes_sse41( + seed: &[u8; 16], + key: &[u8; 16], + checkpoints: &PotCheckpoints, + checkpoint_iterations: u32, +) -> bool { + let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice()); + + let keys = expand_key(key); + let xor_key = _mm_xor_si128(keys[10], keys[0]); + + // Invert keys for 
decryption, the first and last element is not used below, hence they are + // copied as is from encryption keys (otherwise the first and last element would need to be + // swapped) + let mut inv_keys = keys; + for i in 1..10 { + inv_keys[i] = _mm_aesimc_si128(keys[10 - i]); + } + + let mut inputs: [__m128i; 8] = [ + __m128i::from(u8x16::from(*seed)), + __m128i::from(u8x16::from(checkpoints[0])), + __m128i::from(u8x16::from(checkpoints[1])), + __m128i::from(u8x16::from(checkpoints[2])), + __m128i::from(u8x16::from(checkpoints[3])), + __m128i::from(u8x16::from(checkpoints[4])), + __m128i::from(u8x16::from(checkpoints[5])), + __m128i::from(u8x16::from(checkpoints[6])), + ]; + + let mut outputs: [__m128i; 8] = [ + __m128i::from(u8x16::from(checkpoints[0])), + __m128i::from(u8x16::from(checkpoints[1])), + __m128i::from(u8x16::from(checkpoints[2])), + __m128i::from(u8x16::from(checkpoints[3])), + __m128i::from(u8x16::from(checkpoints[4])), + __m128i::from(u8x16::from(checkpoints[5])), + __m128i::from(u8x16::from(checkpoints[6])), + __m128i::from(u8x16::from(checkpoints[7])), + ]; + + inputs = inputs.map(|input| _mm_xor_si128(input, keys[0])); + outputs = outputs.map(|output| _mm_xor_si128(output, keys[10])); + + for _ in 0..checkpoint_iterations / 2 { + for i in 1..10 { + inputs = inputs.map(|input| _mm_aesenc_si128(input, keys[i])); + outputs = outputs.map(|output| _mm_aesdec_si128(output, inv_keys[i])); + } + + inputs = inputs.map(|input| _mm_aesenclast_si128(input, xor_key)); + outputs = outputs.map(|output| _mm_aesdeclast_si128(output, xor_key)); + } + + // All bits set + let all_ones = _mm_set1_epi8(-1); + + inputs.into_iter().zip(outputs).all(|(input, output)| { + let diff = _mm_xor_si128(input, output); + let cmp = _mm_xor_si128(diff, xor_key); + _mm_test_all_zeros(cmp, all_ones) == 1 + }) +} + +/// Verification mimics `create` function, but also has decryption half for better performance +#[target_feature(enable = "avx2,vaes")] +#[inline] +pub(super) fn 
verify_sequential_avx2_vaes( + seed: &[u8; 16], + key: &[u8; 16], + checkpoints: &PotCheckpoints, + checkpoint_iterations: u32, +) -> bool { + let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice()); + + let keys = expand_key(key); + let xor_key = _mm_xor_si128(keys[10], keys[0]); + let xor_key_256 = _mm256_broadcastsi128_si256(xor_key); + + // Invert keys for decryption, the first and last element is not used below, hence they are + // copied as is from encryption keys (otherwise the first and last element would need to be + // swapped) + let mut inv_keys = keys; + for i in 1..10 { + inv_keys[i] = _mm_aesimc_si128(keys[10 - i]); + } + + let keys_256 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(keys[i])); + let inv_keys_256 = + array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm256_broadcastsi128_si256(inv_keys[i])); + + let mut input_0 = [[0u8; 16]; 2]; + input_0[0] = *seed; + input_0[1] = checkpoints[0]; + let mut input_0 = __m256i::from(u8x32::from_slice(input_0.as_flattened())); + + let mut input_1 = __m256i::from(u8x32::from_slice(checkpoints[1..3].as_flattened())); + let mut input_2 = __m256i::from(u8x32::from_slice(checkpoints[3..5].as_flattened())); + let mut input_3 = __m256i::from(u8x32::from_slice(checkpoints[5..7].as_flattened())); + + let mut output_0 = __m256i::from(u8x32::from_slice(checkpoints[0..2].as_flattened())); + let mut output_1 = __m256i::from(u8x32::from_slice(checkpoints[2..4].as_flattened())); + let mut output_2 = __m256i::from(u8x32::from_slice(checkpoints[4..6].as_flattened())); + let mut output_3 = __m256i::from(u8x32::from_slice(checkpoints[6..8].as_flattened())); + + input_0 = _mm256_xor_si256(input_0, keys_256[0]); + input_1 = _mm256_xor_si256(input_1, keys_256[0]); + input_2 = _mm256_xor_si256(input_2, keys_256[0]); + input_3 = _mm256_xor_si256(input_3, keys_256[0]); + + output_0 = _mm256_xor_si256(output_0, keys_256[10]); + output_1 = _mm256_xor_si256(output_1, keys_256[10]); + output_2 = 
_mm256_xor_si256(output_2, keys_256[10]); + output_3 = _mm256_xor_si256(output_3, keys_256[10]); + + for _ in 0..checkpoint_iterations / 2 { + // TODO: Shouldn't be unsafe: https://github.com/rust-lang/rust/issues/141718 + unsafe { + for i in 1..10 { + input_0 = _mm256_aesenc_epi128(input_0, keys_256[i]); + input_1 = _mm256_aesenc_epi128(input_1, keys_256[i]); + input_2 = _mm256_aesenc_epi128(input_2, keys_256[i]); + input_3 = _mm256_aesenc_epi128(input_3, keys_256[i]); + + output_0 = _mm256_aesdec_epi128(output_0, inv_keys_256[i]); + output_1 = _mm256_aesdec_epi128(output_1, inv_keys_256[i]); + output_2 = _mm256_aesdec_epi128(output_2, inv_keys_256[i]); + output_3 = _mm256_aesdec_epi128(output_3, inv_keys_256[i]); + } + + input_0 = _mm256_aesenclast_epi128(input_0, xor_key_256); + input_1 = _mm256_aesenclast_epi128(input_1, xor_key_256); + input_2 = _mm256_aesenclast_epi128(input_2, xor_key_256); + input_3 = _mm256_aesenclast_epi128(input_3, xor_key_256); + + output_0 = _mm256_aesdeclast_epi128(output_0, xor_key_256); + output_1 = _mm256_aesdeclast_epi128(output_1, xor_key_256); + output_2 = _mm256_aesdeclast_epi128(output_2, xor_key_256); + output_3 = _mm256_aesdeclast_epi128(output_3, xor_key_256); + } + } + + // Code below is a more efficient version of this: + // input_0 = _mm256_xor_si256(input_0, keys_256[0]); + // input_1 = _mm256_xor_si256(input_1, keys_256[0]); + // input_2 = _mm256_xor_si256(input_2, keys_256[0]); + // input_3 = _mm256_xor_si256(input_3, keys_256[0]); + // output_0 = _mm256_xor_si256(output_0, keys_256[10]); + // output_1 = _mm256_xor_si256(output_1, keys_256[10]); + // output_2 = _mm256_xor_si256(output_2, keys_256[10]); + // output_3 = _mm256_xor_si256(output_3, keys_256[10]); + // + // let mask_0 = _mm256_cmpeq_epi64(input_0, output_0); + // let mask_1 = _mm256_cmpeq_epi64(input_1, output_1); + // let mask_2 = _mm256_cmpeq_epi64(input_2, output_1); + // let mask_3 = _mm256_cmpeq_epi64(input_3, output_1); + + let diff_0 = 
_mm256_xor_si256(input_0, output_0); + let diff_1 = _mm256_xor_si256(input_1, output_1); + let diff_2 = _mm256_xor_si256(input_2, output_2); + let diff_3 = _mm256_xor_si256(input_3, output_3); + + let mask_0 = _mm256_cmpeq_epi64(diff_0, xor_key_256); + let mask_1 = _mm256_cmpeq_epi64(diff_1, xor_key_256); + let mask_2 = _mm256_cmpeq_epi64(diff_2, xor_key_256); + let mask_3 = _mm256_cmpeq_epi64(diff_3, xor_key_256); + + // All bits set + let all_ones = _mm256_set1_epi64x(-1); + + let match_0 = _mm256_testc_si256(mask_0, all_ones) != 0; + let match_1 = _mm256_testc_si256(mask_1, all_ones) != 0; + let match_2 = _mm256_testc_si256(mask_2, all_ones) != 0; + let match_3 = _mm256_testc_si256(mask_3, all_ones) != 0; + + match_0 && match_1 && match_2 && match_3 +} + /// Verification mimics `create` function, but also has decryption half for better performance #[target_feature(enable = "avx512f,vaes")] #[inline] -pub(super) fn verify_sequential_avx512f( +pub(super) fn verify_sequential_avx512f_vaes( seed: &[u8; 16], key: &[u8; 16], checkpoints: &PotCheckpoints, @@ -107,17 +288,17 @@ pub(super) fn verify_sequential_avx512f( // output_0 = _mm512_xor_si512(output_0, keys_512[10]); // output_1 = _mm512_xor_si512(output_1, keys_512[10]); // - // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0); - // let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1); + // let mask_0 = _mm512_cmpeq_epu64_mask(input_0, output_0); + // let mask_1 = _mm512_cmpeq_epu64_mask(input_1, output_1); let diff_0 = _mm512_xor_si512(input_0, output_0); let diff_1 = _mm512_xor_si512(input_1, output_1); - let mask0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512); - let mask1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512); + let mask_0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512); + let mask_1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512); // All inputs match outputs - (mask0 & mask1) == u8::MAX + (mask_0 & mask_1) == u8::MAX } // Below code copied with minor changes from the following place under 
MIT/Apache-2.0 license by