Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,20 @@

import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;

import javax.crypto.Cipher;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
Expand All @@ -44,18 +50,25 @@
*
*/
public abstract class GenericUDFAesBase extends GenericUDF {
protected transient Converter[] converters = new Converter[2];
protected transient PrimitiveCategory[] inputTypes = new PrimitiveCategory[2];
protected transient Converter[] converters = new Converter[3];
protected transient PrimitiveCategory[] inputTypes = new PrimitiveCategory[3];
protected final BytesWritable output = new BytesWritable();
protected transient boolean isStr0;
protected transient boolean isStr1;
protected transient boolean isKeyConstant;
protected transient Cipher cipher;
protected transient SecretKey secretKey;
private boolean isGCM;
private boolean isCTR;
private static final int GCM_IV_LENGTH = 12;
private static final int CTR_IV_LENGTH = 16;
private static final int GCM_TAG_LENGTH = 128;
private static final String AES_GCM_NOPADDING = "AES/GCM/NoPadding";
private static final String AES_CTR_NOPADDING = "AES/CTR/NoPadding";

@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
checkArgsSize(arguments, 2, 2);
checkArgsSize(arguments, 2, 3);

checkArgPrimitive(arguments, 0);
checkArgPrimitive(arguments, 1);
Expand Down Expand Up @@ -104,8 +117,24 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
secretKey = getSecretKey(key, keyLength);
}

String cipherTransform = "AES";

if (arguments.length == 3) {
checkArgPrimitive(arguments, 2);
checkArgGroups(arguments, 2, inputTypes, STRING_GROUP);

String userDefinedMode = getConstantStringValue(arguments, 2);
if (userDefinedMode != null) {
cipherTransform = userDefinedMode;
}

}

this.isGCM = AES_GCM_NOPADDING.equalsIgnoreCase(cipherTransform);
this.isCTR = AES_CTR_NOPADDING.equalsIgnoreCase(cipherTransform);

try {
cipher = Cipher.getInstance("AES");
cipher = Cipher.getInstance(cipherTransform);
} catch (NoSuchPaddingException | NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -186,9 +215,52 @@ protected SecretKey getSecretKey(byte[] key, int keyLength) {

protected byte[] aesFunction(byte[] input, int inputLength, SecretKey secretKey) {
try {
cipher.init(getCipherMode(), secretKey);
byte[] res = cipher.doFinal(input, 0, inputLength);
return res;
if (isGCM || isCTR) {
int ivLen = isGCM ? GCM_IV_LENGTH : CTR_IV_LENGTH;

if (getCipherMode() == Cipher.ENCRYPT_MODE) {
byte[] iv = new byte[ivLen];
new SecureRandom().nextBytes(iv);

AlgorithmParameterSpec paramSpec;
if (isGCM) {
paramSpec = new GCMParameterSpec(GCM_TAG_LENGTH, iv);
} else {
paramSpec = new IvParameterSpec(iv);
}

cipher.init(Cipher.ENCRYPT_MODE, secretKey, paramSpec);
byte[] cipherText = cipher.doFinal(input, 0, inputLength);

byte[] output = new byte[iv.length + cipherText.length];
System.arraycopy(iv, 0, output, 0, iv.length);
System.arraycopy(cipherText, 0, output, iv.length, cipherText.length);

return output;

} else {
int minLen = isGCM ? (ivLen + (GCM_TAG_LENGTH / 8)) : (ivLen + 1);
if (inputLength < minLen) {
return null;
}

byte[] iv = new byte[ivLen];
System.arraycopy(input, 0, iv, 0, ivLen);

AlgorithmParameterSpec paramSpec;
if (isGCM) {
paramSpec = new GCMParameterSpec(GCM_TAG_LENGTH, iv);
} else {
paramSpec = new IvParameterSpec(iv);
}

cipher.init(Cipher.DECRYPT_MODE, secretKey, paramSpec);
return cipher.doFinal(input, ivLen, inputLength - ivLen);
}
} else {
cipher.init(getCipherMode(), secretKey);
return cipher.doFinal(input, 0, inputLength);
}
} catch (GeneralSecurityException e) {
return null;
}
Expand Down