Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding armv8 crypto extensions to AES #1086

Merged
merged 5 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ else()
set(AES_IMPL ${AES_IMPL} aes/aes256_ni.c)
set_source_files_properties(aes/aes128_ni.c PROPERTIES COMPILE_FLAGS -maes)
set_source_files_properties(aes/aes256_ni.c PROPERTIES COMPILE_FLAGS -maes)
elseif (OQS_USE_ARM_AES_INSTRUCTIONS)
set(AES_IMPL ${AES_IMPL} aes/aes128_armv8.c)
set(AES_IMPL ${AES_IMPL} aes/aes256_armv8.c)
endif()
endif()

Expand Down
82 changes: 64 additions & 18 deletions src/common/aes/aes.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,44 @@
#include "aes_local.h"

#if defined(OQS_DIST_X86_64_BUILD)
#define C_OR_NI(stmt_c, stmt_ni) \
#define C_OR_NI_OR_ARM(stmt_c, stmt_ni, stmt_arm) \
if (OQS_CPU_has_extension(OQS_CPU_EXT_AES)) { \
stmt_ni; \
} else { \
stmt_c; \
}
#elif defined(OQS_USE_AES_INSTRUCTIONS)
#define C_OR_NI(stmt_c, stmt_ni) \
#define C_OR_NI_OR_ARM(stmt_c, stmt_ni, stmt_arm) \
stmt_ni;
#elif defined(OQS_USE_ARM_AES_INSTRUCTIONS)
#define C_OR_NI_OR_ARM(stmt_c, stmt_ni, stmt_arm) \
stmt_arm;
#else
#define C_OR_NI(stmt_c, stmt_ni) \
#define C_OR_NI_OR_ARM(stmt_c, stmt_ni, stmt_arm) \
stmt_c;
#endif

void OQS_AES128_ECB_load_schedule(const uint8_t *key, void **_schedule) {
C_OR_NI(
C_OR_NI_OR_ARM(
oqs_aes128_load_schedule_c(key, _schedule),
oqs_aes128_load_schedule_ni(key, _schedule)
oqs_aes128_load_schedule_ni(key, _schedule),
oqs_aes128_load_schedule_armv8(key, _schedule)
)
}

void OQS_AES128_free_schedule(void *schedule) {
C_OR_NI(
C_OR_NI_OR_ARM(
oqs_aes128_free_schedule_c(schedule),
oqs_aes128_free_schedule_ni(schedule)
oqs_aes128_free_schedule_ni(schedule),
oqs_aes128_free_schedule_armv8(schedule)
)
}

void OQS_AES256_ECB_load_schedule(const uint8_t *key, void **_schedule) {
C_OR_NI(
C_OR_NI_OR_ARM(
oqs_aes256_load_schedule_c(key, _schedule),
oqs_aes256_load_schedule_ni(key, _schedule)
oqs_aes256_load_schedule_ni(key, _schedule),
oqs_aes256_load_schedule_armv8(key, _schedule)
)
}

Expand All @@ -50,9 +56,10 @@ void OQS_AES256_CTR_load_schedule(const uint8_t *key, void **_schedule) {
}

void OQS_AES256_free_schedule(void *schedule) {
C_OR_NI(
C_OR_NI_OR_ARM(
oqs_aes256_free_schedule_c(schedule),
oqs_aes256_free_schedule_ni(schedule)
oqs_aes256_free_schedule_ni(schedule),
oqs_aes256_free_schedule_armv8(schedule)
)
}

Expand All @@ -64,9 +71,10 @@ void OQS_AES128_ECB_enc(const uint8_t *plaintext, const size_t plaintext_len, co
}

void OQS_AES128_ECB_enc_sch(const uint8_t *plaintext, const size_t plaintext_len, const void *schedule, uint8_t *ciphertext) {
C_OR_NI(
C_OR_NI_OR_ARM(
oqs_aes128_ecb_enc_sch_c(plaintext, plaintext_len, schedule, ciphertext),
oqs_aes128_ecb_enc_sch_ni(plaintext, plaintext_len, schedule, ciphertext)
oqs_aes128_ecb_enc_sch_ni(plaintext, plaintext_len, schedule, ciphertext),
oqs_aes128_ecb_enc_sch_armv8(plaintext, plaintext_len, schedule, ciphertext)
)
}

Expand All @@ -78,13 +86,15 @@ void OQS_AES256_ECB_enc(const uint8_t *plaintext, const size_t plaintext_len, co
}

void OQS_AES256_ECB_enc_sch(const uint8_t *plaintext, const size_t plaintext_len, const void *schedule, uint8_t *ciphertext) {
C_OR_NI(
C_OR_NI_OR_ARM(
oqs_aes256_ecb_enc_sch_c(plaintext, plaintext_len, schedule, ciphertext),
oqs_aes256_ecb_enc_sch_ni(plaintext, plaintext_len, schedule, ciphertext)
oqs_aes256_ecb_enc_sch_ni(plaintext, plaintext_len, schedule, ciphertext),
oqs_aes256_ecb_enc_sch_armv8(plaintext, plaintext_len, schedule, ciphertext)
)
}

#if defined(OQS_DIST_X86_64_BUILD) || defined(OQS_USE_AES_INSTRUCTIONS)
#if defined(OQS_DIST_X86_64_BUILD) || defined(OQS_USE_AES_INSTRUCTIONS) || defined(OQS_USE_ARM_AES_INSTRUCTIONS)

static uint32_t UINT32_TO_BE(const uint32_t x) {
union {
uint32_t val;
Expand All @@ -97,7 +107,9 @@ static uint32_t UINT32_TO_BE(const uint32_t x) {
return y.val;
}
#define BE_TO_UINT32(n) (uint32_t)((((uint8_t *) &(n))[0] << 24) | (((uint8_t *) &(n))[1] << 16) | (((uint8_t *) &(n))[2] << 8) | (((uint8_t *) &(n))[3] << 0))
#endif

#if defined(OQS_DIST_X86_64_BUILD) || defined(OQS_USE_AES_INSTRUCTIONS)
void oqs_aes256_ctr_enc_sch_ni(const uint8_t *iv, const size_t iv_len, const void *schedule, uint8_t *out, size_t out_len) {
uint8_t block[16];
uint32_t ctr;
Expand Down Expand Up @@ -129,9 +141,43 @@ void oqs_aes256_ctr_enc_sch_ni(const uint8_t *iv, const size_t iv_len, const voi
}
#endif

#if defined(OQS_USE_ARM_AES_INSTRUCTIONS)
/*
 * AES-256 counter-mode keystream generation using the ARMv8 block routine.
 *
 * iv       - 12 or 16 bytes. The first 12 bytes form the fixed part of every
 *            counter block; with a 16-byte iv the last 4 bytes are the
 *            big-endian starting counter, otherwise the counter starts at 0.
 * iv_len   - must be 12 or 16; any other value terminates the process with
 *            EXIT_FAILURE.
 * schedule - key schedule handed unchanged to oqs_aes256_enc_sch_block_armv8.
 * out      - receives out_len bytes of encrypted counter blocks.
 * out_len  - number of bytes to produce; need not be a multiple of 16.
 */
void oqs_aes256_ctr_enc_sch_armv8(const uint8_t *iv, const size_t iv_len, const void *schedule, uint8_t *out, size_t out_len) {
    uint8_t block[16];
    uint32_t ctr;
    uint32_t ctr_be;

    /* Check the length before reading iv, so an unsupported (possibly
     * shorter) iv is never read past its end. */
    if (iv_len == 12) {
        ctr = 0;
    } else if (iv_len == 16) {
        memcpy(&ctr_be, &iv[12], 4);
        ctr = BE_TO_UINT32(ctr_be);
    } else {
        exit(EXIT_FAILURE);
    }
    memcpy(block, iv, 12);

    /* Whole blocks are encrypted straight into the output buffer. */
    while (out_len >= 16) {
        ctr_be = UINT32_TO_BE(ctr);
        memcpy(&block[12], (uint8_t *) &ctr_be, 4);
        oqs_aes256_enc_sch_block_armv8(block, schedule, out);
        out += 16;
        out_len -= 16;
        ctr++; /* uint32_t arithmetic: wraps modulo 2^32 */
    }

    /* A trailing partial block goes through a scratch buffer so that only
     * out_len bytes are written to out. */
    if (out_len > 0) {
        uint8_t tmp[16];
        ctr_be = UINT32_TO_BE(ctr);
        memcpy(&block[12], (uint8_t *) &ctr_be, 4);
        oqs_aes256_enc_sch_block_armv8(block, schedule, tmp);
        memcpy(out, tmp, out_len);
    }
}
#endif


void OQS_AES256_CTR_sch(const uint8_t *iv, const size_t iv_len, const void *schedule, uint8_t *out, size_t out_len) {
C_OR_NI(
C_OR_NI_OR_ARM(
oqs_aes256_ctr_enc_sch_c(iv, iv_len, schedule, out, out_len),
oqs_aes256_ctr_enc_sch_ni(iv, iv_len, schedule, out, out_len)
oqs_aes256_ctr_enc_sch_ni(iv, iv_len, schedule, out, out_len),
oqs_aes256_ctr_enc_sch_armv8(iv, iv_len, schedule, out, out_len)
)
}
218 changes: 218 additions & 0 deletions src/common/aes/aes128_armv8.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// SPDX-License-Identifier: Public domain

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <oqs/common.h>

#include <arm_neon.h>

/* Number of 64-bit words reserved in aes128ctx. */
#define PQC_AES128_STATESIZE 88
/*
 * Container type exposing an AES-128 schedule as an array of
 * PQC_AES128_STATESIZE 64-bit words.
 *
 * NOTE(review): oqs_aes128_load_schedule_armv8 in this file allocates only
 * 44 * sizeof(int) bytes per schedule, which is smaller than
 * sizeof(aes128ctx). A schedule pointer must therefore never be read or
 * written as a whole aes128ctx. TODO confirm whether this type is still
 * needed for the ARMv8 implementation.
 */
typedef struct {
uint64_t sk_exp[PQC_AES128_STATESIZE];
} aes128ctx;


/*
 * Forward substitution table: 256 byte values indexed by an input byte.
 * Consumed by the f_FSb_32__1 / f_FSb_32__2 macros during key expansion.
 *
 * Entries are unsigned int rather than uint8_t so that the macros can shift
 * an entry left by up to 24 bits without promotion to signed int.
 * The table is never written, so it is declared const.
 */
static const unsigned int FSb[256] = {
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5,
    0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0,
    0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC,
    0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A,
    0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0,
    0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B,
    0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85,
    0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5,
    0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17,
    0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88,
    0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C,
    0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9,
    0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6,
    0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E,
    0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94,
    0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68,
    0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
};

/* Substitute the two most significant bytes of x through FSb, each result
 * staying in the byte position it was taken from; the low 16 bits are 0. */
#define f_FSb_32__1(x) \
    ((FSb[0xFF & ((x) >> 24)] << 24) ^ (FSb[0xFF & ((x) >> 16)] << 16))

/* Substitute the two least significant bytes of x through FSb, each result
 * staying in the byte position it was taken from; the high 16 bits are 0. */
#define f_FSb_32__2(x) \
    ((FSb[0xFF & ((x) >> 8)] << 8) ^ (FSb[0xFF & (x)] & 0xFF))



/*
 * Rotate x right by n bit positions, treating x as a 32-bit word.
 *
 * Both shift counts are reduced modulo 32, so n == 0 (or any multiple of 32)
 * returns x unchanged instead of performing an undefined shift by 32.
 */
static inline unsigned int rotr(const unsigned int x, const unsigned int n) {
    const unsigned int s = n & 31u;
    return (x >> s) | (x << ((32u - s) & 31u));
}
/*
 * Rotate x left by n bit positions, treating x as a 32-bit word.
 *
 * Both shift counts are reduced modulo 32, so n == 0 (or any multiple of 32)
 * returns x unchanged instead of performing an undefined shift by 32.
 */
static inline unsigned int rotl(const unsigned int x, const unsigned int n) {
    const unsigned int s = n & 31u;
    return (x << s) | (x >> ((32u - s) & 31u));
}

// From crypto_core/aes128encrypt/dolbeau/armv8crypto
/*
 * AES-128 key expansion.
 *
 * key      - four input words, read as key[0..3].
 * aes_edrk - output array of 44 words: aes_edrk[0..3] are a copy of the key,
 *            followed by ten groups of four derived words.
 *
 * Each group is derived from the previous one: the last word is rotated by
 * one byte, substituted through FSb, combined with a per-round constant and
 * XORed into the first word; the remaining three words are a running XOR.
 * The rotation direction and the byte holding the round constant are chosen
 * at compile time from __BYTE_ORDER__.
 */
static inline void aes128_armv8_keysched(const unsigned int key[], unsigned int *aes_edrk) {
#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    static const unsigned int rcon[10] = {
        0x01000000u, 0x02000000u, 0x04000000u, 0x08000000u, 0x10000000u,
        0x20000000u, 0x40000000u, 0x80000000u, 0x1B000000u, 0x36000000u
    };
#else
    static const unsigned int rcon[10] = {
        0x00000001u, 0x00000002u, 0x00000004u, 0x00000008u, 0x00000010u,
        0x00000020u, 0x00000040u, 0x00000080u, 0x0000001Bu, 0x00000036u
    };
#endif
    unsigned int w0, w1, w2, w3;

    /* Group 0 is the key itself. */
    w0 = key[0];
    aes_edrk[0] = w0;
    w1 = key[1];
    aes_edrk[1] = w1;
    w2 = key[2];
    aes_edrk[2] = w2;
    w3 = key[3];
    aes_edrk[3] = w3;

    for (unsigned int r = 0; r < 10; r++) {
        unsigned int *group = aes_edrk + 4 * (r + 1);
#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
        const unsigned int rotated = rotl(w3, 8);
#else
        const unsigned int rotated = rotr(w3, 8);
#endif
        const unsigned int substituted = f_FSb_32__1(rotated) ^ f_FSb_32__2(rotated);

        w0 ^= rcon[r] ^ substituted;
        w1 ^= w0;
        w2 ^= w1;
        w3 ^= w2;

        group[0] = w0;
        group[1] = w1;
        group[2] = w2;
        group[3] = w3;
    }
}


/*
 * Allocate an AES-128 key schedule and fill it from key.
 *
 * key       - key bytes; 4 * sizeof(unsigned int) bytes are read. The pointer
 *             needs no particular alignment.
 * _schedule - receives a malloc'd buffer of 44 unsigned ints; release it with
 *             oqs_aes128_free_schedule_armv8.
 *
 * Terminates the process with EXIT_FAILURE if the allocation fails. The check
 * is explicit rather than an assert so that it survives NDEBUG builds.
 */
void oqs_aes128_load_schedule_armv8(const uint8_t *key, void **_schedule) {
    unsigned int *schedule = malloc(44 * sizeof *schedule);
    if (schedule == NULL) {
        exit(EXIT_FAILURE);
    }
    *_schedule = schedule;

    /* key is a byte pointer, so reading it through an unsigned int pointer
     * would be a misaligned / aliasing access. Copy the key bytes into the
     * first four schedule words instead and expand in place: the key
     * schedule routine reads each key word before it stores the
     * corresponding output word, so input and output may share storage. */
    memcpy(schedule, key, 4 * sizeof *schedule);
    aes128_armv8_keysched(schedule, schedule);
}

/*
 * Release a schedule obtained from oqs_aes128_load_schedule_armv8.
 *
 * schedule - buffer to release; NULL is accepted and ignored. Otherwise the
 *            pointer and its size (44 * sizeof(int) bytes) are handed to
 *            OQS_MEM_secure_free.
 */
void oqs_aes128_free_schedule_armv8(void *schedule) {
    if (schedule == NULL) {
        return;
    }
    OQS_MEM_secure_free(schedule, 44 * sizeof(int));
}


// From crypto_core/aes128encrypt/dolbeau/armv8crypto
/*
 * Encrypt a single 16-byte block with the ARMv8 AES instructions.
 *
 * rkeys - 176 bytes of round-key material, consumed as eleven 16-byte keys
 *         at offsets 0, 16, ..., 160.
 * n     - 16 input bytes.
 * out   - receives 16 output bytes.
 */
static inline void aes128_armv8_encrypt(const unsigned char *rkeys, const unsigned char *n, unsigned char *out) {
    uint8x16_t state = vld1q_u8(n);

    /*
     * In ARMv8+crypto, AESE performs AddRoundKey first, then SubBytes and
     * ShiftRows; AESMC performs MixColumns. So rather than a lone XOR with
     * the first round key before the rounds, there is a lone XOR with the
     * last round key after them.
     */

    /* Nine full rounds, using the keys at offsets 0..128. */
    for (unsigned int round = 0; round < 9; round++) {
        state = vaeseq_u8(state, vld1q_u8(rkeys + 16 * round));
        state = vaesmcq_u8(state);
    }

    /* Final round has no MixColumns; finish with the trailing key XOR. */
    state = vaeseq_u8(state, vld1q_u8(rkeys + 144));
    state = veorq_u8(state, vld1q_u8(rkeys + 160));

    vst1q_u8(out, state);
}

/*
 * Encrypt one 16-byte block with an already-loaded schedule.
 *
 * plaintext  - 16 input bytes.
 * _schedule  - schedule buffer, passed to aes128_armv8_encrypt as raw bytes.
 * ciphertext - receives 16 output bytes.
 */
void oqs_aes128_enc_sch_block_armv8(const uint8_t *plaintext, const void *_schedule, uint8_t *ciphertext) {
    aes128_armv8_encrypt((const unsigned char *) _schedule, plaintext, ciphertext);
}

/*
 * ECB-encrypt a buffer, one 16-byte block at a time, with a loaded schedule.
 *
 * plaintext     - input bytes.
 * plaintext_len - input length in bytes; asserted to be a multiple of 16.
 *                 Only plaintext_len / 16 whole blocks are processed.
 * schedule      - schedule buffer as produced by
 *                 oqs_aes128_load_schedule_armv8.
 * ciphertext    - receives one 16-byte output block per input block.
 */
void oqs_aes128_ecb_enc_sch_armv8(const uint8_t *plaintext, const size_t plaintext_len, const void *schedule, uint8_t *ciphertext) {
    assert(plaintext_len % 16 == 0);
    const size_t n_blocks = plaintext_len / 16;

    /* The schedule is forwarded as-is. It is allocated as 44 unsigned ints,
     * so it must not be viewed through the much larger aes128ctx type. */
    for (size_t block = 0; block < n_blocks; block++) {
        const size_t offset = 16 * block;
        oqs_aes128_enc_sch_block_armv8(plaintext + offset, schedule, ciphertext + offset);
    }
}
Loading