diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
index 3e2d12fc5b17..53c37f0a5491 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -17,14 +17,11 @@
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.datasketches.frequencies.ItemsSketch
-import org.apache.datasketches.memory.Memory
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, ApproxTopKAggregateBuffer}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types._
 
@@ -105,9 +102,10 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)
     val kVal = kEval.asInstanceOf[Int]
     ApproxTopK.checkK(kVal)
     ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal)
-    val itemsSketch = ItemsSketch.getInstance(
-      Memory.wrap(dataSketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
-    ApproxTopK.genEvalResult(itemsSketch, kVal, itemDataType)
+    val approxTopKAggregateBuffer = ApproxTopKAggregateBuffer.deserialize(
+      dataSketchBytes,
+      ApproxTopK.genSketchSerDe(itemDataType))
+    approxTopKAggregateBuffer.eval(kVal, itemDataType)
   }
 
   override protected def withNewChildrenInternal(newState: Expression, newK: Expression)
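`ApproxTopKEstimate` now hands both deserialization and result building to `ApproxTopKAggregateBuffer`. That class is not defined in this diff; the call sites here and below only pin down its surface: a constructor taking an `ItemsSketch` and a null count, plus `update`, `merge`, `eval(k, itemDataType)`, `serialize(serDe)`, and a companion `deserialize(bytes, serDe)`. A minimal sketch of the shape those call sites imply — hypothetical names, not the actual Spark class:

```scala
import org.apache.datasketches.frequencies.ItemsSketch

// Hypothetical stand-in for ApproxTopKAggregateBuffer, inferred from its call
// sites only. The real class lives outside this diff.
class NullCountingBuffer[T](val sketch: ItemsSketch[T], var nullCount: Long) {

  // ItemsSketch cannot represent null, so nulls are tallied separately.
  def countNull(): Unit = nullCount += 1

  def merge(other: NullCountingBuffer[T]): NullCountingBuffer[T] = {
    sketch.merge(other.sketch)   // frequent-item counts merge sketch-wise
    nullCount += other.nullCount // null tallies merge by plain addition
    this
  }
}
// serialize/deserialize must carry both fields: ApproxTopKAggregateBuffer.deserialize
// restores the null count from the same byte array that serialize produces.
```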
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 1fca8ad86bc2..7ae542f190d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -260,54 +260,6 @@ object ApproxTopK {
     }
   }
 
-  def updateSketchBuffer(
-      itemExpression: Expression,
-      buffer: ItemsSketch[Any],
-      input: InternalRow): ItemsSketch[Any] = {
-    val v = itemExpression.eval(input)
-    if (v != null) {
-      itemExpression.dataType match {
-        case _: BooleanType => buffer.update(v.asInstanceOf[Boolean])
-        case _: ByteType => buffer.update(v.asInstanceOf[Byte])
-        case _: ShortType => buffer.update(v.asInstanceOf[Short])
-        case _: IntegerType => buffer.update(v.asInstanceOf[Int])
-        case _: LongType => buffer.update(v.asInstanceOf[Long])
-        case _: FloatType => buffer.update(v.asInstanceOf[Float])
-        case _: DoubleType => buffer.update(v.asInstanceOf[Double])
-        case _: DateType => buffer.update(v.asInstanceOf[Int])
-        case _: TimestampType => buffer.update(v.asInstanceOf[Long])
-        case _: TimestampNTZType => buffer.update(v.asInstanceOf[Long])
-        case st: StringType =>
-          val cKey = CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
-          buffer.update(cKey.toString)
-        case _: DecimalType => buffer.update(v.asInstanceOf[Decimal])
-      }
-    }
-    buffer
-  }
-
-  def genEvalResult(
-      itemsSketch: ItemsSketch[Any],
-      k: Int,
-      itemDataType: DataType): GenericArrayData = {
-    val items = itemsSketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES)
-    val resultLength = math.min(items.length, k)
-    val result = new Array[AnyRef](resultLength)
-    for (i <- 0 until resultLength) {
-      val row = items(i)
-      itemDataType match {
-        case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
-             _: LongType | _: FloatType | _: DoubleType | _: DecimalType |
-             _: DateType | _: TimestampType | _: TimestampNTZType =>
-          result(i) = InternalRow.apply(row.getItem, row.getEstimate)
-        case _: StringType =>
-          val item = UTF8String.fromString(row.getItem.asInstanceOf[String])
-          result(i) = InternalRow.apply(item, row.getEstimate)
-      }
-    }
-    new GenericArrayData(result)
-  }
-
   def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = {
     dataType match {
       case _: BooleanType => new ArrayOfBooleansSerDe().asInstanceOf[ArrayOfItemsSerDe[Any]]
@@ -333,7 +285,7 @@ object ApproxTopK {
 
   def dataTypeToDDL(dataType: DataType): String = dataType match {
     case _: StringType =>
-      // Hide collation information in DDL format
+      // Hide collation information in DDL format, otherwise CollationExpressionWalkerSuite fails
      s"item string not null"
    case other =>
      StructField("item", other, nullable = false).toDDL
@@ -552,7 +504,7 @@ case class ApproxTopKAccumulate(
     maxItemsTracked: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[ItemsSketch[Any]]
+  extends TypedImperativeAggregate[ApproxTopKAggregateBuffer[Any]]
   with ImplicitCastInputTypes
   with BinaryLike[Expression] {
 
@@ -592,18 +544,23 @@ case class ApproxTopKAccumulate(
 
   override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType)
 
-  override def createAggregationBuffer(): ItemsSketch[Any] = {
+  override def createAggregationBuffer(): ApproxTopKAggregateBuffer[Any] = {
     val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
-    ApproxTopK.createItemsSketch(expr, maxMapSize)
+    val sketch = ApproxTopK.createItemsSketch(expr, maxMapSize)
+    new ApproxTopKAggregateBuffer[Any](sketch, 0L)
   }
 
-  override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
-    ApproxTopK.updateSketchBuffer(expr, buffer, input)
+  override def update(buffer: ApproxTopKAggregateBuffer[Any], input: InternalRow):
+  ApproxTopKAggregateBuffer[Any] =
+    buffer.update(expr, input)
 
-  override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
+  override def merge(
+      buffer: ApproxTopKAggregateBuffer[Any],
+      input: ApproxTopKAggregateBuffer[Any]):
+  ApproxTopKAggregateBuffer[Any] =
     buffer.merge(input)
 
-  override def eval(buffer: ItemsSketch[Any]): Any = {
+  override def eval(buffer: ApproxTopKAggregateBuffer[Any]): Any = {
     val sketchBytes = serialize(buffer)
     val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
     InternalRow.apply(
@@ -613,11 +570,11 @@ case class ApproxTopKAccumulate(
       UTF8String.fromString(itemDataTypeDDL))
   }
 
-  override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
-    buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+  override def serialize(buffer: ApproxTopKAggregateBuffer[Any]): Array[Byte] =
+    buffer.serialize(ApproxTopK.genSketchSerDe(itemDataType))
 
-  override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
-    ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
+  override def deserialize(storageFormat: Array[Byte]): ApproxTopKAggregateBuffer[Any] =
+    ApproxTopKAggregateBuffer.deserialize(storageFormat, ApproxTopK.genSketchSerDe(itemDataType))
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
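The removed `ApproxTopK.updateSketchBuffer` dropped nulls on the floor (`if (v != null)`), so `approx_top_k_accumulate` never counted them. Moving `update` onto the buffer lets the null branch do real work. A hedged illustration of that branch, reusing the hypothetical `NullCountingBuffer` from above — the real per-type dispatch is the `match` deleted in this hunk:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression

// Hypothetical: the real ApproxTopKAggregateBuffer.update(expr, input) is not
// shown in this diff; this only illustrates the null branch it must add.
def updateWithNulls(
    buffer: NullCountingBuffer[Any],
    itemExpression: Expression,
    input: InternalRow): NullCountingBuffer[Any] = {
  val v = itemExpression.eval(input)
  if (v == null) {
    buffer.countNull() // previously: the row was silently skipped
  } else {
    // stands in for the removed per-DataType match (Boolean, Byte, ..., Decimal)
    buffer.sketch.update(v)
  }
  buffer
}
```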
@@ -644,10 +601,10 @@ case class ApproxTopKAccumulate(
  * @param maxItemsTracked the maximum number of items tracked in the sketch
  */
 class CombineInternal[T](
-    sketch: ItemsSketch[T],
+    sketchWithNullCount: ApproxTopKAggregateBuffer[T],
     var itemDataType: DataType,
     var maxItemsTracked: Int) {
-  def getSketch: ItemsSketch[T] = sketch
+  def getSketchWithNullCount: ApproxTopKAggregateBuffer[T] = sketchWithNullCount
 
   def getItemDataType: DataType = itemDataType
 
@@ -689,6 +646,9 @@ class CombineInternal[T](
     }
   }
 
+  def updateSketchWithNullCount(otherSketchWithNullCount: ApproxTopKAggregateBuffer[T]): Unit =
+    sketchWithNullCount.merge(otherSketchWithNullCount)
+
   /**
    * Serialize the CombineInternal instance to a byte array.
   * Serialization format:
   * sketchBytes
   */
  def serialize(): Array[Byte] = {
-    val sketchBytes = sketch.toByteArray(
+    val sketchWithNullCountBytes = sketchWithNullCount.serialize(
       ApproxTopK.genSketchSerDe(itemDataType).asInstanceOf[ArrayOfItemsSerDe[T]])
     val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
     val ddlBytes: Array[Byte] = itemDataTypeDDL.getBytes(StandardCharsets.UTF_8)
     val byteArray = new Array[Byte](
-      sketchBytes.length + Integer.BYTES + Integer.BYTES + ddlBytes.length)
+      sketchWithNullCountBytes.length + Integer.BYTES + Integer.BYTES + ddlBytes.length)
     val byteBuffer = ByteBuffer.wrap(byteArray)
     byteBuffer.putInt(maxItemsTracked)
     byteBuffer.putInt(ddlBytes.length)
     byteBuffer.put(ddlBytes)
-    byteBuffer.put(sketchBytes)
+    byteBuffer.put(sketchWithNullCountBytes)
     byteArray
   }
 }
@@ -736,9 +696,9 @@ object CombineInternal {
     // read sketchBytes
     val sketchBytes = new Array[Byte](buffer.length - Integer.BYTES - Integer.BYTES - ddlLength)
     byteBuffer.get(sketchBytes)
-    val sketch = ItemsSketch.getInstance(
-      Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
-    new CombineInternal[Any](sketch, itemDataType, maxItemsTracked)
+    val sketchWithNullCount = ApproxTopKAggregateBuffer.deserialize(
+      sketchBytes, ApproxTopK.genSketchSerDe(itemDataType))
+    new CombineInternal[Any](sketchWithNullCount, itemDataType, maxItemsTracked)
   }
 }
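`CombineInternal.serialize` fixes the intermediate state layout that `approx_top_k_combine` ships between partial and final aggregation: a 4-byte `maxItemsTracked`, a 4-byte DDL length, the item type as a DDL string, then the buffer bytes (which now embed the null count). A self-contained round-trip of just that framing, with dummy payload bytes standing in for a real serialized sketch:

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

def frame(maxItemsTracked: Int, ddl: String, payload: Array[Byte]): Array[Byte] = {
  val ddlBytes = ddl.getBytes(StandardCharsets.UTF_8)
  val out = ByteBuffer.allocate(
    Integer.BYTES + Integer.BYTES + ddlBytes.length + payload.length)
  out.putInt(maxItemsTracked) // 4 bytes
  out.putInt(ddlBytes.length) // 4 bytes
  out.put(ddlBytes)           // item type as DDL, e.g. "item string not null"
  out.put(payload)            // sketch-with-null-count bytes in the real code
  out.array()
}

def unframe(bytes: Array[Byte]): (Int, String, Array[Byte]) = {
  val in = ByteBuffer.wrap(bytes)
  val maxItemsTracked = in.getInt()
  val ddlBytes = new Array[Byte](in.getInt())
  in.get(ddlBytes)
  val payload = new Array[Byte](in.remaining())
  in.get(payload)
  (maxItemsTracked, new String(ddlBytes, StandardCharsets.UTF_8), payload)
}

// unframe(frame(30, "item string not null", Array[Byte](1, 2, 3)))
// => (30, "item string not null", Array(1, 2, 3))
```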
@@ -833,7 +793,7 @@ case class ApproxTopKCombine(
     if (combineSizeSpecified) {
       val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
       new CombineInternal[Any](
-        new ItemsSketch[Any](maxMapSize),
+        new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](maxMapSize), 0L),
         null,
         maxItemsTrackedVal)
     } else {
@@ -842,7 +802,7 @@ case class ApproxTopKCombine(
       // The actual maxItemsTracked will be checked during the updates.
       val maxMapSize = ApproxTopK.calMaxMapSize(ApproxTopK.MAX_ITEMS_TRACKED_LIMIT)
       new CombineInternal[Any](
-        new ItemsSketch[Any](maxMapSize),
+        new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](maxMapSize), 0L),
         null,
         ApproxTopK.VOID_MAX_ITEMS_TRACKED)
     }
@@ -863,9 +823,9 @@ case class ApproxTopKCombine(
     // update itemDataType (throw error if not match)
     buffer.updateItemDataType(inputItemDataType)
     // update sketch
-    val inputSketch = ItemsSketch.getInstance(
-      Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
-    buffer.getSketch.merge(inputSketch)
+    val inputSketchWithNullCount = ApproxTopKAggregateBuffer.deserialize(
+      inputSketchBytes, ApproxTopK.genSketchSerDe(inputItemDataType))
+    buffer.updateSketchWithNullCount(inputSketchWithNullCount)
     buffer
   }
 
@@ -876,14 +836,14 @@ case class ApproxTopKCombine(
     buffer.updateMaxItemsTracked(combineSizeSpecified, input.getMaxItemsTracked)
     // update itemDataType (throw error if not match)
     buffer.updateItemDataType(input.getItemDataType)
-    // update sketch
-    buffer.getSketch.merge(input.getSketch)
+    // update sketchWithNullCount
+    buffer.getSketchWithNullCount.merge(input.getSketchWithNullCount)
     buffer
   }
 
   override def eval(buffer: CombineInternal[Any]): Any = {
-    val sketchBytes =
-      buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val sketchBytes = buffer.getSketchWithNullCount
+      .serialize(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
     val maxItemsTracked = buffer.getMaxItemsTracked
     val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(buffer.getItemDataType)
     InternalRow.apply(
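The suite below fixes the estimate-time contract: the null tally competes with regular items for the top-k slots by count, so four nulls outrank `'b'` with three, while a single null can drop out of a top-2 entirely. A hedged sketch of that selection step over plain `(value, count)` pairs, leaving out Spark's `InternalRow`/`GenericArrayData` plumbing (ordering among equal counts is left unspecified here):

```scala
def topKWithNulls(
    frequentItems: Seq[(Any, Long)], // non-null items, as from ItemsSketch.getFrequentItems
    nullCount: Long,
    k: Int): Seq[(Any, Long)] = {
  // nulls enter the ranking as one more candidate, weighted by their tally
  val candidates =
    if (nullCount > 0) frequentItems :+ ((null: Any, nullCount)) else frequentItems
  candidates.sortBy(-_._2).take(k)
}

// topKWithNulls(Seq(("a", 2L), ("b", 3L)), nullCount = 4L, k = 2)
// => Seq((null, 4), ("b", 3))  -- the first SPARK-53960 test below
// topKWithNulls(Seq(("a", 2L), ("b", 3L)), nullCount = 1L, k = 2)
// => Seq(("b", 3), ("a", 2))   -- "null is not in top k"
```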
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index d9d16d1234b7..982c9ff90da7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -381,6 +381,77 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
     )
   }
 
+  test("SPARK-53960: accumulate and estimate count NULL values") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2)
+        |FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS tab(expr)""".stripMargin)
+    checkAnswer(res, Row(Seq(Row(null, 4), Row("b", 3))))
+  }
+
+  test("SPARK-53960: accumulate and estimate null is not in top k") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2)
+        |FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL AS tab(expr)""".stripMargin)
+    checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2))))
+  }
+
+  test("SPARK-53960: accumulate and estimate null is the last in top k") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 3)
+        |FROM VALUES 0, 0, 1, 1, 1, NULL AS tab(expr)""".stripMargin)
+    checkAnswer(res, Row(Seq(Row(1, 3), Row(0, 2), Row(null, 1))))
+  }
+
+  test("SPARK-53960: accumulate and estimate null + frequent items < k") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 5)
+        |FROM VALUES cast(0.0 AS DECIMAL(4, 1)), cast(0.0 AS DECIMAL(4, 1)),
+        |cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)),
+        |NULL AS tab(expr)""".stripMargin)
+    checkAnswer(
+      res,
+      Row(Seq(Row(new java.math.BigDecimal("0.1"), 3),
+        Row(new java.math.BigDecimal("0.0"), 2),
+        Row(null, 1))))
+  }
+
+  test("SPARK-53960: accumulate and estimate work on typed column with only NULL values") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr))
+        |FROM VALUES cast(NULL AS INT), cast(NULL AS INT) AS tab(expr)""".stripMargin)
+    checkAnswer(res, Row(Seq(Row(null, 2))))
+  }
+
+  test("SPARK-53960: accumulate a column of all nulls with type - success") {
+    withView("accumulation") {
+      val res = sql(
+        """SELECT approx_top_k_accumulate(expr) AS acc
+          |FROM VALUES cast(NULL AS INT), cast(NULL AS INT) AS tab(expr)""".stripMargin)
+      assert(res.collect().length == 1)
+      res.createOrReplaceTempView("accumulation")
+      val est = sql("SELECT approx_top_k_estimate(acc) FROM accumulation;")
+      checkAnswer(est, Row(Seq(Row(null, 2))))
+    }
+  }
+
+  test("SPARK-53960: accumulate a column of all nulls without type - fail") {
+    checkError(
+      exception = intercept[ExtendedAnalysisException] {
+        sql("""SELECT approx_top_k_accumulate(expr)
+          |FROM VALUES (NULL), (NULL), (NULL), (NULL) AS tab(expr)""".stripMargin)
+      },
+      condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+      parameters = Map(
+        "sqlExpr" -> "\"approx_top_k_accumulate(expr, 10000)\"",
+        "msg" -> "void columns are not supported",
+        "hint" -> ""
+      ),
+      queryContext = Array(ExpectedContext("approx_top_k_accumulate(expr)", 7, 35))
+    )
+  }
+
   /////////////////////////////////
   // approx_top_k_combine
   /////////////////////////////////
@@ -445,75 +516,87 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
 
   // positive tests for approx_top_k_combine on every types
   gridTest("SPARK-52798: same type, same size, specified combine size - success")(itemsWithTopK) {
     case (input, expected) =>
-      sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS tab(expr);")
-        .createOrReplaceTempView("accumulation1")
-      sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS tab(expr);")
-        .createOrReplaceTempView("accumulation2")
-      sql("SELECT approx_top_k_combine(acc, 30) as com " +
-        "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
-        .createOrReplaceTempView("combined")
-      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
-      // expected should be doubled because we combine two identical sketches
-      val expectedDoubled = expected.map {
-        case Row(value: Any, count: Int) => Row(value, count * 2)
+      withView("accumulation1", "accumulation2", "combined") {
+        sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS tab(expr);")
+          .createOrReplaceTempView("accumulation1")
+        sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS tab(expr);")
+          .createOrReplaceTempView("accumulation2")
+        sql("SELECT approx_top_k_combine(acc, 30) as com " +
+          "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
+          .createOrReplaceTempView("combined")
+        val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+        // expected should be doubled because we combine two identical sketches
+        val expectedDoubled = expected.map {
+          case Row(value: Any, count: Int) => Row(value, count * 2)
+        }
+        checkAnswer(est, Row(expectedDoubled))
       }
-      checkAnswer(est, Row(expectedDoubled))
   }
 
   test("SPARK-52798: same type, same size, specified combine size - success") {
-    setupMixedSizeAccumulations(10, 10)
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      setupMixedSizeAccumulations(10, 10)
 
-    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
-      .createOrReplaceTempView("combined")
+      sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+        .createOrReplaceTempView("combined")
 
-    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
-    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+    }
   }
 
   test("SPARK-52798: same type, same size, unspecified combine size - success") {
-    setupMixedSizeAccumulations(10, 10)
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      setupMixedSizeAccumulations(10, 10)
 
-    sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
-      .createOrReplaceTempView("combined")
+      sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+        .createOrReplaceTempView("combined")
 
-    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
-    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+    }
   }
 
   test("SPARK-52798: same type, different size, specified combine size - success") {
-    setupMixedSizeAccumulations(10, 20)
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      setupMixedSizeAccumulations(10, 20)
 
-    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
-      .createOrReplaceTempView("combination")
+      sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+        .createOrReplaceTempView("combined")
 
-    val est = sql("SELECT approx_top_k_estimate(com) FROM combination;")
-    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+    }
   }
 
   test("SPARK-52798: same type, different size, unspecified combine size - fail") {
-    setupMixedSizeAccumulations(10, 20)
+    withView("accumulation1", "accumulation2", "unioned") {
+      setupMixedSizeAccumulations(10, 20)
 
-    val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
-
-    checkError(
-      exception = intercept[SparkRuntimeException] {
-        comb.collect()
-      },
-      condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
-      parameters = Map("size1" -> "10", "size2" -> "20")
-    )
-  }
+      val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
 
-  gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10, 20))) {
-    case (size1, size2) =>
-      setupMixedSizeAccumulations(size1, size2)
       checkError(
         exception = intercept[SparkRuntimeException] {
-          sql("SELECT approx_top_k_combine(acc, 0) as com FROM unioned").collect()
+          comb.collect()
         },
-        condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
-        parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+        condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+        parameters = Map("size1" -> "10", "size2" -> "20")
       )
+    }
+  }
+
+  gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10, 20))) {
+    case (size1, size2) =>
+      withView("accumulation1", "accumulation2", "unioned") {
+        setupMixedSizeAccumulations(size1, size2)
+        checkError(
+          exception = intercept[SparkRuntimeException] {
+            sql("SELECT approx_top_k_combine(acc, 0) as com FROM unioned").collect()
+          },
+          condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
+          parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+        )
+      }
   }
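The recurring change in this suite is mechanical: every test that used to leave temp views behind now wraps its body in `withView`, so the views are dropped even when an assertion throws mid-test. The `SQLTestUtils` helper behaves roughly like the following paraphrase (not the verbatim Spark implementation):

```scala
import org.apache.spark.sql.SparkSession

// Paraphrased semantics of SQLTestUtils.withView: run the body, then always
// drop the named views. IF EXISTS keeps cleanup safe when the body failed
// before creating some of them.
def withView(spark: SparkSession)(viewNames: String*)(body: => Unit): Unit = {
  try body
  finally viewNames.foreach(name => spark.sql(s"DROP VIEW IF EXISTS $name"))
}
```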
 
   test("SPARK-52798: among different number or datetime types - fail at combine") {
@@ -523,13 +606,15 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
         val (type1, _, seq1) = mixedTypeSeq(i)
         val (type2, _, seq2) = mixedTypeSeq(j)
         setupMixedTypeAccumulation(seq1, seq2)
-        checkError(
-          exception = intercept[SparkRuntimeException] {
-            sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
-          },
-          condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
-          parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(type2))
-        )
+        withView("accumulation1", "accumulation2", "unioned") {
+          checkError(
+            exception = intercept[SparkRuntimeException] {
+              sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+            },
+            condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+            parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(type2))
+          )
+        }
       }
     }
   }
@@ -547,7 +632,9 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
     case ((_, type1, seq1), (_, type2, seq2)) =>
       checkError(
         exception = intercept[ExtendedAnalysisException] {
-          setupMixedTypeAccumulation(seq1, seq2)
+          withView("accumulation1", "accumulation2", "unioned") {
+            setupMixedTypeAccumulation(seq1, seq2)
+          }
         },
         condition = "INCOMPATIBLE_COLUMN_TYPE",
         parameters = Map(
@@ -568,14 +655,17 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
 
   gridTest("SPARK-52798: number vs string - fail at combine")(mixedNumberTypes) {
     case (type1, _, seq1) =>
-      setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
-      checkError(
-        exception = intercept[SparkRuntimeException] {
-          sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
-        },
-        condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
-        parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType))
-      )
+      withView("accumulation1", "accumulation2", "unioned") {
+        setupMixedTypeAccumulation(
+          seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+        checkError(
+          exception = intercept[SparkRuntimeException] {
+            sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+          },
+          condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+          parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType))
+        )
+      }
   }
 
   gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypes) {
@@ -583,7 +673,9 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
     case (_, type1, seq1) =>
       val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
       checkError(
         exception = intercept[ExtendedAnalysisException] {
-          setupMixedTypeAccumulation(seq1, seq2)
+          withView("accumulation1", "accumulation2", "unioned") {
+            setupMixedTypeAccumulation(seq1, seq2)
+          }
         },
         condition = "INCOMPATIBLE_COLUMN_TYPE",
         parameters = Map(
@@ -604,14 +696,17 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
 
   gridTest("SPARK-52798: datetime vs string - fail at combine")(mixedDateTimeTypes) {
     case (type1, _, seq1) =>
-      setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
-      checkError(
-        exception = intercept[SparkRuntimeException] {
-          sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
-        },
-        condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
-        parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType))
-      )
+      withView("accumulation1", "accumulation2", "unioned") {
+        setupMixedTypeAccumulation(
+          seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+        checkError(
+          exception = intercept[SparkRuntimeException] {
+            sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+          },
+          condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+          parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType))
+        )
+      }
   }
 
   gridTest("SPARK-52798: datetime vs boolean - fail at UNION")(mixedDateTimeTypes) {
@@ -619,7 +714,9 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
     case (_, type1, seq1) =>
       val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
       checkError(
         exception = intercept[ExtendedAnalysisException] {
-          setupMixedTypeAccumulation(seq1, seq2)
+          withView("accumulation1", "accumulation2", "unioned") {
+            setupMixedTypeAccumulation(seq1, seq2)
+          }
         },
         condition = "INCOMPATIBLE_COLUMN_TYPE",
         parameters = Map(
@@ -641,65 +738,174 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
 
   test("SPARK-52798: string vs boolean - fail at combine") {
     val seq1 = Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")
     val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
-    setupMixedTypeAccumulation(seq1, seq2)
-    checkError(
-      exception = intercept[SparkRuntimeException] {
-        sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
-      },
-      condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
-      parameters = Map("type1" -> toSQLType(StringType), "type2" -> toSQLType(BooleanType))
-    )
+    withView("accumulation1", "accumulation2", "unioned") {
+      setupMixedTypeAccumulation(seq1, seq2)
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+        },
+        condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+        parameters = Map("type1" -> toSQLType(StringType), "type2" -> toSQLType(BooleanType))
+      )
+    }
   }
 
   test("SPARK-52798: combine more than 2 sketches with specified size") {
-    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
-      "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
-      .createOrReplaceTempView("accumulation1")
+    withView("accumulation1", "accumulation2", "accumulation3", "unioned", "combined") {
+      sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+        "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
+        .createOrReplaceTempView("accumulation1")
 
-    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
-      "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
-      .createOrReplaceTempView("accumulation2")
+      sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+        "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
+        .createOrReplaceTempView("accumulation2")
 
-    sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
-      "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
-      .createOrReplaceTempView("accumulation3")
+      sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
+        "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
+        .createOrReplaceTempView("accumulation3")
 
-    sql("SELECT acc from accumulation1 UNION ALL " +
-      "SELECT acc FROM accumulation2 UNION ALL " +
-      "SELECT acc FROM accumulation3")
-      .createOrReplaceTempView("unioned")
+      sql("SELECT acc from accumulation1 UNION ALL " +
+        "SELECT acc FROM accumulation2 UNION ALL " +
+        "SELECT acc FROM accumulation3")
+        .createOrReplaceTempView("unioned")
 
-    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
-      .createOrReplaceTempView("combined")
+      sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+        .createOrReplaceTempView("combined")
 
-    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
-    checkAnswer(est, Row(Seq(Row(2, 6), Row(3, 5), Row(1, 4), Row(0, 3), Row(4, 2))))
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      checkAnswer(est, Row(Seq(Row(2, 6), Row(3, 5), Row(1, 4), Row(0, 3), Row(4, 2))))
+    }
   }
 
   test("SPARK-52798: combine more than 2 sketches without specified size") {
-    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
-      "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
-      .createOrReplaceTempView("accumulation1")
+    withView("accumulation1", "accumulation2", "accumulation3", "unioned") {
+      sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+        "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
+        .createOrReplaceTempView("accumulation1")
 
-    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
-      "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
-      .createOrReplaceTempView("accumulation2")
+      sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+        "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
+        .createOrReplaceTempView("accumulation2")
 
-    sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
-      "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
-      .createOrReplaceTempView("accumulation3")
+      sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
+        "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
+        .createOrReplaceTempView("accumulation3")
 
-    sql("SELECT acc from accumulation1 UNION ALL " +
-      "SELECT acc FROM accumulation2 UNION ALL " +
-      "SELECT acc FROM accumulation3")
-      .createOrReplaceTempView("unioned")
+      sql("SELECT acc from accumulation1 UNION ALL " +
+        "SELECT acc FROM accumulation2 UNION ALL " +
+        "SELECT acc FROM accumulation3")
+        .createOrReplaceTempView("unioned")
 
-    checkError(
-      exception = intercept[SparkRuntimeException] {
-        sql("SELECT approx_top_k_combine(acc) as com FROM unioned").collect()
-      },
-      condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
-      parameters = Map("size1" -> "10", "size2" -> "20")
-    )
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc) as com FROM unioned").collect()
+        },
+        condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+        parameters = Map("size1" -> "10", "size2" -> "20")
+      )
+    }
+  }
+
+  test("SPARK-53960: combine and estimate count NULL values") {
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 'a', 'a', 'b', NULL, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation1")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 'b', 'b', NULL, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation2")
+
+      sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2")
+        .createOrReplaceTempView("unioned")
+
+      sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+        .createOrReplaceTempView("combined")
+
+      val est = sql("SELECT approx_top_k_estimate(com, 2) FROM combined")
+      checkAnswer(est, Row(Seq(Row(null, 4), Row("b", 3))))
+    }
+  }
+
+  test("SPARK-53960: combine with a sketch of all nulls") {
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES cast(NULL AS INT), cast(NULL AS INT), cast(NULL AS INT)
+          |AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation1")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 1, 1, 2, 2 AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation2")
+
+      sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2")
+        .createOrReplaceTempView("unioned")
+
+      sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+        .createOrReplaceTempView("combined")
+
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined")
+      checkAnswer(est, Row(Seq(Row(null, 3), Row(2, 2), Row(1, 2))))
+    }
+  }
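These combine tests also fix the merge semantics for the null tally: it adds across sketches just like item counts, so two nulls from each accumulation surface as four after `approx_top_k_combine`. With the hypothetical `NullCountingBuffer` from earlier, the first test above reduces to:

```scala
import org.apache.datasketches.frequencies.ItemsSketch

// left  mirrors accumulation1: 'a', 'a', 'b' plus two nulls
val left = new NullCountingBuffer[String](new ItemsSketch[String](32), nullCount = 2L)
Seq("a", "a", "b").foreach(x => left.sketch.update(x))

// right mirrors accumulation2: 'b', 'b' plus two nulls
val right = new NullCountingBuffer[String](new ItemsSketch[String](32), nullCount = 2L)
Seq("b", "b").foreach(x => right.sketch.update(x))

val combined = left.merge(right)
// combined.nullCount == 4; sketch estimates: "b" -> 3, "a" -> 2,
// matching checkAnswer(est, Row(Seq(Row(null, 4), Row("b", 3)))) above
```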
+
+  test("SPARK-53960: combine sketches with nulls from more than 2 sketches") {
+    withView("accumulation1", "accumulation2", "accumulation3", "unioned", "combined") {
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 0, 0, 0, 1, 1, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation1")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES NULL, 1, 1, 2, 2, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation2")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 2, 3, 3, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation3")
+
+      sql(
+        """SELECT acc from accumulation1 UNION ALL
+          |SELECT acc FROM accumulation2 UNION ALL
+          |SELECT acc FROM accumulation3""".stripMargin)
+        .createOrReplaceTempView("unioned")
+
+      sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+        .createOrReplaceTempView("combined")
+
+      val est = sql("SELECT approx_top_k_estimate(com, 2) FROM combined")
+      checkAnswer(est, Row(Seq(Row(1, 4), Row(null, 4))))
+    }
+  }
+
+  test("SPARK-53960: combine 2 sketches with all nulls") {
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES cast(NULL AS INT), cast(NULL AS INT), cast(NULL AS INT)
+          |AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation1")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES cast(NULL AS INT), cast(NULL AS INT)
+          |AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation2")
+
+      sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2")
+        .createOrReplaceTempView("unioned")
+
+      sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+        .createOrReplaceTempView("combined")
+
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined")
+      checkAnswer(est, Row(Seq(Row(null, 5))))
+    }
+  }
 }