Add bounded unique count aggregation #781

Open · wants to merge 1 commit into base: main
@@ -599,6 +599,44 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorTy
}
}

// Exact distinct count with a hard cap: values are collected into a HashSet until it
// holds k entries, after which new values are ignored, so the result saturates at k.
class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[T], Long] {
  override def prepare(input: T): util.Set[T] = {
    val result = new util.HashSet[T](k)
    result.add(input)
    result
  }

  override def update(ir: util.Set[T], input: T): util.Set[T] = {
    if (ir.size() >= k) {
      return ir
    }

    ir.add(input)
    ir
  }

  override def outputType: DataType = LongType

  override def irType: DataType = ListType(inputType)

  override def merge(ir1: util.Set[T], ir2: util.Set[T]): util.Set[T] = {
    ir2.asScala.foreach(v =>
      if (ir1.size() < k) {
        ir1.add(v)
      })

    ir1
  }

  override def finalize(ir: util.Set[T]): Long = ir.size()

  override def clone(ir: util.Set[T]): util.Set[T] = new util.HashSet[T](ir)

  override def normalize(ir: util.Set[T]): Any = new util.ArrayList[T](ir)

  override def denormalize(ir: Any): util.Set[T] = new util.HashSet[T](ir.asInstanceOf[util.ArrayList[T]])
}
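
For orientation, here is a minimal sketch (not part of this diff) of how the aggregator's intermediate representation flows through prepare/update/merge/finalize; the variable names and sample values are illustrative.

import ai.chronon.api.StringType

val agg = new BoundedUniqueCount[String](StringType, k = 3)
var ir = agg.prepare("a")             // IR = {"a"}
ir = agg.update(ir, "b")              // IR = {"a", "b"}
ir = agg.update(ir, "a")              // duplicate, IR unchanged
ir = agg.merge(ir, agg.prepare("c"))  // IR = {"a", "b", "c"}
ir = agg.update(ir, "d")              // IR already holds k = 3 values, so "d" is dropped
val count = agg.finalize(ir)          // 3 (the count saturates at k)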

// Based on CPC sketch (a faster, smaller and more accurate version of HLL)
// See: Back to the future: an even more nearly optimal cardinality estimation algorithm, 2017
// https://arxiv.org/abs/1708.06839
@@ -304,7 +304,19 @@ object ColumnAggregator {
        case BinaryType => simple(new ApproxDistinctCount[Array[Byte]](aggregationPart.getInt("k", Some(8))))
        case _ => mismatchException
      }
    case Operation.BOUNDED_UNIQUE_COUNT =>
      val k = aggregationPart.getInt("k", Some(8))

      inputType match {
        case IntType => simple(new BoundedUniqueCount[Int](inputType, k))
        case LongType => simple(new BoundedUniqueCount[Long](inputType, k))
        case ShortType => simple(new BoundedUniqueCount[Short](inputType, k))
        case DoubleType => simple(new BoundedUniqueCount[Double](inputType, k))
        case FloatType => simple(new BoundedUniqueCount[Float](inputType, k))
        case StringType => simple(new BoundedUniqueCount[String](inputType, k))
        case BinaryType => simple(new BoundedUniqueCount[Array[Byte]](inputType, k))
        case _ => mismatchException
      }
    case Operation.APPROX_PERCENTILE =>
      val k = aggregationPart.getInt("k", Some(128))
      val mapper = new ObjectMapper()
@@ -0,0 +1,44 @@
package ai.chronon.aggregator.test

import ai.chronon.aggregator.base.BoundedUniqueCount
import ai.chronon.api.StringType
import junit.framework.TestCase
import org.junit.Assert._

import java.util
import scala.jdk.CollectionConverters._

class BoundedUniqueCountTest extends TestCase {
  def testHappyCase(): Unit = {
    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
    var ir = boundedDistinctCount.prepare("1")
    ir = boundedDistinctCount.update(ir, "1")
    ir = boundedDistinctCount.update(ir, "2")

    val result = boundedDistinctCount.finalize(ir)
    assertEquals(2, result)
  }

  def testExceedSize(): Unit = {
    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
    var ir = boundedDistinctCount.prepare("1")
    ir = boundedDistinctCount.update(ir, "2")
    ir = boundedDistinctCount.update(ir, "3")
    ir = boundedDistinctCount.update(ir, "4")
    ir = boundedDistinctCount.update(ir, "5")
    ir = boundedDistinctCount.update(ir, "6")
    ir = boundedDistinctCount.update(ir, "7")

    val result = boundedDistinctCount.finalize(ir)
    assertEquals(5, result)
  }

  def testMerge(): Unit = {
    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
    val ir1 = new util.HashSet[String](Seq("1", "2", "3").asJava)
    val ir2 = new util.HashSet[String](Seq("4", "5", "6").asJava)

    val merged = boundedDistinctCount.merge(ir1, ir2)
    assertEquals(5, merged.size())
  }
}
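
The tests above do not exercise IR serialization; a hedged sketch of a round-trip test, assuming the same imports as this file (testNormalizeRoundTrip is a hypothetical addition, not part of the PR):

  def testNormalizeRoundTrip(): Unit = {
    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
    val ir = new util.HashSet[String](Seq("1", "2", "3").asJava)

    // normalize flattens the set into a util.ArrayList (matching irType = ListType(inputType));
    // denormalize rebuilds the HashSet from that list.
    val normalized = boundedDistinctCount.normalize(ir)
    val restored = boundedDistinctCount.denormalize(normalized)

    assertEquals(3, boundedDistinctCount.finalize(restored))
  }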
1 change: 1 addition & 0 deletions api/py/ai/chronon/group_by.py
@@ -64,6 +64,7 @@ class Operation:
    # https://github.com/apache/incubator-datasketches-java/blob/master/src/main/java/org/apache/datasketches/cpc/CpcSketch.java#L180
    APPROX_UNIQUE_COUNT_LGK = collector(ttypes.Operation.APPROX_UNIQUE_COUNT)
    UNIQUE_COUNT = ttypes.Operation.UNIQUE_COUNT
    BOUNDED_UNIQUE_COUNT = ttypes.Operation.BOUNDED_UNIQUE_COUNT
    COUNT = ttypes.Operation.COUNT
    SUM = ttypes.Operation.SUM
    AVERAGE = ttypes.Operation.AVERAGE
3 changes: 2 additions & 1 deletion api/thrift/api.thrift
@@ -160,7 +160,8 @@ enum Operation {
    BOTTOM_K = 16

    HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
    APPROX_HISTOGRAM_K = 18,
    BOUNDED_UNIQUE_COUNT = 19
}

// integers map to milliseconds in the timeunit
1 change: 1 addition & 0 deletions docs/source/authoring_features/GroupBy.md
@@ -147,6 +147,7 @@ Limitations:
| approx_unique_count | primitive types | list, map | long | no | k=8 | yes |
| approx_percentile | primitive types | list, map | list<input,> | no | k=128, percentiles | yes |
| unique_count | primitive types | list, map | long | no | | no |
| bounded_unique_count | primitive types | list, map | long | no | k=8 | yes |
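
For example, capping the distinct count at 5 over a 30-day window could be requested like this (a sketch using the Scala Builders API exercised in the tests below; the Python GroupBy API takes the same operation and argMap):

Builders.Aggregation(operation = Operation.BOUNDED_UNIQUE_COUNT,
                     argMap = Map("k" -> "5"),
                     inputColumn = "user",
                     windows = Seq(new Window(30, TimeUnit.DAYS)))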


## Accuracy
12 changes: 10 additions & 2 deletions spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -322,8 +322,11 @@ class FetcherTest extends TestCase {
Builders.Aggregation(operation = Operation.LAST_K,
argMap = Map("k" -> "300"),
inputColumn = "user",
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS))),
Builders.Aggregation(operation = Operation.BOUNDED_UNIQUE_COUNT,
argMap = Map("k" -> "5"),
inputColumn = "user",
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))),
metaData = Builders.MetaData(name = "unit_test/vendor_ratings", namespace = namespace),
accuracy = Accuracy.SNAPSHOT
)
@@ -547,6 +550,11 @@
operation = Operation.APPROX_HISTOGRAM_K,
inputColumn = "rating",
windows = Seq(new Window(1, TimeUnit.DAYS))
),
Builders.Aggregation(
operation = Operation.BOUNDED_UNIQUE_COUNT,
inputColumn = "rating",
windows = Seq(new Window(1, TimeUnit.DAYS))
)
),
accuracy = Accuracy.TEMPORAL,
42 changes: 42 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
@@ -656,4 +656,46 @@ class GroupByTest {
             tableUtils = tableUtils,
             additionalAgg = aggs)
  }

  @Test
  def testBoundedUniqueCounts(): Unit = {
    val (source, endPartition) = createTestSource(suffix = "_bounded_counts")
    val tableUtils = TableUtils(spark)
    val namespace = "test_bounded_counts"
    val aggs = Seq(
      Builders.Aggregation(
        operation = Operation.BOUNDED_UNIQUE_COUNT,
        inputColumn = "item",
        windows = Seq(
          new Window(15, TimeUnit.DAYS),
          new Window(60, TimeUnit.DAYS)
        ),
        argMap = Map("k" -> "5")
      ),
      Builders.Aggregation(
        operation = Operation.BOUNDED_UNIQUE_COUNT,
        inputColumn = "price",
        windows = Seq(
          new Window(15, TimeUnit.DAYS),
          new Window(60, TimeUnit.DAYS)
        ),
        argMap = Map("k" -> "5")
      )
    )
    backfill(name = "unit_test_group_by_bounded_counts",
             source = source,
             endPartition = endPartition,
             namespace = namespace,
             tableUtils = tableUtils,
             additionalAgg = aggs)

    val result = spark.sql(
      """
        |select *
        |from test_bounded_counts.unit_test_group_by_bounded_counts
        |where item_bounded_unique_count_60d > 5 or price_bounded_unique_count_60d > 5
        |""".stripMargin)

    assertTrue(result.isEmpty)
  }
}
}