Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.datasketches.frequencies.ItemsSketch
import org.apache.datasketches.memory.Memory

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK
import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, ApproxTopKAggregateBuffer}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -105,9 +102,10 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)
val kVal = kEval.asInstanceOf[Int]
ApproxTopK.checkK(kVal)
ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal)
val itemsSketch = ItemsSketch.getInstance(
Memory.wrap(dataSketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
ApproxTopK.genEvalResult(itemsSketch, kVal, itemDataType)
val approxTopKAggregateBuffer = ApproxTopKAggregateBuffer.deserialize(
dataSketchBytes,
ApproxTopK.genSketchSerDe(itemDataType))
approxTopKAggregateBuffer.eval(kVal, itemDataType)
}

override protected def withNewChildrenInternal(newState: Expression, newK: Expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,54 +260,6 @@ object ApproxTopK {
}
}

def updateSketchBuffer(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused functions as update and evaluation are handled by ApproxTopKAggregateBuffer, not on ItemsSketch anymore.

itemExpression: Expression,
buffer: ItemsSketch[Any],
input: InternalRow): ItemsSketch[Any] = {
val v = itemExpression.eval(input)
if (v != null) {
itemExpression.dataType match {
case _: BooleanType => buffer.update(v.asInstanceOf[Boolean])
case _: ByteType => buffer.update(v.asInstanceOf[Byte])
case _: ShortType => buffer.update(v.asInstanceOf[Short])
case _: IntegerType => buffer.update(v.asInstanceOf[Int])
case _: LongType => buffer.update(v.asInstanceOf[Long])
case _: FloatType => buffer.update(v.asInstanceOf[Float])
case _: DoubleType => buffer.update(v.asInstanceOf[Double])
case _: DateType => buffer.update(v.asInstanceOf[Int])
case _: TimestampType => buffer.update(v.asInstanceOf[Long])
case _: TimestampNTZType => buffer.update(v.asInstanceOf[Long])
case st: StringType =>
val cKey = CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
buffer.update(cKey.toString)
case _: DecimalType => buffer.update(v.asInstanceOf[Decimal])
}
}
buffer
}

def genEvalResult(
itemsSketch: ItemsSketch[Any],
k: Int,
itemDataType: DataType): GenericArrayData = {
val items = itemsSketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES)
val resultLength = math.min(items.length, k)
val result = new Array[AnyRef](resultLength)
for (i <- 0 until resultLength) {
val row = items(i)
itemDataType match {
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
_: LongType | _: FloatType | _: DoubleType | _: DecimalType |
_: DateType | _: TimestampType | _: TimestampNTZType =>
result(i) = InternalRow.apply(row.getItem, row.getEstimate)
case _: StringType =>
val item = UTF8String.fromString(row.getItem.asInstanceOf[String])
result(i) = InternalRow.apply(item, row.getEstimate)
}
}
new GenericArrayData(result)
}

def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = {
dataType match {
case _: BooleanType => new ArrayOfBooleansSerDe().asInstanceOf[ArrayOfItemsSerDe[Any]]
Expand All @@ -333,7 +285,7 @@ object ApproxTopK {

def dataTypeToDDL(dataType: DataType): String = dataType match {
case _: StringType =>
// Hide collation information in DDL format
// Hide collation information in DDL format, otherwise CollationExpressionWalkerSuite fails
s"item string not null"
case other =>
StructField("item", other, nullable = false).toDDL
Expand Down Expand Up @@ -552,7 +504,7 @@ case class ApproxTopKAccumulate(
maxItemsTracked: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[ItemsSketch[Any]]
extends TypedImperativeAggregate[ApproxTopKAggregateBuffer[Any]]
with ImplicitCastInputTypes
with BinaryLike[Expression] {

Expand Down Expand Up @@ -592,18 +544,23 @@ case class ApproxTopKAccumulate(

override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType)

override def createAggregationBuffer(): ItemsSketch[Any] = {
override def createAggregationBuffer(): ApproxTopKAggregateBuffer[Any] = {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
ApproxTopK.createItemsSketch(expr, maxMapSize)
val sketch = ApproxTopK.createItemsSketch(expr, maxMapSize)
new ApproxTopKAggregateBuffer[Any](sketch, 0L)
}

override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
ApproxTopK.updateSketchBuffer(expr, buffer, input)
override def update(buffer: ApproxTopKAggregateBuffer[Any], input: InternalRow):
ApproxTopKAggregateBuffer[Any] =
buffer.update(expr, input)

override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
override def merge(
buffer: ApproxTopKAggregateBuffer[Any],
input: ApproxTopKAggregateBuffer[Any]):
ApproxTopKAggregateBuffer[Any] =
buffer.merge(input)

override def eval(buffer: ItemsSketch[Any]): Any = {
override def eval(buffer: ApproxTopKAggregateBuffer[Any]): Any = {
val sketchBytes = serialize(buffer)
val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
InternalRow.apply(
Expand All @@ -613,11 +570,11 @@ case class ApproxTopKAccumulate(
UTF8String.fromString(itemDataTypeDDL))
}

override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
override def serialize(buffer: ApproxTopKAggregateBuffer[Any]): Array[Byte] =
buffer.serialize(ApproxTopK.genSketchSerDe(itemDataType))

override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
override def deserialize(storageFormat: Array[Byte]): ApproxTopKAggregateBuffer[Any] =
ApproxTopKAggregateBuffer.deserialize(storageFormat, ApproxTopK.genSketchSerDe(itemDataType))

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
Expand All @@ -644,10 +601,10 @@ case class ApproxTopKAccumulate(
* @param maxItemsTracked the maximum number of items tracked in the sketch
*/
class CombineInternal[T](
sketch: ItemsSketch[T],
sketchWithNullCount: ApproxTopKAggregateBuffer[T],
var itemDataType: DataType,
var maxItemsTracked: Int) {
def getSketch: ItemsSketch[T] = sketch
def getSketchWithNullCount: ApproxTopKAggregateBuffer[T] = sketchWithNullCount

def getItemDataType: DataType = itemDataType

Expand Down Expand Up @@ -689,6 +646,9 @@ class CombineInternal[T](
}
}

def updateSketchWithNullCount(otherSketchWithNullCount: ApproxTopKAggregateBuffer[T]): Unit =
sketchWithNullCount.merge(otherSketchWithNullCount)

/**
* Serialize the CombineInternal instance to a byte array.
* Serialization format:
Expand All @@ -698,18 +658,18 @@ class CombineInternal[T](
* sketchBytes
*/
def serialize(): Array[Byte] = {
val sketchBytes = sketch.toByteArray(
val sketchWithNullCountBytes = sketchWithNullCount.serialize(
ApproxTopK.genSketchSerDe(itemDataType).asInstanceOf[ArrayOfItemsSerDe[T]])
val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
val ddlBytes: Array[Byte] = itemDataTypeDDL.getBytes(StandardCharsets.UTF_8)
val byteArray = new Array[Byte](
sketchBytes.length + Integer.BYTES + Integer.BYTES + ddlBytes.length)
sketchWithNullCountBytes.length + Integer.BYTES + Integer.BYTES + ddlBytes.length)

val byteBuffer = ByteBuffer.wrap(byteArray)
byteBuffer.putInt(maxItemsTracked)
byteBuffer.putInt(ddlBytes.length)
byteBuffer.put(ddlBytes)
byteBuffer.put(sketchBytes)
byteBuffer.put(sketchWithNullCountBytes)
byteArray
}
}
Expand All @@ -736,9 +696,9 @@ object CombineInternal {
// read sketchBytes
val sketchBytes = new Array[Byte](buffer.length - Integer.BYTES - Integer.BYTES - ddlLength)
byteBuffer.get(sketchBytes)
val sketch = ItemsSketch.getInstance(
Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
new CombineInternal[Any](sketch, itemDataType, maxItemsTracked)
val sketchWithNullCount = ApproxTopKAggregateBuffer.deserialize(
sketchBytes, ApproxTopK.genSketchSerDe(itemDataType))
new CombineInternal[Any](sketchWithNullCount, itemDataType, maxItemsTracked)
}
}

Expand Down Expand Up @@ -833,7 +793,7 @@ case class ApproxTopKCombine(
if (combineSizeSpecified) {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
new CombineInternal[Any](
new ItemsSketch[Any](maxMapSize),
new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](maxMapSize), 0L),
null,
maxItemsTrackedVal)
} else {
Expand All @@ -842,7 +802,7 @@ case class ApproxTopKCombine(
// The actual maxItemsTracked will be checked during the updates.
val maxMapSize = ApproxTopK.calMaxMapSize(ApproxTopK.MAX_ITEMS_TRACKED_LIMIT)
new CombineInternal[Any](
new ItemsSketch[Any](maxMapSize),
new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](maxMapSize), 0L),
null,
ApproxTopK.VOID_MAX_ITEMS_TRACKED)
}
Expand All @@ -863,9 +823,9 @@ case class ApproxTopKCombine(
// update itemDataType (throw error if not match)
buffer.updateItemDataType(inputItemDataType)
// update sketch
val inputSketch = ItemsSketch.getInstance(
Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
buffer.getSketch.merge(inputSketch)
val inputSketchWithNullCount = ApproxTopKAggregateBuffer.deserialize(
inputSketchBytes, ApproxTopK.genSketchSerDe(inputItemDataType))
buffer.updateSketchWithNullCount(inputSketchWithNullCount)
buffer
}

Expand All @@ -876,14 +836,14 @@ case class ApproxTopKCombine(
buffer.updateMaxItemsTracked(combineSizeSpecified, input.getMaxItemsTracked)
// update itemDataType (throw error if not match)
buffer.updateItemDataType(input.getItemDataType)
// update sketch
buffer.getSketch.merge(input.getSketch)
// update sketchWithNullCount
buffer.getSketchWithNullCount.merge(input.getSketchWithNullCount)
buffer
}

override def eval(buffer: CombineInternal[Any]): Any = {
val sketchBytes =
buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
val sketchBytes = buffer.getSketchWithNullCount
.serialize(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
val maxItemsTracked = buffer.getMaxItemsTracked
val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(buffer.getItemDataType)
InternalRow.apply(
Expand Down
Loading