6 changes: 6 additions & 0 deletions api/src/main/scala/ai/chronon/api/DataRange.scala
@@ -165,4 +165,10 @@ object PartitionRange {
     if (ranges == null) return ""
     rangesToString(ranges)
   }
+
+  def toTimeRange(partitionRange: PartitionRange): TimeRange = {
+    val spec = partitionRange.partitionSpec
+    val shiftedEnd = spec.after(partitionRange.end)
+    TimeRange(spec.epochMillis(partitionRange.start), spec.epochMillis(shiftedEnd) - 1)(spec)
+  }
 }
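
For context on the new helper: it maps a partition range to an inclusive millisecond range whose end is the first millisecond of the partition after `end`, minus one. Below is a hedged sketch (not part of the PR); the `PartitionSpec`/`PartitionRange` constructor shapes are assumptions inferred from how the helper uses them above.

// Hedged sketch only. PartitionSpec("yyyy-MM-dd", spanMillis) and
// PartitionRange(start, end)(spec) are assumed constructor shapes.
import ai.chronon.api.PartitionRange.toTimeRange
import ai.chronon.api.{PartitionRange, PartitionSpec}

object ToTimeRangeSketch {
  def main(args: Array[String]): Unit = {
    implicit val spec: PartitionSpec = PartitionSpec("yyyy-MM-dd", 24 * 60 * 60 * 1000L)
    val range = PartitionRange("2023-01-01", "2023-01-02")(spec)

    // start = epoch millis of 2023-01-01 00:00:00
    // end   = epoch millis of 2023-01-03 00:00:00 minus 1 ms, i.e. the last
    //         millisecond covered by the 2023-01-02 partition
    println(toTimeRange(range))
  }
}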
44 changes: 26 additions & 18 deletions spark/src/main/scala/ai/chronon/spark/JoinPartJob.scala
@@ -2,7 +2,8 @@ package ai.chronon.spark
 
 import ai.chronon.api.DataModel.{Entities, Events}
 import ai.chronon.api.Extensions.{DateRangeOps, DerivationOps, GroupByOps, JoinPartOps, MetadataOps}
-import ai.chronon.api.{Accuracy, Constants, DateRange, JoinPart, PartitionRange}
+import ai.chronon.api.PartitionRange.toTimeRange
+import ai.chronon.api.{Accuracy, Builders, Constants, DateRange, JoinPart, PartitionRange}
 import ai.chronon.online.Metrics
 import ai.chronon.orchestration.JoinPartNode
 import ai.chronon.spark.Extensions.{DfWithStats, _}
@@ -38,9 +39,18 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)
 
   def run(context: Option[JoinPartJobContext] = None): Option[DataFrame] = {
 
+    logger.info(s"Running join part job for ${joinPart.groupBy.metaData.name} on range $dateRange")
+
     val jobContext = context.getOrElse {
       // LeftTable is already computed by SourceJob, no need to apply query/filters/etc
-      val cachedLeftDf = tableUtils.scanDf(query = null, leftTable, range = Some(dateRange))
+      val relevantLeftCols =
+        joinPart.rightToLeft.keys.toArray ++ Seq(tableUtils.partitionColumn) ++ (leftDataModel match {
+          case Entities => None
+          case Events => Some(Constants.TimeColumn)
+        })
+
+      val query = Builders.Query(selects = relevantLeftCols.map(t => t -> t).toMap)
+      val cachedLeftDf = tableUtils.scanDf(query = query, leftTable, range = Some(dateRange))
 
       val leftTimeRangeOpt: Option[PartitionRange] =
         if (cachedLeftDf.schema.fieldNames.contains(Constants.TimePartitionColumn)) {
@@ -55,7 +65,7 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)
       val leftWithStats = cachedLeftDf.withStats
 
       val joinLevelBloomMapOpt =
-        JoinUtils.genBloomFilterIfNeeded(joinPart, leftDataModel, cachedLeftDf.count, dateRange, None)
+        JoinUtils.genBloomFilterIfNeeded(joinPart, leftDataModel, dateRange, None)
 
       JoinPartJobContext(Option(leftWithStats),
                          joinLevelBloomMapOpt,
@@ -107,9 +117,7 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)
       }
       val elapsedMins = (System.currentTimeMillis() - start) / 60000
       partMetrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
-      val partitionCount = rightRange.partitions.length
-      partMetrics.gauge(Metrics.Name.PartitionCount, partitionCount)
-      logger.info(s"Wrote $partitionCount partitions to join part table: $partTable in $elapsedMins minutes")
+      logger.info(s"Wrote to join part table: $partTable in $elapsedMins minutes")
     } catch {
       case e: Exception =>
         logger.error(s"Error while processing groupBy: ${joinPart.groupBy.getMetaData.getName}")
@@ -137,15 +145,12 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)
     }
 
     val statsDf = leftDfWithStats.get
-    val rowCount = statsDf.count
-    val unfilledRange = statsDf.partitionRange
 
-    logger.info(
-      s"\nBackfill is required for ${joinPart.groupBy.metaData.name} for $rowCount rows on range $unfilledRange")
+    logger.info(s"\nBackfill is required for ${joinPart.groupBy.metaData.name}")
     val rightBloomMap = if (skipBloom) {
       None
     } else {
-      JoinUtils.genBloomFilterIfNeeded(joinPart, leftDataModel, rowCount, unfilledRange, joinLevelBloomMapOpt)
+      JoinUtils.genBloomFilterIfNeeded(joinPart, leftDataModel, dateRange, joinLevelBloomMapOpt)
     }
 
     val rightSkewFilter = JoinUtils.partSkewFilter(joinPart, skewKeys)
Expand All @@ -160,12 +165,15 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)
showDf = showDf)

// all lazy vals - so evaluated only when needed by each case.
lazy val partitionRangeGroupBy = genGroupBy(unfilledRange)
lazy val partitionRangeGroupBy = genGroupBy(dateRange)

lazy val unfilledTimeRange = {
lazy val unfilledPartitionRange = if (tableUtils.checkLeftTimeRange) {
val timeRange = statsDf.timeRange
logger.info(s"left unfilled time range: $timeRange")
timeRange
logger.info(s"left unfilled time range checked to be: $timeRange")
timeRange.toPartitionRange
} else {
logger.info(s"Not checking time range, but inferring it from partition range: $dateRange")
dateRange
}

val leftSkewFilter =
@@ -202,20 +210,20 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)
       skewFilteredLeft.select(columns: _*)
     }
 
-    lazy val shiftedPartitionRange = unfilledTimeRange.toPartitionRange.shift(-1)
+    lazy val shiftedPartitionRange = unfilledPartitionRange.shift(-1)
 
     val renamedLeftDf = renamedLeftRawDf.select(renamedLeftRawDf.columns.map {
      case c if c == tableUtils.partitionColumn =>
        date_format(renamedLeftRawDf.col(c), tableUtils.partitionFormat).as(c)
      case c => renamedLeftRawDf.col(c)
     }.toList: _*)
     val rightDf = (leftDataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match {
-      case (Entities, Events, _) => partitionRangeGroupBy.snapshotEvents(unfilledRange)
+      case (Entities, Events, _) => partitionRangeGroupBy.snapshotEvents(dateRange)
       case (Entities, Entities, _) => partitionRangeGroupBy.snapshotEntities
       case (Events, Events, Accuracy.SNAPSHOT) =>
         genGroupBy(shiftedPartitionRange).snapshotEvents(shiftedPartitionRange)
       case (Events, Events, Accuracy.TEMPORAL) =>
-        genGroupBy(unfilledTimeRange.toPartitionRange).temporalEvents(renamedLeftDf, Some(unfilledTimeRange))
+        genGroupBy(unfilledPartitionRange).temporalEvents(renamedLeftDf, Some(toTimeRange(unfilledPartitionRange)))
 
       case (Events, Entities, Accuracy.SNAPSHOT) => genGroupBy(shiftedPartitionRange).snapshotEntities
 
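One note on the left-scan change above: the query built from `relevantLeftCols.map(t => t -> t).toMap` is an identity projection, so the cached left table produced by SourceJob is only narrowed to the join keys, the partition column, and (for event sources) the time column; no values are transformed. A minimal sketch of that map, using hypothetical column names:

// Hypothetical column names, illustrating the identity select map used above.
val relevantLeftCols = Array("user_id", "ts", "ds")
val selects = relevantLeftCols.map(t => t -> t).toMap
// selects == Map("user_id" -> "user_id", "ts" -> "ts", "ds" -> "ds")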
2 changes: 0 additions & 2 deletions spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -333,7 +333,6 @@ object JoinUtils {
   def genBloomFilterIfNeeded(
       joinPart: ai.chronon.api.JoinPart,
       leftDataModel: DataModel,
-      leftRowCount: Long,
       unfilledRange: PartitionRange,
       joinLevelBloomMapOpt: Option[util.Map[String, BloomFilter]]): Option[util.Map[String, BloomFilter]] = {
 
@@ -362,7 +361,6 @@
                 | right type: ${joinPart.groupBy.dataModel},
                 | accuracy : ${joinPart.groupBy.inferredAccuracy},
                 | part unfilled range: $unfilledRange,
-                | left row count: $leftRowCount
                 | bloom sizes: $bloomSizes
                 | groupBy: ${joinPart.groupBy.toString}
                 |""".stripMargin)
2 changes: 2 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -64,6 +64,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   // default threshold is 100K rows
   val bloomFilterThreshold: Long =
     sparkSession.conf.get("spark.chronon.backfill.bloomfilter.threshold", "1000000").toLong
+  val checkLeftTimeRange: Boolean =
+    sparkSession.conf.get("spark.chronon.join.backfill.check.left_time_range", "false").toBoolean
 
   private val minWriteShuffleParallelism = 200
 
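The new `checkLeftTimeRange` flag defaults to false and is captured in a val when `TableUtils` is constructed, so it has to be on the Spark conf before `TableUtils(spark)` is created. A hedged sketch of enabling it on a plain SparkSession (the test change below does the equivalent through `SparkSessionBuilder`'s `additionalConfig`):

import org.apache.spark.sql.SparkSession

// Sketch only: set the flag at session build time so TableUtils picks it up.
val spark = SparkSession
  .builder()
  .appName("join-part-backfill")
  .master("local[*]")
  .config("spark.chronon.join.backfill.check.left_time_range", "true")
  .getOrCreate()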
@@ -37,8 +37,9 @@ class MutationsTest extends AnyFlatSpec {
 
   val spark: SparkSession =
     SparkSessionBuilder.build("MutationsTest",
-                              local = true
-    ) //, additionalConfig = Some(Map("spark.chronon.backfill.validation.enabled" -> "false")))
+                              local = true,
+                              additionalConfig =
+                                Some(Map("spark.chronon.join.backfill.check.left_time_range" -> "true")))
   private implicit val tableUtils: TableUtils = TableUtils(spark)
 
   private def namespace(suffix: String) = s"test_mutations_$suffix"
@@ -823,6 +824,7 @@
   }
 
   it should "test with generated data" in {
+
     val suffix = "generated"
     val reviews = List(
       Column("listing_id", api.StringType, 10),