diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 6c6b3b805048..1fca8ad86bc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -79,7 +79,7 @@ case class ApproxTopK(
     maxItemsTracked: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[ItemsSketch[Any]]
+  extends TypedImperativeAggregate[ApproxTopKAggregateBuffer[Any]]
   with ImplicitCastInputTypes
   with TernaryLike[Expression] {

@@ -137,25 +137,30 @@ case class ApproxTopK(

   override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType)

-  override def createAggregationBuffer(): ItemsSketch[Any] = {
+  override def createAggregationBuffer(): ApproxTopKAggregateBuffer[Any] = {
     val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
-    ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+    val sketch = ApproxTopK.createItemsSketch(expr, maxMapSize)
+    new ApproxTopKAggregateBuffer[Any](sketch, 0L)
   }

-  override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
-    ApproxTopK.updateSketchBuffer(expr, buffer, input)
+  override def update(buffer: ApproxTopKAggregateBuffer[Any], input: InternalRow):
+  ApproxTopKAggregateBuffer[Any] =
+    buffer.update(expr, input)

-  override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
+  override def merge(
+      buffer: ApproxTopKAggregateBuffer[Any],
+      input: ApproxTopKAggregateBuffer[Any]):
+  ApproxTopKAggregateBuffer[Any] =
     buffer.merge(input)

-  override def eval(buffer: ItemsSketch[Any]): GenericArrayData =
-    ApproxTopK.genEvalResult(buffer, kVal, itemDataType)
+  override def eval(buffer: ApproxTopKAggregateBuffer[Any]): GenericArrayData =
+    buffer.eval(kVal, itemDataType)

-  override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
-    buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+  override def serialize(buffer: ApproxTopKAggregateBuffer[Any]): Array[Byte] =
+    buffer.serialize(ApproxTopK.genSketchSerDe(itemDataType))

-  override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
-    ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
+  override def deserialize(storageFormat: Array[Byte]): ApproxTopKAggregateBuffer[Any] =
+    ApproxTopKAggregateBuffer.deserialize(storageFormat, ApproxTopK.genSketchSerDe(itemDataType))

   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -214,7 +219,7 @@ object ApproxTopK {

   def getResultDataType(itemDataType: DataType): DataType = {
     val resultEntryType = StructType(
-      StructField("item", itemDataType, nullable = false) ::
+      StructField("item", itemDataType, nullable = true) ::
       StructField("count", LongType, nullable = false) :: Nil)
     ArrayType(resultEntryType, containsNull = false)
   }
@@ -238,7 +243,7 @@ object ApproxTopK {
     math.pow(2, math.ceil(math.log(ceilMaxMapSize) / math.log(2))).toInt
   }

-  def createAggregationBuffer(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = {
+  def createItemsSketch(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = {
     itemExpression.dataType match {
       case _: BooleanType =>
         new ItemsSketch[Boolean](maxMapSize).asInstanceOf[ItemsSketch[Any]]
@@ -369,6 +374,145 @@ object ApproxTopK {
   }
 }

+/**
+ * An internal class used as the aggregation buffer for ApproxTopK.
+ *
+ * @param sketch the ItemsSketch instance for counting non-null items
+ * @param nullCount the count of null items
+ */
+class ApproxTopKAggregateBuffer[T](val sketch: ItemsSketch[T], private var nullCount: Long) {
+  def update(itemExpression: Expression, input: InternalRow): ApproxTopKAggregateBuffer[T] = {
+    val v = itemExpression.eval(input)
+    if (v != null) {
+      itemExpression.dataType match {
+        case _: BooleanType =>
+          sketch.asInstanceOf[ItemsSketch[Boolean]].update(v.asInstanceOf[Boolean])
+        case _: ByteType =>
+          sketch.asInstanceOf[ItemsSketch[Byte]].update(v.asInstanceOf[Byte])
+        case _: ShortType =>
+          sketch.asInstanceOf[ItemsSketch[Short]].update(v.asInstanceOf[Short])
+        case _: IntegerType =>
+          sketch.asInstanceOf[ItemsSketch[Int]].update(v.asInstanceOf[Int])
+        case _: LongType =>
+          sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+        case _: FloatType =>
+          sketch.asInstanceOf[ItemsSketch[Float]].update(v.asInstanceOf[Float])
+        case _: DoubleType =>
+          sketch.asInstanceOf[ItemsSketch[Double]].update(v.asInstanceOf[Double])
+        case _: DateType =>
+          sketch.asInstanceOf[ItemsSketch[Int]].update(v.asInstanceOf[Int])
+        case _: TimestampType =>
+          sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+        case _: TimestampNTZType =>
+          sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+        case st: StringType =>
+          val cKey = CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
+          sketch.asInstanceOf[ItemsSketch[String]].update(cKey.toString)
+        case _: DecimalType =>
+          sketch.asInstanceOf[ItemsSketch[Decimal]].update(v.asInstanceOf[Decimal])
+      }
+    } else {
+      nullCount += 1
+    }
+    this
+  }
+
+  def merge(other: ApproxTopKAggregateBuffer[T]): ApproxTopKAggregateBuffer[T] = {
+    sketch.merge(other.sketch)
+    nullCount += other.nullCount
+    this
+  }
+
+  /**
+   * Serialize the buffer into bytes.
+   * The format is:
+   * [sketch bytes][null count (8 bytes Long)]
+   */
+  def serialize(serDe: ArrayOfItemsSerDe[T]): Array[Byte] = {
+    val sketchBytes = sketch.toByteArray(serDe)
+    val result = new Array[Byte](sketchBytes.length + java.lang.Long.BYTES)
+    val byteBuffer = java.nio.ByteBuffer.wrap(result)
+    byteBuffer.put(sketchBytes)
+    byteBuffer.putLong(nullCount)
+    result
+  }
+
+  /**
+   * Evaluate the buffer and return the top K items (including null) with their estimated frequency.
+   * The result is sorted by frequency in descending order.
+   */
+  def eval(k: Int, itemDataType: DataType): GenericArrayData = {
+    // frequent items from the sketch
+    val frequentItems = sketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES)
+    // total number of frequent items (including null, if any)
+    val itemsLength = frequentItems.length + (if (nullCount > 0) 1 else 0)
+    // actual number of items to return
+    val resultLength = math.min(itemsLength, k)
+    val result = new Array[AnyRef](resultLength)
+
+    // index pointers for merging frequent items and nullCount into result
+    var fiIndex = 0 // pointer into frequentItems
+    var resultIndex = 0 // pointer into result
+    var isNullAdded = false // whether nullCount has been added to result
+
+    // helper to get the nullCount estimate: once null has been added, return Long.MinValue
+    // so that it is never added again; otherwise return nullCount
+    @inline def getNullEstimate: Long = if (!isNullAdded) nullCount else Long.MinValue
+
+    // loop until the result is full or we run out of frequent items
+    while (resultIndex < resultLength && fiIndex < frequentItems.length) {
+      val curFrequentItem = frequentItems(fiIndex)
+      val itemEstimate = curFrequentItem.getEstimate
+      val nullEstimate = getNullEstimate
+
+      val (item, estimate) = if (nullEstimate > itemEstimate) {
+        // insert (null, nullCount) into the result
+        isNullAdded = true
+        (null, nullCount)
+      } else {
+        // insert the current frequent item into the result
+        val item: Any = itemDataType match {
+          case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
+               _: LongType | _: FloatType | _: DoubleType | _: DecimalType |
+               _: DateType | _: TimestampType | _: TimestampNTZType =>
+            curFrequentItem.getItem
+          case _: StringType =>
+            UTF8String.fromString(curFrequentItem.getItem.asInstanceOf[String])
+        }
+        fiIndex += 1 // move to the next frequent item
+        (item, itemEstimate)
+      }
+      result(resultIndex) = InternalRow(item, estimate)
+      resultIndex += 1 // move to the next result position
+    }
+
+    // if there is still space in the result and a positive nullCount has not been added yet
+    if (resultIndex < resultLength && nullCount > 0 && !isNullAdded) {
+      result(resultIndex) = InternalRow(null, nullCount)
+    }
+
+    new GenericArrayData(result)
+  }
+}
+
+object ApproxTopKAggregateBuffer {
+  /**
+   * Deserialize the buffer from bytes.
+   * The format is:
+   * [sketch bytes][null count (8 bytes Long)]
+   */
+  def deserialize(bytes: Array[Byte], serDe: ArrayOfItemsSerDe[Any]):
+  ApproxTopKAggregateBuffer[Any] = {
+    val byteBuffer = java.nio.ByteBuffer.wrap(bytes)
+    val sketchBytesLength = bytes.length - java.lang.Long.BYTES
+    val sketchBytes = new Array[Byte](sketchBytesLength)
+    byteBuffer.get(sketchBytes, 0, sketchBytesLength)
+    val nullCount = byteBuffer.getLong(sketchBytesLength)
+    val deserializedSketch = ItemsSketch.getInstance(Memory.wrap(sketchBytes), serDe)
+    new ApproxTopKAggregateBuffer[Any](deserializedSketch, nullCount)
+  }
+}
+
 /**
  * An aggregate function that accumulates items into a sketch, which can then be used
  * to combine with other sketches, via ApproxTopKCombine,
@@ -450,7 +594,7 @@ case class ApproxTopKAccumulate(

   override def createAggregationBuffer(): ItemsSketch[Any] = {
     val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
-    ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+    ApproxTopK.createItemsSketch(expr, maxMapSize)
   }

   override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index 702f361ace28..d9d16d1234b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -196,13 +196,62 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
     )
   }

-  test("SPARK-52515: does not count NULL values") {
+  test("SPARK-53947: count NULL values") {
     val res = sql(
       "SELECT approx_top_k(expr, 2) " +
-        "FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL AS tab(expr);")
+        "FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS tab(expr);")
+    checkAnswer(res, Row(Seq(Row(null, 4), Row("b", 3))))
+  }
+
+  test("SPARK-53947: null is not in top k") {
+    val res = sql(
+      "SELECT approx_top_k(expr, 2) FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL AS tab(expr)"
+    )
     checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2))))
   }

+  test("SPARK-53947: null is the last in top k") {
+    val res = sql(
+      "SELECT approx_top_k(expr, 3) FROM VALUES 0, 0, 1, 1, 1, NULL AS tab(expr)"
+    )
+    checkAnswer(res, Row(Seq(Row(1, 3), Row(0, 2), Row(null, 1))))
+  }
+
+  test("SPARK-53947: null + frequent items < k") {
+    val res = sql(
+      """SELECT approx_top_k(expr, 5)
+        |FROM VALUES cast(0.0 AS DECIMAL(4, 1)), cast(0.0 AS DECIMAL(4, 1)),
+        |cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)),
+        |NULL AS tab(expr)""".stripMargin)
+    checkAnswer(
+      res,
+      Row(Seq(Row(new java.math.BigDecimal("0.1"), 3),
+        Row(new java.math.BigDecimal("0.0"), 2),
+        Row(null, 1))))
+  }
+
+  test("SPARK-53947: work on typed column with only NULL values") {
+    val res = sql(
+      "SELECT approx_top_k(expr) FROM VALUES cast(NULL AS INT), cast(NULL AS INT) AS tab(expr)"
+    )
+    checkAnswer(res, Row(Seq(Row(null, 2))))
+  }
+
+  test("SPARK-53947: invalid item: void columns") {
+    checkError(
+      exception = intercept[ExtendedAnalysisException] {
+        sql("SELECT approx_top_k(expr) FROM VALUES (NULL), (NULL), (NULL) AS tab(expr)")
+      },
+      condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+      parameters = Map(
+        "sqlExpr" -> "\"approx_top_k(expr, 5, 10000)\"",
+        "msg" -> "void columns are not supported",
+        "hint" -> ""
+      ),
+      queryContext = Array(ExpectedContext("approx_top_k(expr)", 7, 24))
+    )
+  }
+
   /////////////////////////////////
   // approx_top_k_accumulate and //
   // approx_top_k_estimate tests //
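
A note for reviewers on the new buffer layout: serialize writes the sketch's own byte image followed by the null count as a trailing 8-byte Long, and deserialize recovers the split purely from the total length, reading the count with an absolute getLong at the offset where the sketch bytes end. The standalone sketch below illustrates that round trip outside Spark. It is not part of the patch: the object name is hypothetical, and the imports assume the datasketches-java 4.x+ package layout, with ArrayOfStringsSerDe standing in for the ArrayOfItemsSerDe that ApproxTopK.genSketchSerDe would produce.

import java.nio.ByteBuffer
import java.util.Arrays

import org.apache.datasketches.common.ArrayOfStringsSerDe
import org.apache.datasketches.frequencies.ItemsSketch
import org.apache.datasketches.memory.Memory

// Hypothetical, self-contained round-trip check; not part of the patch.
object ApproxTopKBufferLayoutDemo {
  def main(args: Array[String]): Unit = {
    val serDe = new ArrayOfStringsSerDe
    val sketch = new ItemsSketch[String](64) // maxMapSize must be a power of 2
    Seq("a", "a", "b").foreach(s => sketch.update(s))
    val nullCount = 2L

    // serialize: [sketch bytes][null count (8 bytes Long)]
    val sketchBytes = sketch.toByteArray(serDe)
    val bytes = ByteBuffer
      .allocate(sketchBytes.length + java.lang.Long.BYTES)
      .put(sketchBytes)
      .putLong(nullCount)
      .array()

    // deserialize: the sketch occupies everything except the trailing Long
    val sketchLen = bytes.length - java.lang.Long.BYTES
    val restoredSketch =
      ItemsSketch.getInstance(Memory.wrap(Arrays.copyOfRange(bytes, 0, sketchLen)), serDe)
    val restoredNulls = ByteBuffer.wrap(bytes).getLong(sketchLen) // absolute read at sketchLen

    assert(restoredNulls == nullCount)
    assert(restoredSketch.getEstimate("a") == 2L)
    assert(restoredSketch.getEstimate("b") == 1L)
  }
}

Because the null count is fixed-width and always trails the sketch bytes, no explicit length prefix is needed, and the leading region can be handed directly to ItemsSketch.getInstance as a Memory slice.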