@@ -17,55 +17,21 @@

 package org.apache.spark.sql.execution.stat

-import scala.collection.mutable.{Map => MutableMap}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.mutable

 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.{functions, Column, DataFrame}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._

 object FrequentItems extends Logging {
-
-  /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
-  private class FreqItemCounter(size: Int) extends Serializable {
-    val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]
-
-    /**
-     * Add a new example to the counts if it exists, otherwise deduct the count
-     * from existing items.
-     */
-    def add(key: Any, count: Long): this.type = {
-      if (baseMap.contains(key)) {
-        baseMap(key) += count
-      } else {
-        if (baseMap.size < size) {
-          baseMap += key -> count
-        } else {
-          val minCount = if (baseMap.values.isEmpty) 0 else baseMap.values.min
-          val remainder = count - minCount
-          if (remainder >= 0) {
-            baseMap += key -> count // something will get kicked out, so we can add this
-            baseMap.retain((k, v) => v > minCount)
-            baseMap.transform((k, v) => v - minCount)
-          } else {
-            baseMap.transform((k, v) => v - count)
-          }
-        }
-      }
-      this
-    }
-
-    /**
-     * Merge two maps of counts.
-     * @param other The map containing the counts for that partition
-     */
-    def merge(other: FreqItemCounter): this.type = {
-      other.baseMap.foreach { case (k, v) =>
-        add(k, v)
-      }
-      this
-    }
-  }
-
   /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
@@ -85,42 +51,144 @@ object FrequentItems extends Logging {
       cols: Seq[String],
       support: Double): DataFrame = {
     require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
-    val numCols = cols.length
     // number of max items to keep counts for
     val sizeOfMap = (1 / support).toInt
-    val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
-
-    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
-      seqOp = (counts, row) => {
-        var i = 0
-        while (i < numCols) {
-          val thisMap = counts(i)
-          val key = row.get(i)
-          thisMap.add(key, 1L)
-          i += 1
-        }
-        counts
-      },
-      combOp = (baseCounts, counts) => {
-        var i = 0
-        while (i < numCols) {
-          baseCounts(i).merge(counts(i))
-          i += 1
-        }
-        baseCounts
-      }
-    )
-    val justItems = freqItems.map(m => m.baseMap.keys.toArray)
-    val resultRow = Row(justItems : _*)
-
-    val outputCols = cols.map { name =>
-      val originalField = df.resolve(name)
-
-      // append frequent Items to the column name for easy debugging
-      StructField(name + "_freqItems", ArrayType(originalField.dataType, originalField.nullable))
-    }.toArray
-
-    val schema = StructType(outputCols).toAttributes
-    Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
+
+    val frequentItemCols = cols.map { col =>
+      val aggExpr = new CollectFrequentItems(functions.col(col).expr, sizeOfMap)
+      Column(aggExpr.toAggregateExpression(isDistinct = false)).as(s"${col}_freqItems")
+    }
+
+    df.select(frequentItemCols: _*)
   }
 }

+case class CollectFrequentItems(
+    child: Expression,
+    size: Int,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[mutable.Map[Any, Long]]
+  with ImplicitCastInputTypes with UnaryLike[Expression] {
+  require(size > 0)
+
+  def this(child: Expression, size: Int) = this(child, size, 0, 0)
+
+  // Returns empty array for empty inputs
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = ArrayType(child.dataType, containsNull = child.nullable)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
Contributor:
Seems we don't have any input type requirement, so we don't need to extend ImplicitCastInputTypes.

Contributor Author:
Good point, will update.
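A sketch of that follow-up in the same diff notation (illustrative only; the code in this commit still carries the mixin):

-    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[mutable.Map[Any, Long]]
-  with ImplicitCastInputTypes with UnaryLike[Expression] {
+    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[mutable.Map[Any, Long]]
+  with UnaryLike[Expression] {
     ...
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)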


+
+  override def prettyName: String = "collect_frequent_items"
+
+  override def createAggregationBuffer(): mutable.Map[Any, Long] =
+    mutable.Map.empty[Any, Long]
+
+  private def add(map: mutable.Map[Any, Long], key: Any, count: Long): mutable.Map[Any, Long] = {
+    if (map.contains(key)) {
+      map(key) += count
+    } else {
+      if (map.size < size) {
+        map += key -> count
+      } else {
+        val minCount = if (map.values.isEmpty) 0 else map.values.min
+        val remainder = count - minCount
+        if (remainder >= 0) {
+          map += key -> count // something will get kicked out, so we can add this
+          map.retain((k, v) => v > minCount)
+          map.transform((k, v) => v - minCount)
+        } else {
+          map.transform((k, v) => v - count)
+        }
+      }
+    }
+    map
+  }
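The `add` method above is the core of the sketch: it keeps at most `size` keys and, when a newcomer displaces the current minimum, decrements all surviving counts. A self-contained toy driver (not part of this PR; names are illustrative) shows the resulting guarantee that any item occurring in more than a `1/size` fraction of the input survives:

import scala.collection.mutable

// Toy illustration of the decrement-based counting used by CollectFrequentItems.add.
object FreqSketchDemo {
  def add(map: mutable.Map[Any, Long], key: Any, count: Long, size: Int): Unit = {
    if (map.contains(key)) {
      map(key) += count
    } else if (map.size < size) {
      map += key -> count
    } else {
      val minCount = if (map.values.isEmpty) 0 else map.values.min
      if (count - minCount >= 0) {
        map += key -> count                   // newcomer displaces the minimum
        map.retain((_, v) => v > minCount)    // drop everything at the old minimum
        map.transform((_, v) => v - minCount)
      } else {
        map.transform((_, v) => v - count)    // newcomer absorbed as a decrement
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val m = mutable.Map.empty[Any, Long]
    // "a" occurs in 3 of 5 rows (support 0.6); with size = 2 (support 0.5) it is
    // guaranteed to be reported; "b"/"c" could only appear as false positives.
    Seq("a", "a", "b", "a", "c").foreach(k => add(m, k, 1L, size = 2))
    println(m.keys) // here: Set(a)
  }
}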

+  override def update(
+      buffer: mutable.Map[Any, Long],
+      input: InternalRow): mutable.Map[Any, Long] = {
+    val key = child.eval(input)
+    if (key != null) {
+      this.add(buffer, InternalRow.copyValue(key), 1L)
+    } else {
+      this.add(buffer, key, 1L)
+    }
+  }
+
+  override def merge(
+      buffer: mutable.Map[Any, Long],
+      input: mutable.Map[Any, Long]): mutable.Map[Any, Long] = {
+    val otherIter = input.iterator
+    while (otherIter.hasNext) {
+      val (key, count) = otherIter.next
+      add(buffer, key, count)
+    }
+    buffer
+  }
+
+  override def eval(buffer: mutable.Map[Any, Long]): Any =
+    new GenericArrayData(buffer.keys.toArray)
Member:
Can we reuse one GenericArrayData with cleaning up?

Contributor Author:
I guess it is only invoked once, so no need to reuse?


+
+  private lazy val projection =
+    UnsafeProjection.create(Array[DataType](child.dataType, LongType))
+
+  override def serialize(map: mutable.Map[Any, Long]): Array[Byte] = {
+    val buffer = new Array[Byte](4 << 10) // 4K
+    val bos = new ByteArrayOutputStream()
+    val out = new DataOutputStream(bos)
+    try {
Member:
You can leverage Utils.tryWithResource.

Contributor Author:
Good point, will update.

Contributor Author:
It seems we cannot leverage Utils.tryWithResource here: it only supports a single Closeable, but there are two here, bos and out.

Member:
We could use Utils.tryWithSafeFinally, but that's fine.

Contributor Author:
Cool, let me update it.

+      // Write pairs in counts map to byte buffer.
+      map.foreach { case (key, count) =>
+        val row = InternalRow.apply(key, count)
+        val unsafeRow = projection.apply(row)
+        out.writeInt(unsafeRow.getSizeInBytes)
+        unsafeRow.writeToStream(out, buffer)
+      }
+      out.writeInt(-1)
+      out.flush()
+
+      bos.toByteArray
+    } finally {
+      out.close()
+      bos.close()
+    }
+  }
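Picking up the thread above: a sketch of how serialize could be written with Utils.tryWithSafeFinally from org.apache.spark.util.Utils, which takes a block and a finally-block and so can close both streams. Illustrative of the reviewer's suggestion, not the code in this commit; it assumes access to the same `projection` field defined above:

import java.io.{ByteArrayOutputStream, DataOutputStream}

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.util.Utils

// Sketch: same serialization loop, with Utils.tryWithSafeFinally(block)(finallyBlock)
// guaranteeing cleanup of both streams even if the block throws.
def serializeSketch(map: mutable.Map[Any, Long]): Array[Byte] = {
  val buffer = new Array[Byte](4 << 10)
  val bos = new ByteArrayOutputStream()
  val out = new DataOutputStream(bos)
  Utils.tryWithSafeFinally {
    map.foreach { case (key, count) =>
      val unsafeRow = projection.apply(InternalRow.apply(key, count))
      out.writeInt(unsafeRow.getSizeInBytes)
      unsafeRow.writeToStream(out, buffer)
    }
    out.writeInt(-1)
    out.flush()
    bos.toByteArray
  } {
    out.close() // closing the DataOutputStream also flushes it
    bos.close() // no-op for ByteArrayOutputStream, kept for symmetry
  }
}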

+
+  override def deserialize(bytes: Array[Byte]): mutable.Map[Any, Long] = {
+    val bis = new ByteArrayInputStream(bytes)
+    val ins = new DataInputStream(bis)
+    try {
Member:
Ditto.

+      val map = mutable.Map.empty[Any, Long]
+      // Read unsafeRow size and content in bytes.
+      var sizeOfNextRow = ins.readInt()
+      while (sizeOfNextRow >= 0) {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(2)
+        row.pointTo(bs, sizeOfNextRow)
+        // Insert the pairs into counts map.
+        val key = row.get(0, child.dataType)
+        val count = row.get(1, LongType).asInstanceOf[Long]
+        map.update(key, count)
+        sizeOfNextRow = ins.readInt()
+      }
+
+      map
+    } finally {
+      ins.close()
+      bis.close()
+    }
+  }

+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
+}
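For completeness, the user-facing entry point that exercises this aggregate is DataFrameStatFunctions.freqItems, which delegates to FrequentItems.singlePassFreqItems. A minimal, illustrative driver (not part of the diff):

import org.apache.spark.sql.SparkSession

// Minimal usage example for df.stat.freqItems. Illustrative only.
object FreqItemsUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("freqItems-demo")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(1, 1, 1, 2, 3).toDF("a")
    // "1" appears in 60% of rows, so with support = 0.4 it must be reported;
    // the result may contain false positives but never false negatives.
    df.stat.freqItems(Seq("a"), 0.4).show(truncate = false)
    // Result shape: a single row with one array column named "a_freqItems".

    spark.stop()
  }
}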