From 024d838cfff82cbcb633df2f1363b89983a0f958 Mon Sep 17 00:00:00 2001 From: aaronshan Date: Thu, 8 Sep 2016 16:33:24 +0800 Subject: [PATCH] add array_equals function --- README.md | 5 +- pom.xml | 2 +- .../functions/array/UDFArrayEquals.java | 126 +++++++++++++++++++ .../functions/utils/ArrayUtils.java | 50 +++++++ 4 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 src/main/java/cc/shanruifeng/functions/array/UDFArrayEquals.java diff --git a/README.md b/README.md index 63411d1..55ccba2 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ It will generate hive-third-functions-${version}-shaded.jar in target directory. You can also directly download file from [release page](https://github.com/aaronshan/hive-third-functions/releases). -> current latest version is `2.2.0` +> current latest version is `2.2.1` ## Functions @@ -52,6 +52,7 @@ You can also directly download file from [release page](https://github.com/aaron | function| description | |:--|:--| |array_contains(array<E>, E) -> boolean | whether array contains value or not.| +|array_equals(array<E>, array<E>) -> boolean | whether two arrays are equal or not.| |array_intersect(array, array) -> array | returns the two array's intersection, without duplicates.| |array_max(array<E>) -> E | returns the maximum value of input array.| |array_min(array<E>) -> E | returns the minimum value of input array.| @@ -146,6 +147,7 @@ Put these statements into `${HOME}/.hiverc` or exec its on hive cli env. 
``` add jar ${jar_location_dir}/hive-third-functions-${version}-shaded.jar create temporary function array_contains as 'cc.shanruifeng.functions.array.UDFArrayContains'; +create temporary function array_equals as 'cc.shanruifeng.functions.array.UDFArrayEquals'; create temporary function array_intersect as 'cc.shanruifeng.functions.array.UDFArrayIntersect'; create temporary function array_max as 'cc.shanruifeng.functions.array.UDFArrayMax'; create temporary function array_min as 'cc.shanruifeng.functions.array.UDFArrayMin'; @@ -236,6 +238,7 @@ select zodiac_en('1989-01-08') => Capricorn ``` select array_contains(array(16,12,18,9), 12) => true +select array_equals(array(16,12,18,9), array(16,12,18,9)) => true select array_intersect(array(16,12,18,9,null), array(14,9,6,18,null)) => [null,9,18] select array_max(array(16,13,12,13,18,16,9,18)) => 18 select array_min(array(16,12,18,9)) => 9 diff --git a/pom.xml b/pom.xml index 5ca98e7..d22e307 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ cc.shanruifeng hive-third-functions - 2.2.0 + 2.2.1 UTF-8 diff --git a/src/main/java/cc/shanruifeng/functions/array/UDFArrayEquals.java b/src/main/java/cc/shanruifeng/functions/array/UDFArrayEquals.java new file mode 100644 index 0000000..f9811d7 --- /dev/null +++ b/src/main/java/cc/shanruifeng/functions/array/UDFArrayEquals.java @@ -0,0 +1,126 @@
+package cc.shanruifeng.functions.array;
+
+import cc.shanruifeng.functions.utils.ArrayUtils;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.BooleanWritable;
+
+/**
+ * Hive generic UDF {@code array_equals(array, array)}: reports whether two arrays have the
+ * same length and pairwise-equal elements.
+ *
+ * @author ruifeng.shan
+ * @date 2016-09-08
+ * @time 16:03
+ */
+@Description(name = "array_equals"
+        , value = "_FUNC_(array, array) - whether two arrays equals or not."
+        , extended = "Example:\n > select _FUNC_(array, array) from src;")
+public class UDFArrayEquals extends GenericUDF {
+    private static final int ARG_COUNT = 2; // Number of arguments to this UDF
+    private transient ListObjectInspector leftArrayOI;
+    private transient ListObjectInspector rightArrayOI;
+    private transient ObjectInspector leftArrayElementOI;
+    private transient ObjectInspector rightArrayElementOI;
+
+    private BooleanWritable result;
+
+    public UDFArrayEquals() {
+    }
+
+    /**
+     * Validates the call signature and caches the list inspectors of both arguments.
+     *
+     * @param arguments object inspectors of the call arguments
+     * @return the writable boolean object inspector describing the result
+     * @throws UDFArgumentException if the argument count is not two, an argument is not a
+     *                              list, the element types differ, or the element type is
+     *                              not comparable
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
+        // Check if two arguments were passed
+        if (arguments.length != ARG_COUNT) {
+            throw new UDFArgumentLengthException(
+                    "The function array_equals(array, array) takes exactly " + ARG_COUNT + " arguments.");
+        }
+
+        // Check that both arguments are of category LIST
+        for (int i = 0; i < ARG_COUNT; i++) {
+            if (!arguments[i].getCategory().equals(ObjectInspector.Category.LIST)) {
+                throw new UDFArgumentTypeException(i,
+                        "\"" + org.apache.hadoop.hive.serde.serdeConstants.LIST_TYPE_NAME + "\" "
+                                + "expected at function array_equals, but "
+                                + "\"" + arguments[i].getTypeName() + "\" "
+                                + "is found");
+            }
+        }
+
+        leftArrayOI = (ListObjectInspector) arguments[0];
+        rightArrayOI = (ListObjectInspector) arguments[1];
+
+        leftArrayElementOI = leftArrayOI.getListElementObjectInspector();
+        rightArrayElementOI = rightArrayOI.getListElementObjectInspector();
+
+        // Check if two array are of same type
+        if (!ObjectInspectorUtils.compareTypes(leftArrayElementOI, rightArrayElementOI)) {
+            throw new UDFArgumentTypeException(1,
+                    "\"" + leftArrayElementOI.getTypeName() + "\""
+                            + " expected at function array_equals, but "
+                            + "\"" + rightArrayElementOI.getTypeName() + "\""
+                            + " is found");
+        }
+
+        // Check if the comparison is supported for this type
+        if (!ObjectInspectorUtils.compareSupported(leftArrayElementOI)) {
+            throw new UDFArgumentException("The function array_equals"
+                    + " does not support comparison for "
+                    + "\"" + leftArrayElementOI.getTypeName() + "\""
+                    + " types");
+        }
+
+        result = new BooleanWritable(false);
+        return PrimitiveObjectInspectorFactory.writableBooleanObjectInspector;
+    }
+
+    /**
+     * Compares the two array arguments of the current call.
+     *
+     * @param arguments deferred values of the two arrays
+     * @return the reused {@link BooleanWritable} holding the comparison result
+     * @throws HiveException propagated from resolving the deferred arguments
+     */
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        Object leftArray = arguments[0].get();
+        Object rightArray = arguments[1].get();
+
+        // Read each array through its own inspector: initialize() only verified that the
+        // element types match, not that both sides use the same inspector implementation.
+        boolean ret = ArrayUtils.arrayEquals(leftArray, rightArray, leftArrayOI, rightArrayOI);
+        result.set(ret);
+
+        return result;
+    }
+
+    /**
+     * Builds the display string for this call.
+     *
+     * @param strings display strings of the two arguments
+     * @return {@code array_equals(} first argument, second argument {@code )}
+     */
+    @Override
+    public String getDisplayString(String[] strings) {
+        assert (strings.length == ARG_COUNT);
+        return "array_equals(" + strings[0] + ", "
+                + strings[1] + ")";
+    }
+}
diff --git a/src/main/java/cc/shanruifeng/functions/utils/ArrayUtils.java b/src/main/java/cc/shanruifeng/functions/utils/ArrayUtils.java index 3b0e69b..e04f3b3 100644 --- a/src/main/java/cc/shanruifeng/functions/utils/ArrayUtils.java +++ b/src/main/java/cc/shanruifeng/functions/utils/ArrayUtils.java @@ -34,4 +34,54 @@ public int compare(int left, int right) {
             }
         };
     }
+
+    /**
+     * Compares two arrays that are both read through the same list inspector.
+     *
+     * @param left    first array, may be {@code null}
+     * @param right   second array, may be {@code null}
+     * @param arrayOI inspector used to read both arrays
+     * @return {@code true} if both are {@code null}, or both have the same length and
+     *         pairwise-equal elements
+     */
+    public static boolean arrayEquals(Object left, Object right, ListObjectInspector arrayOI) {
+        return arrayEquals(left, right, arrayOI, arrayOI);
+    }
+
+    /**
+     * Compares two arrays, reading each one through its own list inspector.
+     *
+     * @param left         first array, may be {@code null}
+     * @param right        second array, may be {@code null}
+     * @param leftArrayOI  inspector used to read {@code left}
+     * @param rightArrayOI inspector used to read {@code right}
+     * @return {@code true} if both are {@code null}, or both have the same length and
+     *         pairwise-equal elements
+     */
+    public static boolean arrayEquals(Object left, Object right,
+                                      ListObjectInspector leftArrayOI,
+                                      ListObjectInspector rightArrayOI) {
+        if (left == null || right == null) {
+            // Two null arrays count as equal; null never equals a non-null array.
+            return left == null && right == null;
+        }
+
+        int length = leftArrayOI.getListLength(left);
+        if (length != rightArrayOI.getListLength(right)) {
+            return false;
+        }
+
+        ObjectInspector leftElementOI = leftArrayOI.getListElementObjectInspector();
+        ObjectInspector rightElementOI = rightArrayOI.getListElementObjectInspector();
+        for (int i = 0; i < length; i++) {
+            Object leftElement = leftArrayOI.getListElement(left, i);
+            Object rightElement = rightArrayOI.getListElement(right, i);
+            if (ObjectInspectorUtils.compare(leftElement, leftElementOI,
+                    rightElement, rightElementOI) != 0) {
+                return false;
+            }
+        }
+
+        return true;
+    }
 }