Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions cpp/src/gandiva/encrypt_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,24 @@
// under the License.

#include "gandiva/encrypt_utils.h"
#include <string.h>

#include <stdexcept>
#include <string>

namespace gandiva {
GANDIVA_EXPORT
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
unsigned char* cipher) {
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
int32_t key_len, unsigned char* cipher) {
int32_t cipher_len = 0;
int32_t len = 0;
EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new();
const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len);

if (!en_ctx) {
throw std::runtime_error("could not create a new evp cipher ctx for encryption");
}

if (!EVP_EncryptInit_ex(en_ctx, EVP_aes_128_ecb(), nullptr,
if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr,
reinterpret_cast<const unsigned char*>(key), nullptr)) {
throw std::runtime_error("could not initialize evp cipher ctx for encryption");
}
Expand All @@ -55,17 +57,18 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke
}

GANDIVA_EXPORT
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
unsigned char* plaintext) {
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
int32_t key_len, unsigned char* plaintext) {
int32_t plaintext_len = 0;
int32_t len = 0;
EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new();
const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len);

if (!de_ctx) {
throw std::runtime_error("could not create a new evp cipher ctx for decryption");
}

if (!EVP_DecryptInit_ex(de_ctx, EVP_aes_128_ecb(), nullptr,
if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr,
reinterpret_cast<const unsigned char*>(key), nullptr)) {
throw std::runtime_error("could not initialize evp cipher ctx for decryption");
}
Expand All @@ -87,4 +90,17 @@ int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char*
EVP_CIPHER_CTX_free(de_ctx);
return plaintext_len;
}

// Select the OpenSSL AES cipher (ECB mode) that matches the given key size.
//
// key_length is the key size in bytes: 16, 24 and 32 map to
// EVP_aes_128_ecb(), EVP_aes_192_ecb() and EVP_aes_256_ecb() respectively.
//
// Throws std::runtime_error for any other length. The message keeps the
// "unsupported key length" prefix and appends the rejected value so that the
// failing input can be identified from the error alone.
const EVP_CIPHER* get_cipher_algo(int32_t key_length) {
  switch (key_length) {
    case 16:
      return EVP_aes_128_ecb();
    case 24:
      return EVP_aes_192_ecb();
    case 32:
      return EVP_aes_256_ecb();
    default:
      throw std::runtime_error("unsupported key length: " +
                               std::to_string(key_length));
  }
}
} // namespace gandiva
6 changes: 4 additions & 2 deletions cpp/src/gandiva/encrypt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ namespace gandiva {
**/
GANDIVA_EXPORT
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
unsigned char* cipher);
int32_t key_len, unsigned char* cipher);

/**
* Decrypt data using aes algorithm
**/
GANDIVA_EXPORT
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
unsigned char* plaintext);
int32_t key_len, unsigned char* plaintext);

/**
 * Return the AES ECB cipher matching a key length given in bytes
 * (16, 24 or 32); throws std::runtime_error for any other length
 **/
const EVP_CIPHER* get_cipher_algo(int32_t key_length);

} // namespace gandiva
114 changes: 35 additions & 79 deletions cpp/src/gandiva/encrypt_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,38 @@
#include <gtest/gtest.h>

TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) {
// 8 bytes key
auto* key = "1234abcd";
// 16 bytes key
auto* key = "12345678abcdefgh";
auto* to_encrypt = "some test string";

auto key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
auto to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_1[64];

int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_1);
int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_1);

unsigned char decrypted_1[64];
int32_t decrypted_1_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_1),
cipher_1_len, key, decrypted_1);
cipher_1_len, key, key_len, decrypted_1);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_1), decrypted_1_len));

// 16 bytes key
key = "12345678abcdefgh";
// 24 bytes key
key = "12345678abcdefgh12345678";
to_encrypt = "some\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_2[64];

int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_2);
int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_2);

unsigned char decrypted_2[64];
int32_t decrypted_2_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_2),
cipher_2_len, key, decrypted_2);
cipher_2_len, key, key_len, decrypted_2);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_2), decrypted_2_len));
Expand All @@ -58,97 +60,51 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) {
key = "12345678abcdefgh12345678abcdefgh";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_3[64];

int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_3);
int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_3);

unsigned char decrypted_3[64];
int32_t decrypted_3_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_3),
cipher_3_len, key, decrypted_3);
cipher_3_len, key, key_len, decrypted_3);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_3), decrypted_3_len));

// 64 bytes key
// check exception
char cipher[64] = "JBB7oJAQuqhDCx01fvBRi8PcljW1+nbnOSMk+R0Sz7E==";
int32_t cipher_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(cipher)));
unsigned char plain_text[64];

key = "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_4[64];
ASSERT_THROW({
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_4);
}, std::runtime_error);

int32_t cipher_4_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_4);

unsigned char decrypted_4[64];
int32_t decrypted_4_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_4),
cipher_4_len, key, decrypted_4);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_4), decrypted_4_len));

// 128 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh";
to_encrypt = "A much more longer string then the previous one, but without newline";

to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_5[128];

int32_t cipher_5_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_5);

unsigned char decrypted_5[128];
int32_t decrypted_5_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_5),
cipher_5_len, key, decrypted_5);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_5), decrypted_5_len));

// 192 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
"5678abcdefgh12345678abcdefgh";
to_encrypt =
"A much more longer string then the previous one, but with \nnewline, pretty cool, "
"right?";

to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_6[256];

int32_t cipher_6_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_6);

unsigned char decrypted_6[256];
int32_t decrypted_6_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_6),
cipher_6_len, key, decrypted_6);
ASSERT_THROW({
gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text);
}, std::runtime_error);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_6), decrypted_6_len));

// 256 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
"5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456"
"78abcdefgh";
to_encrypt =
"A much more longer string then the previous one, but with \nnewline, pretty cool, "
"right?";
key = "12345678";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_7[256];

int32_t cipher_7_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_7);

unsigned char decrypted_7[256];
int32_t decrypted_7_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_7),
cipher_7_len, key, decrypted_7);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_7), decrypted_7_len));
unsigned char cipher_5[64];
ASSERT_THROW({
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_5);
}, std::runtime_error);
ASSERT_THROW({
gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text);
}, std::runtime_error);
}
26 changes: 21 additions & 5 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8)
#undef GDV_FN_CAST_VARCHAR_INTEGER
#undef GDV_FN_CAST_VARCHAR_REAL

static constexpr int64_t kAesBlockSize = 16; // bytes

GANDIVA_EXPORT
const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_len,
const char* key_data, int32_t key_data_len,
Expand All @@ -318,6 +316,15 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l
return "";
}

int64_t kAesBlockSize = 0;
if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) {
kAesBlockSize = static_cast<int64_t>(key_data_len);
} else {
gdv_fn_context_set_error_msg(context, "invalid key length");
*out_len = 0;
return nullptr;
Comment on lines +320 to +325
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't the helper functions already check the length? I think it would be better to just do the check in one place since the code would be easier to read and better tested. It doesn't look like this path is being tested since the unit tests call the helper functions.

}

*out_len =
static_cast<int32_t>(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize));
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
Expand All @@ -330,7 +337,7 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l
}

try {
*out_len = gandiva::aes_encrypt(data, data_len, key_data,
*out_len = gandiva::aes_encrypt(data, data_len, key_data, key_data_len,
reinterpret_cast<unsigned char*>(ret));
} catch (const std::runtime_error& e) {
gdv_fn_context_set_error_msg(context, e.what());
Expand All @@ -351,6 +358,15 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l
return "";
}

int64_t kAesBlockSize = 0;
if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) {
kAesBlockSize = static_cast<int64_t>(key_data_len);
} else {
gdv_fn_context_set_error_msg(context, "invalid key length");
*out_len = 0;
return nullptr;
}

*out_len =
static_cast<int32_t>(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize));
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
Expand All @@ -363,14 +379,14 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l
}

try {
*out_len = gandiva::aes_decrypt(data, data_len, key_data,
*out_len = gandiva::aes_decrypt(data, data_len, key_data, key_data_len,
reinterpret_cast<unsigned char*>(ret));
} catch (const std::runtime_error& e) {
gdv_fn_context_set_error_msg(context, e.what());
*out_len = 0;
return nullptr;
}

ret[*out_len] = '\0';
return ret;
}

Expand Down
70 changes: 70 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1345,4 +1345,74 @@ TEST(TestGdvFnStubs, TestMask) {
EXPECT_EQ(std::string(result, out_len), expected);
}

// Round-trips a short string through the AES encrypt/decrypt stubs with a
// 16-byte key and expects the original text back.
TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) {
  const std::string plain_text{"test string"};
  const std::string secret{"12345678abcdefgh"};
  const auto plain_text_size = static_cast<int32_t>(plain_text.size());
  const auto secret_size = static_cast<int32_t>(secret.size());

  gandiva::ExecutionContext exec_ctx;
  const auto ctx_handle = reinterpret_cast<int64_t>(&exec_ctx);

  int32_t encrypted_size = 0;
  const char* encrypted =
      gdv_fn_aes_encrypt(ctx_handle, plain_text.data(), plain_text_size,
                         secret.data(), secret_size, &encrypted_size);

  int32_t recovered_size = 0;
  const char* recovered =
      gdv_fn_aes_decrypt(ctx_handle, encrypted, encrypted_size, secret.data(),
                         secret_size, &recovered_size);

  EXPECT_EQ(plain_text, std::string(recovered, recovered_size));
}

// Round-trips a short string through the AES encrypt/decrypt stubs with a
// 24-byte key and expects the original text back.
TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) {
  const std::string plain_text{"test string"};
  const std::string secret{"12345678abcdefgh12345678"};
  const auto plain_text_size = static_cast<int32_t>(plain_text.size());
  const auto secret_size = static_cast<int32_t>(secret.size());

  gandiva::ExecutionContext exec_ctx;
  const auto ctx_handle = reinterpret_cast<int64_t>(&exec_ctx);

  int32_t encrypted_size = 0;
  const char* encrypted =
      gdv_fn_aes_encrypt(ctx_handle, plain_text.data(), plain_text_size,
                         secret.data(), secret_size, &encrypted_size);

  int32_t recovered_size = 0;
  const char* recovered =
      gdv_fn_aes_decrypt(ctx_handle, encrypted, encrypted_size, secret.data(),
                         secret_size, &recovered_size);

  EXPECT_EQ(plain_text, std::string(recovered, recovered_size));
}

// Round-trips a short string through the AES encrypt/decrypt stubs with a
// 32-byte key and expects the original text back.
TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) {
  const std::string plain_text{"test string"};
  const std::string secret{"12345678abcdefgh12345678abcdefgh"};
  const auto plain_text_size = static_cast<int32_t>(plain_text.size());
  const auto secret_size = static_cast<int32_t>(secret.size());

  gandiva::ExecutionContext exec_ctx;
  const auto ctx_handle = reinterpret_cast<int64_t>(&exec_ctx);

  int32_t encrypted_size = 0;
  const char* encrypted =
      gdv_fn_aes_encrypt(ctx_handle, plain_text.data(), plain_text_size,
                         secret.data(), secret_size, &encrypted_size);

  int32_t recovered_size = 0;
  const char* recovered =
      gdv_fn_aes_decrypt(ctx_handle, encrypted, encrypted_size, secret.data(),
                         secret_size, &recovered_size);

  EXPECT_EQ(plain_text, std::string(recovered, recovered_size));
}

// Verifies that both AES stubs reject a 33-byte key: each call must record an
// "invalid key length" error on the context, return a null pointer and report
// an output length of zero.
TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) {
  gandiva::ExecutionContext ctx;
  int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  std::string key33 = "12345678abcdefgh12345678abcdefghb";
  auto key33_len = static_cast<int32_t>(key33.length());
  std::string data = "test string";
  auto data_len = static_cast<int32_t>(data.length());
  std::string cipher = "12345678abcdefgh12345678abcdefghb";
  auto cipher_len = static_cast<int32_t>(cipher.length());

  // Use a dedicated output length for the encrypt call. Passing &cipher_len
  // here would let the stub overwrite it, and the decrypt call below would
  // then be exercised with the wrong input length.
  // Start from a non-zero sentinel so the zero check proves the stub wrote it.
  int32_t encrypted_len = -1;
  const char* encrypted = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len,
                                             key33.c_str(), key33_len, &encrypted_len);
  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("invalid key length"));
  EXPECT_TRUE(encrypted == nullptr);
  EXPECT_EQ(encrypted_len, 0);
  ctx.Reset();

  int32_t decrypted_len = -1;
  const char* decrypted = gdv_fn_aes_decrypt(ctx_ptr, cipher.c_str(), cipher_len,
                                             key33.c_str(), key33_len, &decrypted_len);
  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("invalid key length"));
  EXPECT_TRUE(decrypted == nullptr);
  EXPECT_EQ(decrypted_len, 0);
  ctx.Reset();
}
} // namespace gandiva
Loading