1+ use subspace_core_primitives:: pot:: { PotCheckpoints , PotOutput } ;
12use core:: arch:: x86_64:: * ;
2- use core:: mem;
3- use subspace_core_primitives:: pot:: PotCheckpoints ;
3+ use core:: { array, mem} ;
44
55/// Create PoT proof with checkpoints
66#[ target_feature( enable = "aes" ) ]
@@ -12,40 +12,117 @@ pub(super) unsafe fn create(
1212) -> PotCheckpoints {
1313 let mut checkpoints = PotCheckpoints :: default ( ) ;
1414
15- let keys_reg = expand_key ( key) ;
16- let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
17- let mut seed_reg = _mm_loadu_si128 ( seed. as_ptr ( ) as * const __m128i ) ;
18- seed_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
19- for checkpoint in checkpoints. iter_mut ( ) {
20- for _ in 0 ..checkpoint_iterations {
21- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 1 ] ) ;
22- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 2 ] ) ;
23- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 3 ] ) ;
24- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 4 ] ) ;
25- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 5 ] ) ;
26- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 6 ] ) ;
27- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 7 ] ) ;
28- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 8 ] ) ;
29- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 9 ] ) ;
30- seed_reg = _mm_aesenclast_si128 ( seed_reg, xor_key) ;
31- }
15+ unsafe {
16+ let keys_reg = expand_key ( key) ;
17+ let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
18+ let mut seed_reg = _mm_loadu_si128 ( seed. as_ptr ( ) as * const __m128i ) ;
19+ seed_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
20+ for checkpoint in checkpoints. iter_mut ( ) {
21+ for _ in 0 ..checkpoint_iterations {
22+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 1 ] ) ;
23+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 2 ] ) ;
24+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 3 ] ) ;
25+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 4 ] ) ;
26+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 5 ] ) ;
27+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 6 ] ) ;
28+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 7 ] ) ;
29+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 8 ] ) ;
30+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 9 ] ) ;
31+ seed_reg = _mm_aesenclast_si128 ( seed_reg, xor_key) ;
32+ }
3233
33- let checkpoint_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
34- _mm_storeu_si128 (
35- checkpoint. as_mut ( ) . as_mut_ptr ( ) as * mut __m128i ,
36- checkpoint_reg,
37- ) ;
34+ let checkpoint_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
35+ _mm_storeu_si128 ( checkpoint. as_mut_ptr ( ) as * mut __m128i , checkpoint_reg) ;
36+ }
3837 }
3938
4039 checkpoints
4140}
4241
42+ /// Verification mimics `create` function, but also has decryption half for better performance
43+ #[ target_feature( enable = "avx512f,vaes" ) ]
44+ #[ inline]
45+ pub ( super ) unsafe fn verify_sequential_avx512f (
46+ seed : & [ u8 ; 16 ] ,
47+ key : & [ u8 ; 16 ] ,
48+ checkpoints : & PotCheckpoints ,
49+ checkpoint_iterations : u32 ,
50+ ) -> bool {
51+ let checkpoints = PotOutput :: repr_from_slice ( checkpoints. as_slice ( ) ) ;
52+
53+ unsafe {
54+ let keys_reg = expand_key ( key) ;
55+ let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
56+ let xor_key_512 = _mm512_broadcast_i32x4 ( xor_key) ;
57+
58+ // Invert keys for decryption
59+ let mut inv_keys = keys_reg;
60+ for i in 1 ..10 {
61+ inv_keys[ i] = _mm_aesimc_si128 ( keys_reg[ 10 - i] ) ;
62+ }
63+
64+ let keys_512 = array:: from_fn :: < _ , NUM_ROUNDS , _ > ( |i| _mm512_broadcast_i32x4 ( keys_reg[ i] ) ) ;
65+ let inv_keys_512 =
66+ array:: from_fn :: < _ , NUM_ROUNDS , _ > ( |i| _mm512_broadcast_i32x4 ( inv_keys[ i] ) ) ;
67+
68+ let mut input_0 = [ [ 0u8 ; 16 ] ; 4 ] ;
69+ input_0[ 0 ] = * seed;
70+ input_0[ 1 ..] . copy_from_slice ( & checkpoints[ ..3 ] ) ;
71+ let mut input_0 = _mm512_loadu_si512 ( input_0. as_ptr ( ) as * const __m512i ) ;
72+ let mut input_1 = _mm512_loadu_si512 ( checkpoints[ 3 ..7 ] . as_ptr ( ) as * const __m512i ) ;
73+
74+ let mut output_0 = _mm512_loadu_si512 ( checkpoints[ 0 ..4 ] . as_ptr ( ) as * const __m512i ) ;
75+ let mut output_1 = _mm512_loadu_si512 ( checkpoints[ 4 ..8 ] . as_ptr ( ) as * const __m512i ) ;
76+
77+ input_0 = _mm512_xor_si512 ( input_0, keys_512[ 0 ] ) ;
78+ input_1 = _mm512_xor_si512 ( input_1, keys_512[ 0 ] ) ;
79+
80+ output_0 = _mm512_xor_si512 ( output_0, keys_512[ 10 ] ) ;
81+ output_1 = _mm512_xor_si512 ( output_1, keys_512[ 10 ] ) ;
82+
83+ for _ in 0 ..checkpoint_iterations / 2 {
84+ for i in 1 ..10 {
85+ input_0 = _mm512_aesenc_epi128 ( input_0, keys_512[ i] ) ;
86+ input_1 = _mm512_aesenc_epi128 ( input_1, keys_512[ i] ) ;
87+
88+ output_0 = _mm512_aesdec_epi128 ( output_0, inv_keys_512[ i] ) ;
89+ output_1 = _mm512_aesdec_epi128 ( output_1, inv_keys_512[ i] ) ;
90+ }
91+
92+ input_0 = _mm512_aesenclast_epi128 ( input_0, xor_key_512) ;
93+ input_1 = _mm512_aesenclast_epi128 ( input_1, xor_key_512) ;
94+
95+ output_0 = _mm512_aesdeclast_epi128 ( output_0, xor_key_512) ;
96+ output_1 = _mm512_aesdeclast_epi128 ( output_1, xor_key_512) ;
97+ }
98+
99+ // Code below is a more efficient version of this:
100+ // input_0 = _mm512_xor_si512(input_0, keys_512[0]);
101+ // input_1 = _mm512_xor_si512(input_1, keys_512[0]);
102+ // output_0 = _mm512_xor_si512(output_0, keys_512[10]);
103+ // output_1 = _mm512_xor_si512(output_1, keys_512[10]);
104+ //
105+ // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
106+ // let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
107+
108+ let diff_0 = _mm512_xor_si512 ( input_0, output_0) ;
109+ let diff_1 = _mm512_xor_si512 ( input_1, output_1) ;
110+
111+ let mask0 = _mm512_cmpeq_epu64_mask ( diff_0, xor_key_512) ;
112+ let mask1 = _mm512_cmpeq_epu64_mask ( diff_1, xor_key_512) ;
113+
114+ // All inputs match outputs
115+ ( mask0 & mask1) == u8:: MAX
116+ }
117+ }
118+
43119// Below code copied with minor changes from following place under MIT/Apache-2.0 license by Artyom
44120// Pavlov:
45121// https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2/aes/src/ni/aes128.rs
46122
123+ const NUM_ROUNDS : usize = 11 ;
47124/// AES-128 round keys
48- type RoundKeys = [ __m128i ; 11 ] ;
125+ type RoundKeys = [ __m128i ; NUM_ROUNDS ] ;
49126
50127macro_rules! expand_round {
51128 ( $keys: expr, $pos: expr, $round: expr) => {
@@ -72,9 +149,10 @@ macro_rules! expand_round {
72149unsafe fn expand_key ( key : & [ u8 ; 16 ] ) -> RoundKeys {
73150 // SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized
74151 // with all zeroes.
75- let mut keys: RoundKeys = mem:: zeroed ( ) ;
152+ let mut keys: RoundKeys = unsafe { mem:: zeroed ( ) } ;
76153
77- let k = _mm_loadu_si128 ( key. as_ptr ( ) as * const __m128i ) ;
154+ // SAFETY: No alignment requirement in `_mm_loadu_si128`
155+ let k = unsafe { _mm_loadu_si128 ( key. as_ptr ( ) as * const __m128i ) } ;
78156 keys[ 0 ] = k;
79157
80158 expand_round ! ( keys, 1 , 0x01 ) ;
0 commit comments