From eaacfa6aa2609495321a90f5e45b5dbd35cd8d89 Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 30 May 2025 13:11:41 -0700 Subject: [PATCH 01/28] code to write daily irs --- .../main/scala/ai/chronon/spark/GroupBy.scala | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index 79ef1095d3..8fbd990083 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -546,12 +546,23 @@ object GroupBy { df } + //make it parameterized + val incrementalAgg = true + if (incrementalAgg) { + new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, + keyColumns, + nullFiltered, + mutationDfFn, + finalize = false) + } else { + new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, + keyColumns, + nullFiltered, + mutationDfFn, + finalize = finalize) + + } - new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, - keyColumns, - nullFiltered, - mutationDfFn, - finalize = finalize) } def getIntersectedRange(source: api.Source, From 40b6cb2645bbe18d2d7b6b19289b83aea1d262c8 Mon Sep 17 00:00:00 2001 From: chaitu Date: Tue, 3 Jun 2025 10:47:48 -0700 Subject: [PATCH 02/28] store incremental agg and compute final IRs --- .../scala/ai/chronon/api/Extensions.scala | 1 + .../main/scala/ai/chronon/spark/GroupBy.scala | 83 ++++++++++++++----- 2 files changed, 61 insertions(+), 23 deletions(-) diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index 0ce907145b..d813b945a6 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -98,6 +98,7 @@ object Extensions { def cleanName: String = metaData.name.sanitize def outputTable = s"${metaData.outputNamespace}.${metaData.cleanName}" + def incOutputTable = s"${metaData.outputNamespace}.${metaData.cleanName}_inc" def outputLabelTable = s"${metaData.outputNamespace}.${metaData.cleanName}_labels" def outputFinalView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled" def outputLatestLabelView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled_latest" diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index 8fbd990083..b7dd1f4f67 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -461,18 +461,19 @@ object GroupBy { bloomMapOpt: Option[util.Map[String, BloomFilter]] = None, skewFilter: Option[String] = None, finalize: Boolean = true, - showDf: Boolean = false): GroupBy = { + showDf: Boolean = false, + incrementalAgg: Boolean = false): GroupBy = { logger.info(s"\n----[Processing GroupBy: ${groupByConfOld.metaData.name}]----") val groupByConf = replaceJoinSource(groupByConfOld, queryRange, tableUtils, computeDependency, showDf) val inputDf = groupByConf.sources.toScala .map { source => renderDataSourceQuery(groupByConf, - source, - groupByConf.getKeyColumns.toScala, - queryRange, - tableUtils, - groupByConf.maxWindow, - groupByConf.inferredAccuracy) + source, + groupByConf.getKeyColumns.toScala, + queryRange, + tableUtils, + groupByConf.maxWindow, + groupByConf.inferredAccuracy) } .map { @@ -543,26 +544,18 @@ object GroupBy { logger.info(s"printing mutation data for groupBy: ${groupByConf.metaData.name}") df.prettyPrint() } - df } - //make it 
parameterized - val incrementalAgg = true - if (incrementalAgg) { - new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, - keyColumns, - nullFiltered, - mutationDfFn, - finalize = false) + val finalizeValue = if (incrementalAgg) { + !incrementalAgg } else { - new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, + finalize + } + new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, keyColumns, nullFiltered, mutationDfFn, - finalize = finalize) - - } - + finalize = finalizeValue) } def getIntersectedRange(source: api.Source, @@ -681,12 +674,51 @@ object GroupBy { query } + def saveAndGetIncDf( + groupByConf: api.GroupBy, + range: PartitionRange, + tableUtils: TableUtils, + ): GroupBy = { + val incOutputTable = groupByConf.metaData.incOutputTable + val tableProps = Option(groupByConf.metaData.tableProperties) + .map(_.toScala) + .orNull + //range should be modified to incremental range + val incGroupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true, incrementalAgg = true) + val incOutputDf = incGroupByBackfill.snapshotEvents(range) + incOutputDf.save(incOutputTable, tableProps) + + val maxWindow = groupByConf.maxWindow.get + val sourceQueryableRange = PartitionRange( + range.start, + tableUtils.partitionSpec.minus(range.end, maxWindow) + )(tableUtils) + + val incTableFirstPartition: Option[String] = tableUtils.firstAvailablePartition(incOutputTable) + val incTableLastPartition: Option[String] = tableUtils.lastAvailablePartition(incOutputTable) + + val incTableRange = PartitionRange( + incTableFirstPartition.get, + incTableLastPartition.get + )(tableUtils) + + val incDfQuery = incTableRange.intersect(sourceQueryableRange).genScanQuery(null, incOutputTable) + val incDf: DataFrame = tableUtils.sql(incDfQuery) + + new GroupBy( + incGroupByBackfill.aggregations, + incGroupByBackfill.keyColumns, + incDf + ) + } + def computeBackfill(groupByConf: api.GroupBy, endPartition: String, tableUtils: TableUtils, stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None, - skipFirstHole: Boolean = true): Unit = { + skipFirstHole: Boolean = true, + incrementalAgg: Boolean = true): Unit = { assert( groupByConf.backfillStartDate != null, s"GroupBy:${groupByConf.metaData.name} has null backfillStartDate. 
This needs to be set for offline backfilling.") @@ -725,7 +757,12 @@ object GroupBy { stepRanges.zipWithIndex.foreach { case (range, index) => logger.info(s"Computing group by for range: $range [${index + 1}/${stepRanges.size}]") - val groupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true) + val groupByBackfill = if (incrementalAgg) { + saveAndGetIncDf(groupByConf, range, tableUtils) + //from(groupByConf, range, tableUtils, computeDependency = true) + } else { + from(groupByConf, range, tableUtils, computeDependency = true) + } val outputDf = groupByConf.dataModel match { // group by backfills have to be snapshot only case Entities => groupByBackfill.snapshotEntities From a014b6ef4706ce7200af1bb8a0681d04fc25208d Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 6 Jun 2025 19:04:46 -0700 Subject: [PATCH 03/28] Store hops to inc tables --- .../aggregator/row/RowAggregator.scala | 5 ++ .../scala/ai/chronon/api/Extensions.scala | 5 ++ .../main/scala/ai/chronon/spark/GroupBy.scala | 68 +++++++++++++++---- 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala index c8bc1da08c..6bda47bf19 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala @@ -70,6 +70,11 @@ class RowAggregator(val inputSchema: Seq[(String, DataType)], val aggregationPar .toArray .zip(columnAggregators.map(_.irType)) + val incSchema = aggregationParts + .map(_.incOutputColumnName) + .toArray + .zip(columnAggregators.map(_.irType)) + val outputSchema: Array[(String, DataType)] = aggregationParts .map(_.outputColumnName) .toArray diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index d813b945a6..c6c8074757 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -177,8 +177,13 @@ object Extensions { def outputColumnName = s"${aggregationPart.inputColumn}_$opSuffix${aggregationPart.window.suffix}${bucketSuffix}" + + def incOutputColumnName = + s"${aggregationPart.inputColumn}_$opSuffix${bucketSuffix}" + } + implicit class AggregationOps(aggregation: Aggregation) { // one agg part per bucket per window diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index b7dd1f4f67..a3f76f1483 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -18,6 +18,7 @@ package ai.chronon.spark import ai.chronon.aggregator.base.TimeTuple import ai.chronon.aggregator.row.RowAggregator +import ai.chronon.aggregator.windowing.HopsAggregator.HopIr import ai.chronon.aggregator.windowing._ import ai.chronon.api import ai.chronon.api.DataModel.{Entities, Events} @@ -41,7 +42,9 @@ class GroupBy(val aggregations: Seq[api.Aggregation], val inputDf: DataFrame, val mutationDfFn: () => DataFrame = null, skewFilter: Option[String] = None, - finalize: Boolean = true) + finalize: Boolean = true, + incAgg: Boolean = false + ) extends Serializable { @transient lazy val logger = LoggerFactory.getLogger(getClass) @@ -88,7 +91,11 @@ class GroupBy(val aggregations: Seq[api.Aggregation], lazy val aggPartWithSchema = aggregationParts.zip(columnAggregators.map(_.outputType)) lazy val postAggSchema: StructType = { - val 
valueChrononSchema = if (finalize) windowAggregator.outputSchema else windowAggregator.irSchema + val valueChrononSchema = if (finalize) { + windowAggregator.outputSchema + } else { + windowAggregator.irSchema + } SparkConversions.fromChrononSchema(valueChrononSchema) } @@ -141,12 +148,13 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } def snapshotEventsBase(partitionRange: PartitionRange, - resolution: Resolution = DailyResolution): RDD[(Array[Any], Array[Any])] = { + resolution: Resolution = DailyResolution, + incAgg: Boolean = true): RDD[(Array[Any], Array[Any])] = { val endTimes: Array[Long] = partitionRange.toTimePoints // add 1 day to the end times to include data [ds 00:00:00.000, ds + 1 00:00:00.000) val shiftedEndTimes = endTimes.map(_ + tableUtils.partitionSpec.spanMillis) val sawtoothAggregator = new SawtoothAggregator(aggregations, selectedSchema, resolution) - val hops = hopsAggregate(endTimes.min, resolution) + val hops = hopsAggregate(endTimes.min, resolution, incAgg) hops .flatMap { @@ -356,12 +364,43 @@ class GroupBy(val aggregations: Seq[api.Aggregation], toDf(outputRdd, Seq(Constants.TimeColumn -> LongType, tableUtils.partitionColumn -> StringType)) } + //def dfToOutputArrayType(df: DataFrame): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { + // val keyBuilder: Row => KeyWithHash = + // FastHashing.generateKeyBuilder(keyColumns.toArray, df.schema) + + // df.rdd + // .keyBy(keyBuilder) + // .mapValues(SparkConversions.toChrononRow(_, tsIndex)) + // .mapValues(windowAggregator.toTimeSortedArray) + //} + + def flattenOutputArrayType(hopsArrays: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)]): RDD[(Array[Any], Array[Any])] = { + hopsArrays.flatMap { case (keyWithHash: KeyWithHash, hopsArray: HopsAggregator.OutputArrayType) => + val hopsArrayHead: Array[HopIr] = hopsArray.headOption.get + hopsArrayHead.map { array: HopIr => + // the last element is a timestamp, we need to drop it + // and add it to the key + val timestamp = array.last.asInstanceOf[Long] + val withoutTimestamp = array.dropRight(1) + ((keyWithHash.data :+ tableUtils.partitionSpec.at(timestamp)), withoutTimestamp) + } + } + } + + def convertHopsToDf(range: PartitionRange, + schema: Array[(String, ai.chronon.api.DataType)] + ): DataFrame = { + val hops = hopsAggregate(range.toTimePoints.min, DailyResolution) + val hopsDf = flattenOutputArrayType(hops) + toDf(hopsDf, Seq((tableUtils.partitionColumn, StringType)), Some(SparkConversions.fromChrononSchema(schema))) + } + // convert raw data into IRs, collected by hopSizes // TODO cache this into a table: interface below // Class HopsCacher(keySchema, irSchema, resolution) extends RddCacher[(KeyWithHash, HopsOutput)] // buildTableRow((keyWithHash, hopsOutput)) -> GenericRowWithSchema // buildRddRow(GenericRowWithSchema) -> (keyWithHash, hopsOutput) - def hopsAggregate(minQueryTs: Long, resolution: Resolution): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { + def hopsAggregate(minQueryTs: Long, resolution: Resolution, incAgg: Boolean = false): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { val hopsAggregator = new HopsAggregator(minQueryTs, aggregations, selectedSchema, resolution) val keyBuilder: Row => KeyWithHash = @@ -378,9 +417,9 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } protected[spark] def toDf(aggregateRdd: RDD[(Array[Any], Array[Any])], - additionalFields: Seq[(String, DataType)]): DataFrame = { + additionalFields: Seq[(String, DataType)], schema: Option[StructType] = None): DataFrame = { val 
finalKeySchema = StructType(keySchema ++ additionalFields.map { case (name, typ) => StructField(name, typ) }) - KvRdd(aggregateRdd, finalKeySchema, postAggSchema).toFlatDf + KvRdd(aggregateRdd, finalKeySchema, schema.getOrElse(postAggSchema)).toFlatDf } private def normalizeOrFinalize(ir: Array[Any]): Array[Any] = @@ -555,7 +594,9 @@ object GroupBy { keyColumns, nullFiltered, mutationDfFn, - finalize = finalizeValue) + finalize = finalizeValue, + incAgg = incrementalAgg, + ) } def getIntersectedRange(source: api.Source, @@ -685,13 +726,16 @@ object GroupBy { .orNull //range should be modified to incremental range val incGroupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true, incrementalAgg = true) - val incOutputDf = incGroupByBackfill.snapshotEvents(range) - incOutputDf.save(incOutputTable, tableProps) + val selectedSchema = incGroupByBackfill.selectedSchema + //TODO is there any other way to get incSchema? + val incSchema = new RowAggregator(selectedSchema, incGroupByBackfill.aggregations.flatMap(_.unWindowed)).incSchema + val hopsDf = incGroupByBackfill.convertHopsToDf(range, incSchema) + hopsDf.save(incOutputTable, tableProps) val maxWindow = groupByConf.maxWindow.get val sourceQueryableRange = PartitionRange( - range.start, - tableUtils.partitionSpec.minus(range.end, maxWindow) + tableUtils.partitionSpec.minus(range.start, maxWindow), + range.end )(tableUtils) val incTableFirstPartition: Option[String] = tableUtils.firstAvailablePartition(incOutputTable) From 32d559eac7683dfc36a7704d27071f6578bf0baf Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 13 Jun 2025 18:10:03 -0700 Subject: [PATCH 04/28] add code changes to generate final output from IR for AVG --- .../aggregator/base/SimpleAggregators.scala | 51 ++++++++++++++++ .../aggregator/row/ColumnAggregator.scala | 8 +++ .../main/scala/ai/chronon/spark/GroupBy.scala | 59 +++++++++++++------ 3 files changed, 100 insertions(+), 18 deletions(-) diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala index b120d29e7f..31bf93cc6b 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala @@ -116,6 +116,57 @@ class UniqueCount[T](inputType: DataType) extends SimpleAggregator[T, util.HashS } } +class AverageIR extends SimpleAggregator[Array[Any], Array[Any], Double] { + override def outputType: DataType = DoubleType + + override def irType: DataType = + StructType( + "AvgIr", + Array(StructField("sum", DoubleType), StructField("count", IntType)) + ) + + override def prepare(input: Array[Any]): Array[Any] = { + Array(input(0).asInstanceOf[Double], input(1).asInstanceOf[Int]) + } + + // mutating + override def update(ir: Array[Any], input: Array[Any]): Array[Any] = { + val inputSum = input(0).asInstanceOf[Double] + val inputCount = input(1).asInstanceOf[Int] + ir.update(0, ir(0).asInstanceOf[Double] + inputSum) + ir.update(1, ir(1).asInstanceOf[Int] + inputCount) + ir + } + + // mutating + override def merge(ir1: Array[Any], ir2: Array[Any]): Array[Any] = { + ir1.update(0, ir1(0).asInstanceOf[Double] + ir2(0).asInstanceOf[Double]) + ir1.update(1, ir1(1).asInstanceOf[Int] + ir2(1).asInstanceOf[Int]) + ir1 + } + + override def finalize(ir: Array[Any]): Double = + ir(0).asInstanceOf[Double] / ir(1).asInstanceOf[Int].toDouble + + override def delete(ir: Array[Any], input: Array[Any]): Array[Any] = { 
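    // AverageIR re-aggregates pre-computed (sum, count) pairs instead of raw values: merging
    // daily IRs (10.0, 2) and (6.0, 1) gives (16.0, 3), which finalizes to 16.0 / 3 ≈ 5.33.
    // This method subtracts an incoming (sum, count) pair from the running IR, assuming the
    // same [Double sum, Int count] layout that prepare/update operate on above.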
+ val inputSum = input(0).asInstanceOf[Double] + val inputCount = input(1).asInstanceOf[Int] + ir.update(0, ir(0).asInstanceOf[Double] - inputSum) + ir.update(1, ir(1).asInstanceOf[Int] - inputCount) + ir + } + + override def clone(ir: Array[Any]): Array[Any] = { + val arr = new Array[Any](ir.length) + ir.copyToArray(arr) + arr + } + + override def isDeletable: Boolean = true +} + + + class Average extends SimpleAggregator[Double, Array[Any], Double] { override def outputType: DataType = DoubleType diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala index d5f21b3072..5c8a9bcf56 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala @@ -217,6 +217,13 @@ object ColumnAggregator { private def toJavaDouble[A: Numeric](inp: Any) = implicitly[Numeric[A]].toDouble(inp.asInstanceOf[A]).asInstanceOf[java.lang.Double] + + private def toStructArray(inp: Any): Array[Any] = inp match { + case r: org.apache.spark.sql.Row => r.toSeq.toArray + case null => null + case other => throw new IllegalArgumentException(s"Expected Row, got: $other") + } + def construct(baseInputType: DataType, aggregationPart: AggregationPart, columnIndices: ColumnIndices, @@ -330,6 +337,7 @@ object ColumnAggregator { case ShortType => simple(new Average, toDouble[Short]) case DoubleType => simple(new Average) case FloatType => simple(new Average, toDouble[Float]) + case StructType(name, fields) => simple(new AverageIR, toStructArray) case _ => mismatchException } diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index a3f76f1483..2765e7b99f 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -147,15 +147,16 @@ class GroupBy(val aggregations: Seq[api.Aggregation], toDf(snapshotEntitiesBase, Seq(tableUtils.partitionColumn -> StringType)) } - def snapshotEventsBase(partitionRange: PartitionRange, - resolution: Resolution = DailyResolution, - incAgg: Boolean = true): RDD[(Array[Any], Array[Any])] = { - val endTimes: Array[Long] = partitionRange.toTimePoints + def computeHopsAggregate(endTimes: Array[Long], resolution: Resolution): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { + hopsAggregate(endTimes.min, resolution) + } + + def computeSawtoothAggregate(hops: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)], + endTimes: Array[Long], + resolution: Resolution): RDD[(Array[Any], Array[Any])] = { // add 1 day to the end times to include data [ds 00:00:00.000, ds + 1 00:00:00.000) val shiftedEndTimes = endTimes.map(_ + tableUtils.partitionSpec.spanMillis) val sawtoothAggregator = new SawtoothAggregator(aggregations, selectedSchema, resolution) - val hops = hopsAggregate(endTimes.min, resolution, incAgg) - hops .flatMap { case (keys, hopsArrays) => @@ -169,6 +170,15 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } } + def snapshotEventsBase(partitionRange: PartitionRange, + resolution: Resolution = DailyResolution, + incAgg: Boolean = true): RDD[(Array[Any], Array[Any])] = { + val endTimes: Array[Long] = partitionRange.toTimePoints + + val hops = computeHopsAggregate(endTimes, resolution) + computeSawtoothAggregate(hops, endTimes, resolution) + } + // Calculate snapshot accurate windows for ALL keys at pre-defined "endTimes" // At this time, we hardcode 
the resolution to Daily, but it is straight forward to support // hourly resolution. @@ -364,14 +374,13 @@ class GroupBy(val aggregations: Seq[api.Aggregation], toDf(outputRdd, Seq(Constants.TimeColumn -> LongType, tableUtils.partitionColumn -> StringType)) } - //def dfToOutputArrayType(df: DataFrame): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { + //def convertDfToOutputArrayType(df: DataFrame): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { // val keyBuilder: Row => KeyWithHash = // FastHashing.generateKeyBuilder(keyColumns.toArray, df.schema) // df.rdd // .keyBy(keyBuilder) // .mapValues(SparkConversions.toChrononRow(_, tsIndex)) - // .mapValues(windowAggregator.toTimeSortedArray) //} def flattenOutputArrayType(hopsArrays: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)]): RDD[(Array[Any], Array[Any])] = { @@ -382,17 +391,16 @@ class GroupBy(val aggregations: Seq[api.Aggregation], // and add it to the key val timestamp = array.last.asInstanceOf[Long] val withoutTimestamp = array.dropRight(1) - ((keyWithHash.data :+ tableUtils.partitionSpec.at(timestamp)), withoutTimestamp) + ((keyWithHash.data :+ tableUtils.partitionSpec.at(timestamp) :+ timestamp), withoutTimestamp) } } } - def convertHopsToDf(range: PartitionRange, + def convertHopsToDf(hops: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)], schema: Array[(String, ai.chronon.api.DataType)] ): DataFrame = { - val hops = hopsAggregate(range.toTimePoints.min, DailyResolution) val hopsDf = flattenOutputArrayType(hops) - toDf(hopsDf, Seq((tableUtils.partitionColumn, StringType)), Some(SparkConversions.fromChrononSchema(schema))) + toDf(hopsDf, Seq(tableUtils.partitionColumn -> StringType, Constants.TimeColumn -> LongType), Some(SparkConversions.fromChrononSchema(schema))) } // convert raw data into IRs, collected by hopSizes @@ -728,8 +736,10 @@ object GroupBy { val incGroupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true, incrementalAgg = true) val selectedSchema = incGroupByBackfill.selectedSchema //TODO is there any other way to get incSchema? - val incSchema = new RowAggregator(selectedSchema, incGroupByBackfill.aggregations.flatMap(_.unWindowed)).incSchema - val hopsDf = incGroupByBackfill.convertHopsToDf(range, incSchema) + val incFlattendAgg = new RowAggregator(selectedSchema, incGroupByBackfill.aggregations.flatMap(_.unWindowed)) + val incSchema = incFlattendAgg.incSchema + val hops = incGroupByBackfill.computeHopsAggregate(range.toTimePoints, DailyResolution) + val hopsDf = incGroupByBackfill.convertHopsToDf(hops, incSchema) hopsDf.save(incOutputTable, tableProps) val maxWindow = groupByConf.maxWindow.get @@ -746,14 +756,27 @@ object GroupBy { incTableLastPartition.get )(tableUtils) + //val dfQuery = groupByConf. 
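    // Scan the stored daily IRs back out: windows ending inside `range` need partitions from
    // (range.start - maxWindow) through range.end, clipped to the partitions that actually
    // exist in the incremental table, before being re-aggregated into final windowed values below.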
val incDfQuery = incTableRange.intersect(sourceQueryableRange).genScanQuery(null, incOutputTable) val incDf: DataFrame = tableUtils.sql(incDfQuery) + //incGroupByBackfill.computeSawtoothAggregate(incDf, range.toTimePoints, DailyResolution) + + val a = incFlattendAgg.aggregationParts.zip(groupByConf.getAggregations.toScala).map{ case (part, agg) => + val newAgg = agg.deepCopy() + newAgg.setInputColumn(part.incOutputColumnName) + newAgg + } + new GroupBy( - incGroupByBackfill.aggregations, - incGroupByBackfill.keyColumns, - incDf + a, + groupByConf.getKeyColumns.toScala, + incDf, + () => null, + finalize = true, + incAgg = false, ) + } def computeBackfill(groupByConf: api.GroupBy, @@ -801,7 +824,7 @@ object GroupBy { stepRanges.zipWithIndex.foreach { case (range, index) => logger.info(s"Computing group by for range: $range [${index + 1}/${stepRanges.size}]") - val groupByBackfill = if (incrementalAgg) { + val groupByBackfill: GroupBy = if (incrementalAgg) { saveAndGetIncDf(groupByConf, range, tableUtils) //from(groupByConf, range, tableUtils, computeDependency = true) } else { From 37293df2c7c7b491993016d96145bc5f7463301e Mon Sep 17 00:00:00 2001 From: chaitu Date: Thu, 19 Jun 2025 13:39:06 -0700 Subject: [PATCH 05/28] change function structure and variable names --- .../aggregator/row/RowAggregator.scala | 4 +- .../scala/ai/chronon/api/Extensions.scala | 4 +- .../main/scala/ai/chronon/spark/GroupBy.scala | 104 +++++++++++------- 3 files changed, 69 insertions(+), 43 deletions(-) diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala index 6bda47bf19..e9d0608d25 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/RowAggregator.scala @@ -70,8 +70,8 @@ class RowAggregator(val inputSchema: Seq[(String, DataType)], val aggregationPar .toArray .zip(columnAggregators.map(_.irType)) - val incSchema = aggregationParts - .map(_.incOutputColumnName) + val incrementalOutputSchema = aggregationParts + .map(_.incrementalOutputColumnName) .toArray .zip(columnAggregators.map(_.irType)) diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index c6c8074757..b39bb2f016 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -98,7 +98,7 @@ object Extensions { def cleanName: String = metaData.name.sanitize def outputTable = s"${metaData.outputNamespace}.${metaData.cleanName}" - def incOutputTable = s"${metaData.outputNamespace}.${metaData.cleanName}_inc" + def incrementalOutputTable = s"${metaData.outputNamespace}.${metaData.cleanName}_inc" def outputLabelTable = s"${metaData.outputNamespace}.${metaData.cleanName}_labels" def outputFinalView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled" def outputLatestLabelView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled_latest" @@ -178,7 +178,7 @@ object Extensions { def outputColumnName = s"${aggregationPart.inputColumn}_$opSuffix${aggregationPart.window.suffix}${bucketSuffix}" - def incOutputColumnName = + def incrementalOutputColumnName = s"${aggregationPart.inputColumn}_$opSuffix${bucketSuffix}" } diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index 2765e7b99f..36c8362618 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ 
b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -23,7 +23,7 @@ import ai.chronon.aggregator.windowing._ import ai.chronon.api import ai.chronon.api.DataModel.{Entities, Events} import ai.chronon.api.Extensions._ -import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro} +import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro, Source} import ai.chronon.online.{RowWrapper, SparkConversions} import ai.chronon.spark.Extensions._ import org.apache.spark.rdd.RDD @@ -43,7 +43,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation], val mutationDfFn: () => DataFrame = null, skewFilter: Option[String] = None, finalize: Boolean = true, - incAgg: Boolean = false + incrementalMode: Boolean = false ) extends Serializable { @transient lazy val logger = LoggerFactory.getLogger(getClass) @@ -99,6 +99,10 @@ class GroupBy(val aggregations: Seq[api.Aggregation], SparkConversions.fromChrononSchema(valueChrononSchema) } + lazy val flattenedAgg: RowAggregator = new RowAggregator(selectedSchema, aggregations.flatMap(_.unWindowed)) + lazy val incrementalSchema: Array[(String, api.DataType)] = flattenedAgg.incrementalOutputSchema + + @transient protected[spark] lazy val windowAggregator: RowAggregator = new RowAggregator(selectedSchema, aggregations.flatMap(_.unpack)) @@ -509,7 +513,7 @@ object GroupBy { skewFilter: Option[String] = None, finalize: Boolean = true, showDf: Boolean = false, - incrementalAgg: Boolean = false): GroupBy = { + incrementalMode: Boolean = false): GroupBy = { logger.info(s"\n----[Processing GroupBy: ${groupByConfOld.metaData.name}]----") val groupByConf = replaceJoinSource(groupByConfOld, queryRange, tableUtils, computeDependency, showDf) val inputDf = groupByConf.sources.toScala @@ -593,17 +597,21 @@ object GroupBy { } df } - val finalizeValue = if (incrementalAgg) { - !incrementalAgg + + //if incrementalMode is enabled, we do not compute finalize values + //IR values are stored in the table + val finalizeValue = if (incrementalMode) { + false } else { finalize } + new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, keyColumns, nullFiltered, mutationDfFn, finalize = finalizeValue, - incAgg = incrementalAgg, + incrementalMode = incrementalMode, ) } @@ -723,58 +731,77 @@ object GroupBy { query } - def saveAndGetIncDf( + /** + * Computes and saves the output of hopsAggregation. 
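   * One row is written per key per day, with the aggregation IRs as columns, into the
   * GroupBy's incremental output table (the output table name suffixed with _inc).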
+ * HopsAggregate computes event level data to daily aggregates and saves the output in IR format + * + * @param groupByConf + * @param range + * @param tableUtils + */ + def computeIncrementalDf( + groupByConf: api.GroupBy, + range: PartitionRange, + tableUtils: TableUtils, + ): GroupBy = { + + val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable + val tableProps = Option(groupByConf.metaData.tableProperties) + .map(_.toScala) + .orNull + + val incrementalGroupByBackfill = + from(groupByConf, range, tableUtils, computeDependency = true, incrementalMode = true) + + val incrementalSchema = incrementalGroupByBackfill.incrementalSchema + + val hops = incrementalGroupByBackfill.computeHopsAggregate(range.toTimePoints, DailyResolution) + val hopsDf = incrementalGroupByBackfill.convertHopsToDf(hops, incrementalSchema) + hopsDf.save(incrementalOutputTable, tableProps) + + incrementalGroupByBackfill + } + + def fromIncrementalDf( groupByConf: api.GroupBy, range: PartitionRange, tableUtils: TableUtils, ): GroupBy = { - val incOutputTable = groupByConf.metaData.incOutputTable - val tableProps = Option(groupByConf.metaData.tableProperties) - .map(_.toScala) - .orNull - //range should be modified to incremental range - val incGroupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true, incrementalAgg = true) - val selectedSchema = incGroupByBackfill.selectedSchema - //TODO is there any other way to get incSchema? - val incFlattendAgg = new RowAggregator(selectedSchema, incGroupByBackfill.aggregations.flatMap(_.unWindowed)) - val incSchema = incFlattendAgg.incSchema - val hops = incGroupByBackfill.computeHopsAggregate(range.toTimePoints, DailyResolution) - val hopsDf = incGroupByBackfill.convertHopsToDf(hops, incSchema) - hopsDf.save(incOutputTable, tableProps) - - val maxWindow = groupByConf.maxWindow.get + + + val incrementalGroupByBackfill: GroupBy = computeIncrementalDf(groupByConf, range, tableUtils) + + val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable val sourceQueryableRange = PartitionRange( - tableUtils.partitionSpec.minus(range.start, maxWindow), + tableUtils.partitionSpec.minus(range.start, groupByConf.maxWindow.get), range.end )(tableUtils) - val incTableFirstPartition: Option[String] = tableUtils.firstAvailablePartition(incOutputTable) - val incTableLastPartition: Option[String] = tableUtils.lastAvailablePartition(incOutputTable) + val incTableFirstPartition: Option[String] = tableUtils.firstAvailablePartition(incrementalOutputTable) + val incTableLastPartition: Option[String] = tableUtils.lastAvailablePartition(incrementalOutputTable) val incTableRange = PartitionRange( incTableFirstPartition.get, incTableLastPartition.get )(tableUtils) - //val dfQuery = groupByConf. 
- val incDfQuery = incTableRange.intersect(sourceQueryableRange).genScanQuery(null, incOutputTable) - val incDf: DataFrame = tableUtils.sql(incDfQuery) - //incGroupByBackfill.computeSawtoothAggregate(incDf, range.toTimePoints, DailyResolution) + val incrementalDf: DataFrame = tableUtils.sql( + incTableRange.intersect(sourceQueryableRange).genScanQuery(null, incrementalOutputTable) + ) - val a = incFlattendAgg.aggregationParts.zip(groupByConf.getAggregations.toScala).map{ case (part, agg) => + val incrementalAggregations = incrementalGroupByBackfill.flattenedAgg.aggregationParts.zip(groupByConf.getAggregations.toScala).map{ case (part, agg) => val newAgg = agg.deepCopy() - newAgg.setInputColumn(part.incOutputColumnName) + newAgg.setInputColumn(part.incrementalOutputColumnName) newAgg } - new GroupBy( - a, + incrementalAggregations, groupByConf.getKeyColumns.toScala, - incDf, + incrementalDf, () => null, finalize = true, - incAgg = false, + incrementalMode = false, ) } @@ -785,7 +812,7 @@ object GroupBy { stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None, skipFirstHole: Boolean = true, - incrementalAgg: Boolean = true): Unit = { + incrementalMode: Boolean = true): Unit = { assert( groupByConf.backfillStartDate != null, s"GroupBy:${groupByConf.metaData.name} has null backfillStartDate. This needs to be set for offline backfilling.") @@ -824,9 +851,8 @@ object GroupBy { stepRanges.zipWithIndex.foreach { case (range, index) => logger.info(s"Computing group by for range: $range [${index + 1}/${stepRanges.size}]") - val groupByBackfill: GroupBy = if (incrementalAgg) { - saveAndGetIncDf(groupByConf, range, tableUtils) - //from(groupByConf, range, tableUtils, computeDependency = true) + val groupByBackfill: GroupBy = if (incrementalMode) { + fromIncrementalDf(groupByConf, range, tableUtils) } else { from(groupByConf, range, tableUtils, computeDependency = true) } From 6263706d0c2420af3cf1e5be6ecfb118ce04e69e Mon Sep 17 00:00:00 2001 From: chaitu Date: Thu, 19 Jun 2025 13:45:18 -0700 Subject: [PATCH 06/28] remove unused functions --- .../main/scala/ai/chronon/spark/GroupBy.scala | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index 36c8362618..53347dc970 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -150,11 +150,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } else { toDf(snapshotEntitiesBase, Seq(tableUtils.partitionColumn -> StringType)) } - - def computeHopsAggregate(endTimes: Array[Long], resolution: Resolution): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { - hopsAggregate(endTimes.min, resolution) - } - + def computeSawtoothAggregate(hops: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)], endTimes: Array[Long], resolution: Resolution): RDD[(Array[Any], Array[Any])] = { @@ -179,7 +175,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation], incAgg: Boolean = true): RDD[(Array[Any], Array[Any])] = { val endTimes: Array[Long] = partitionRange.toTimePoints - val hops = computeHopsAggregate(endTimes, resolution) + val hops = hopsAggregate(endTimes.min, resolution) computeSawtoothAggregate(hops, endTimes, resolution) } @@ -378,21 +374,10 @@ class GroupBy(val aggregations: Seq[api.Aggregation], toDf(outputRdd, Seq(Constants.TimeColumn -> LongType, tableUtils.partitionColumn -> StringType)) } - //def convertDfToOutputArrayType(df: 
DataFrame): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { - // val keyBuilder: Row => KeyWithHash = - // FastHashing.generateKeyBuilder(keyColumns.toArray, df.schema) - - // df.rdd - // .keyBy(keyBuilder) - // .mapValues(SparkConversions.toChrononRow(_, tsIndex)) - //} - def flattenOutputArrayType(hopsArrays: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)]): RDD[(Array[Any], Array[Any])] = { hopsArrays.flatMap { case (keyWithHash: KeyWithHash, hopsArray: HopsAggregator.OutputArrayType) => val hopsArrayHead: Array[HopIr] = hopsArray.headOption.get hopsArrayHead.map { array: HopIr => - // the last element is a timestamp, we need to drop it - // and add it to the key val timestamp = array.last.asInstanceOf[Long] val withoutTimestamp = array.dropRight(1) ((keyWithHash.data :+ tableUtils.partitionSpec.at(timestamp) :+ timestamp), withoutTimestamp) @@ -412,7 +397,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation], // Class HopsCacher(keySchema, irSchema, resolution) extends RddCacher[(KeyWithHash, HopsOutput)] // buildTableRow((keyWithHash, hopsOutput)) -> GenericRowWithSchema // buildRddRow(GenericRowWithSchema) -> (keyWithHash, hopsOutput) - def hopsAggregate(minQueryTs: Long, resolution: Resolution, incAgg: Boolean = false): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { + def hopsAggregate(minQueryTs: Long, resolution: Resolution): RDD[(KeyWithHash, HopsAggregator.OutputArrayType)] = { val hopsAggregator = new HopsAggregator(minQueryTs, aggregations, selectedSchema, resolution) val keyBuilder: Row => KeyWithHash = @@ -755,7 +740,7 @@ object GroupBy { val incrementalSchema = incrementalGroupByBackfill.incrementalSchema - val hops = incrementalGroupByBackfill.computeHopsAggregate(range.toTimePoints, DailyResolution) + val hops = incrementalGroupByBackfill.hopsAggregate(range.toTimePoints.min, DailyResolution) val hopsDf = incrementalGroupByBackfill.convertHopsToDf(hops, incrementalSchema) hopsDf.save(incrementalOutputTable, tableProps) From cb4325ba90456e3d93a0d4ec365bcdd329318b0e Mon Sep 17 00:00:00 2001 From: chaitu Date: Thu, 19 Jun 2025 14:00:33 -0700 Subject: [PATCH 07/28] change function defs --- .../main/scala/ai/chronon/spark/GroupBy.scala | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index 53347dc970..d56877af69 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -150,10 +150,13 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } else { toDf(snapshotEntitiesBase, Seq(tableUtils.partitionColumn -> StringType)) } - - def computeSawtoothAggregate(hops: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)], - endTimes: Array[Long], - resolution: Resolution): RDD[(Array[Any], Array[Any])] = { + + def snapshotEventsBase(partitionRange: PartitionRange, + resolution: Resolution = DailyResolution, + incAgg: Boolean = true): RDD[(Array[Any], Array[Any])] = { + val endTimes: Array[Long] = partitionRange.toTimePoints + + val hops = hopsAggregate(endTimes.min, resolution) // add 1 day to the end times to include data [ds 00:00:00.000, ds + 1 00:00:00.000) val shiftedEndTimes = endTimes.map(_ + tableUtils.partitionSpec.spanMillis) val sawtoothAggregator = new SawtoothAggregator(aggregations, selectedSchema, resolution) @@ -170,15 +173,6 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } } - def snapshotEventsBase(partitionRange: 
PartitionRange, - resolution: Resolution = DailyResolution, - incAgg: Boolean = true): RDD[(Array[Any], Array[Any])] = { - val endTimes: Array[Long] = partitionRange.toTimePoints - - val hops = hopsAggregate(endTimes.min, resolution) - computeSawtoothAggregate(hops, endTimes, resolution) - } - // Calculate snapshot accurate windows for ALL keys at pre-defined "endTimes" // At this time, we hardcode the resolution to Daily, but it is straight forward to support // hourly resolution. From 796ef9660bcbee5a8263e990b282787ab59f53aa Mon Sep 17 00:00:00 2001 From: chaitu Date: Thu, 19 Jun 2025 14:01:43 -0700 Subject: [PATCH 08/28] make changes --- spark/src/main/scala/ai/chronon/spark/GroupBy.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index d56877af69..dfbfd43090 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -155,11 +155,11 @@ class GroupBy(val aggregations: Seq[api.Aggregation], resolution: Resolution = DailyResolution, incAgg: Boolean = true): RDD[(Array[Any], Array[Any])] = { val endTimes: Array[Long] = partitionRange.toTimePoints - - val hops = hopsAggregate(endTimes.min, resolution) // add 1 day to the end times to include data [ds 00:00:00.000, ds + 1 00:00:00.000) val shiftedEndTimes = endTimes.map(_ + tableUtils.partitionSpec.spanMillis) val sawtoothAggregator = new SawtoothAggregator(aggregations, selectedSchema, resolution) + val hops = hopsAggregate(endTimes.min, resolution) + hops .flatMap { case (keys, hopsArrays) => From f218b231bd3e430850f3304c1f9be40dfdcdace3 Mon Sep 17 00:00:00 2001 From: chaitu Date: Thu, 19 Jun 2025 14:03:46 -0700 Subject: [PATCH 09/28] change function order --- spark/src/main/scala/ai/chronon/spark/GroupBy.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index dfbfd43090..cb557a50f3 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -152,8 +152,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } def snapshotEventsBase(partitionRange: PartitionRange, - resolution: Resolution = DailyResolution, - incAgg: Boolean = true): RDD[(Array[Any], Array[Any])] = { + resolution: Resolution = DailyResolution): RDD[(Array[Any], Array[Any])] = { val endTimes: Array[Long] = partitionRange.toTimePoints // add 1 day to the end times to include data [ds 00:00:00.000, ds + 1 00:00:00.000) val shiftedEndTimes = endTimes.map(_ + tableUtils.partitionSpec.spanMillis) From b1d4ee99b96a9671cec04e9ced50a73760de44ed Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 20 Jun 2025 10:52:14 -0700 Subject: [PATCH 10/28] add new field is_incremental to python api --- api/py/ai/chronon/group_by.py | 2 ++ api/thrift/api.thrift | 1 + 2 files changed, 3 insertions(+) diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py index b2290e34e8..e919b74060 100644 --- a/api/py/ai/chronon/group_by.py +++ b/api/py/ai/chronon/group_by.py @@ -349,6 +349,7 @@ def GroupBy( tags: Dict[str, str] = None, derivations: List[ttypes.Derivation] = None, deprecation_date: str = None, + is_incremental: bool = False, **kwargs, ) -> ttypes.GroupBy: """ @@ -556,6 +557,7 @@ def _normalize_source(source): backfillStartDate=backfill_start_date, accuracy=accuracy, 
derivations=derivations, + isIncremental=is_incremental, ) validate_group_by(group_by) return group_by diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift index 3fd8f5428a..16f19d1681 100644 --- a/api/thrift/api.thrift +++ b/api/thrift/api.thrift @@ -278,6 +278,7 @@ struct GroupBy { 6: optional string backfillStartDate // Optional derivation list 7: optional list derivations + 8: optional bool isIncremental } struct JoinPart { From 2ab7659c18db9c764a3d3e0cd078f649d9ac4d26 Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 20 Jun 2025 10:54:01 -0700 Subject: [PATCH 11/28] get argument for isIncremental in scala spark backend --- spark/src/main/scala/ai/chronon/spark/Driver.scala | 3 ++- .../src/main/scala/ai/chronon/spark/GroupBy.scala | 14 +++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index 560d6d261a..748c4765bf 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -386,7 +386,8 @@ object Driver { tableUtils, args.stepDays.toOption, args.startPartitionOverride.toOption, - !args.runFirstHole() + !args.runFirstHole(), + args.groupByConf.isIncremental ) if (args.shouldExport()) { diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index cb557a50f3..0f607680c5 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -497,12 +497,12 @@ object GroupBy { val inputDf = groupByConf.sources.toScala .map { source => renderDataSourceQuery(groupByConf, - source, - groupByConf.getKeyColumns.toScala, - queryRange, - tableUtils, - groupByConf.maxWindow, - groupByConf.inferredAccuracy) + source, + groupByConf.getKeyColumns.toScala, + queryRange, + tableUtils, + groupByConf.maxWindow, + groupByConf.inferredAccuracy) } .map { @@ -790,7 +790,7 @@ object GroupBy { stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None, skipFirstHole: Boolean = true, - incrementalMode: Boolean = true): Unit = { + incrementalMode: Boolean = false): Unit = { assert( groupByConf.backfillStartDate != null, s"GroupBy:${groupByConf.metaData.name} has null backfillStartDate. 
This needs to be set for offline backfilling.") From 238c781f5b8adb13928ed3fa6634843746bc0516 Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 20 Jun 2025 15:36:56 -0700 Subject: [PATCH 12/28] add unit test for incremental groupby --- .../ai/chronon/spark/test/GroupByTest.scala | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala index dd979e4422..baa7a1cbf4 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala @@ -23,6 +23,7 @@ import ai.chronon.api.{Aggregation, Builders, Constants, Derivation, DoubleType, import ai.chronon.online.{RowWrapper, SparkConversions} import ai.chronon.spark.Extensions._ import ai.chronon.spark._ +import ai.chronon.spark.test.TestUtils.makeDf import com.google.gson.Gson import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{StructField, StructType, LongType => SparkLongType, StringType => SparkStringType} @@ -423,6 +424,7 @@ class GroupByTest { additionalAgg = aggs) } + private def createTestSource(windowSize: Int = 365, suffix: String = ""): (Source, String) = { lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) implicit val tableUtils = TableUtils(spark) @@ -694,4 +696,96 @@ class GroupByTest { tableUtils = tableUtils, additionalAgg = aggs) } + + @Test + def testIncrementalMode(): Unit = { + lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + implicit val tableUtils = TableUtils(spark) + + val namespace = "test_incremental_group_by" + + val schema = + ai.chronon.api.StructType( + "test_incremental_group_by", + Array( + ai.chronon.api.StructField("user", StringType), + ai.chronon.api.StructField("purchase_price", IntType), + ai.chronon.api.StructField("ds", StringType), + ai.chronon.api.StructField("ts", LongType) + ) + ) + + + val sourceTable = "test_incremental_group_by_" + Random.alphanumeric.take(6).mkString + val data = List( + Row("user1", 100, "2025-06-01", 1748772000000L), + Row("user2", 200, "2025-06-01", 1748772000000L), + Row("user3", 300, "2025-06-01", 1748772000000L), + ) + val df = makeDf(spark, schema, data).save(s"${namespace}.${sourceTable}") + + + val source = Builders.Source.events( + query = Builders.Query(selects = Builders.Selects("ts", "user", "purchase_price", "ds"), + startPartition = "2025-06-01"), + table = sourceTable + ) + + // Define aggregations + val aggregations: Seq[Aggregation] = Seq( + Builders.Aggregation(Operation.SUM, "purchase_price", Seq(new Window(3, TimeUnit.DAYS))), + Builders.Aggregation(Operation.COUNT, "purchase_price", Seq(new Window(3, TimeUnit.DAYS))) + ) + + val groupByConf = Builders.GroupBy( + sources = Seq(source), + keyColumns = Seq("user"), + aggregations = aggregations, + metaData = Builders.MetaData(name = "intermediate_output_gb", namespace = "test_incremental_group_by", team = "chronon"), + backfillStartDate = tableUtils.partitionSpec.minus(tableUtils.partitionSpec.at(System.currentTimeMillis()), + new Window(60, TimeUnit.DAYS)), + ) + + val outputTableName = groupByConf.metaData.incrementalOutputTable + + GroupBy.computeIncrementalDf( + groupByConf, + PartitionRange("2025-06-01", "2025-06-01"), + tableUtils + ) + + //check if the table exists + assertTrue(s"Output table $outputTableName should exist", 
spark.catalog.tableExists(outputTableName)) + + // Create GroupBy with incrementalMode = true + /* + val incrementalGroupBy = new GroupBy( + aggregations = aggregations, + keyColumns = Seq("user"), + inputDf = df, + incrementalMode = true + ) + */ + + // Test that incremental schema is available + //val incrementalSchema = incrementalGroupBy.incrementalSchema + //assertNotNull("Incremental schema should not be null", incrementalSchema) + //assertEquals("Should have correct number of incremental schema columns", + // aggregations.length, incrementalSchema.length) + + + // Test that we can compute snapshot events + //val result = incrementalGroupBy.snapshotEvents(PartitionRange("2025-06-01", "2025-06-01")) + //println("================================") + //println(df.show()) + //println(result.show()) + //println("================================") + //assertNotNull("Should be able to compute snapshot events", result) + //assertTrue("Result should have data", result.count() > 100) + + // Verify that result contains expected columns + //val resultColumns = result.columns.toSet + //assertTrue("Result should contain user column", resultColumns.contains("user")) + //assertTrue("Result should contain partition column", resultColumns.contains(tableUtils.partitionColumn)) + } } From 8edfd2785aab51adf858d43b7e526a5124ad7065 Mon Sep 17 00:00:00 2001 From: chaitu Date: Thu, 17 Jul 2025 21:00:54 -0700 Subject: [PATCH 13/28] reuse table ccreation --- .../ai/chronon/spark/test/GroupByTest.scala | 45 +++++++------------ 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala index baa7a1cbf4..de934561e6 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala @@ -699,49 +699,31 @@ class GroupByTest { @Test def testIncrementalMode(): Unit = { - lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTestIncremental" + "_" + Random.alphanumeric.take(6).mkString, local = true) implicit val tableUtils = TableUtils(spark) - val namespace = "test_incremental_group_by" - - val schema = - ai.chronon.api.StructType( - "test_incremental_group_by", - Array( - ai.chronon.api.StructField("user", StringType), - ai.chronon.api.StructField("purchase_price", IntType), - ai.chronon.api.StructField("ds", StringType), - ai.chronon.api.StructField("ts", LongType) - ) - ) - - - val sourceTable = "test_incremental_group_by_" + Random.alphanumeric.take(6).mkString - val data = List( - Row("user1", 100, "2025-06-01", 1748772000000L), - Row("user2", 200, "2025-06-01", 1748772000000L), - Row("user3", 300, "2025-06-01", 1748772000000L), + val schema = List( + Column("user", StringType, 10), // ts = last 10 days + Column("session_length", IntType, 2), + Column("rating", DoubleType, 2000) ) - val df = makeDf(spark, schema, data).save(s"${namespace}.${sourceTable}") + val df = DataFrameGen.events(spark, schema, count = 100000, partitions = 100) - val source = Builders.Source.events( - query = Builders.Query(selects = Builders.Selects("ts", "user", "purchase_price", "ds"), - startPartition = "2025-06-01"), - table = sourceTable - ) + println(s"Input DataFrame: ${df.count()}") - // Define aggregations val aggregations: Seq[Aggregation] = Seq( - Builders.Aggregation(Operation.SUM, 
"purchase_price", Seq(new Window(3, TimeUnit.DAYS))), - Builders.Aggregation(Operation.COUNT, "purchase_price", Seq(new Window(3, TimeUnit.DAYS))) + Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), + Builders.Aggregation(Operation.UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), + Builders.Aggregation(Operation.SUM, "session_length", Seq(new Window(10, TimeUnit.DAYS))) ) + val groupByConf = Builders.GroupBy( sources = Seq(source), keyColumns = Seq("user"), aggregations = aggregations, - metaData = Builders.MetaData(name = "intermediate_output_gb", namespace = "test_incremental_group_by", team = "chronon"), + metaData = Builders.MetaData(name = "intermediate_output_gb", namespace = namespace, team = "chronon"), backfillStartDate = tableUtils.partitionSpec.minus(tableUtils.partitionSpec.at(System.currentTimeMillis()), new Window(60, TimeUnit.DAYS)), ) @@ -756,6 +738,9 @@ class GroupByTest { //check if the table exists assertTrue(s"Output table $outputTableName should exist", spark.catalog.tableExists(outputTableName)) + + val df1 = spark.sql(s"SELECT * FROM $outputTableName").toDF() + println(s"Table output ${df1.show()}") // Create GroupBy with incrementalMode = true /* From e903683dc6d7d3a28a943d0bfd95e7e16a20ef5b Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 18 Jul 2025 09:39:14 -0700 Subject: [PATCH 14/28] Update GroupByTest --- .../main/scala/ai/chronon/spark/GroupBy.scala | 20 +++-- .../ai/chronon/spark/test/GroupByTest.scala | 83 +++++++------------ 2 files changed, 43 insertions(+), 60 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index 0f607680c5..939843be1f 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -419,6 +419,16 @@ class GroupBy(val aggregations: Seq[api.Aggregation], windowAggregator.normalize(ir) } + + def computeIncrementalDf(incrementalOutputTable: String, + range: PartitionRange, + tableProps: Map[String, String]) = { + + val hops = hopsAggregate(range.toTimePoints.min, DailyResolution) + println(s"Saving incremental hops to ${hops.map(x => x._1.data.mkString(",")).take(20)}.") + val hopsDf: DataFrame = convertHopsToDf(hops, incrementalSchema) + hopsDf.save(incrementalOutputTable, tableProps) + } } // TODO: truncate queryRange for caching @@ -723,19 +733,15 @@ object GroupBy { tableUtils: TableUtils, ): GroupBy = { - val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable - val tableProps = Option(groupByConf.metaData.tableProperties) + val incrementalOutputTable: String = groupByConf.metaData.incrementalOutputTable + val tableProps: Map[String, String] = Option(groupByConf.metaData.tableProperties) .map(_.toScala) .orNull val incrementalGroupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true, incrementalMode = true) - val incrementalSchema = incrementalGroupByBackfill.incrementalSchema - - val hops = incrementalGroupByBackfill.hopsAggregate(range.toTimePoints.min, DailyResolution) - val hopsDf = incrementalGroupByBackfill.convertHopsToDf(hops, incrementalSchema) - hopsDf.save(incrementalOutputTable, tableProps) + incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, range, tableProps) incrementalGroupByBackfill } diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala 
b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala index de934561e6..8cb886c498 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala @@ -701,7 +701,8 @@ class GroupByTest { def testIncrementalMode(): Unit = { lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTestIncremental" + "_" + Random.alphanumeric.take(6).mkString, local = true) implicit val tableUtils = TableUtils(spark) - + val namespace = "incremental" + tableUtils.createDatabase(namespace) val schema = List( Column("user", StringType, 10), // ts = last 10 days Column("session_length", IntType, 2), @@ -713,64 +714,40 @@ class GroupByTest { println(s"Input DataFrame: ${df.count()}") val aggregations: Seq[Aggregation] = Seq( - Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), - Builders.Aggregation(Operation.UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), + //Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), + //Builders.Aggregation(Operation.UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), Builders.Aggregation(Operation.SUM, "session_length", Seq(new Window(10, TimeUnit.DAYS))) ) - - val groupByConf = Builders.GroupBy( - sources = Seq(source), - keyColumns = Seq("user"), - aggregations = aggregations, - metaData = Builders.MetaData(name = "intermediate_output_gb", namespace = namespace, team = "chronon"), - backfillStartDate = tableUtils.partitionSpec.minus(tableUtils.partitionSpec.at(System.currentTimeMillis()), - new Window(60, TimeUnit.DAYS)), + val tableProps: Map[String, String] = Map( + "source" -> "chronon" ) - val outputTableName = groupByConf.metaData.incrementalOutputTable + val groupBy = new GroupBy(aggregations, Seq("user"), df) + groupBy.computeIncrementalDf("incremental.testIncrementalOutput", PartitionRange("2025-05-01", "2025-06-01"), tableProps) - GroupBy.computeIncrementalDf( - groupByConf, - PartitionRange("2025-06-01", "2025-06-01"), - tableUtils - ) + val actualIncrementalDf = spark.sql(s"select * from incremental.testIncrementalOutput where ds='2025-05-11'") + df.createOrReplaceTempView("test_incremental_input") + //spark.sql(s"select * from test_incremental_input where user='user7' and ds='2025-05-11'").show(numRows=100) + + spark.sql(s"select * from incremental.testIncrementalOutput where ds='2025-05-11'").show() + + val query = + s""" + |select user, ds, UNIX_TIMESTAMP(ds, 'yyyy-MM-dd')*1000 as ts, sum(session_length) as session_length_sum + |from test_incremental_input + |where ds='2025-05-11' + |group by user, ds, UNIX_TIMESTAMP(ds, 'yyyy-MM-dd')*1000 + |""".stripMargin + + val expectedDf = spark.sql(query) + + val diff = Comparison.sideBySide(actualIncrementalDf, expectedDf, List("user", tableUtils.partitionColumn)) + if (diff.count() > 0) { + diff.show() + println("diff result rows") + } + assertEquals(0, diff.count()) - //check if the table exists - assertTrue(s"Output table $outputTableName should exist", spark.catalog.tableExists(outputTableName)) - - val df1 = spark.sql(s"SELECT * FROM $outputTableName").toDF() - println(s"Table output ${df1.show()}") - - // Create GroupBy with incrementalMode = true - /* - val incrementalGroupBy = new GroupBy( - aggregations = aggregations, - keyColumns = Seq("user"), - inputDf = df, - incrementalMode = true - 
) - */ - - // Test that incremental schema is available - //val incrementalSchema = incrementalGroupBy.incrementalSchema - //assertNotNull("Incremental schema should not be null", incrementalSchema) - //assertEquals("Should have correct number of incremental schema columns", - // aggregations.length, incrementalSchema.length) - - - // Test that we can compute snapshot events - //val result = incrementalGroupBy.snapshotEvents(PartitionRange("2025-06-01", "2025-06-01")) - //println("================================") - //println(df.show()) - //println(result.show()) - //println("================================") - //assertNotNull("Should be able to compute snapshot events", result) - //assertTrue("Result should have data", result.count() > 100) - - // Verify that result contains expected columns - //val resultColumns = result.columns.toSet - //assertTrue("Result should contain user column", resultColumns.contains("user")) - //assertTrue("Result should contain partition column", resultColumns.contains(tableUtils.partitionColumn)) } } From 0bdc4fc5670ecebc1d5cfe3de4c9677485c47307 Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 18 Jul 2025 15:43:23 -0700 Subject: [PATCH 15/28] Add GroupByTest for events --- .../ai/chronon/spark/test/GroupByTest.scala | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala index 8cb886c498..5a7d2f7e58 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala @@ -101,6 +101,7 @@ class GroupByTest { val groupBy = new GroupBy(aggregations, Seq("user"), df) val actualDf = groupBy.snapshotEvents(PartitionRange(outputDates.min, outputDates.max)) + val outputDatesRdd: RDD[Row] = spark.sparkContext.parallelize(outputDates.map(Row(_))) val outputDatesDf = spark.createDataFrame(outputDatesRdd, StructType(Seq(StructField("ds", SparkStringType)))) val datesViewName = "test_group_by_snapshot_events_output_range" @@ -748,6 +749,61 @@ class GroupByTest { println("diff result rows") } assertEquals(0, diff.count()) + } + + @Test + def testSnapshotIncrementalEvents(): Unit = { + lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + implicit val tableUtils = TableUtils(spark) + val schema = List( + Column("user", StringType, 10), // ts = last 10 days + Column("session_length", IntType, 2), + Column("rating", DoubleType, 2000) + ) + + val outputDates = CStream.genPartitions(10, tableUtils.partitionSpec) + + val df = DataFrameGen.events(spark, schema, count = 100000, partitions = 100) + df.drop("ts") // snapshots don't need ts. 
+ val viewName = "test_group_by_snapshot_events" + df.createOrReplaceTempView(viewName) + val aggregations: Seq[Aggregation] = Seq( + Builders.Aggregation(Operation.SUM, "session_length", Seq(new Window(10, TimeUnit.DAYS))), + Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(10, TimeUnit.DAYS))) + ) + + val groupBy = new GroupBy(aggregations, Seq("user"), df) + val actualDf = groupBy.snapshotEvents(PartitionRange(outputDates.min, outputDates.max)) + + val groupByIncremental = new GroupBy(aggregations, Seq("user"), df, incrementalMode = true) + val actualDfIncremental = groupByIncremental.snapshotEvents(PartitionRange(outputDates.min, outputDates.max)) + + val outputDatesRdd: RDD[Row] = spark.sparkContext.parallelize(outputDates.map(Row(_))) + val outputDatesDf = spark.createDataFrame(outputDatesRdd, StructType(Seq(StructField("ds", SparkStringType)))) + val datesViewName = "test_group_by_snapshot_events_output_range" + outputDatesDf.createOrReplaceTempView(datesViewName) + val expectedDf = df.sqlContext.sql(s""" + |select user, + | $datesViewName.ds, + | SUM(IF(ts >= (unix_timestamp($datesViewName.ds, 'yyyy-MM-dd') - 86400*(10-1)) * 1000, session_length, null)) AS session_length_sum_10d, + | SUM(IF(ts >= (unix_timestamp($datesViewName.ds, 'yyyy-MM-dd') - 86400*(10-1)) * 1000, rating, null)) AS rating_sum_10d + |FROM $viewName CROSS JOIN $datesViewName + |WHERE ts < unix_timestamp($datesViewName.ds, 'yyyy-MM-dd') * 1000 + ${tableUtils.partitionSpec.spanMillis} + |group by user, $datesViewName.ds + |""".stripMargin) + + val diff = Comparison.sideBySide(actualDf, expectedDf, List("user", tableUtils.partitionColumn)) + if (diff.count() > 0) { + diff.show() + println("diff result rows") + } + assertEquals(0, diff.count()) + val diffIncremental = Comparison.sideBySide(actualDfIncremental, expectedDf, List("user", tableUtils.partitionColumn)) + if (diffIncremental.count() > 0) { + diffIncremental.show() + println("diff result rows incremental") + } + assertEquals(0, diffIncremental.count()) } } From 7987931b1fda6229a689fc0779ce137d846269c7 Mon Sep 17 00:00:00 2001 From: chaitu Date: Tue, 2 Sep 2025 21:45:19 -0700 Subject: [PATCH 16/28] changes for incrementalg --- api/py/test/sample/scripts/spark_submit.sh | 9 ++++++++- api/py/test/sample/teams.json | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/api/py/test/sample/scripts/spark_submit.sh b/api/py/test/sample/scripts/spark_submit.sh index 45102e8843..ef048a5532 100644 --- a/api/py/test/sample/scripts/spark_submit.sh +++ b/api/py/test/sample/scripts/spark_submit.sh @@ -28,13 +28,14 @@ set -euxo pipefail CHRONON_WORKING_DIR=${CHRONON_TMPDIR:-/tmp}/${USER} +echo $CHRONON_WORKING_DIR mkdir -p ${CHRONON_WORKING_DIR} export TEST_NAME="${APP_NAME}_${USER}_test" unset PYSPARK_DRIVER_PYTHON unset PYSPARK_PYTHON unset SPARK_HOME unset SPARK_CONF_DIR -export LOG4J_FILE="${CHRONON_WORKING_DIR}/log4j_file" +export LOG4J_FILE="${CHRONON_WORKING_DIR}/log4j.properties" cat > ${LOG4J_FILE} << EOF log4j.rootLogger=INFO, stdout log4j.appender.stdout=org.apache.log4j.ConsoleAppender @@ -47,6 +48,9 @@ EOF $SPARK_SUBMIT_PATH \ --driver-java-options " -Dlog4j.configuration=file:${LOG4J_FILE}" \ --conf "spark.executor.extraJavaOptions= -XX:ParallelGCThreads=4 -XX:+UseParallelGC -XX:+UseCompressedOops" \ +--conf "spark.driver.extraJavaOptions=-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=5005 -Dlog4j.configuration=file:${LOG4J_FILE}" \ +--conf "spark.sql.warehouse.dir=/home/chaitu/projects/chronon/spark-warehouse" \ 
+--conf "javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=/home/chaitu/projects/chronon/hive-metastore/metastore_db;create=true" \ --conf spark.sql.shuffle.partitions=${PARALLELISM:-4000} \ --conf spark.dynamicAllocation.maxExecutors=${MAX_EXECUTORS:-1000} \ --conf spark.default.parallelism=${PARALLELISM:-4000} \ @@ -77,3 +81,6 @@ tee ${CHRONON_WORKING_DIR}/${APP_NAME}_spark.log + +#--conf "spark.driver.extraJavaOptions=-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=5005 -Dlog4j.rootLogger=INFO,console" \ + diff --git a/api/py/test/sample/teams.json b/api/py/test/sample/teams.json index 39f7a25559..a60502b65d 100644 --- a/api/py/test/sample/teams.json +++ b/api/py/test/sample/teams.json @@ -5,7 +5,7 @@ }, "common_env": { "VERSION": "latest", - "SPARK_SUBMIT_PATH": "[TODO]/path/to/spark-submit", + "SPARK_SUBMIT_PATH": "spark-submit", "JOB_MODE": "local[*]", "HADOOP_DIR": "[STREAMING-TODO]/path/to/folder/containing", "CHRONON_ONLINE_CLASS": "[ONLINE-TODO]your.online.class", From 7b62a43415a45abffe837b4a89d3fdd4370d13db Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 5 Sep 2025 02:14:13 -0700 Subject: [PATCH 17/28] add last hole logic for incrementnal bacckfill --- .../main/scala/ai/chronon/spark/GroupBy.scala | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index d536ed2876..e5f265c5b4 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -36,6 +36,7 @@ import org.slf4j.LoggerFactory import java.util import scala.collection.{Seq, mutable} import scala.util.ScalaJavaConversions.{JListOps, ListOps, MapOps} +import _root_.com.google.common.collect.Table class GroupBy(val aggregations: Seq[api.Aggregation], val keyColumns: Seq[String], @@ -426,7 +427,6 @@ class GroupBy(val aggregations: Seq[api.Aggregation], tableProps: Map[String, String]) = { val hops = hopsAggregate(range.toTimePoints.min, DailyResolution) - println(s"Saving incremental hops to ${hops.map(x => x._1.data.mkString(",")).take(20)}.") val hopsDf: DataFrame = convertHopsToDf(hops, incrementalSchema) hopsDf.save(incrementalOutputTable, tableProps) } @@ -748,14 +748,49 @@ object GroupBy { .map(_.toScala) .orNull + logger.info(s"Writing incremental df to $incrementalOutputTable") val incrementalGroupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true, incrementalMode = true) - incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, range, tableProps) + val incTableRange = PartitionRange( + tableUtils.firstAvailablePartition(incrementalOutputTable).get, + tableUtils.lastAvailablePartition(incrementalOutputTable).get + )(tableUtils) + + val allPartitionRangeHoles: Option[Seq[PartitionRange]] = computePartitionRangeHoles(incTableRange, range, tableUtils) + + allPartitionRangeHoles.foreach { holes => + holes.foreach { hole => + logger.info(s"Filling hole in incremental table: $hole") + incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, range, tableProps) + } + } incrementalGroupByBackfill } +/** + * Compute the holes in the incremental output table + * + * holes are partitions that are not in the incremetnal output table but are in the source queryable range + * + * @param incTableRange the range of the incremental output table + * @param sourceQueryableRange the range of the source queryable range + * @return the holes in the incremental 
output table + */ + private def computePartitionRangeHoles( + incTableRange: PartitionRange, + queryRange: PartitionRange, + tableUtils: TableUtils): Option[Seq[PartitionRange]] = { + + + if (queryRange.end <= incTableRange.end) { + None + } else { + Some(Seq(PartitionRange(tableUtils.partitionSpec.shift(incTableRange.end, 1), queryRange.end))) + } + } + def fromIncrementalDf( groupByConf: api.GroupBy, range: PartitionRange, @@ -779,9 +814,7 @@ object GroupBy { incTableLastPartition.get )(tableUtils) - val incrementalDf: DataFrame = tableUtils.sql( - incTableRange.intersect(sourceQueryableRange).genScanQuery(null, incrementalOutputTable) - ) + val (_, incrementalDf: DataFrame) = incTableRange.intersect(sourceQueryableRange).scanQueryStringAndDf(null, incrementalOutputTable) val incrementalAggregations = incrementalGroupByBackfill.flattenedAgg.aggregationParts.zip(groupByConf.getAggregations.toScala).map{ case (part, agg) => val newAgg = agg.deepCopy() From aeeb5ecb71f773a262ebd2333bd42e3363bcc1a5 Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 5 Sep 2025 02:20:14 -0700 Subject: [PATCH 18/28] fix syntax --- spark/src/main/scala/ai/chronon/spark/GroupBy.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index e5f265c5b4..b1e4390984 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -787,7 +787,7 @@ object GroupBy { if (queryRange.end <= incTableRange.end) { None } else { - Some(Seq(PartitionRange(tableUtils.partitionSpec.shift(incTableRange.end, 1), queryRange.end))) + Some(Seq(PartitionRange(tableUtils.partitionSpec.shift(incTableRange.end, 1), queryRange.end)(tableUtils))) } } From 9180d233cdd74976cdb18528b08552f198f5841e Mon Sep 17 00:00:00 2001 From: chaitu Date: Sat, 6 Sep 2025 01:09:01 -0700 Subject: [PATCH 19/28] fix bug : backfill only for missing holes --- spark/src/main/scala/ai/chronon/spark/GroupBy.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index b1e4390984..a8f5b4e45b 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -762,7 +762,7 @@ object GroupBy { allPartitionRangeHoles.foreach { holes => holes.foreach { hole => logger.info(s"Filling hole in incremental table: $hole") - incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, range, tableProps) + incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, hole, tableProps) } } From ee81672109f80a022d3b53b268e21486b66698c9 Mon Sep 17 00:00:00 2001 From: chaitu Date: Sun, 7 Sep 2025 05:38:37 -0700 Subject: [PATCH 20/28] fix none error for inc Table --- .../main/scala/ai/chronon/spark/GroupBy.scala | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index a8f5b4e45b..fe142949d5 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -752,14 +752,20 @@ object GroupBy { val incrementalGroupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true, incrementalMode = true) - val incTableRange = PartitionRange( - tableUtils.firstAvailablePartition(incrementalOutputTable).get, - 
tableUtils.lastAvailablePartition(incrementalOutputTable).get - )(tableUtils) + val incTableRange: Option[PartitionRange] = for { + first <- tableUtils.firstAvailablePartition(incrementalOutputTable) + last <- tableUtils.lastAvailablePartition(incrementalOutputTable) + } yield + PartitionRange(first, last)(tableUtils) - val allPartitionRangeHoles: Option[Seq[PartitionRange]] = computePartitionRangeHoles(incTableRange, range, tableUtils) - allPartitionRangeHoles.foreach { holes => + val partitionRangeHoles: Option[Seq[PartitionRange]] = incTableRange match { + case None => Some(Seq(range)) + case Some(incrementalTableRange) => + computePartitionRangeHoles(incrementalTableRange, range, tableUtils) + } + + partitionRangeHoles.foreach { holes => holes.foreach { hole => logger.info(s"Filling hole in incremental table: $hole") incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, hole, tableProps) @@ -771,8 +777,7 @@ object GroupBy { /** * Compute the holes in the incremental output table - * - * holes are partitions that are not in the incremetnal output table but are in the source queryable range + * * * @param incTableRange the range of the incremental output table * @param sourceQueryableRange the range of the source queryable range From 29a3f2814ff8c720f0f7528fc7c9c53e9e1749ae Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 19 Sep 2025 13:41:11 -0700 Subject: [PATCH 21/28] add incremental table queryable range --- .../main/scala/ai/chronon/spark/GroupBy.scala | 68 ++++++------------- 1 file changed, 19 insertions(+), 49 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index fe142949d5..debf47f1b9 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -741,29 +741,27 @@ object GroupBy { groupByConf: api.GroupBy, range: PartitionRange, tableUtils: TableUtils, - ): GroupBy = { + incrementalOutputTable: String, + incrementalGroupByBackfill: GroupBy, + ): PartitionRange = { - val incrementalOutputTable: String = groupByConf.metaData.incrementalOutputTable val tableProps: Map[String, String] = Option(groupByConf.metaData.tableProperties) .map(_.toScala) .orNull - logger.info(s"Writing incremental df to $incrementalOutputTable") - val incrementalGroupByBackfill = - from(groupByConf, range, tableUtils, computeDependency = true, incrementalMode = true) + val incrementalQueryableRange = PartitionRange( + tableUtils.partitionSpec.minus(range.start, groupByConf.maxWindow.get), + range.end + )(tableUtils) - val incTableRange: Option[PartitionRange] = for { - first <- tableUtils.firstAvailablePartition(incrementalOutputTable) - last <- tableUtils.lastAvailablePartition(incrementalOutputTable) - } yield - PartitionRange(first, last)(tableUtils) + logger.info(s"Writing incremental df to $incrementalOutputTable") - val partitionRangeHoles: Option[Seq[PartitionRange]] = incTableRange match { - case None => Some(Seq(range)) - case Some(incrementalTableRange) => - computePartitionRangeHoles(incrementalTableRange, range, tableUtils) - } + + val partitionRangeHoles: Option[Seq[PartitionRange]] = tableUtils.unfilledRanges( + incrementalOutputTable, + incrementalQueryableRange, + ) partitionRangeHoles.foreach { holes => holes.foreach { hole => @@ -772,29 +770,9 @@ object GroupBy { } } - incrementalGroupByBackfill + incrementalQueryableRange } -/** - * Compute the holes in the incremental output table - * - * - * @param incTableRange the range 
of the incremental output table - * @param sourceQueryableRange the range of the source queryable range - * @return the holes in the incremental output table - */ - private def computePartitionRangeHoles( - incTableRange: PartitionRange, - queryRange: PartitionRange, - tableUtils: TableUtils): Option[Seq[PartitionRange]] = { - - - if (queryRange.end <= incTableRange.end) { - None - } else { - Some(Seq(PartitionRange(tableUtils.partitionSpec.shift(incTableRange.end, 1), queryRange.end)(tableUtils))) - } - } def fromIncrementalDf( groupByConf: api.GroupBy, @@ -802,24 +780,16 @@ object GroupBy { tableUtils: TableUtils, ): GroupBy = { + val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable - val incrementalGroupByBackfill: GroupBy = computeIncrementalDf(groupByConf, range, tableUtils) + val incrementalGroupByBackfill = + from(groupByConf, range, tableUtils, computeDependency = true, incrementalMode = true) - val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable - val sourceQueryableRange = PartitionRange( - tableUtils.partitionSpec.minus(range.start, groupByConf.maxWindow.get), - range.end - )(tableUtils) - val incTableFirstPartition: Option[String] = tableUtils.firstAvailablePartition(incrementalOutputTable) - val incTableLastPartition: Option[String] = tableUtils.lastAvailablePartition(incrementalOutputTable) + val incrementalQueryableRange = computeIncrementalDf(groupByConf, range, tableUtils, incrementalOutputTable, incrementalGroupByBackfill) - val incTableRange = PartitionRange( - incTableFirstPartition.get, - incTableLastPartition.get - )(tableUtils) + val (_, incrementalDf: DataFrame) = incrementalQueryableRange.scanQueryStringAndDf(null, incrementalOutputTable) - val (_, incrementalDf: DataFrame) = incTableRange.intersect(sourceQueryableRange).scanQueryStringAndDf(null, incrementalOutputTable) val incrementalAggregations = incrementalGroupByBackfill.flattenedAgg.aggregationParts.zip(groupByConf.getAggregations.toScala).map{ case (part, agg) => val newAgg = agg.deepCopy() From aa1601084f0676f040a1b9fd40aaf0c34f599078 Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 19 Sep 2025 14:10:40 -0700 Subject: [PATCH 22/28] add logging for tableUtils --- .../src/main/scala/ai/chronon/spark/TableUtils.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index a83943407d..44d7b19758 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -855,6 +855,8 @@ case class TableUtils(sparkSession: SparkSession) { inputToOutputShift: Int = 0, skipFirstHole: Boolean = true): Option[Seq[PartitionRange]] = { + logger.info(s"-----------UnfilledRanges---------------------") + logger.info(s"unfilled range called for output table: $outputTable") val validPartitionRange = if (outputPartitionRange.start == null) { // determine partition range automatically val inputStart = inputTables.flatMap( _.map(table => @@ -872,6 +874,8 @@ case class TableUtils(sparkSession: SparkSession) { } else { outputPartitionRange } + + logger.info(s"Determined valid partition range: $validPartitionRange") val outputExisting = partitions(outputTable) // To avoid recomputing partitions removed by retention mechanisms we will not fill holes in the very beginning of the range // If a user fills a new partition in the newer end of the range, then we will never fill any partitions before that range. 
@@ -881,13 +885,19 @@ case class TableUtils(sparkSession: SparkSession) { } else { validPartitionRange.start } + + logger.info(s"Cutoff partition for skipping holes is set to $cutoffPartition") val fillablePartitions = if (skipFirstHole) { validPartitionRange.partitions.toSet.filter(_ >= cutoffPartition) } else { validPartitionRange.partitions.toSet } + + logger.info(s"Fillable partitions : ${fillablePartitions}") val outputMissing = fillablePartitions -- outputExisting + + logger.info(s"outputMissing : ${outputMissing}") val allInputExisting = inputTables .map { tables => tables @@ -900,6 +910,8 @@ case class TableUtils(sparkSession: SparkSession) { } .getOrElse(fillablePartitions) + logger.info(s"allInputExisting : ${allInputExisting}") + val inputMissing = fillablePartitions -- allInputExisting val missingPartitions = outputMissing -- inputMissing val missingChunks = chunk(missingPartitions) From ff41cc9f317db023e5323d776e68daabac97e729 Mon Sep 17 00:00:00 2001 From: chaitu Date: Fri, 19 Sep 2025 15:34:15 -0700 Subject: [PATCH 23/28] add log --- spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 44d7b19758..2710ccb873 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -877,6 +877,7 @@ case class TableUtils(sparkSession: SparkSession) { logger.info(s"Determined valid partition range: $validPartitionRange") val outputExisting = partitions(outputTable) + logger.info(s"outputExisting : ${outputExisting}") // To avoid recomputing partitions removed by retention mechanisms we will not fill holes in the very beginning of the range // If a user fills a new partition in the newer end of the range, then we will never fill any partitions before that range. // We instead log a message saying why we won't fill the earliest hole. 
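Patches 17 through 21 above converge on delegating missing-partition discovery to tableUtils.unfilledRanges over the incremental output table, with patches 22-23 only adding logging around that call. Below is a minimal, self-contained sketch of the same idea -- listing the daily partitions of a requested range that the incremental table does not yet contain -- using plain yyyy-MM-dd strings in place of Chronon's PartitionRange and TableUtils; every name in it is an illustrative assumption, not a project API.

import java.time.LocalDate

// Illustrative model of "holes": daily partitions inside the requested range
// that are missing from the incremental output table and still need backfill.
object HoleDetectionSketch {
  private def partitionsBetween(start: String, end: String): List[String] = {
    val (s, e) = (LocalDate.parse(start), LocalDate.parse(end))
    Iterator.iterate(s)(_.plusDays(1)).takeWhile(!_.isAfter(e)).map(_.toString).toList
  }

  // comparable in spirit to unfilledRanges(incrementalOutputTable, incrementalQueryableRange)
  def computeHoles(existingPartitions: Set[String], start: String, end: String): List[String] =
    partitionsBetween(start, end).filterNot(existingPartitions)

  def main(args: Array[String]): Unit = {
    val existing = Set("2025-05-01", "2025-05-02", "2025-05-04")
    println(computeHoles(existing, "2025-05-01", "2025-05-05")) // List(2025-05-03, 2025-05-05)
  }
}

An empty incremental table degenerates into "every partition in the range is a hole", which is the case None => Some(Seq(range)) branch introduced in patch 20 and later subsumed by unfilledRanges.
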
From aa25f9fb499c8a0ad4565b21fa1754fe18ecb6a3 Mon Sep 17 00:00:00 2001 From: chaitu Date: Sun, 21 Sep 2025 23:19:47 -0700 Subject: [PATCH 24/28] fill incremental holes --- .../scala/ai/chronon/spark/DataRange.scala | 5 +++ .../main/scala/ai/chronon/spark/GroupBy.scala | 35 ++++++++++--------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/DataRange.scala b/spark/src/main/scala/ai/chronon/spark/DataRange.scala index b0af96e05d..e6f13e0ea0 100644 --- a/spark/src/main/scala/ai/chronon/spark/DataRange.scala +++ b/spark/src/main/scala/ai/chronon/spark/DataRange.scala @@ -53,6 +53,11 @@ case class PartitionRange(start: String, end: String)(implicit tableUtils: Table } } + def daysBetween: Int = { + if (start == null || end == null) 0 + else Stream.iterate(start)(tableUtils.partitionSpec.after).takeWhile(_ <= end).size + } + def isSingleDay: Boolean = { start == end } diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index debf47f1b9..ac45afb825 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -23,7 +23,7 @@ import ai.chronon.aggregator.windowing._ import ai.chronon.api import ai.chronon.api.DataModel.{Entities, Events} import ai.chronon.api.Extensions._ -import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro, Source} +import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro, Source, TimeUnit, Window} import ai.chronon.online.{RowWrapper, SparkConversions} import ai.chronon.spark.Extensions._ import org.apache.spark.rdd.RDD @@ -505,6 +505,8 @@ object GroupBy { incrementalMode: Boolean = false): GroupBy = { logger.info(s"\n----[Processing GroupBy: ${groupByConfOld.metaData.name}]----") val groupByConf = replaceJoinSource(groupByConfOld, queryRange, tableUtils, computeDependency, showDf) + val sourceQueryWindow: Option[Window] = if (incrementalMode) Some(new Window(queryRange.daysBetween, TimeUnit.DAYS)) else groupByConf.maxWindow + val backfillQueryRange: PartitionRange = if (incrementalMode) PartitionRange(queryRange.end, queryRange.end)(tableUtils) else queryRange val inputDf = groupByConf.sources.toScala .map { source => val partitionColumn = tableUtils.getPartitionColumn(source.query) @@ -513,9 +515,9 @@ object GroupBy { groupByConf, source, groupByConf.getKeyColumns.toScala, - queryRange, + backfillQueryRange, tableUtils, - groupByConf.maxWindow, + sourceQueryWindow, groupByConf.inferredAccuracy, partitionColumn = partitionColumn ), @@ -742,8 +744,7 @@ object GroupBy { range: PartitionRange, tableUtils: TableUtils, incrementalOutputTable: String, - incrementalGroupByBackfill: GroupBy, - ): PartitionRange = { + ): (PartitionRange, Seq[api.AggregationPart]) = { val tableProps: Map[String, String] = Option(groupByConf.metaData.tableProperties) .map(_.toScala) @@ -754,23 +755,28 @@ object GroupBy { range.end )(tableUtils) - logger.info(s"Writing incremental df to $incrementalOutputTable") - val partitionRangeHoles: Option[Seq[PartitionRange]] = tableUtils.unfilledRanges( incrementalOutputTable, incrementalQueryableRange, ) - partitionRangeHoles.foreach { holes => + val incrementalGroupByAggParts = partitionRangeHoles.map { holes => holes.foreach { hole => logger.info(s"Filling hole in incremental table: $hole") + val incrementalGroupByBackfill = + from(groupByConf, hole, tableUtils, computeDependency = true, incrementalMode = true) 
incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, hole, tableProps) } - } - incrementalQueryableRange + holes.headOption.map { firstHole => + from(groupByConf, firstHole, tableUtils, computeDependency = true, incrementalMode = true) + .flattenedAgg.aggregationParts + }.getOrElse(Seq.empty) + }.getOrElse(Seq.empty) + + (incrementalQueryableRange, incrementalGroupByAggParts) } @@ -782,16 +788,11 @@ object GroupBy { val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable - val incrementalGroupByBackfill = - from(groupByConf, range, tableUtils, computeDependency = true, incrementalMode = true) - - - val incrementalQueryableRange = computeIncrementalDf(groupByConf, range, tableUtils, incrementalOutputTable, incrementalGroupByBackfill) + val (incrementalQueryableRange, aggregationParts) = computeIncrementalDf(groupByConf, range, tableUtils, incrementalOutputTable) val (_, incrementalDf: DataFrame) = incrementalQueryableRange.scanQueryStringAndDf(null, incrementalOutputTable) - - val incrementalAggregations = incrementalGroupByBackfill.flattenedAgg.aggregationParts.zip(groupByConf.getAggregations.toScala).map{ case (part, agg) => + val incrementalAggregations = aggregationParts.zip(groupByConf.getAggregations.toScala).map{ case (part, agg) => val newAgg = agg.deepCopy() newAgg.setInputColumn(part.incrementalOutputColumnName) newAgg From 3efe8cd588edba06cb558fbc3621148a149e9244 Mon Sep 17 00:00:00 2001 From: chaitu Date: Wed, 1 Oct 2025 23:07:06 -0700 Subject: [PATCH 25/28] modify incremental aggregation parts --- spark/src/main/scala/ai/chronon/spark/GroupBy.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index ac45afb825..5131a18121 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -763,18 +763,16 @@ object GroupBy { ) val incrementalGroupByAggParts = partitionRangeHoles.map { holes => - holes.foreach { hole => + val incrementalAggregationParts = holes.map{ hole => logger.info(s"Filling hole in incremental table: $hole") val incrementalGroupByBackfill = from(groupByConf, hole, tableUtils, computeDependency = true, incrementalMode = true) incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, hole, tableProps) + incrementalGroupByBackfill.flattenedAgg.aggregationParts } - holes.headOption.map { firstHole => - from(groupByConf, firstHole, tableUtils, computeDependency = true, incrementalMode = true) - .flattenedAgg.aggregationParts - }.getOrElse(Seq.empty) - }.getOrElse(Seq.empty) + incrementalAggregationParts.headOption.getOrElse(Seq.empty) + }.getOrElse(Seq.empty) (incrementalQueryableRange, incrementalGroupByAggParts) } From a3bece636429584fd324940ef6a241dc9506d902 Mon Sep 17 00:00:00 2001 From: chaitu Date: Mon, 6 Oct 2025 22:10:14 -0700 Subject: [PATCH 26/28] remove logs for debugging --- spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 2710ccb873..3879bf1636 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -855,8 +855,6 @@ case class TableUtils(sparkSession: SparkSession) { inputToOutputShift: Int = 0, skipFirstHole: Boolean = true): 
Option[Seq[PartitionRange]] = { - logger.info(s"-----------UnfilledRanges---------------------") - logger.info(s"unfilled range called for output table: $outputTable") val validPartitionRange = if (outputPartitionRange.start == null) { // determine partition range automatically val inputStart = inputTables.flatMap( _.map(table => @@ -874,8 +872,6 @@ case class TableUtils(sparkSession: SparkSession) { } else { outputPartitionRange } - - logger.info(s"Determined valid partition range: $validPartitionRange") val outputExisting = partitions(outputTable) logger.info(s"outputExisting : ${outputExisting}") // To avoid recomputing partitions removed by retention mechanisms we will not fill holes in the very beginning of the range @@ -895,10 +891,8 @@ case class TableUtils(sparkSession: SparkSession) { validPartitionRange.partitions.toSet } - logger.info(s"Fillable partitions : ${fillablePartitions}") val outputMissing = fillablePartitions -- outputExisting - logger.info(s"outputMissing : ${outputMissing}") val allInputExisting = inputTables .map { tables => tables @@ -910,9 +904,7 @@ case class TableUtils(sparkSession: SparkSession) { .map(partitionSpec.shift(_, inputToOutputShift)) } .getOrElse(fillablePartitions) - - logger.info(s"allInputExisting : ${allInputExisting}") - + val inputMissing = fillablePartitions -- allInputExisting val missingPartitions = outputMissing -- inputMissing val missingChunks = chunk(missingPartitions) From 32f6ac9adef39909efe73769d5172201ca52101f Mon Sep 17 00:00:00 2001 From: Abby Whittier Date: Fri, 10 Oct 2025 03:30:05 +0000 Subject: [PATCH 27/28] support serializeIR --- .../aggregator/base/SimpleAggregators.scala | 51 ------ .../aggregator/row/ColumnAggregator.scala | 8 - .../test/SawtoothOnlineAggregatorTest.scala | 14 +- .../scala/ai/chronon/api/Extensions.scala | 1 - .../ai/chronon/online/SparkConversions.scala | 74 ++++++++ .../main/scala/ai/chronon/spark/GroupBy.scala | 158 +++++++++++------- .../scala/ai/chronon/spark/TableUtils.scala | 2 +- .../ai/chronon/spark/test/GroupByTest.scala | 38 +++-- .../chronon/spark/test/StagingQueryTest.scala | 57 ++++--- 9 files changed, 234 insertions(+), 169 deletions(-) diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala index 370386cb45..520ae135b2 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala @@ -117,57 +117,6 @@ class UniqueCount[T](inputType: DataType) extends SimpleAggregator[T, util.HashS } } -class AverageIR extends SimpleAggregator[Array[Any], Array[Any], Double] { - override def outputType: DataType = DoubleType - - override def irType: DataType = - StructType( - "AvgIr", - Array(StructField("sum", DoubleType), StructField("count", IntType)) - ) - - override def prepare(input: Array[Any]): Array[Any] = { - Array(input(0).asInstanceOf[Double], input(1).asInstanceOf[Int]) - } - - // mutating - override def update(ir: Array[Any], input: Array[Any]): Array[Any] = { - val inputSum = input(0).asInstanceOf[Double] - val inputCount = input(1).asInstanceOf[Int] - ir.update(0, ir(0).asInstanceOf[Double] + inputSum) - ir.update(1, ir(1).asInstanceOf[Int] + inputCount) - ir - } - - // mutating - override def merge(ir1: Array[Any], ir2: Array[Any]): Array[Any] = { - ir1.update(0, ir1(0).asInstanceOf[Double] + ir2(0).asInstanceOf[Double]) - ir1.update(1, ir1(1).asInstanceOf[Int] + 
ir2(1).asInstanceOf[Int]) - ir1 - } - - override def finalize(ir: Array[Any]): Double = - ir(0).asInstanceOf[Double] / ir(1).asInstanceOf[Int].toDouble - - override def delete(ir: Array[Any], input: Array[Any]): Array[Any] = { - val inputSum = input(0).asInstanceOf[Double] - val inputCount = input(1).asInstanceOf[Int] - ir.update(0, ir(0).asInstanceOf[Double] - inputSum) - ir.update(1, ir(1).asInstanceOf[Int] - inputCount) - ir - } - - override def clone(ir: Array[Any]): Array[Any] = { - val arr = new Array[Any](ir.length) - ir.copyToArray(arr) - arr - } - - override def isDeletable: Boolean = true -} - - - class Average extends SimpleAggregator[Double, Array[Any], Double] { override def outputType: DataType = DoubleType diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala index 72bfb4f405..d0be4ead20 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala @@ -217,13 +217,6 @@ object ColumnAggregator { private def toJavaDouble[A: Numeric](inp: Any) = implicitly[Numeric[A]].toDouble(inp.asInstanceOf[A]).asInstanceOf[java.lang.Double] - - private def toStructArray(inp: Any): Array[Any] = inp match { - case r: org.apache.spark.sql.Row => r.toSeq.toArray - case null => null - case other => throw new IllegalArgumentException(s"Expected Row, got: $other") - } - def construct(baseInputType: DataType, aggregationPart: AggregationPart, columnIndices: ColumnIndices, @@ -349,7 +342,6 @@ object ColumnAggregator { case ShortType => simple(new Average, toDouble[Short]) case DoubleType => simple(new Average) case FloatType => simple(new Average, toDouble[Float]) - case StructType(name, fields) => simple(new AverageIR, toStructArray) case _ => mismatchException } diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala index 86546e3669..4b2a6e31b8 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala @@ -143,7 +143,7 @@ class SawtoothOnlineAggregatorTest extends TestCase { operation = Operation.HISTOGRAM, inputColumn = "action", windows = Seq( - new Window(3, TimeUnit.DAYS), + new Window(3, TimeUnit.DAYS) ) ) ) @@ -162,15 +162,15 @@ class SawtoothOnlineAggregatorTest extends TestCase { val finalBatchIr = FinalBatchIr( Array[Any]( - null, // collapsed (T-1 -> T) + null // collapsed (T-1 -> T) ), Array( - Array.empty, // 1‑day hops (not used) - Array( // 1-hour hops - hop(1, 1746745200000L), // 2025-05-08 23:00:00 UTC - hop(1, 1746766800000L), // 2025-05-09 05:00:00 UTC + Array.empty, // 1‑day hops (not used) + Array( // 1-hour hops + hop(1, 1746745200000L), // 2025-05-08 23:00:00 UTC + hop(1, 1746766800000L) // 2025-05-09 05:00:00 UTC ), - Array.empty // 5‑minute hops (not used) + Array.empty // 5‑minute hops (not used) ) ) val queryTs = batchEndTs + 100 diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index d65262d420..6a4cb6de6f 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -185,7 +185,6 @@ object Extensions { } - implicit class AggregationOps(aggregation: Aggregation) { 
// one agg part per bucket per window diff --git a/online/src/main/scala/ai/chronon/online/SparkConversions.scala b/online/src/main/scala/ai/chronon/online/SparkConversions.scala index c669f29a1d..22d2605a04 100644 --- a/online/src/main/scala/ai/chronon/online/SparkConversions.scala +++ b/online/src/main/scala/ai/chronon/online/SparkConversions.scala @@ -163,4 +163,78 @@ object SparkConversions { extraneousRecord ) } + + /** + * Converts a single Spark column value to Chronon normalized IR format. + * + * This is the inverse of toSparkRow() - used when reading pre-computed IR values + * from Spark DataFrames. Each IR column in the DataFrame is converted based on its + * Chronon IR type. + * + * Examples: + * - Count IR: Long → Long (pass-through, primitives stay primitives) + * - Sum IR: Double → Double (pass-through) + * - Average IR: Spark Row(sum, count) → Array[Any](sum, count) + * - UniqueCount IR: Spark Array[T] → java.util.ArrayList[T] + * - Histogram IR: Spark Map[K,V] → java.util.HashMap[K,V] + * - ApproxPercentile IR: Array[Byte] → Array[Byte] (pass-through for binary) + * + * @param sparkValue The value from a Spark DataFrame column + * @param irType The Chronon IR type for this column (from RowAggregator.incrementalOutputSchema) + * @return Normalized IR value ready for denormalize() + */ + def fromSparkValue(sparkValue: Any, irType: api.DataType): Any = { + if (sparkValue == null) return null + + (sparkValue, irType) match { + // Primitives - pass through (Count, Sum, Min, Max, Binary sketches) + case (v, + api.IntType | api.LongType | api.ShortType | api.ByteType | api.FloatType | api.DoubleType | + api.StringType | api.BooleanType | api.BinaryType) => + v + + // Spark Row → Array[Any] (Average, Variance, Skew, Kurtosis, FirstK/LastK) + case (row: Row, api.StructType(_, fields)) => + val arr = new Array[Any](fields.length) + fields.zipWithIndex.foreach { + case (field, idx) => + arr(idx) = fromSparkValue(row.get(idx), field.fieldType) + } + arr + + // Spark mutable.WrappedArray → util.ArrayList (UniqueCount, TopK, BottomK) + case (arr: mutable.WrappedArray[_], api.ListType(elementType)) => + val result = new util.ArrayList[Any](arr.length) + arr.foreach { elem => + result.add(fromSparkValue(elem, elementType)) + } + result + + // Spark native Array → util.ArrayList (alternative array representation) + case (arr: Array[_], api.ListType(elementType)) => + val result = new util.ArrayList[Any](arr.length) + arr.foreach { elem => + result.add(fromSparkValue(elem, elementType)) + } + result + + // Spark scala.collection.Map → util.HashMap (Histogram) + case (map: scala.collection.Map[_, _], api.MapType(keyType, valueType)) => + val result = new util.HashMap[Any, Any]() + map.foreach { + case (k, v) => + result.put( + fromSparkValue(k, keyType), + fromSparkValue(v, valueType) + ) + } + result + + case (value, tpe) => + throw new IllegalArgumentException( + s"Cannot convert Spark value $value (${value.getClass.getSimpleName}) " + + s"to Chronon IR type $tpe" + ) + } + } } diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index 5131a18121..8b6a5fe69f 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -44,8 +44,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation], val mutationDfFn: () => DataFrame = null, skewFilter: Option[String] = None, finalize: Boolean = true, - incrementalMode: Boolean = false - ) + incrementalMode: Boolean = false) 
extends Serializable { @transient lazy val logger = LoggerFactory.getLogger(getClass) @@ -103,7 +102,6 @@ class GroupBy(val aggregations: Seq[api.Aggregation], lazy val flattenedAgg: RowAggregator = new RowAggregator(selectedSchema, aggregations.flatMap(_.unWindowed)) lazy val incrementalSchema: Array[(String, api.DataType)] = flattenedAgg.incrementalOutputSchema - @transient protected[spark] lazy val windowAggregator: RowAggregator = new RowAggregator(selectedSchema, aggregations.flatMap(_.unpack)) @@ -368,22 +366,27 @@ class GroupBy(val aggregations: Seq[api.Aggregation], toDf(outputRdd, Seq(Constants.TimeColumn -> LongType, tableUtils.partitionColumn -> StringType)) } - def flattenOutputArrayType(hopsArrays: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)]): RDD[(Array[Any], Array[Any])] = { - hopsArrays.flatMap { case (keyWithHash: KeyWithHash, hopsArray: HopsAggregator.OutputArrayType) => - val hopsArrayHead: Array[HopIr] = hopsArray.headOption.get - hopsArrayHead.map { array: HopIr => - val timestamp = array.last.asInstanceOf[Long] - val withoutTimestamp = array.dropRight(1) - ((keyWithHash.data :+ tableUtils.partitionSpec.at(timestamp) :+ timestamp), withoutTimestamp) - } + def flattenOutputArrayType( + hopsArrays: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)]): RDD[(Array[Any], Array[Any])] = { + hopsArrays.flatMap { + case (keyWithHash: KeyWithHash, hopsArray: HopsAggregator.OutputArrayType) => + val hopsArrayHead: Array[HopIr] = hopsArray.headOption.get + hopsArrayHead.map { array: HopIr => + val timestamp = array.last.asInstanceOf[Long] + val withoutTimestamp = array.dropRight(1) + // Convert IR to Spark-native format for DataFrame storage + val sparkIr = SerializeIR.toSparkRow(withoutTimestamp, flattenedAgg) + ((keyWithHash.data :+ tableUtils.partitionSpec.at(timestamp) :+ timestamp), sparkIr) + } } } def convertHopsToDf(hops: RDD[(KeyWithHash, HopsAggregator.OutputArrayType)], - schema: Array[(String, ai.chronon.api.DataType)] - ): DataFrame = { + schema: Array[(String, ai.chronon.api.DataType)]): DataFrame = { val hopsDf = flattenOutputArrayType(hops) - toDf(hopsDf, Seq(tableUtils.partitionColumn -> StringType, Constants.TimeColumn -> LongType), Some(SparkConversions.fromChrononSchema(schema))) + toDf(hopsDf, + Seq(tableUtils.partitionColumn -> StringType, Constants.TimeColumn -> LongType), + Some(SparkConversions.fromChrononSchema(schema))) } // convert raw data into IRs, collected by hopSizes @@ -409,7 +412,8 @@ class GroupBy(val aggregations: Seq[api.Aggregation], } protected[spark] def toDf(aggregateRdd: RDD[(Array[Any], Array[Any])], - additionalFields: Seq[(String, DataType)], schema: Option[StructType] = None): DataFrame = { + additionalFields: Seq[(String, DataType)], + schema: Option[StructType] = None): DataFrame = { val finalKeySchema = StructType(keySchema ++ additionalFields.map { case (name, typ) => StructField(name, typ) }) KvRdd(aggregateRdd, finalKeySchema, schema.getOrElse(postAggSchema)).toFlatDf } @@ -421,15 +425,12 @@ class GroupBy(val aggregations: Seq[api.Aggregation], windowAggregator.normalize(ir) } + def computeIncrementalDf(incrementalOutputTable: String, range: PartitionRange, tableProps: Map[String, String]) = { - def computeIncrementalDf(incrementalOutputTable: String, - range: PartitionRange, - tableProps: Map[String, String]) = { - - val hops = hopsAggregate(range.toTimePoints.min, DailyResolution) - val hopsDf: DataFrame = convertHopsToDf(hops, incrementalSchema) - hopsDf.save(incrementalOutputTable, tableProps) - } + val hops = 
hopsAggregate(range.toTimePoints.min, DailyResolution) + val hopsDf: DataFrame = convertHopsToDf(hops, incrementalSchema) + hopsDf.save(incrementalOutputTable, tableProps) + } } // TODO: truncate queryRange for caching @@ -505,8 +506,10 @@ object GroupBy { incrementalMode: Boolean = false): GroupBy = { logger.info(s"\n----[Processing GroupBy: ${groupByConfOld.metaData.name}]----") val groupByConf = replaceJoinSource(groupByConfOld, queryRange, tableUtils, computeDependency, showDf) - val sourceQueryWindow: Option[Window] = if (incrementalMode) Some(new Window(queryRange.daysBetween, TimeUnit.DAYS)) else groupByConf.maxWindow - val backfillQueryRange: PartitionRange = if (incrementalMode) PartitionRange(queryRange.end, queryRange.end)(tableUtils) else queryRange + val sourceQueryWindow: Option[Window] = + if (incrementalMode) Some(new Window(queryRange.daysBetween, TimeUnit.DAYS)) else groupByConf.maxWindow + val backfillQueryRange: PartitionRange = + if (incrementalMode) PartitionRange(queryRange.end, queryRange.end)(tableUtils) else queryRange val inputDf = groupByConf.sources.toScala .map { source => val partitionColumn = tableUtils.getPartitionColumn(source.query) @@ -606,12 +609,11 @@ object GroupBy { } new GroupBy(Option(groupByConf.getAggregations).map(_.toScala).orNull, - keyColumns, - nullFiltered, - mutationDfFn, - finalize = finalizeValue, - incrementalMode = incrementalMode, - ) + keyColumns, + nullFiltered, + mutationDfFn, + finalize = finalizeValue, + incrementalMode = incrementalMode) } def getIntersectedRange(source: api.Source, @@ -740,11 +742,11 @@ object GroupBy { * @param tableUtils */ def computeIncrementalDf( - groupByConf: api.GroupBy, - range: PartitionRange, - tableUtils: TableUtils, - incrementalOutputTable: String, - ): (PartitionRange, Seq[api.AggregationPart]) = { + groupByConf: api.GroupBy, + range: PartitionRange, + tableUtils: TableUtils, + incrementalOutputTable: String + ): (PartitionRange, Seq[api.AggregationPart]) = { val tableProps: Map[String, String] = Option(groupByConf.metaData.tableProperties) .map(_.toScala) @@ -759,50 +761,82 @@ object GroupBy { val partitionRangeHoles: Option[Seq[PartitionRange]] = tableUtils.unfilledRanges( incrementalOutputTable, - incrementalQueryableRange, + incrementalQueryableRange ) - val incrementalGroupByAggParts = partitionRangeHoles.map { holes => - val incrementalAggregationParts = holes.map{ hole => - logger.info(s"Filling hole in incremental table: $hole") - val incrementalGroupByBackfill = - from(groupByConf, hole, tableUtils, computeDependency = true, incrementalMode = true) - incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, hole, tableProps) - incrementalGroupByBackfill.flattenedAgg.aggregationParts - } + val incrementalGroupByAggParts = partitionRangeHoles + .map { holes => + val incrementalAggregationParts = holes.map { hole => + logger.info(s"Filling hole in incremental table: $hole") + val incrementalGroupByBackfill = + from(groupByConf, hole, tableUtils, computeDependency = true, incrementalMode = true) + incrementalGroupByBackfill.computeIncrementalDf(incrementalOutputTable, hole, tableProps) + incrementalGroupByBackfill.flattenedAgg.aggregationParts + } - incrementalAggregationParts.headOption.getOrElse(Seq.empty) - }.getOrElse(Seq.empty) + incrementalAggregationParts.headOption.getOrElse(Seq.empty) + } + .getOrElse(Seq.empty) (incrementalQueryableRange, incrementalGroupByAggParts) } - def fromIncrementalDf( - groupByConf: api.GroupBy, - range: PartitionRange, - tableUtils: 
TableUtils, - ): GroupBy = { + groupByConf: api.GroupBy, + range: PartitionRange, + tableUtils: TableUtils + ): GroupBy = { val incrementalOutputTable = groupByConf.metaData.incrementalOutputTable - val (incrementalQueryableRange, aggregationParts) = computeIncrementalDf(groupByConf, range, tableUtils, incrementalOutputTable) + val (incrementalQueryableRange, aggregationParts) = + computeIncrementalDf(groupByConf, range, tableUtils, incrementalOutputTable) + + val (_, incrementalDf: DataFrame) = incrementalQueryableRange.scanQueryStringAndDf(null, incrementalOutputTable) + + // Create RowAggregator for deserializing and merging IRs + val selectedSchema = SparkConversions.toChrononSchema(incrementalDf.schema) + val flattenedAggregations = groupByConf.getAggregations.toScala.flatMap(_.unWindowed) + val flattenedAgg = new RowAggregator(selectedSchema, flattenedAggregations) + + // Convert Spark DataFrame to RDD, deserialize IRs, and merge by key + val keyColumns = groupByConf.getKeyColumns.toScala + val keySchema = StructType(keyColumns.map(incrementalDf.schema.apply).toArray) + + val irRdd = incrementalDf.rdd.map { sparkRow => + // Extract keys + val keys = keyColumns.map(sparkRow.getAs[Any]).toArray - val (_, incrementalDf: DataFrame) = incrementalQueryableRange.scanQueryStringAndDf(null, incrementalOutputTable) + // Deserialize IR columns from Spark Row + val ir = SerializeIR.fromSparkRow(sparkRow, flattenedAgg) - val incrementalAggregations = aggregationParts.zip(groupByConf.getAggregations.toScala).map{ case (part, agg) => - val newAgg = agg.deepCopy() - newAgg.setInputColumn(part.incrementalOutputColumnName) - newAgg + (keys.toSeq, ir) } + // Merge IRs by key + val mergedRdd = irRdd.reduceByKey { (ir1, ir2) => + flattenedAgg.merge(ir1, ir2) + } + + // Finalize IRs to get final feature values + val finalRdd = mergedRdd.map { + case (keys, ir) => + (keys.toArray, flattenedAgg.finalize(ir)) + } + + // Convert back to DataFrame + val outputChrononSchema = flattenedAgg.outputSchema + val outputSparkSchema = SparkConversions.fromChrononSchema(outputChrononSchema) + implicit val session: SparkSession = incrementalDf.sparkSession + val finalDf = KvRdd(finalRdd, keySchema, outputSparkSchema).toFlatDf + new GroupBy( - incrementalAggregations, - groupByConf.getKeyColumns.toScala, - incrementalDf, + groupByConf.getAggregations.toScala, + keyColumns, + finalDf, () => null, finalize = true, - incrementalMode = false, + incrementalMode = false ) } diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 3879bf1636..57e6c80435 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -904,7 +904,7 @@ case class TableUtils(sparkSession: SparkSession) { .map(partitionSpec.shift(_, inputToOutputShift)) } .getOrElse(fillablePartitions) - + val inputMissing = fillablePartitions -- allInputExisting val missingPartitions = outputMissing -- inputMissing val missingChunks = chunk(missingPartitions) diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala index b7fa2ef651..8739f1bbb3 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala @@ -452,7 +452,6 @@ class GroupByTest { additionalAgg = aggs) } - private def createTestSource(windowSize: Int = 365, suffix: String = "", partitionColOpt: 
Option[String] = None): (Source, String) = { @@ -866,10 +865,10 @@ class GroupByTest { val testDatabase = s"staging_query_view_test_${Random.alphanumeric.take(6).mkString}" tableUtils.createDatabase(testDatabase) - // Create source data table with partitions + // Create source data table with partitions val sourceSchema = List( Column("user", StringType, 20), - Column("item", StringType, 50), + Column("item", StringType, 50), Column("time_spent_ms", LongType, 5000), Column("price", DoubleType, 100) ) @@ -892,7 +891,7 @@ class GroupByTest { stagingQueryJob.createStagingQueryView() val viewTable = s"$testDatabase.test_staging_view" - + // Now create a GroupBy that uses the staging query view as its source val source = Builders.Source.events( table = viewTable, @@ -901,12 +900,12 @@ class GroupByTest { val aggregations = Seq( Builders.Aggregation(operation = Operation.COUNT, inputColumn = "time_spent_ms"), - Builders.Aggregation(operation = Operation.AVERAGE, - inputColumn = "price", - windows = Seq(new Window(7, TimeUnit.DAYS))), + Builders.Aggregation(operation = Operation.AVERAGE, + inputColumn = "price", + windows = Seq(new Window(7, TimeUnit.DAYS))), Builders.Aggregation(operation = Operation.MAX, - inputColumn = "time_spent_ms", - windows = Seq(new Window(30, TimeUnit.DAYS))) + inputColumn = "time_spent_ms", + windows = Seq(new Window(30, TimeUnit.DAYS))) ) val groupByConf = Builders.GroupBy( @@ -963,7 +962,8 @@ class GroupByTest { @Test def testIncrementalMode(): Unit = { - lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTestIncremental" + "_" + Random.alphanumeric.take(6).mkString, local = true) + lazy val spark: SparkSession = + SparkSessionBuilder.build("GroupByTestIncremental" + "_" + Random.alphanumeric.take(6).mkString, local = true) implicit val tableUtils = TableUtils(spark) val namespace = "incremental" tableUtils.createDatabase(namespace) @@ -978,9 +978,9 @@ class GroupByTest { println(s"Input DataFrame: ${df.count()}") val aggregations: Seq[Aggregation] = Seq( - //Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), - //Builders.Aggregation(Operation.UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), - Builders.Aggregation(Operation.SUM, "session_length", Seq(new Window(10, TimeUnit.DAYS))) + //Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), + //Builders.Aggregation(Operation.UNIQUE_COUNT, "session_length", Seq(new Window(10, TimeUnit.DAYS), WindowUtils.Unbounded)), + Builders.Aggregation(Operation.SUM, "session_length", Seq(new Window(10, TimeUnit.DAYS))) ) val tableProps: Map[String, String] = Map( @@ -988,7 +988,9 @@ class GroupByTest { ) val groupBy = new GroupBy(aggregations, Seq("user"), df) - groupBy.computeIncrementalDf("incremental.testIncrementalOutput", PartitionRange("2025-05-01", "2025-06-01"), tableProps) + groupBy.computeIncrementalDf("incremental.testIncrementalOutput", + PartitionRange("2025-05-01", "2025-06-01"), + tableProps) val actualIncrementalDf = spark.sql(s"select * from incremental.testIncrementalOutput where ds='2025-05-11'") df.createOrReplaceTempView("test_incremental_input") @@ -1016,7 +1018,8 @@ class GroupByTest { @Test def testSnapshotIncrementalEvents(): Unit = { - lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + lazy val spark: SparkSession = + 
SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) implicit val tableUtils = TableUtils(spark) val schema = List( Column("user", StringType, 10), // ts = last 10 days @@ -1032,7 +1035,7 @@ class GroupByTest { df.createOrReplaceTempView(viewName) val aggregations: Seq[Aggregation] = Seq( Builders.Aggregation(Operation.SUM, "session_length", Seq(new Window(10, TimeUnit.DAYS))), - Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(10, TimeUnit.DAYS))) + Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(10, TimeUnit.DAYS))) ) val groupBy = new GroupBy(aggregations, Seq("user"), df) @@ -1062,7 +1065,8 @@ class GroupByTest { } assertEquals(0, diff.count()) - val diffIncremental = Comparison.sideBySide(actualDfIncremental, expectedDf, List("user", tableUtils.partitionColumn)) + val diffIncremental = + Comparison.sideBySide(actualDfIncremental, expectedDf, List("user", tableUtils.partitionColumn)) if (diffIncremental.count() > 0) { diffIncremental.show() println("diff result rows incremental") diff --git a/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala b/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala index cbf84be04e..193aa575e9 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala @@ -323,28 +323,35 @@ class StagingQueryTest { val outputView = stagingQueryConfView.metaData.outputTable val isView = tableUtils.tableReadFormat(outputView) match { case View => true - case _ => false + case _ => false } - + assert(isView, s"Expected $outputView to be a view when createView=true") // Verify virtual partition metadata was written for the view - val virtualPartitionExists = try { - val metadataCount = tableUtils.sql(s"SELECT COUNT(*) as count FROM ${stagingQueryView.signalPartitionsTable} WHERE table_name = '$outputView'").collect()(0).getAs[Long]("count") - metadataCount > 0 - } catch { - case _: Exception => false - } + val virtualPartitionExists = + try { + val metadataCount = tableUtils + .sql( + s"SELECT COUNT(*) as count FROM ${stagingQueryView.signalPartitionsTable} WHERE table_name = '$outputView'") + .collect()(0) + .getAs[Long]("count") + metadataCount > 0 + } catch { + case _: Exception => false + } assert(virtualPartitionExists, s"Expected virtual partition metadata to exist for view $outputView") // Verify the structure of virtual partition metadata if (virtualPartitionExists) { - val metadataRows = tableUtils.sql(s"SELECT * FROM ${stagingQueryView.signalPartitionsTable} WHERE table_name = '$outputView'").collect() + val metadataRows = tableUtils + .sql(s"SELECT * FROM ${stagingQueryView.signalPartitionsTable} WHERE table_name = '$outputView'") + .collect() assert(metadataRows.length > 0, "Should have at least one partition metadata entry") - + val firstRow = metadataRows(0) val tableName = firstRow.getAs[String]("table_name") - + assertEquals(s"Virtual partition metadata should have correct table name", outputView, tableName) } @@ -362,19 +369,25 @@ class StagingQueryTest { val outputTable = stagingQueryConfTable.metaData.outputTable val isTable = tableUtils.tableReadFormat(outputTable) match { case View => false - case _ => true + case _ => true } - + assert(isTable, s"Expected $outputTable to be a table when createView=false") // Verify virtual partition metadata was NOT written for the table - val virtualPartitionExistsForTable = try { - val metadataCountForTable = 
tableUtils.sql(s"SELECT COUNT(*) as count FROM ${stagingQueryTable.signalPartitionsTable} WHERE table_name = '$outputTable'").collect()(0).getAs[Long]("count") - metadataCountForTable > 0 - } catch { - case _: Exception => false - } - assert(!virtualPartitionExistsForTable, s"Expected NO virtual partition metadata for table $outputTable when createView=false") + val virtualPartitionExistsForTable = + try { + val metadataCountForTable = tableUtils + .sql( + s"SELECT COUNT(*) as count FROM ${stagingQueryTable.signalPartitionsTable} WHERE table_name = '$outputTable'") + .collect()(0) + .getAs[Long]("count") + metadataCountForTable > 0 + } catch { + case _: Exception => false + } + assert(!virtualPartitionExistsForTable, + s"Expected NO virtual partition metadata for table $outputTable when createView=false") // Test Case 3: createView unset (should default to false and create table) val stagingQueryConfUnset = Builders.StagingQuery( @@ -389,9 +402,9 @@ class StagingQueryTest { val outputUnset = stagingQueryConfUnset.metaData.outputTable val isTableUnset = tableUtils.tableReadFormat(outputUnset) match { case View => false - case _ => true + case _ => true } - + assert(isTableUnset, s"Expected $outputUnset to be a table when createView is unset") } } From 6c9e74d29bb13cecce1368690693ac1b2a7e4624 Mon Sep 17 00:00:00 2001 From: Abby Whittier Date: Fri, 10 Oct 2025 04:03:57 +0000 Subject: [PATCH 28/28] add the actually important file --- .../scala/ai/chronon/spark/SerializeIR.scala | 117 ++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 spark/src/main/scala/ai/chronon/spark/SerializeIR.scala diff --git a/spark/src/main/scala/ai/chronon/spark/SerializeIR.scala b/spark/src/main/scala/ai/chronon/spark/SerializeIR.scala new file mode 100644 index 0000000000..e887529d58 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/SerializeIR.scala @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2023 The Chronon Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.chronon.spark + +import ai.chronon.aggregator.row.RowAggregator +import ai.chronon.online.SparkConversions +import org.apache.spark.sql.Row + +/** + * Utilities for serializing/deserializing Chronon IR values to/from Spark DataFrames. + * + * Key concepts: + * - IR tables have MULTIPLE COLUMNS, each with a different IR type + * - Example: [count:Long, sum:Double, avg:Struct(sum, count)] + * - Each column is converted independently based on its IR type + * + * This bridges between: + * - Java IR types (internal aggregator state: HashSet, CpcSketch, etc.) + * - Normalized IR types (serializable format: Array, ArrayList, bytes) + * - Spark native types (DataFrame columns: primitives, Row, WrappedArray, Map) + * + * The conversion pipeline: + * Writing: Java IR → normalize() → toSparkRow() → Spark Row → DataFrame + * Reading: DataFrame → Spark Row → fromSparkRow() → denormalize() → Java IR + */ +object SerializeIR { + + /** + * Converts a Chronon IR array to Spark-native format for DataFrame writing. 
+ * + * Processes each column independently based on its IR type from + * RowAggregator.incrementalOutputSchema. + * + * @param ir The IR array from RowAggregator (Java types) + * @param rowAgg The RowAggregator with column schemas + * @return Array where each element is in Spark-native format + */ + def toSparkRow(ir: Array[Any], rowAgg: RowAggregator): Array[Any] = { + // Step 1: Normalize (Java types → serializable types) + val normalized = rowAgg.normalize(ir) + + // Step 2: Convert each column to Spark-native type + val sparkColumns = new Array[Any](normalized.length) + rowAgg.incrementalOutputSchema.zipWithIndex.foreach { + case ((_, irType), idx) => + sparkColumns(idx) = SparkConversions.toSparkRow(normalized(idx), irType) + } + sparkColumns + } + + /** + * Converts Spark DataFrame Row to Chronon IR format for aggregation. + * + * Reads each IR column from the Spark Row by name and converts based on IR type. + * Uses RowAggregator.incrementalOutputSchema to get both column names and types. + * + * @param sparkRow The Spark Row from DataFrame.read() + * @param rowAgg The RowAggregator with IR schemas + * @return Denormalized IR array ready for merge() (Java types) + */ + def fromSparkRow(sparkRow: Row, rowAgg: RowAggregator): Array[Any] = { + val normalized = new Array[Any](rowAgg.incrementalOutputSchema.length) + + // Step 1: Extract each IR column from Spark Row by name + rowAgg.incrementalOutputSchema.zipWithIndex.foreach { + case ((colName, irType), idx) => + // Get column from Spark Row by NAME + val sparkValue = sparkRow.getAs[Any](colName) + // Convert using IR type + normalized(idx) = SparkConversions.fromSparkValue(sparkValue, irType) + } + + // Step 2: Denormalize (serializable types → Java types) + rowAgg.denormalize(normalized) + } + + /** + * Alternative: Extract IR columns by position instead of name. + * Faster but requires column order to match exactly. + * + * @param sparkRow The Spark Row from DataFrame + * @param rowAgg The RowAggregator with IR schemas + * @param startIndex The starting index of IR columns in the Row + * @return Denormalized IR array (Java types) + */ + def fromSparkRowByPosition( + sparkRow: Row, + rowAgg: RowAggregator, + startIndex: Int + ): Array[Any] = { + val normalized = new Array[Any](rowAgg.incrementalOutputSchema.length) + + // Step 1: Extract each IR column by position + rowAgg.incrementalOutputSchema.zipWithIndex.foreach { + case ((_, irType), idx) => + val sparkValue = sparkRow.get(startIndex + idx) + normalized(idx) = SparkConversions.fromSparkValue(sparkValue, irType) + } + + // Step 2: Denormalize + rowAgg.denormalize(normalized) + } +}
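
The round trip described in the SerializeIR scaladoc above can be exercised roughly as follows. This is a minimal sketch, assuming a RowAggregator (rowAgg) and a partially aggregated IR (partialIr) are already in scope (e.g. produced by the incremental GroupBy job); it rebuilds a schemaless Row with Row.fromSeq, so it reads the columns back by position rather than by name.

// Minimal usage sketch under the assumptions stated above; rowAgg and
// partialIr are hypothetical inputs, not values defined in this patch.
import ai.chronon.aggregator.row.RowAggregator
import ai.chronon.spark.SerializeIR
import org.apache.spark.sql.Row

def irRoundTrip(rowAgg: RowAggregator, partialIr: Array[Any]): Array[Any] = {
  // Writing path: Java IR -> normalize() -> Spark-native column values.
  val sparkColumns: Array[Any] = SerializeIR.toSparkRow(partialIr, rowAgg)

  // In the incremental table these values would follow the key columns;
  // here we build a bare Row with no schema attached.
  val irRow: Row = Row.fromSeq(sparkColumns.toIndexedSeq)

  // Reading path: Spark Row -> fromSparkRowByPosition() -> denormalize() -> Java IR.
  // Position-based access is used because a schemaless Row cannot be read by name.
  val restoredIr: Array[Any] = SerializeIR.fromSparkRowByPosition(irRow, rowAgg, startIndex = 0)

  // restoredIr is now in the form the scaladoc calls "ready for merge()".
  restoredIr
}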