Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,19 @@ object FrequentItemsFriendly {
}
}

class FrequentItems[T: FrequentItemsFriendly](val mapSize: Int, val errorType: ErrorType = ErrorType.NO_FALSE_POSITIVES)
class FrequentItems[T: FrequentItemsFriendly](val mapSize: Int, val errorType: ErrorType = ErrorType.NO_FALSE_NEGATIVES)
extends SimpleAggregator[T, ItemsSketchIR[T], util.Map[String, Long]] {
private type Sketch = ItemsSketchIR[T]

// The ItemsSketch implementation requires a size with a positive power of 2
// Initialize the sketch with the next closest power of 2
val sketchSize: Int = if (mapSize > 1) Integer.highestOneBit(mapSize - 1) << 1 else 2
val sketchSize: Int = {
  // Sizing rationale: the sketch's internal map is 0.75x of its nominal size, and a purge of that
  // map removes more than half of its elements. To keep `mapSize` items at all times the nominal
  // size is therefore over-provisioned by a factor of 1 / (0.75 * 0.5).
  val retainedFraction = 0.75 * 0.5
  val requiredSlots = math.ceil(mapSize.toDouble / retainedFraction).toInt

  // The ItemsSketch implementation requires a size that is a positive power of 2:
  // round up to the next power of 2, never going below 2.
  if (requiredSlots <= 1) 2
  else Integer.highestOneBit(requiredSlots - 1) << 1
}

override def outputType: DataType = MapType(StringType, LongType)

Expand Down Expand Up @@ -432,6 +438,11 @@ class FrequentItems[T: FrequentItemsFriendly](val mapSize: Int, val errorType: E
return new util.HashMap[String, Long]()
}

// useful with debugger on - keep around
// val outputSketchSize = ir.sketch.getNumActiveItems
// val serializer = implicitly[FrequentItemsFriendly[T]].serializer
// val outputSketchBytes = ir.sketch.toByteArray(serializer).length

val items = ir.sketch.getFrequentItems(errorType).map(sk => sk.getItem -> sk.getEstimate)
val heap = mutable.PriorityQueue[(T, Long)]()(Ordering.by(_._2))

Expand Down Expand Up @@ -473,150 +484,6 @@ class FrequentItems[T: FrequentItemsFriendly](val mapSize: Int, val errorType: E
}
}

// Intermediate representation used by ApproxHistogram.
// As constructed by ApproxHistogram in this file, exactly one of `sketch` / `histogram` is defined:
//   - histogram = Some(...), sketch = None, isApprox = false  -> exact per-key counts
//   - sketch = Some(...), histogram = None, isApprox = true   -> approximate counts held in an ItemsSketchIR
case class ApproxHistogramIr[T: FrequentItemsFriendly](
// true when the counts are held in `sketch`, false when they are held in `histogram`
isApprox: Boolean,
// approximate state; defined only when isApprox is true
sketch: Option[ItemsSketchIR[T]],
// exact key -> count state; defined only when isApprox is false
histogram: Option[util.Map[T, Long]]
)

// Serializable mirror of ApproxHistogramIr: same fields, except the sketch is carried as bytes.
// ApproxHistogram.normalize writes an instance of this class with an ObjectOutputStream and
// ApproxHistogram.denormalize reads it back with an ObjectInputStream.
case class ApproxHistogramIrSerializable[T: FrequentItemsFriendly](
// copied verbatim from ApproxHistogramIr.isApprox
isApprox: Boolean,
// The ItemsSketch isn't directly serializable
sketch: Option[Array[Byte]],
// copied verbatim from ApproxHistogramIr.histogram
histogram: Option[util.Map[T, Long]]
)

// Histogram aggregator that is exact while small and approximate once it grows.
//
// The ItemsSketch uses approximations and estimates for both values below and above k, so this
// aggregator keeps an exact key -> count map for as long as that map has at most `mapSize` keys.
// As soon as the map holds more than `mapSize` keys it is converted into a sketch and every later
// update / merge goes through the sketch. All sketch-side work (merge, clone, finalize, normalize,
// denormalize, toSketch) is delegated to a FrequentItems aggregator built with the same arguments.
class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorType = ErrorType.NO_FALSE_POSITIVES)
    extends SimpleAggregator[T, ApproxHistogramIr[T], util.Map[String, Long]] {
  private val frequentItemsAggregator = new FrequentItems[T](mapSize, errorType)

  // Starts in the exact state with a single occurrence of `input`.
  override def prepare(input: T): ApproxHistogramIr[T] = {
    val counts = new util.HashMap[T, Long]()
    counts.put(input, 1L)
    exactIr(counts)
  }

  // Adds one occurrence of `input`. The exact map is preferred when present (and may be promoted
  // to a sketch by toIr); otherwise the sketch is updated in place.
  override def update(ir: ApproxHistogramIr[T], input: T): ApproxHistogramIr[T] =
    ir.histogram match {
      case Some(counts) =>
        increment(input, 1L, counts)
        toIr(counts)
      case _ =>
        ir.sketch match {
          case Some(sketchState) =>
            sketchState.sketch.update(input)
            approxIr(sketchState)
          case _ => missingState
        }
    }

  override def outputType: DataType = MapType(StringType, LongType)

  override def irType: DataType = BinaryType

  // Merges two states. Exact + exact stays exact unless the combined key count exceeds `mapSize`;
  // any combination involving a sketch yields a sketch.
  override def merge(ir1: ApproxHistogramIr[T], ir2: ApproxHistogramIr[T]): ApproxHistogramIr[T] =
    ((ir1.histogram, ir1.sketch), (ir2.histogram, ir2.sketch)) match {
      case ((Some(leftCounts), None), (Some(rightCounts), None)) => mergeExact(leftCounts, rightCounts)
      case ((None, Some(leftSketch)), (None, Some(rightSketch))) => mergeSketches(leftSketch, rightSketch)
      case ((Some(leftCounts), None), (None, Some(rightSketch))) => foldIntoSketch(leftCounts, rightSketch)
      case ((None, Some(leftSketch)), (Some(rightCounts), None)) => foldIntoSketch(rightCounts, leftSketch)
      case _ => missingState
    }

  // Produces the output map: delegated to FrequentItems for a sketch, stringified keys for exact counts.
  override def finalize(ir: ApproxHistogramIr[T]): util.Map[String, Long] =
    (ir.histogram, ir.sketch) match {
      case (None, Some(sketchState)) => frequentItemsAggregator.finalize(sketchState)
      case (Some(counts), None) => toOutputMap(counts)
      case _ => missingState
    }

  // Copies the state: sketches are cloned by FrequentItems, exact maps are copied into a new HashMap.
  override def clone(ir: ApproxHistogramIr[T]): ApproxHistogramIr[T] =
    (ir.histogram, ir.sketch) match {
      case (None, Some(sketchState)) =>
        approxIr(frequentItemsAggregator.clone(sketchState))
      case (Some(counts), None) =>
        exactIr(new util.HashMap[T, Long](counts))
      case _ => missingState
    }

  // Serializes the state to a byte array using Java object serialization of
  // ApproxHistogramIrSerializable; the sketch (if any) is first normalized by FrequentItems.
  override def normalize(ir: ApproxHistogramIr[T]): Any = {
    val payload = ApproxHistogramIrSerializable(
      isApprox = ir.isApprox,
      sketch = ir.sketch.map(frequentItemsAggregator.normalize),
      histogram = ir.histogram
    )

    val buffer = new ByteArrayOutputStream()
    val writer = new ObjectOutputStream(buffer)

    try writer.writeObject(payload)
    finally {
      writer.close()
      buffer.close()
    }

    buffer.toByteArray
  }

  // Inverse of normalize: expects the byte array produced there and rebuilds the state,
  // denormalizing the sketch bytes (if any) through FrequentItems.
  override def denormalize(ir: Any): ApproxHistogramIr[T] = {
    val source = new ByteArrayInputStream(ir.asInstanceOf[Array[Byte]])
    val reader = new ObjectInputStream(source)

    try {
      val payload = reader.readObject().asInstanceOf[ApproxHistogramIrSerializable[T]]
      ApproxHistogramIr(
        isApprox = payload.isApprox,
        sketch = payload.sketch.map(frequentItemsAggregator.denormalize),
        histogram = payload.histogram
      )
    } finally {
      reader.close()
      source.close()
    }
  }

  // Sums two exact maps into a fresh map, then promotes to a sketch if it grew past `mapSize`.
  private def mergeExact(left: util.Map[T, Long], right: util.Map[T, Long]): ApproxHistogramIr[T] = {
    val summed = new util.HashMap[T, Long]()
    Seq(left, right).foreach { counts =>
      counts.asScala.foreach { case (key, count) => increment(key, count, summed) }
    }
    toIr(summed)
  }

  // Merges two sketches through the FrequentItems aggregator.
  private def mergeSketches(left: ItemsSketchIR[T], right: ItemsSketchIR[T]): ApproxHistogramIr[T] =
    approxIr(frequentItemsAggregator.merge(left, right))

  // Feeds every exact count into `target` (which is updated in place) and returns it as the state.
  private def foldIntoSketch(counts: util.Map[T, Long], target: ItemsSketchIR[T]): ApproxHistogramIr[T] = {
    counts.asScala.foreach { case (key, count) => target.sketch.update(key, count) }
    approxIr(target)
  }

  // Wraps an exact map, converting it to a sketch first when it has more than `mapSize` keys.
  private def toIr(hist: util.Map[T, Long]): ApproxHistogramIr[T] =
    if (hist.size <= mapSize) exactIr(hist)
    else approxIr(frequentItemsAggregator.toSketch(hist))

  // Adds `times` to the count stored for `value`, treating a missing key as 0.
  private def increment(value: T, times: Long, values: util.Map[T, Long]): Unit = {
    val current = values.getOrDefault(value, 0L)
    values.put(value, current + times)
  }

  // Copies the exact counts into a new map keyed by String.valueOf(key).
  private def toOutputMap(map: util.Map[T, Long]): util.Map[String, Long] = {
    val output = new util.HashMap[String, Long](map.size())
    map.asScala.foreach { case (key, count) => output.put(String.valueOf(key), count) }
    output
  }

  // State backed by an exact key -> count map.
  private def exactIr(counts: util.Map[T, Long]): ApproxHistogramIr[T] =
    ApproxHistogramIr(isApprox = false, sketch = None, histogram = Some(counts))

  // State backed by a sketch.
  private def approxIr(sketchState: ItemsSketchIR[T]): ApproxHistogramIr[T] =
    ApproxHistogramIr(isApprox = true, sketch = Some(sketchState), histogram = None)

  // Raised whenever a state does not hold exactly one of histogram / sketch.
  private def missingState: Nothing =
    throw new IllegalStateException("Histogram state is missing")
}

// Based on CPC sketch (a faster, smaller and more accurate version of HLL)
// See: Back to the future: an even more nearly optimal cardinality estimation algorithm, 2017
// https://arxiv.org/abs/1708.06839
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import ai.chronon.api.Extensions.AggregationPartOps
import ai.chronon.api.Extensions.OperationOps
import ai.chronon.api._
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.datasketches.frequencies.ErrorType

import java.util
import scala.collection.JavaConverters.asScalaIteratorConverter
Expand Down Expand Up @@ -133,25 +134,28 @@ case class ColumnIndices(input: Int, output: Int)

object ColumnAggregator {

// Boxes a primitive Long into a java.lang.Long.
private def toJLong(l: Long): java.lang.Long = Long.box(l)

// Boxes a primitive Double into a java.lang.Double.
private def toJDouble(d: Double): java.lang.Double = Double.box(d)

// Coerces a boxed value to java.lang.Long where a conversion is defined here.
//
// - Integer / Short / Byte / Double / Float: converted with Number.longValue()
//   (for Double and Float this is Java's narrowing conversion, not rounding).
// - String: parsed with java.lang.Long.parseLong, which throws NumberFormatException
//   when the text is not a valid long.
// - Anything else (including a value that is already a java.lang.Long, or null) is
//   returned unchanged.
//
// Boxing goes through toJLong (java.lang.Long.valueOf) rather than the deprecated
// `new java.lang.Long(...)` constructor, and each input type has exactly one case arm.
def castToLong(value: AnyRef): AnyRef =
  value match {
    case i: java.lang.Integer => toJLong(i.longValue())
    case i: java.lang.Short   => toJLong(i.longValue())
    case i: java.lang.Byte    => toJLong(i.longValue())
    case i: java.lang.Double  => toJLong(i.longValue())
    case i: java.lang.Float   => toJLong(i.longValue())
    case i: java.lang.String  => toJLong(java.lang.Long.parseLong(i))
    case _                    => value
  }

// Coerces a boxed value to java.lang.Double where a conversion is defined here.
//
// - Integer / Short / Byte / Float / Long: converted with Number.doubleValue().
// - String: parsed with java.lang.Double.parseDouble, which throws NumberFormatException
//   when the text is not a valid double.
// - Anything else (including a value that is already a java.lang.Double, or null) is
//   returned unchanged.
//
// Boxing goes through toJDouble (java.lang.Double.valueOf) rather than the deprecated
// `new java.lang.Double(...)` constructor, and each input type has exactly one case arm.
def castToDouble(value: AnyRef): AnyRef =
  value match {
    case i: java.lang.Integer => toJDouble(i.doubleValue())
    case i: java.lang.Short   => toJDouble(i.doubleValue())
    case i: java.lang.Byte    => toJDouble(i.doubleValue())
    case i: java.lang.Float   => toJDouble(i.doubleValue())
    case i: java.lang.Long    => toJDouble(i.doubleValue())
    case i: java.lang.String  => toJDouble(java.lang.Double.parseDouble(i))
    case _                    => value
  }

Expand Down Expand Up @@ -260,15 +264,28 @@ object ColumnAggregator {
aggregationPart.operation match {
case Operation.COUNT => simple(new Count)
case Operation.HISTOGRAM => simple(new Histogram(aggregationPart.getInt("k", Some(0))))
case Operation.APPROX_HISTOGRAM_K =>
case Operation.APPROX_FREQUENT_K =>
val k = aggregationPart.getInt("k", Some(8))
inputType match {
case IntType => simple(new FrequentItems[java.lang.Long](k), toJavaLong[Int])
case LongType => simple(new FrequentItems[java.lang.Long](k))
case ShortType => simple(new FrequentItems[java.lang.Long](k), toJavaLong[Short])
case DoubleType => simple(new FrequentItems[java.lang.Double](k))
case FloatType => simple(new FrequentItems[java.lang.Double](k), toJavaDouble[Float])
case StringType => simple(new FrequentItems[String](k))
case _ => mismatchException
}
case Operation.APPROX_HEAVY_HITTERS_K =>
val k = aggregationPart.getInt("k", Some(8))
inputType match {
case IntType => simple(new ApproxHistogram[java.lang.Long](k), toJavaLong[Int])
case LongType => simple(new ApproxHistogram[java.lang.Long](k))
case ShortType => simple(new ApproxHistogram[java.lang.Long](k), toJavaLong[Short])
case DoubleType => simple(new ApproxHistogram[java.lang.Double](k))
case FloatType => simple(new ApproxHistogram[java.lang.Double](k), toJavaDouble[Float])
case StringType => simple(new ApproxHistogram[String](k))
case IntType => simple(new FrequentItems[java.lang.Long](k, ErrorType.NO_FALSE_POSITIVES), toJavaLong[Int])
case LongType => simple(new FrequentItems[java.lang.Long](k, ErrorType.NO_FALSE_POSITIVES))
case ShortType =>
simple(new FrequentItems[java.lang.Long](k, ErrorType.NO_FALSE_POSITIVES), toJavaLong[Short])
case DoubleType => simple(new FrequentItems[java.lang.Double](k, ErrorType.NO_FALSE_POSITIVES))
case FloatType =>
simple(new FrequentItems[java.lang.Double](k, ErrorType.NO_FALSE_POSITIVES), toJavaDouble[Float])
case StringType => simple(new FrequentItems[String](k, ErrorType.NO_FALSE_POSITIVES))
case _ => mismatchException
}
case Operation.SUM =>
Expand Down
Loading
Loading