@@ -79,7 +79,7 @@ case class ApproxTopK(
maxItemsTracked: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends TypedImperativeAggregate[ItemsSketch[Any]]
+ extends TypedImperativeAggregate[ApproxTopKAggregateBuffer[Any]]
with ImplicitCastInputTypes
with TernaryLike[Expression] {

@@ -137,25 +137,30 @@ case class ApproxTopK(

override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType)

- override def createAggregationBuffer(): ItemsSketch[Any] = {
+ override def createAggregationBuffer(): ApproxTopKAggregateBuffer[Any] = {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
- ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+ val sketch = ApproxTopK.createItemsSketch(expr, maxMapSize)
+ new ApproxTopKAggregateBuffer[Any](sketch, 0L)
}

- override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
-   ApproxTopK.updateSketchBuffer(expr, buffer, input)
+ override def update(buffer: ApproxTopKAggregateBuffer[Any], input: InternalRow):
+   ApproxTopKAggregateBuffer[Any] =
+   buffer.update(expr, input)

- override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
+ override def merge(
+   buffer: ApproxTopKAggregateBuffer[Any],
+   input: ApproxTopKAggregateBuffer[Any]):
+   ApproxTopKAggregateBuffer[Any] =
buffer.merge(input)

- override def eval(buffer: ItemsSketch[Any]): GenericArrayData =
-   ApproxTopK.genEvalResult(buffer, kVal, itemDataType)
+ override def eval(buffer: ApproxTopKAggregateBuffer[Any]): GenericArrayData =
+   buffer.eval(kVal, itemDataType)

- override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
-   buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+ override def serialize(buffer: ApproxTopKAggregateBuffer[Any]): Array[Byte] =
+   buffer.serialize(ApproxTopK.genSketchSerDe(itemDataType))

- override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
-   ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
+ override def deserialize(storageFormat: Array[Byte]): ApproxTopKAggregateBuffer[Any] =
+   ApproxTopKAggregateBuffer.deserialize(storageFormat, ApproxTopK.genSketchSerDe(itemDataType))

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -214,7 +219,7 @@ object ApproxTopK {

def getResultDataType(itemDataType: DataType): DataType = {
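// "item" is nullable so that NULL itself can appear as a top-k entry (SPARK-53947)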
val resultEntryType = StructType(
StructField("item", itemDataType, nullable = false) ::
StructField("item", itemDataType, nullable = true) ::
StructField("count", LongType, nullable = false) :: Nil)
ArrayType(resultEntryType, containsNull = false)
}
@@ -238,7 +243,7 @@ object ApproxTopK {
math.pow(2, math.ceil(math.log(ceilMaxMapSize) / math.log(2))).toInt
}

- def createAggregationBuffer(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = {
+ def createItemsSketch(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = {
itemExpression.dataType match {
case _: BooleanType =>
new ItemsSketch[Boolean](maxMapSize).asInstanceOf[ItemsSketch[Any]]
@@ -369,6 +374,145 @@ object ApproxTopK {
}
}

/**
* An internal class used as the aggregation buffer for ApproxTopK.
*
* @param sketch the ItemsSketch instance for counting non-null items
* @param nullCount the count of null items
*/
class ApproxTopKAggregateBuffer[T](val sketch: ItemsSketch[T], private var nullCount: Long) {
def update(itemExpression: Expression, input: InternalRow): ApproxTopKAggregateBuffer[T] = {
val v = itemExpression.eval(input)
if (v != null) {
itemExpression.dataType match {
case _: BooleanType =>
sketch.asInstanceOf[ItemsSketch[Boolean]].update(v.asInstanceOf[Boolean])
case _: ByteType =>
sketch.asInstanceOf[ItemsSketch[Byte]].update(v.asInstanceOf[Byte])
case _: ShortType =>
sketch.asInstanceOf[ItemsSketch[Short]].update(v.asInstanceOf[Short])
case _: IntegerType =>
sketch.asInstanceOf[ItemsSketch[Int]].update(v.asInstanceOf[Int])
case _: LongType =>
sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
case _: FloatType =>
sketch.asInstanceOf[ItemsSketch[Float]].update(v.asInstanceOf[Float])
case _: DoubleType =>
sketch.asInstanceOf[ItemsSketch[Double]].update(v.asInstanceOf[Double])
case _: DateType =>
sketch.asInstanceOf[ItemsSketch[Int]].update(v.asInstanceOf[Int])
case _: TimestampType =>
sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
case _: TimestampNTZType =>
sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
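// strings are counted by collation key, so values that compare equal under
// the column's collation collapse into a single item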
case st: StringType =>
val cKey = CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
sketch.asInstanceOf[ItemsSketch[String]].update(cKey.toString)
case _: DecimalType =>
sketch.asInstanceOf[ItemsSketch[Decimal]].update(v.asInstanceOf[Decimal])
}
} else {
nullCount += 1
}
this
}

def merge(other: ApproxTopKAggregateBuffer[T]): ApproxTopKAggregateBuffer[T] = {
sketch.merge(other.sketch)
nullCount += other.nullCount
this
}

/**
* Serialize the buffer into bytes.
* The format is:
* [sketch bytes][null count (8 bytes Long)]
*/
def serialize(serDe: ArrayOfItemsSerDe[T]): Array[Byte] = {
val sketchBytes = sketch.toByteArray(serDe)
val result = new Array[Byte](sketchBytes.length + java.lang.Long.BYTES)
val byteBuffer = java.nio.ByteBuffer.wrap(result)
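// ByteBuffer defaults to big-endian; deserialize reads the trailing Long with the same default order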
byteBuffer.put(sketchBytes)
byteBuffer.putLong(nullCount)
result
}

/**
* Evaluate the buffer and return the top k items (including null, if any) with their
* estimated frequencies. The result is sorted by frequency in descending order.
*/
def eval(k: Int, itemDataType: DataType): GenericArrayData = {
// frequent items from sketch
val frequentItems = sketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES)
// total number of frequent items (including null, if any)
val itemsLength = frequentItems.length + (if (nullCount > 0) 1 else 0)
// actual number of items to return
val resultLength = math.min(itemsLength, k)
val result = new Array[AnyRef](resultLength)
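// e.g. k = 2, frequentItems = [("b", 3)], nullCount = 4: the loop below
// emits (null, 4) first, then ("b", 3)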

// index pointers used while merging frequent items and the null count into result
var fiIndex = 0 // pointer for frequentItems
var resultIndex = 0 // pointer for result
var isNullAdded = false // whether nullCount has been added to result

// helper function to get nullCount estimate: if nullCount has been added, return Long.MinValue
// so that it won't be added again; otherwise return nullCount
@inline def getNullEstimate: Long = if (!isNullAdded) nullCount else Long.MinValue

// loop until the result is full or we run out of frequent items
while (resultIndex < resultLength && fiIndex < frequentItems.length) {
val curFrequentItem = frequentItems(fiIndex)
val itemEstimate = curFrequentItem.getEstimate
val nullEstimate = getNullEstimate
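// note the strict '>' below: on a tie, the frequent item wins and null is deferred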

val (item, estimate) = if (nullEstimate > itemEstimate) {
// insert (null, nullCount) into result
isNullAdded = true
(null, nullCount.toLong)
} else {
// insert frequent item into result
val item: Any = itemDataType match {
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
_: LongType | _: FloatType | _: DoubleType | _: DecimalType |
_: DateType | _: TimestampType | _: TimestampNTZType =>
curFrequentItem.getItem
case _: StringType =>
UTF8String.fromString(curFrequentItem.getItem.asInstanceOf[String])
}
fiIndex += 1 // move to next frequent item
(item, itemEstimate)
}
result(resultIndex) = InternalRow(item, estimate)
resultIndex += 1 // move to next result position
}

// if there is still space in the result and a nonzero null count has not been added yet
if (resultIndex < resultLength && nullCount > 0 && !isNullAdded) {
result(resultIndex) = InternalRow(null, nullCount.toLong)
}

new GenericArrayData(result)
}
}

object ApproxTopKAggregateBuffer {
/**
* Deserialize the buffer from bytes.
* The format is:
* [sketch bytes][null count (8 bytes Long)]
*/
def deserialize(bytes: Array[Byte], serDe: ArrayOfItemsSerDe[Any]):
ApproxTopKAggregateBuffer[Any] = {
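// mirror of serialize's layout: everything before the trailing 8 bytes is sketch data,
// the last 8 bytes are nullCount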
val byteBuffer = java.nio.ByteBuffer.wrap(bytes)
val sketchBytesLength = bytes.length - 8
val sketchBytes = new Array[Byte](sketchBytesLength)
byteBuffer.get(sketchBytes, 0, sketchBytesLength)
val nullCount = byteBuffer.getLong(sketchBytesLength)
val deserializedSketch = ItemsSketch.getInstance(Memory.wrap(sketchBytes), serDe)
new ApproxTopKAggregateBuffer[Any](deserializedSketch, nullCount)
}
}
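
For intuition, here is a minimal round trip through serialize and deserialize. This is an illustrative sketch only, not part of the change: it assumes a LongType column, an arbitrary power-of-two map size of 64, and reuses the genSketchSerDe helper referenced above.

// Illustrative only: round-trip a buffer holding one counted item and two nulls.
val serDe = ApproxTopK.genSketchSerDe(LongType)
val buffer = new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](64), 2L)
buffer.sketch.update(42L)
val bytes = buffer.serialize(serDe) // [sketch bytes][nullCount as 8-byte Long]
val restored = ApproxTopKAggregateBuffer.deserialize(bytes, serDe)
// restored yields the same estimate for 42L and the same null count of 2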

/**
* An aggregate function that accumulates items into a sketch, which can then be used
* to combine with other sketches, via ApproxTopKCombine,
@@ -450,7 +594,7 @@ case class ApproxTopKAccumulate(

override def createAggregationBuffer(): ItemsSketch[Any] = {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
- ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+ ApproxTopK.createItemsSketch(expr, maxMapSize)
}

override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
53 changes: 51 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -196,13 +196,62 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
)
}

test("SPARK-52515: does not count NULL values") {
test("SPARK-53947: count NULL values") {
val res = sql(
"SELECT approx_top_k(expr, 2)" +
"FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL AS tab(expr);")
"FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS tab(expr);")
checkAnswer(res, Row(Seq(Row(null, 4), Row("b", 3))))
}

test("SPARK-53947: null is not in top k") {
val res = sql(
"SELECT approx_top_k(expr, 2) FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL AS tab(expr)"
)
checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2))))
}

test("SPARK-53947: null is the last in top k") {
val res = sql(
"SELECT approx_top_k(expr, 3) FROM VALUES 0, 0, 1, 1, 1, NULL AS tab(expr)"
)
checkAnswer(res, Row(Seq(Row(1, 3), Row(0, 2), Row(null, 1))))
}

test("SPARK-53947: null + frequent items < k") {
val res = sql(
"""SELECT approx_top_k(expr, 5)
|FROM VALUES cast(0.0 AS DECIMAL(4, 1)), cast(0.0 AS DECIMAL(4, 1)),
|cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)),
|NULL AS tab(expr)""".stripMargin)
checkAnswer(
res,
Row(Seq(Row(new java.math.BigDecimal("0.1"), 3),
Row(new java.math.BigDecimal("0.0"), 2),
Row(null, 1))))
}

test("SPARK-53947: work on typed column with only NULL values") {
val res = sql(
"SELECT approx_top_k(expr) FROM VALUES cast(NULL AS INT), cast(NULL AS INT) AS tab(expr)"
)
checkAnswer(res, Row(Seq(Row(null, 2))))
}

test("SPARK-53947: invalid item void columns") {
checkError(
exception = intercept[ExtendedAnalysisException] {
sql("SELECT approx_top_k(expr) FROM VALUES (NULL), (NULL), (NULL) AS tab(expr)")
},
condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map(
"sqlExpr" -> "\"approx_top_k(expr, 5, 10000)\"",
"msg" -> "void columns are not supported",
"hint" -> ""
),
queryContext = Array(ExpectedContext("approx_top_k(expr)", 7, 24))
)
}

/////////////////////////////////
// approx_top_k_accumulate and
// approx_top_k_estimate tests