diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 6bc1c89a9..fef2eb75f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -255,14 +255,16 @@ jobs: - name: Ensure no panics in annotated code env: # Increase inlining threshold to make sure the compiler can see that some functions do not panic - RUSTFLAGS: ${{ env.RUSTFLAGS }} -Cllvm-args=--inline-threshold=2000 + # Native CPU and LTO to allow compiler to apply more optimizations and prove lack of panics in more cases + RUSTFLAGS: ${{ env.RUSTFLAGS }} -Cllvm-args=--inline-threshold=3000 -C embed-bitcode -C lto=fat -Z dylib-lto -C target-cpu=native run: | cargo -Zgitoxide -Zgit build --release --all-targets --features no-panic - name: Ensure no panics in annotated code (various features) env: # Increase inlining threshold to make sure the compiler can see that some functions do not panic - RUSTFLAGS: ${{ env.RUSTFLAGS }} -Cllvm-args=--inline-threshold=2000 + # Native CPU and LTO to allow compiler to apply more optimizations and prove lack of panics in more cases + RUSTFLAGS: ${{ env.RUSTFLAGS }} -Cllvm-args=--inline-threshold=3000 -C embed-bitcode -C lto=fat -Z dylib-lto -C target-cpu=native run: | # Ensure no panics with `guest` feature echo "Checking `no-panic` in contracts" diff --git a/crates/shared/ab-proof-of-time/Cargo.toml b/crates/shared/ab-proof-of-time/Cargo.toml index e1c350304..b290fa552 100644 --- a/crates/shared/ab-proof-of-time/Cargo.toml +++ b/crates/shared/ab-proof-of-time/Cargo.toml @@ -31,7 +31,7 @@ aes = { workspace = true } no-panic = { workspace = true, optional = true } thiserror = { workspace = true } -[target.'cfg(target_arch = "x86_64")'.dependencies] +[target.'cfg(any(target_arch = "aarch64", target_arch = "x86_64"))'.dependencies] cpufeatures = { workspace = true } [dev-dependencies] diff --git a/crates/shared/ab-proof-of-time/src/aes.rs b/crates/shared/ab-proof-of-time/src/aes.rs index 5e7ae5264..0908ea4a5 100644 --- 
a/crates/shared/ab-proof-of-time/src/aes.rs +++ b/crates/shared/ab-proof-of-time/src/aes.rs @@ -1,5 +1,7 @@ //! AES related functionality. +#[cfg(target_arch = "aarch64")] +mod aarch64; #[cfg(target_arch = "x86_64")] mod x86_64; @@ -21,6 +23,14 @@ pub(crate) fn create(seed: PotSeed, key: PotKey, checkpoint_iterations: u32) -> return unsafe { x86_64::create(seed.as_ref(), key.as_ref(), checkpoint_iterations) }; } } + #[cfg(target_arch = "aarch64")] + { + cpufeatures::new!(has_aes, "aes"); + if has_aes::get() { + // SAFETY: Checked `aes` feature + return unsafe { aarch64::create(seed.as_ref(), key.as_ref(), checkpoint_iterations) }; + } + } create_generic(seed, key, checkpoint_iterations) } @@ -89,6 +99,16 @@ pub(crate) fn verify_sequential( }; } } + #[cfg(target_arch = "aarch64")] + { + cpufeatures::new!(has_aes, "aes"); + if has_aes::get() { + // SAFETY: Checked `aes` feature + return unsafe { + aarch64::verify_sequential_aes(&seed, &key, checkpoints, checkpoint_iterations) + }; + } + } verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations) } @@ -151,9 +171,8 @@ mod tests { checkpoint_iterations: u32, ) -> bool { let sequential = verify_sequential(seed, key, checkpoints, checkpoint_iterations); - let sequential_generic = - verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations); - assert_eq!(sequential, sequential_generic); + let generic = verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations); + assert_eq!(sequential, generic); #[cfg(target_arch = "x86_64")] { @@ -188,7 +207,7 @@ mod tests { cpufeatures::new!(has_aes_sse41, "aes", "sse4.1"); if has_aes_sse41::get() { // SAFETY: Checked `aes` and `sse4.1` features - let aes = unsafe { + let aes_sse41 = unsafe { x86_64::verify_sequential_aes_sse41( &seed, &key, @@ -196,6 +215,17 @@ mod tests { checkpoint_iterations, ) }; + assert_eq!(sequential, aes_sse41); + } + } + #[cfg(target_arch = "aarch64")] + { + cpufeatures::new!(has_aes, "aes"); + if has_aes::get() { + 
// SAFETY: Checked `aes` feature + let aes = unsafe { + aarch64::verify_sequential_aes(&seed, &key, checkpoints, checkpoint_iterations) + }; assert_eq!(sequential, aes); } } diff --git a/crates/shared/ab-proof-of-time/src/aes/aarch64.rs b/crates/shared/ab-proof-of-time/src/aes/aarch64.rs new file mode 100644 index 000000000..97cd1ea6f --- /dev/null +++ b/crates/shared/ab-proof-of-time/src/aes/aarch64.rs @@ -0,0 +1,175 @@ +use ab_core_primitives::pot::{PotCheckpoints, PotOutput}; +use core::arch::aarch64::*; +use core::simd::u8x16; +use core::slice; + +const NUM_ROUND_KEYS: usize = 11; + +/// Create PoT proof with checkpoints +#[target_feature(enable = "aes")] +#[inline] +#[cfg_attr(feature = "no-panic", no_panic::no_panic)] +pub(super) fn create( + seed: &[u8; 16], + key: &[u8; 16], + checkpoint_iterations: u32, +) -> PotCheckpoints { + let mut checkpoints = PotCheckpoints::default(); + + let keys = expand_key(key); + let xor_key = veorq_u8(keys[10], keys[0]); + let mut seed = uint8x16_t::from(u8x16::from(*seed)); + seed = veorq_u8(seed, keys[10]); + for checkpoint in checkpoints.iter_mut() { + for _ in 0..checkpoint_iterations { + seed = vaesmcq_u8(vaeseq_u8(seed, xor_key)); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[1])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[2])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[3])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[4])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[5])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[6])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[7])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[8])); + seed = vaeseq_u8(seed, keys[9]); + } + + let checkpoint_reg = veorq_u8(seed, keys[10]); + **checkpoint = u8x16::from(checkpoint_reg).to_array(); + } + + checkpoints +} + +/// Verification mimics `create` function, but also has decryption half for better performance +#[target_feature(enable = "aes")] +#[inline] +// TODO: Enable on all platforms once it works +#[cfg_attr(all(feature = "no-panic", target_os = "linux"), 
no_panic::no_panic)] +pub(super) fn verify_sequential_aes( + seed: &[u8; 16], + key: &[u8; 16], + checkpoints: &PotCheckpoints, + checkpoint_iterations: u32, +) -> bool { + let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice()); + + let keys = expand_key(key); + let xor_key = veorq_u8(keys[10], keys[0]); + + // Invert keys for decryption, the first and last element is not used below, hence they are + // copied as is from encryption keys (otherwise the first and last element would need to be + // swapped) + let mut inv_keys = keys; + for i in 1..10 { + inv_keys[i] = vaesimcq_u8(keys[10 - i]); + } + + let mut inputs: [uint8x16_t; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [ + uint8x16_t::from(u8x16::from(*seed)), + uint8x16_t::from(u8x16::from(checkpoints[0])), + uint8x16_t::from(u8x16::from(checkpoints[1])), + uint8x16_t::from(u8x16::from(checkpoints[2])), + uint8x16_t::from(u8x16::from(checkpoints[3])), + uint8x16_t::from(u8x16::from(checkpoints[4])), + uint8x16_t::from(u8x16::from(checkpoints[5])), + uint8x16_t::from(u8x16::from(checkpoints[6])), + ]; + + let mut outputs: [uint8x16_t; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [ + uint8x16_t::from(u8x16::from(checkpoints[0])), + uint8x16_t::from(u8x16::from(checkpoints[1])), + uint8x16_t::from(u8x16::from(checkpoints[2])), + uint8x16_t::from(u8x16::from(checkpoints[3])), + uint8x16_t::from(u8x16::from(checkpoints[4])), + uint8x16_t::from(u8x16::from(checkpoints[5])), + uint8x16_t::from(u8x16::from(checkpoints[6])), + uint8x16_t::from(u8x16::from(checkpoints[7])), + ]; + + inputs = inputs.map(|input| veorq_u8(input, keys[10])); + outputs = outputs.map(|output| veorq_u8(output, keys[0])); + + for _ in 0..checkpoint_iterations / 2 { + inputs = inputs.map(|input| vaesmcq_u8(vaeseq_u8(input, xor_key))); + outputs = outputs.map(|output| vaesimcq_u8(vaesdq_u8(output, xor_key))); + + for i in 1..9 { + inputs = inputs.map(|input| vaesmcq_u8(vaeseq_u8(input, keys[i]))); + outputs = 
outputs.map(|output| vaesimcq_u8(vaesdq_u8(output, inv_keys[i]))); + } + + inputs = inputs.map(|input| vaeseq_u8(input, keys[9])); + outputs = outputs.map(|output| vaesdq_u8(output, inv_keys[9])); + } + + inputs.into_iter().zip(outputs).all(|(input, output)| { + let diff = veorq_u8(input, output); + let cmp = vceqq_u8(diff, xor_key); + vminvq_u8(cmp) == u8::MAX + }) +} + +// Below code copied with minor changes from the following place under MIT/Apache-2.0 license by +// Artyom Pavlov: +// https://github.com/RustCrypto/block-ciphers/blob/fbb68f40b122909d92e40ee8a50112b6e5d0af8f/aes/src/armv8/expand.rs + +/// There are 4 AES words in a block. +const BLOCK_WORDS: usize = 4; + +/// The AES (nee Rijndael) notion of a word is always 32-bits, or 4-bytes. +const WORD_SIZE: usize = 4; + +/// AES round constants. +const ROUND_CONSTS: [u32; 10] = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]; + +/// AES key expansion. +#[target_feature(enable = "aes")] +#[cfg_attr(feature = "no-panic", no_panic::no_panic)] +fn expand_key(key: &[u8; 16]) -> [uint8x16_t; NUM_ROUND_KEYS] { + let mut expanded_keys = [uint8x16_t::from(u8x16::default()); NUM_ROUND_KEYS]; + + // Sanity check, as this is required in order for the subsequent conversion to be sound. + const _: () = assert!(align_of::<uint8x16_t>() >= align_of::<u32>()); + let columns = unsafe { + slice::from_raw_parts_mut( + expanded_keys.as_mut_ptr().cast::<u32>(), + NUM_ROUND_KEYS * BLOCK_WORDS, + ) + }; + + for (i, chunk) in key.array_chunks::<WORD_SIZE>().enumerate() { + columns[i] = u32::from_ne_bytes(*chunk); + } + + // From "The Rijndael Block Cipher" Section 4.1: + // > The number of columns of the Cipher Key is denoted by `Nk` and is + // > equal to the key length divided by 32 [bits].
+ let nk = 16 / WORD_SIZE; + + for i in nk..NUM_ROUND_KEYS * BLOCK_WORDS { + let mut word = columns[i - 1]; + + if i % nk == 0 { + word = sub_word(word).rotate_right(8) ^ ROUND_CONSTS[i / nk - 1]; + } else if nk > 6 && i % nk == 4 { + word = sub_word(word); + } + + columns[i] = columns[i - nk] ^ word; + } + + expanded_keys +} + +/// Sub bytes for a single AES word: used for key expansion +#[target_feature(enable = "aes")] +#[cfg_attr(feature = "no-panic", no_panic::no_panic)] +fn sub_word(input: u32) -> u32 { + let input = vreinterpretq_u8_u32(vdupq_n_u32(input)); + + // AES single round encryption (with a "round" key of all zeros) + let sub_input = vaeseq_u8(input, vdupq_n_u8(0)); + + vgetq_lane_u32::<0>(vreinterpretq_u32_u8(sub_input)) +} diff --git a/crates/shared/ab-proof-of-time/src/aes/x86_64.rs b/crates/shared/ab-proof-of-time/src/aes/x86_64.rs index 42c916d56..c7965b5a5 100644 --- a/crates/shared/ab-proof-of-time/src/aes/x86_64.rs +++ b/crates/shared/ab-proof-of-time/src/aes/x86_64.rs @@ -16,25 +16,25 @@ pub(super) fn create( ) -> PotCheckpoints { let mut checkpoints = PotCheckpoints::default(); - let keys_reg = expand_key(key); - let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]); - let mut seed_reg = __m128i::from(u8x16::from_array(*seed)); - seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]); + let keys = expand_key(key); + let xor_key = _mm_xor_si128(keys[10], keys[0]); + let mut seed = __m128i::from(u8x16::from_array(*seed)); + seed = _mm_xor_si128(seed, keys[0]); for checkpoint in checkpoints.iter_mut() { for _ in 0..checkpoint_iterations { - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]); - seed_reg = _mm_aesenc_si128(seed_reg, 
keys_reg[8]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]); - seed_reg = _mm_aesenclast_si128(seed_reg, xor_key); + seed = _mm_aesenc_si128(seed, keys[1]); + seed = _mm_aesenc_si128(seed, keys[2]); + seed = _mm_aesenc_si128(seed, keys[3]); + seed = _mm_aesenc_si128(seed, keys[4]); + seed = _mm_aesenc_si128(seed, keys[5]); + seed = _mm_aesenc_si128(seed, keys[6]); + seed = _mm_aesenc_si128(seed, keys[7]); + seed = _mm_aesenc_si128(seed, keys[8]); + seed = _mm_aesenc_si128(seed, keys[9]); + seed = _mm_aesenclast_si128(seed, xor_key); } - let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]); + let checkpoint_reg = _mm_xor_si128(seed, keys[0]); **checkpoint = u8x16::from(checkpoint_reg).to_array(); } @@ -64,7 +64,7 @@ pub(super) fn verify_sequential_aes_sse41( inv_keys[i] = _mm_aesimc_si128(keys[10 - i]); } - let mut inputs: [__m128i; 8] = [ + let mut inputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [ __m128i::from(u8x16::from(*seed)), __m128i::from(u8x16::from(checkpoints[0])), __m128i::from(u8x16::from(checkpoints[1])), @@ -75,7 +75,7 @@ pub(super) fn verify_sequential_aes_sse41( __m128i::from(u8x16::from(checkpoints[6])), ]; - let mut outputs: [__m128i; 8] = [ + let mut outputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [ __m128i::from(u8x16::from(checkpoints[0])), __m128i::from(u8x16::from(checkpoints[1])), __m128i::from(u8x16::from(checkpoints[2])), diff --git a/crates/shared/ab-proof-of-time/src/lib.rs b/crates/shared/ab-proof-of-time/src/lib.rs index 3cfa429de..0e0085c51 100644 --- a/crates/shared/ab-proof-of-time/src/lib.rs +++ b/crates/shared/ab-proof-of-time/src/lib.rs @@ -1,5 +1,6 @@ //! Proof of time implementation. 
+#![cfg_attr(target_arch = "aarch64", feature(array_chunks, iter_array_chunks))] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] #![feature(portable_simd)] #![no_std] diff --git a/subspace/docker/bootstrap-node.Dockerfile b/subspace/docker/bootstrap-node.Dockerfile index a0d6c3731..6216f4075 100644 --- a/subspace/docker/bootstrap-node.Dockerfile +++ b/subspace/docker/bootstrap-node.Dockerfile @@ -67,7 +67,7 @@ RUN \ ; fi && \ if [ $TARGETARCH = "amd64" ] && [ "$RUSTFLAGS" = "" ]; then \ case "$TARGETVARIANT" in \ - # x86-64-v2 with AES-NI + # x86-64-v2 "v2") export RUSTFLAGS="-C target-cpu=x86-64-v2" ;; \ # x86-64-v3 with AES-NI "v3") export RUSTFLAGS="-C target-cpu=x86-64-v3 -C target-feature=+aes" ;; \ diff --git a/subspace/docker/farmer.Dockerfile b/subspace/docker/farmer.Dockerfile index c69f26c91..7d675bc79 100644 --- a/subspace/docker/farmer.Dockerfile +++ b/subspace/docker/farmer.Dockerfile @@ -106,7 +106,7 @@ RUN \ ; fi && \ if [ $TARGETARCH = "amd64" ] && [ "$RUSTFLAGS" = "" ]; then \ case "$TARGETVARIANT" in \ - # x86-64-v2 with AES-NI + # x86-64-v2 "v2") export RUSTFLAGS="-C target-cpu=x86-64-v2" ;; \ # x86-64-v3 with AES-NI "v3") export RUSTFLAGS="-C target-cpu=x86-64-v3 -C target-feature=+aes" ;; \ diff --git a/subspace/docker/node.Dockerfile b/subspace/docker/node.Dockerfile index e711cf598..2f6ad86c2 100644 --- a/subspace/docker/node.Dockerfile +++ b/subspace/docker/node.Dockerfile @@ -68,7 +68,7 @@ RUN \ ; fi && \ if [ $TARGETARCH = "amd64" ] && [ "$RUSTFLAGS" = "" ]; then \ case "$TARGETVARIANT" in \ - # x86-64-v2 with AES-NI + # x86-64-v2 "v2") export RUSTFLAGS="-C target-cpu=x86-64-v2" ;; \ # x86-64-v3 with AES-NI "v3") export RUSTFLAGS="-C target-cpu=x86-64-v3 -C target-feature=+aes" ;; \ diff --git a/subspace/docker/runtime.Dockerfile b/subspace/docker/runtime.Dockerfile index 6046f1a7c..68ddbdcff 100644 --- a/subspace/docker/runtime.Dockerfile +++ b/subspace/docker/runtime.Dockerfile @@ -67,7 +67,7 @@ RUN \ ; fi && \ if [ 
$TARGETARCH = "amd64" ] && [ "$RUSTFLAGS" = "" ]; then \ case "$TARGETVARIANT" in \ - # x86-64-v2 with AES-NI + # x86-64-v2 "v2") export RUSTFLAGS="-C target-cpu=x86-64-v2" ;; \ # x86-64-v3 with AES-NI "v3") export RUSTFLAGS="-C target-cpu=x86-64-v3 -C target-feature=+aes" ;; \