diff --git a/include/gmssl/sm4.h b/include/gmssl/sm4.h index d5cd87a7..dc240a20 100644 --- a/include/gmssl/sm4.h +++ b/include/gmssl/sm4.h @@ -94,7 +94,8 @@ int sm4_ctr32_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen); #define SM4_GCM_MAX_AAD_SIZE (1<<24) // 16MiB #define SM4_GCM_MIN_PLAINTEXT_SIZE 0 -#define SM4_GCM_MAX_PLAINTEXT_SIZE ((((uint64_t)1 << 39) - 256) >> 3) // 68719476704 +#define SM4_GCM_MAX_PLAINTEXT_NBLOCKS (((uint64_t)1 << 32) - 2) +#define SM4_GCM_MAX_PLAINTEXT_SIZE (SM4_GCM_MAX_PLAINTEXT_NBLOCKS * 16) // 68719476704 #define SM4_GCM_MAX_TAG_SIZE 16 #define SM4_GCM_MIN_TAG_SIZE 12 @@ -117,6 +118,7 @@ typedef struct { size_t taglen; uint8_t mac[16]; size_t maclen; + uint64_t encedlen; } SM4_GCM_CTX; int sm4_gcm_encrypt_init(SM4_GCM_CTX *ctx, diff --git a/src/sm4_gcm.c b/src/sm4_gcm.c index d56b72e5..5fbfe561 100644 --- a/src/sm4_gcm.c +++ b/src/sm4_gcm.c @@ -8,6 +8,7 @@ */ +#include #include #include #include @@ -33,7 +34,15 @@ int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, uint8_t Y[16]; uint8_t T[16]; - if (taglen > SM4_GCM_MAX_TAG_SIZE) { + if (ivlen < SM4_GCM_MIN_IV_SIZE || ivlen > SM4_GCM_MAX_IV_SIZE) { + error_print(); + return -1; + } + if (taglen < SM4_GCM_MIN_TAG_SIZE || taglen > SM4_GCM_MAX_TAG_SIZE) { + error_print(); + return -1; + } + if (inlen > SM4_GCM_MAX_PLAINTEXT_SIZE) { error_print(); return -1; } @@ -67,6 +76,19 @@ int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, uint8_t Y[16]; uint8_t T[16]; + if (ivlen < SM4_GCM_MIN_IV_SIZE || ivlen > SM4_GCM_MAX_IV_SIZE) { + error_print(); + return -1; + } + if (taglen < SM4_GCM_MIN_TAG_SIZE || taglen > SM4_GCM_MAX_TAG_SIZE) { + error_print(); + return -1; + } + if (inlen > SM4_GCM_MAX_PLAINTEXT_SIZE) { + error_print(); + return -1; + } + sm4_encrypt(key, H, H); if (ivlen == 12) { @@ -111,7 +133,7 @@ int sm4_gcm_encrypt_init(SM4_GCM_CTX *ctx, error_print(); return -1; } - if (taglen < 8 || taglen > 16) { + if (taglen < SM4_GCM_MIN_TAG_SIZE || taglen > SM4_GCM_MAX_TAG_SIZE) { error_print(); return -1; } @@ -152,15 +174,27 @@ int sm4_gcm_encrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, ui error_print(); return -1; } + if (inlen > INT_MAX) { + error_print(); + return -1; + } + if (inlen > SM4_GCM_MAX_PLAINTEXT_SIZE - ctx->encedlen) { + error_print(); + return -1; + } if (!out) { *outlen = 16 * ((inlen + 15)/16); return 1; } + if (sm4_ctr32_encrypt_update(&ctx->enc_ctx, in, inlen, out, outlen) != 1) { error_print(); return -1; } + ghash_update(&ctx->mac_ctx, out, *outlen); + + ctx->encedlen += inlen; return 1; } @@ -205,6 +239,15 @@ int sm4_gcm_decrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, ui error_print(); return -1; } + if (inlen > INT_MAX) { + error_print(); + return -1; + } + if (inlen > SM4_GCM_MAX_PLAINTEXT_SIZE - ctx->encedlen) { + error_print(); + return -1; + } + if (!out) { *outlen = 16 * ((inlen + 15)/16); return 1; @@ -256,6 +299,8 @@ int sm4_gcm_decrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, ui *outlen += len; memcpy(ctx->mac, in + inlen, GHASH_SIZE); } + + ctx->encedlen += inlen; return 1; }