diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAesBase.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAesBase.java index 0311b4f500ba..92ce86068d39 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAesBase.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAesBase.java @@ -22,14 +22,20 @@ import java.security.GeneralSecurityException; import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.spec.AlgorithmParameterSpec; import javax.crypto.Cipher; import javax.crypto.NoSuchPaddingException; import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.SecretKeySpec; +import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; @@ -44,18 +50,25 @@ * */ public abstract class GenericUDFAesBase extends GenericUDF { - protected transient Converter[] converters = new Converter[2]; - protected transient PrimitiveCategory[] inputTypes = new PrimitiveCategory[2]; + protected transient Converter[] converters = new Converter[3]; + protected transient PrimitiveCategory[] inputTypes = new PrimitiveCategory[3]; protected final BytesWritable output = new BytesWritable(); protected transient boolean isStr0; protected transient boolean isStr1; protected transient boolean isKeyConstant; protected transient Cipher cipher; protected transient SecretKey secretKey; + private boolean isGCM; + private boolean isCTR; + private static final int GCM_IV_LENGTH = 12; + private static final int CTR_IV_LENGTH = 16; + private static final int GCM_TAG_LENGTH = 128; + private static final String AES_GCM_NOPADDING = "AES/GCM/NoPadding"; + private static final String AES_CTR_NOPADDING = "AES/CTR/NoPadding"; @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { - checkArgsSize(arguments, 2, 2); + checkArgsSize(arguments, 2, 3); checkArgPrimitive(arguments, 0); checkArgPrimitive(arguments, 1); @@ -104,8 +117,24 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen secretKey = getSecretKey(key, keyLength); } + String cipherTransform = "AES"; + + if (arguments.length == 3) { + checkArgPrimitive(arguments, 2); + checkArgGroups(arguments, 2, inputTypes, STRING_GROUP); + + String userDefinedMode = getConstantStringValue(arguments, 2); + if (userDefinedMode != null) { + cipherTransform = userDefinedMode; + } + + } + + this.isGCM = AES_GCM_NOPADDING.equalsIgnoreCase(cipherTransform); + this.isCTR = AES_CTR_NOPADDING.equalsIgnoreCase(cipherTransform); + try { - cipher = Cipher.getInstance("AES"); + cipher = Cipher.getInstance(cipherTransform); } catch (NoSuchPaddingException | NoSuchAlgorithmException e) { throw new RuntimeException(e); } @@ -186,9 +215,52 @@ protected SecretKey getSecretKey(byte[] key, int keyLength) { protected byte[] aesFunction(byte[] input, int inputLength, SecretKey secretKey) { try { - cipher.init(getCipherMode(), secretKey); - byte[] res = cipher.doFinal(input, 0, inputLength); - return res; + if (isGCM || isCTR) { + int ivLen = isGCM ? GCM_IV_LENGTH : CTR_IV_LENGTH; + + if (getCipherMode() == Cipher.ENCRYPT_MODE) { + byte[] iv = new byte[ivLen]; + new SecureRandom().nextBytes(iv); + + AlgorithmParameterSpec paramSpec; + if (isGCM) { + paramSpec = new GCMParameterSpec(GCM_TAG_LENGTH, iv); + } else { + paramSpec = new IvParameterSpec(iv); + } + + cipher.init(Cipher.ENCRYPT_MODE, secretKey, paramSpec); + byte[] cipherText = cipher.doFinal(input, 0, inputLength); + + byte[] output = new byte[iv.length + cipherText.length]; + System.arraycopy(iv, 0, output, 0, iv.length); + System.arraycopy(cipherText, 0, output, iv.length, cipherText.length); + + return output; + + } else { + int minLen = isGCM ? (ivLen + (GCM_TAG_LENGTH / 8)) : (ivLen + 1); + if (inputLength < minLen) { + return null; + } + + byte[] iv = new byte[ivLen]; + System.arraycopy(input, 0, iv, 0, ivLen); + + AlgorithmParameterSpec paramSpec; + if (isGCM) { + paramSpec = new GCMParameterSpec(GCM_TAG_LENGTH, iv); + } else { + paramSpec = new IvParameterSpec(iv); + } + + cipher.init(Cipher.DECRYPT_MODE, secretKey, paramSpec); + return cipher.doFinal(input, ivLen, inputLength - ivLen); + } + } else { + cipher.init(getCipherMode(), secretKey); + return cipher.doFinal(input, 0, inputLength); + } } catch (GeneralSecurityException e) { return null; }