From 4890fac7bd1eb54caf2cc9b2a795d580ec84f145 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 11 Apr 2018 21:52:25 -0700 Subject: [PATCH 1/9] [SPARK-23912][SQL]add array_distinct --- python/pyspark/sql/functions.py | 14 ++++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 26 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 14 ++++++++++ .../org/apache/spark/sql/functions.scala | 7 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 22 ++++++++++++++++ 6 files changed, 84 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1cdbb8a4c3e8..192ca9e7df4f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1999,6 +1999,20 @@ def array_remove(col, element): return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) +@since(2.4) +def array_distinct(col): + """ + Collection function: removes duplicate values from the array. + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3700c63d817e..4b09b9a7e75d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -433,6 +433,7 @@ object FunctionRegistry { expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, // mask functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d76f3013f0c4..3d52c19e4549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2355,3 +2355,29 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +/** + * Removes duplicate values from the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Removes duplicate values from the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3)); + [1,2,3,null] + """, since = "2.4.0") +case class ArrayDistinct(child: Expression) + extends UnaryExpression with ExpectsInputTypes with CodegenFallback { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + override def nullSafeEval(array: Any): Any = { + val elementType = child.dataType.asInstanceOf[ArrayType].elementType + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType).distinct + new GenericArrayData(data.asInstanceOf[Array[Any]]) + } + + override def prettyName: String = "array_distinct" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 85e692bdc4ef..b6da579fac4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -766,4 +766,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Unique") { + val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a", "a"), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + + checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) + checkEvaluation(new ArrayDistinct(a1), Seq[Integer]()) + checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) + checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) + checkEvaluation(new ArrayDistinct(a4), Seq(null)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 87bd7b3b0f9c..9b0255c12c63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3189,6 +3189,13 @@ object functions { ArrayRemove(column.expr, Literal(element)) } + /** + * Removes duplicate values from the array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4e5c1c56e267..3dc696bd01ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1216,6 +1216,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } + test("array_distinct functions") { + val df = Seq( + (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")), + (Array.empty[Int], Array.empty[String]), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(array_distinct($"a"), array_distinct($"b")), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_distinct(a)", "array_distinct(b)"), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 1cefd72be4f2e28de5c706a8970610458f39133d Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 25 Apr 2018 14:41:23 -0700 Subject: [PATCH 2/9] resolve conflicts --- .../expressions/collectionOperations.scala | 53 ++++++++++++++++++- .../CollectionExpressionsSuite.scala | 6 +-- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3d52c19e4549..826da17985d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2365,9 +2365,9 @@ case class ArrayRemove(left: Expression, right: Expression) Examples: > SELECT _FUNC_(array(1, 2, 3, null, 3)); [1,2,3,null] - """, since = "2.4.0") + """, since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ExpectsInputTypes with CodegenFallback { + extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -2379,5 +2379,54 @@ case class ArrayDistinct(child: Expression) new GenericArrayData(data.asInstanceOf[Array[Any]]) } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val elementType = dataType.asInstanceOf[ArrayType].elementType + nullSafeCodeGen(ctx, ev, (array) => { + val arrayClass = classOf[GenericArrayData].getName + val tempArray = ctx.freshName("tempArray") + val distinctArray = ctx.freshName("distinctArray") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val pos = ctx.freshName("arrayPosition") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) + s""" + |int $pos = 0; + |Object[] $tempArray = new Object[$array.numElements()]; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if ($array.isNullAt($j)) + | break; + | } + | if ($i == $j) { + | $tempArray[$pos] = null; + | $pos = $pos + 1; + | } + | } + | else { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if (${ctx.genEqual(elementType, getValue1, getValue2)}) + | break; + | } + | if ($i == $j) { + | $tempArray[$pos] = ${CodeGenerator.getValue(array, elementType, s"$i")}; + | $pos = $pos + 1; + | } + | } + |} + | + |Object[] $distinctArray = new Object[$pos]; + |for (int $i = 0; $i < $pos; $i ++) { + | $distinctArray[$i] = $tempArray[$i]; + |} + | + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + }) + } + override def prettyName: String = "array_distinct" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b6da579fac4f..690b431204f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -767,15 +767,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } - test("Array Unique") { + test("Array Distinct") { val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) - val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a", "a"), ArrayType(StringType)) val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) - checkEvaluation(new ArrayDistinct(a1), Seq[Integer]()) + checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) checkEvaluation(new ArrayDistinct(a4), Seq(null)) From 3ce3ddcb5d020162dc84888679f642a7718dfc24 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 1 May 2018 13:59:06 -0700 Subject: [PATCH 3/9] address comments(2) --- .../expressions/collectionOperations.scala | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 826da17985d0..da4c2ebcfd9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2383,7 +2383,6 @@ case class ArrayDistinct(child: Expression) val elementType = dataType.asInstanceOf[ArrayType].elementType nullSafeCodeGen(ctx, ev, (array) => { val arrayClass = classOf[GenericArrayData].getName - val tempArray = ctx.freshName("tempArray") val distinctArray = ctx.freshName("distinctArray") val i = ctx.freshName("i") val j = ctx.freshName("j") @@ -2392,7 +2391,6 @@ case class ArrayDistinct(child: Expression) val getValue2 = CodeGenerator.getValue(array, elementType, j) s""" |int $pos = 0; - |Object[] $tempArray = new Object[$array.numElements()]; |for (int $i = 0; $i < $array.numElements(); $i ++) { | if ($array.isNullAt($i)) { | int $j; @@ -2401,26 +2399,46 @@ case class ArrayDistinct(child: Expression) | break; | } | if ($i == $j) { - | $tempArray[$pos] = null; | $pos = $pos + 1; | } | } | else { | int $j; | for ($j = 0; $j < $i; $j ++) { - | if (${ctx.genEqual(elementType, getValue1, getValue2)}) - | break; + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) + | break; | } | if ($i == $j) { - | $tempArray[$pos] = ${CodeGenerator.getValue(array, elementType, s"$i")}; | $pos = $pos + 1; | } | } |} | |Object[] $distinctArray = new Object[$pos]; - |for (int $i = 0; $i < $pos; $i ++) { - | $distinctArray[$i] = $tempArray[$i]; + |$pos = 0; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if ($array.isNullAt($j)) + | break; + | } + | if ($i == $j) { + | $distinctArray[$pos] = null; + | $pos = $pos + 1; + | } + | } + | else { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) + | break; + | } + | if ($i == $j) { + | $distinctArray[$pos] = ${CodeGenerator.getValue(array, elementType, s"$i")}; + | $pos = $pos + 1; + | } + | } |} | |${ev.value} = new $arrayClass($distinctArray); From 12196265f06d1c31da29ed2fb27fba9cf1f456f7 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 17 May 2018 19:36:50 -0700 Subject: [PATCH 4/9] use OpenHashSet to check duplication in the array --- .../expressions/collectionOperations.scala | 129 ++++++++++++------ 1 file changed, 87 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index da4c2ebcfd9f..fcb969e9f18e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import scala.collection.mutable - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder @@ -31,6 +30,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.collection.OpenHashSet /** * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit @@ -2373,6 +2373,8 @@ case class ArrayDistinct(child: Expression) override def dataType: DataType = child.dataType + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + override def nullSafeEval(array: Any): Any = { val elementType = child.dataType.asInstanceOf[ArrayType].elementType val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType).distinct @@ -2380,70 +2382,113 @@ case class ArrayDistinct(child: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementType = dataType.asInstanceOf[ArrayType].elementType nullSafeCodeGen(ctx, ev, (array) => { - val arrayClass = classOf[GenericArrayData].getName - val distinctArray = ctx.freshName("distinctArray") val i = ctx.freshName("i") val j = ctx.freshName("j") - val pos = ctx.freshName("arrayPosition") - val getValue1 = CodeGenerator.getValue(array, elementType, i) - val getValue2 = CodeGenerator.getValue(array, elementType, j) + val hs = ctx.freshName("hs") + val distinctArrayLen = ctx.freshName("distinctArrayLen") + val getValue = CodeGenerator.getValue(array, elementType, i) + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" s""" - |int $pos = 0; - |for (int $i = 0; $i < $array.numElements(); $i ++) { + |int $distinctArrayLen = 0; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i++) { | if ($array.isNullAt($i)) { - | int $j; - | for ($j = 0; $j < $i; $j ++) { - | if ($array.isNullAt($j)) - | break; - | } - | if ($i == $j) { - | $pos = $pos + 1; - | } - | } - | else { | int $j; | for ($j = 0; $j < $i; $j ++) { - | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) + | if ($array.isNullAt($j)) | break; | } | if ($i == $j) { - | $pos = $pos + 1; + | $distinctArrayLen = $distinctArrayLen + 1; + | } + | } + | else { + | if (!($hs.contains($getValue))) { + | $hs.add($getValue); + | $distinctArrayLen = $distinctArrayLen + 1; | } | } |} - | - |Object[] $distinctArray = new Object[$pos]; - |$pos = 0; - |for (int $i = 0; $i < $array.numElements(); $i ++) { - | if ($array.isNullAt($i)) { - | int $j; - | for ($j = 0; $j < $i; $j ++) { - | if ($array.isNullAt($j)) - | break; - | } - | if ($i == $j) { - | $distinctArray[$pos] = null; - | $pos = $pos + 1; - | } + |${genCodeForResult(ctx, ev, array, distinctArrayLen)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + newArraySize: String): String = { + val distinctArr = ctx.freshName("distinctArray") + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val pos = ctx.freshName("pos") + val genericArrayData = classOf[GenericArrayData].getName + val getValue = CodeGenerator.getValue(inputArray, elementType, i) + + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + s""" + |Object[] $distinctArr = new Object[$newArraySize]; + |int $pos = 0; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i++) { + | if ($inputArray.isNullAt($i)) { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if ($inputArray.isNullAt($j)) + | break; + | } + | if ($i == $j) { + | $distinctArr[$pos] = null; + | $pos = $pos + 1; + | } | } | else { + | if (!($hs.contains($getValue))) { + | $hs.add($getValue); + | $distinctArr[$pos] = $getValue; + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = new $arrayClass($distinctArr); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" + s""" + |${ctx.createUnsafeArray(distinctArr, newArraySize, elementType, s" $prettyName failed.")} + |int $pos = 0; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i++) { + | if ($inputArray.isNullAt($i)) { | int $j; | for ($j = 0; $j < $i; $j ++) { - | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) + | if ($inputArray.isNullAt($j)) | break; | } | if ($i == $j) { - | $distinctArray[$pos] = ${CodeGenerator.getValue(array, elementType, s"$i")}; - | $pos = $pos + 1; + | $distinctArr.setNullAt($pos); + | $pos = $pos + 1; + | } + | } + | else { + | if (!($hs.contains($getValue))) { + | $hs.add($getValue); + | $distinctArr.set$primitiveValueTypeName($pos, $getValue); + | $pos = $pos + 1; | } | } |} - | - |${ev.value} = new $arrayClass($distinctArray); - """.stripMargin - }) + |${ev.value} = $distinctArr; + """.stripMargin + } } override def prettyName: String = "array_distinct" From d127ac48ed7254c00bb517e8aaf0c84e66bcba1a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 18 May 2018 11:30:55 -0700 Subject: [PATCH 5/9] address comments(3) --- .../expressions/collectionOperations.scala | 30 ++++++++----------- .../CollectionExpressionsSuite.scala | 12 ++++++-- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fcb969e9f18e..cfa4a3930c17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder @@ -2386,22 +2387,20 @@ case class ArrayDistinct(child: Expression) val i = ctx.freshName("i") val j = ctx.freshName("j") val hs = ctx.freshName("hs") + val foundNullElement = ctx.freshName("foundNullElement") val distinctArrayLen = ctx.freshName("distinctArrayLen") val getValue = CodeGenerator.getValue(array, elementType, i) val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" s""" |int $distinctArrayLen = 0; + |boolean $foundNullElement = false; |$openHashSet $hs = new $openHashSet($classTag); |for (int $i = 0; $i < $array.numElements(); $i++) { | if ($array.isNullAt($i)) { - | int $j; - | for ($j = 0; $j < $i; $j ++) { - | if ($array.isNullAt($j)) - | break; - | } - | if ($i == $j) { + | if (!($foundNullElement)) { | $distinctArrayLen = $distinctArrayLen + 1; + | $foundNullElement = true; | } | } | else { @@ -2427,6 +2426,7 @@ case class ArrayDistinct(child: Expression) val i = ctx.freshName("i") val j = ctx.freshName("j") val pos = ctx.freshName("pos") + val foundNullElement = ctx.freshName("foundNullElement") val genericArrayData = classOf[GenericArrayData].getName val getValue = CodeGenerator.getValue(inputArray, elementType, i) @@ -2436,17 +2436,14 @@ case class ArrayDistinct(child: Expression) s""" |Object[] $distinctArr = new Object[$newArraySize]; |int $pos = 0; + |boolean $foundNullElement = false; |$openHashSet $hs = new $openHashSet($classTag); |for (int $i = 0; $i < $inputArray.numElements(); $i++) { | if ($inputArray.isNullAt($i)) { - | int $j; - | for ($j = 0; $j < $i; $j ++) { - | if ($inputArray.isNullAt($j)) - | break; - | } - | if ($i == $j) { + | if (!($foundNullElement)) { | $distinctArr[$pos] = null; | $pos = $pos + 1; + | $foundNullElement = true; | } | } | else { @@ -2465,17 +2462,14 @@ case class ArrayDistinct(child: Expression) s""" |${ctx.createUnsafeArray(distinctArr, newArraySize, elementType, s" $prettyName failed.")} |int $pos = 0; + |boolean $foundNullElement = false; |$openHashSet $hs = new $openHashSet($classTag); |for (int $i = 0; $i < $inputArray.numElements(); $i++) { | if ($inputArray.isNullAt($i)) { - | int $j; - | for ($j = 0; $j < $i; $j ++) { - | if ($inputArray.isNullAt($j)) - | break; - | } - | if ($i == $j) { + | if (!($foundNullElement)) { | $distinctArr.setNullAt($pos); | $pos = $pos + 1; + | $foundNullElement = true; | } | } | else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 690b431204f8..6369c2c5d963 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -771,13 +771,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) - val a3 = Literal.create(Seq("b", null, "a", "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234), + ArrayType(DoubleType)) + val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), + ArrayType(FloatType)) checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) checkEvaluation(new ArrayDistinct(a4), Seq(null)) + checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) + checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) + checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) } } From ba0d60f2e7fd7de4e571d78af362b7f06935b3e8 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 12 Jun 2018 14:23:15 -0700 Subject: [PATCH 6/9] add complex data type support --- .../expressions/collectionOperations.scala | 339 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 23 ++ 2 files changed, 285 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index cfa4a3930c17..d35bed1539b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2376,112 +2376,297 @@ case class ArrayDistinct(child: Expression) lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + } + } + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + override def nullSafeEval(array: Any): Any = { - val elementType = child.dataType.asInstanceOf[ArrayType].elementType - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType).distinct - new GenericArrayData(data.asInstanceOf[Array[Any]]) + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementTypeSupportEquals) { + new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + var foundNullElement = false + var pos = 0 + for(i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } + } else { + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j-1) { + pos = pos + 1 + } + } + } + new GenericArrayData(data.slice(0, pos)) + } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (array) => { val i = ctx.freshName("i") val j = ctx.freshName("j") - val hs = ctx.freshName("hs") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) val foundNullElement = ctx.freshName("foundNullElement") - val distinctArrayLen = ctx.freshName("distinctArrayLen") - val getValue = CodeGenerator.getValue(array, elementType, i) val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + if(elementTypeSupportEquals) { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { + | $foundNullElement = true; + | } + | } + | else { + | if (!($hs.contains($getValue1))) { + | $hs.add($getValue1); + | } + | } + |} + |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } + else { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | $foundNullElement = true; + | } + | } + | else { + | int $j; + | for ($j = 0; $j < $i; $j++) { + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) + | break; + | } + | if ($i == $j) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | } + | } + |} + | + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } + }) + } + + private def setNull( + isPrimitive: Boolean, + foundNullElement: String, + distinctArray: String, + pos: String): String = { + val setNullValue = + if (!isPrimitive) { + s""" + |$distinctArray[$pos] = null; + """. + stripMargin + } else { + s""" + |$distinctArray.setNullAt($pos); + """. + stripMargin + } + + s""" + |if (!($foundNullElement)) { + | $setNullValue; + | $pos = $pos + 1; + | $foundNullElement = true; + |} + """.stripMargin + } + + private def setNotNullValue(isPrimitive: Boolean, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + if (!isPrimitive) { s""" - |int $distinctArrayLen = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $array.numElements(); $i++) { - | if ($array.isNullAt($i)) { - | if (!($foundNullElement)) { - | $distinctArrayLen = $distinctArrayLen + 1; - | $foundNullElement = true; - | } - | } - | else { - | if (!($hs.contains($getValue))) { - | $hs.add($getValue); - | $distinctArrayLen = $distinctArrayLen + 1; - | } - | } - |} - |${genCodeForResult(ctx, ev, array, distinctArrayLen)} + |$distinctArray[$pos] = $getValue1; """.stripMargin - }) + } else { + s""" + |$distinctArray.set$primitiveValueTypeName($pos, $getValue1); + """.stripMargin + } + } + + private def setValueForFastEval( + isPrimitive: Boolean, + hs: String, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |if (!($hs.contains($getValue1))) { + | $hs.add($getValue1); + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + private def setValueForbruteForceEval( + isPrimitive: Boolean, + i: String, + j: String, + inputArray: String, + distinctArray: String, + pos: String, + getValue1: String, + isEqual: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |int $j; + |for ($j = 0; $j < $i; $j ++) { + | if (!$inputArray.isNullAt($j) && $isEqual) + | break; + | } + | if ($i == $j) { + | $setValue; + | $pos = $pos + 1; + | } + """.stripMargin } def genCodeForResult( ctx: CodegenContext, ev: ExprCode, inputArray: String, - newArraySize: String): String = { - val distinctArr = ctx.freshName("distinctArray") - val hs = ctx.freshName("hs") - val openHashSet = classOf[OpenHashSet[_]].getName + size: String): String = { + val distinctArray = ctx.freshName("distinctArray") val i = ctx.freshName("i") val j = ctx.freshName("j") val pos = ctx.freshName("pos") + val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) + val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) + val isEqual = ctx.genEqual(elementType, getValue1, getValue2) val foundNullElement = ctx.freshName("foundNullElement") - val genericArrayData = classOf[GenericArrayData].getName - val getValue = CodeGenerator.getValue(inputArray, elementType, i) - + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName if (!CodeGenerator.isPrimitiveType(elementType)) { val arrayClass = classOf[GenericArrayData].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" - s""" - |Object[] $distinctArr = new Object[$newArraySize]; - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i++) { - | if ($inputArray.isNullAt($i)) { - | if (!($foundNullElement)) { - | $distinctArr[$pos] = null; - | $pos = $pos + 1; - | $foundNullElement = true; - | } - | } - | else { - | if (!($hs.contains($getValue))) { - | $hs.add($getValue); - | $distinctArr[$pos] = $getValue; - | $pos = $pos + 1; - | } - | } - |} - |${ev.value} = new $arrayClass($distinctArr); + val setNullForNonPrimitive = + setNull(false, foundNullElement, distinctArray, pos) + if (elementTypeSupportEquals) { + val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } + | else { + | $setValueForFast; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } + else { + val setValueForbruteForce = setValueForbruteForceEval(false, i, j, + inputArray, distinctArray, pos, getValue1: String, isEqual, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } + | else { + | $setValueForbruteForce; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); """.stripMargin + } } else { val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" - s""" - |${ctx.createUnsafeArray(distinctArr, newArraySize, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i++) { - | if ($inputArray.isNullAt($i)) { - | if (!($foundNullElement)) { - | $distinctArr.setNullAt($pos); - | $pos = $pos + 1; - | $foundNullElement = true; - | } - | } - | else { - | if (!($hs.contains($getValue))) { - | $hs.add($getValue); - | $distinctArr.set$primitiveValueTypeName($pos, $getValue); - | $pos = $pos + 1; - | } - | } - |} - |${ev.value} = $distinctArr; - """.stripMargin + val setValueForFast = + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) + if (elementTypeSupportEquals) { + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } + | else { + | $setValueForFast; + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin + } else { + val setValueForbruteForce = setValueForbruteForceEval(true, i, j, + inputArray, distinctArray, pos, getValue1: String, isEqual, primitiveValueTypeName) + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } + | else { + | $setValueForbruteForce; + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6369c2c5d963..f377f9c8cd53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -787,5 +787,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2), + null, Array[Byte](5, 6), null), ArrayType(BinaryType)) + + checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2))) + checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null, + Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } } From 2df2cffbf15becc86a7d4f3f731a9b9345a8e206 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 17 Jun 2018 09:20:27 +0200 Subject: [PATCH 7/9] address comments --- .../expressions/collectionOperations.scala | 116 ++++++------------ 1 file changed, 39 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d35bed1539b5..940d1c5d6061 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2374,7 +2374,7 @@ case class ArrayDistinct(child: Expression) override def dataType: DataType = child.dataType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) @@ -2442,21 +2442,15 @@ case class ArrayDistinct(child: Expression) |$openHashSet $hs = new $openHashSet($classTag); |for (int $i = 0; $i < $array.numElements(); $i++) { | if ($array.isNullAt($i)) { - | if (!($foundNullElement)) { - | $foundNullElement = true; - | } - | } - | else { - | if (!($hs.contains($getValue1))) { - | $hs.add($getValue1); - | } + | $foundNullElement = true; + | } else { + | $hs.add($getValue1); | } |} |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} """.stripMargin - } - else { + } else { s""" |int $sizeOfDistinctArray = 0; |boolean $foundNullElement = false; @@ -2466,12 +2460,12 @@ case class ArrayDistinct(child: Expression) | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; | $foundNullElement = true; | } - | } - | else { + | } else { | int $j; | for ($j = 0; $j < $i; $j++) { - | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { | break; + | } | } | if ($i == $j) { | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; @@ -2492,15 +2486,9 @@ case class ArrayDistinct(child: Expression) pos: String): String = { val setNullValue = if (!isPrimitive) { - s""" - |$distinctArray[$pos] = null; - """. - stripMargin + s"$distinctArray[$pos] = null"; } else { - s""" - |$distinctArray.setNullAt($pos); - """. - stripMargin + s"$distinctArray.setNullAt($pos)"; } s""" @@ -2518,13 +2506,9 @@ case class ArrayDistinct(child: Expression) getValue1: String, primitiveValueTypeName: String): String = { if (!isPrimitive) { - s""" - |$distinctArray[$pos] = $getValue1; - """.stripMargin + s"$distinctArray[$pos] = $getValue1"; } else { - s""" - |$distinctArray.set$primitiveValueTypeName($pos, $getValue1); - """.stripMargin + s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; } } @@ -2546,7 +2530,7 @@ case class ArrayDistinct(child: Expression) """.stripMargin } - private def setValueForbruteForceEval( + private def setValueForBruteForceEval( isPrimitive: Boolean, i: String, j: String, @@ -2561,13 +2545,14 @@ case class ArrayDistinct(child: Expression) s""" |int $j; |for ($j = 0; $j < $i; $j ++) { - | if (!$inputArray.isNullAt($j) && $isEqual) + | if (!$inputArray.isNullAt($j) && $isEqual) { | break; | } - | if ($i == $j) { - | $setValue; - | $pos = $pos + 1; - | } + |} + |if ($i == $j) { + | $setValue; + | $pos = $pos + 1; + |} """.stripMargin } @@ -2601,17 +2586,15 @@ case class ArrayDistinct(child: Expression) |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { | if ($inputArray.isNullAt($i)) { | $setNullForNonPrimitive; - | } - | else { + | } else { | $setValueForFast; | } |} |${ev.value} = new $arrayClass($distinctArray); """.stripMargin - } - else { - val setValueForbruteForce = setValueForbruteForceEval(false, i, j, - inputArray, distinctArray, pos, getValue1: String, isEqual, "") + } else { + val setValueForbruteForce = setValueForBruteForceEval(false, i, j, + inputArray, distinctArray, pos, getValue1, isEqual, "") s""" |int $pos = 0; |Object[] $distinctArray = new Object[$size]; @@ -2619,8 +2602,7 @@ case class ArrayDistinct(child: Expression) |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { | if ($inputArray.isNullAt($i)) { | $setNullForNonPrimitive; - | } - | else { + | } else { | $setValueForbruteForce; | } |} @@ -2632,41 +2614,21 @@ case class ArrayDistinct(child: Expression) val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" val setValueForFast = - setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) - if (elementTypeSupportEquals) { - s""" - |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForPrimitive; - | } - | else { - | $setValueForFast; - | } - |} - |${ev.value} = $distinctArray; - """.stripMargin - } else { - val setValueForbruteForce = setValueForbruteForceEval(true, i, j, - inputArray, distinctArray, pos, getValue1: String, isEqual, primitiveValueTypeName) - s""" - |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForPrimitive; - | } - | else { - | $setValueForbruteForce; - | } - |} - |${ev.value} = $distinctArray; - """.stripMargin - } + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin } } From 2d778803d316006b8d333ff57f999ce4b275e378 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 17 Jun 2018 09:34:32 +0200 Subject: [PATCH 8/9] change b to B in setValueForbruteForce --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 940d1c5d6061..f4e881ec3b97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2593,7 +2593,7 @@ case class ArrayDistinct(child: Expression) |${ev.value} = new $arrayClass($distinctArray); """.stripMargin } else { - val setValueForbruteForce = setValueForBruteForceEval(false, i, j, + val setValueForBruteForce = setValueForBruteForceEval(false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") s""" |int $pos = 0; @@ -2603,7 +2603,7 @@ case class ArrayDistinct(child: Expression) | if ($inputArray.isNullAt($i)) { | $setNullForNonPrimitive; | } else { - | $setValueForbruteForce; + | $setValueForBruteForce; | } |} |${ev.value} = new $arrayClass($distinctArray); From 3f5d03b617ceff2e6f02735f73372976674855ef Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 20 Jun 2018 23:57:48 +0100 Subject: [PATCH 9/9] fix comments --- .../expressions/collectionOperations.scala | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f4e881ec3b97..7c064a130ff3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2400,7 +2400,7 @@ case class ArrayDistinct(child: Expression) } else { var foundNullElement = false var pos = 0 - for(i <- 0 until data.length) { + for (i <- 0 until data.length) { if (data(i) == null) { if (!foundNullElement) { foundNullElement = true @@ -2415,7 +2415,7 @@ case class ArrayDistinct(child: Expression) } j = j + 1 } - if (i == j-1) { + if (i == j - 1) { pos = pos + 1 } } @@ -2435,12 +2435,12 @@ case class ArrayDistinct(child: Expression) val openHashSet = classOf[OpenHashSet[_]].getName val hs = ctx.freshName("hs") val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" - if(elementTypeSupportEquals) { + if (elementTypeSupportEquals) { s""" |int $sizeOfDistinctArray = 0; |boolean $foundNullElement = false; |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $array.numElements(); $i++) { + |for (int $i = 0; $i < $array.numElements(); $i ++) { | if ($array.isNullAt($i)) { | $foundNullElement = true; | } else { @@ -2462,7 +2462,7 @@ case class ArrayDistinct(child: Expression) | } | } else { | int $j; - | for ($j = 0; $j < $i; $j++) { + | for ($j = 0; $j < $i; $j ++) { | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { | break; | } @@ -2593,8 +2593,8 @@ case class ArrayDistinct(child: Expression) |${ev.value} = new $arrayClass($distinctArray); """.stripMargin } else { - val setValueForBruteForce = setValueForBruteForceEval(false, i, j, - inputArray, distinctArray, pos, getValue1, isEqual, "") + val setValueForBruteForce = setValueForBruteForceEval( + false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") s""" |int $pos = 0; |Object[] $distinctArray = new Object[$size]; @@ -2614,20 +2614,20 @@ case class ArrayDistinct(child: Expression) val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" val setValueForFast = - setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) s""" - |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForPrimitive; - | } else { - | $setValueForFast; - | } - |} - |${ev.value} = $distinctArray; + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = $distinctArray; """.stripMargin } }