diff --git a/api/src/main/scala/ai/chronon/api/Constants.scala b/api/src/main/scala/ai/chronon/api/Constants.scala
index e72152bb50..4468234070 100644
--- a/api/src/main/scala/ai/chronon/api/Constants.scala
+++ b/api/src/main/scala/ai/chronon/api/Constants.scala
@@ -31,8 +31,9 @@ object Constants {
   val GroupByServingInfoKey = "group_by_serving_info"
   val UTF8 = "UTF-8"
   val TopicInvalidSuffix = "_invalid"
-  val lineTab = "\n "
   val SemanticHashKey = "semantic_hash"
+  val SemanticHashOptionsKey = "semantic_hash_options"
+  val SemanticHashExcludeTopic = "exclude_topic"
   val SchemaHash = "schema_hash"
   val BootstrapHash = "bootstrap_hash"
   val MatchedHashes = "matched_hashes"
diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala
index 6e86281b70..0ce907145b 100644
--- a/api/src/main/scala/ai/chronon/api/Extensions.scala
+++ b/api/src/main/scala/ai/chronon/api/Extensions.scala
@@ -30,6 +30,7 @@ import java.util.regex.Pattern
 import scala.collection.{Seq, mutable}
 import scala.util.ScalaJavaConversions.{IteratorOps, ListOps, MapOps}
 import scala.util.{Failure, Success, Try}
+import scala.collection.JavaConverters._

 object Extensions {

@@ -845,14 +846,14 @@ object Extensions {
     def partOutputTable(jp: JoinPart): String =
       (Seq(join.metaData.outputTable) ++ Option(jp.prefix) :+ jp.groupBy.metaData.cleanName).mkString("_")

-    private val leftSourceKey: String = "left_source"
-    private val derivedKey: String = "derived"
+    val leftSourceKey: String = "left_source"
+    val derivedKey: String = "derived"

     /*
      * semanticHash contains hashes of left side and each join part, and is used to detect join definition
      * changes and determine whether any intermediate/final tables of the join need to be recomputed.
      */
-    def semanticHash: Map[String, String] = {
+    private[api] def baseSemanticHash: Map[String, String] = {
       val leftHash = ThriftJsonCodec.md5Digest(join.left)
       logger.info(s"Join Left Object: ${ThriftJsonCodec.toJsonStr(join.left)}")
       val partHashes = join.joinParts.toScala.map { jp => partOutputTable(jp) -> jp.groupBy.semanticHash }.toMap
@@ -867,6 +868,42 @@ object Extensions {
       partHashes ++ Map(leftSourceKey -> leftHash, join.metaData.bootstrapTable -> bootstrapHash) ++ derivedHashMap
     }

+    /*
+     * Unset topic / mutationTopic everywhere in this join, recursively.
+     * The input join is modified in place.
+     */
+    private def cleanTopic(join: Join): Join = {
+      def cleanTopicInSource(source: Source): Unit = {
+        if (source.isSetEvents) {
+          source.getEvents.unsetTopic()
+        } else if (source.isSetEntities) {
+          source.getEntities.unsetMutationTopic()
+        } else if (source.isSetJoinSource) {
+          cleanTopic(source.getJoinSource.getJoin)
+        }
+      }
+
+      cleanTopicInSource(join.left)
+      join.getJoinParts.toScala.foreach(_.groupBy.sources.toScala.foreach(cleanTopicInSource))
+      join
+    }
+
+    /*
+     * Compute variants of semantic_hash with different flags. A flag is stored in Hive metadata and used to
+     * indicate which version of semantic_hash logic to use.
+     */
+    def semanticHash(excludeTopic: Boolean): Map[String, String] = {
+      if (excludeTopic) {
+        // WARN: deepCopy doesn't guarantee the same semantic_hash is produced, due to reordering of map keys,
+        // but the behavior is deterministic
+        val joinCopy = join.deepCopy()
+        cleanTopic(joinCopy)
+        joinCopy.baseSemanticHash
+      } else {
+        baseSemanticHash
+      }
+    }
+
     /*
       External features computed in online env and logged
       This method will get the external feature column names
@@ -887,62 +924,6 @@ object Extensions {
         .getOrElse(Seq.empty)
     }

-    /*
-     * onlineSemanticHash includes everything in semanticHash as well as hashes of each onlineExternalParts (which only
-     * affect online serving but not offline table generation).
-     * It is used to detect join definition change in online serving and to update ttl-cached conf files.
-     */
-    def onlineSemanticHash: Map[String, String] = {
-      if (join.onlineExternalParts == null) {
-        return Map.empty[String, String]
-      }
-
-      val externalPartHashes = join.onlineExternalParts.toScala.map { part => part.fullName -> part.semanticHash }.toMap
-
-      externalPartHashes ++ semanticHash
-    }
-
-    def leftChanged(oldSemanticHash: Map[String, String]): Boolean = {
-      // Checks for semantic changes in left or bootstrap, because those are saved together
-      val bootstrapExistsAndChanged = oldSemanticHash.contains(join.metaData.bootstrapTable) && oldSemanticHash.get(
-        join.metaData.bootstrapTable) != semanticHash.get(join.metaData.bootstrapTable)
-      logger.info(s"Bootstrap table changed: $bootstrapExistsAndChanged")
-      logger.info(s"Old Semantic Hash: $oldSemanticHash")
-      logger.info(s"New Semantic Hash: $semanticHash")
-      oldSemanticHash.get(leftSourceKey) != semanticHash.get(leftSourceKey) || bootstrapExistsAndChanged
-    }
-
-    def tablesToDrop(oldSemanticHash: Map[String, String]): Seq[String] = {
-      val newSemanticHash = semanticHash
-      // only right join part hashes for convenience
-      def partHashes(semanticHashMap: Map[String, String]): Map[String, String] = {
-        semanticHashMap.filter { case (name, _) => name != leftSourceKey && name != derivedKey }
-      }
-
-      // drop everything if left source changes
-      val partsToDrop = if (leftChanged(oldSemanticHash)) {
-        partHashes(oldSemanticHash).keys.toSeq
-      } else {
-        val changed = partHashes(newSemanticHash).flatMap {
-          case (key, newVal) =>
-            oldSemanticHash.get(key).filter(_ != newVal).map(_ => key)
-        }
-        val deleted = partHashes(oldSemanticHash).keys.filterNot(newSemanticHash.contains)
-        (changed ++ deleted).toSeq
-      }
-      val added = newSemanticHash.keys.filter(!oldSemanticHash.contains(_)).filter {
-        // introduce boostrapTable as a semantic_hash but skip dropping to avoid recompute if it is empty
-        case key if key == join.metaData.bootstrapTable => join.isSetBootstrapParts && !join.bootstrapParts.isEmpty
-        case _ => true
-      }
-      val derivedChanges = oldSemanticHash.get(derivedKey) != newSemanticHash.get(derivedKey)
-      // TODO: make this incremental, retain the main table and continue joining, dropping etc
-      val mainTable = if (partsToDrop.nonEmpty || added.nonEmpty || derivedChanges) {
-        Some(join.metaData.outputTable)
-      } else None
-      partsToDrop ++ mainTable
-    }
-
     def isProduction: Boolean = join.getMetaData.isProduction

     def team: String = join.getMetaData.getTeam
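
To illustrate the new flag, a sketch (not part of this diff): two joins that differ only in a streaming topic hash identically under `excludeTopic = true`, but differ under the legacy logic. `querySource` and `txnGroupBy(topic)` are assumed fixtures, analogous to the ones in the tests further down.

```scala
// Sketch only; `querySource` and `txnGroupBy(topic)` are hypothetical fixtures.
val joinV1 = Builders.Join(
  left = querySource,
  joinParts = Seq(Builders.JoinPart(groupBy = txnGroupBy("transactions_topic_v1"))),
  metaData = Builders.MetaData(name = "unit_test.topic_join", namespace = "test_namespace", team = "chronon")
)
val joinV2 = joinV1.deepCopy()
joinV2.getJoinParts.get(0).getGroupBy.getSources.get(0).getEvents.setTopic("transactions_topic_v2")

// The topic is stripped before hashing, so the two configs are semantically equal:
assert(joinV1.semanticHash(excludeTopic = true) == joinV2.semanticHash(excludeTopic = true))
// Legacy behavior hashes the topic, so they differ:
assert(joinV1.semanticHash(excludeTopic = false) != joinV2.semanticHash(excludeTopic = false))
```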
diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala
index 8168007196..a9471e41b0 100644
--- a/spark/src/main/scala/ai/chronon/spark/Driver.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -61,6 +61,17 @@ object Driver {
   def parseConf[T <: TBase[_, _]: Manifest: ClassTag](confPath: String): T =
     ThriftJsonCodec.fromJsonFile[T](confPath, check = true)

+  trait JoinBackfillSubcommand {
+    this: ScallopConf =>
+    val unsetSemanticHash: ScallopOption[Boolean] =
+      opt[Boolean](
+        required = false,
+        default = Some(false),
+        descr =
+          "When set to true, semantic_hash is unset in join tblprops to allow for config update without recompute."
+      )
+  }
+
   trait OfflineSubcommand {
     this: ScallopConf =>
     val confPath: ScallopOption[String] = opt[String](required = true, descr = "Path to conf")
@@ -236,6 +247,7 @@ object Driver {
     class Args
         extends Subcommand("join")
         with OfflineSubcommand
+        with JoinBackfillSubcommand
         with LocalExportTableAbility
         with ResultValidationAbility {
       val selectedJoinParts: ScallopOption[List[String]] =
@@ -256,7 +268,8 @@ object Driver {
         args.endDate(),
         args.buildTableUtils(),
         !args.runFirstHole(),
-        selectedJoinParts = args.selectedJoinParts.toOption
+        selectedJoinParts = args.selectedJoinParts.toOption,
+        unsetSemanticHash = args.unsetSemanticHash.getOrElse(false)
       )

       if (args.selectedJoinParts.isDefined) {
@@ -291,6 +304,7 @@ object Driver {
     class Args
         extends Subcommand("join-left")
         with OfflineSubcommand
+        with JoinBackfillSubcommand
         with LocalExportTableAbility
         with ResultValidationAbility {
       lazy val joinConf: api.Join = parseConf[api.Join](confPath())
     }
@@ -298,12 +312,12 @@ object Driver {

     def run(args: Args): Unit = {
-      val tableUtils = args.buildTableUtils()
       val join = new Join(
         args.joinConf,
         args.endDate(),
         args.buildTableUtils(),
-        !args.runFirstHole()
+        !args.runFirstHole(),
+        unsetSemanticHash = args.unsetSemanticHash.getOrElse(false)
       )
       join.computeLeft(args.startPartitionOverride.toOption)
     }
@@ -314,6 +328,7 @@ object Driver {
     class Args
         extends Subcommand("join-final")
         with OfflineSubcommand
+        with JoinBackfillSubcommand
        with LocalExportTableAbility
        with ResultValidationAbility {
       lazy val joinConf: api.Join = parseConf[api.Join](confPath())
     }
@@ -321,12 +336,12 @@ object Driver {

     def run(args: Args): Unit = {
-      val tableUtils = args.buildTableUtils()
       val join = new Join(
         args.joinConf,
         args.endDate(),
         args.buildTableUtils(),
-        !args.runFirstHole()
+        !args.runFirstHole(),
+        unsetSemanticHash = args.unsetSemanticHash.getOrElse(false)
       )
       join.computeFinal(args.startPartitionOverride.toOption)
     }
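
For context, a sketch of how the new flag reaches the job; the flag name follows Scallop's kebab-casing of `unsetSemanticHash` (the `--unset-semantic-hash` mentioned in the error message below), and the date is illustrative.

```scala
// Roughly equivalent to: run.py ... join --conf <path> --end-date 2023-10-01 --unset-semantic-hash
val joinJob = new Join(
  joinConf,                 // parsed api.Join, as via parseConf[api.Join](confPath)
  "2023-10-01",             // endPartition (illustrative)
  tableUtils,
  unsetSemanticHash = true  // skip the stored-hash comparison; recompute and re-stamp tblprops
)
joinJob.computeJoin()
```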
diff --git a/spark/src/main/scala/ai/chronon/spark/Join.scala b/spark/src/main/scala/ai/chronon/spark/Join.scala
index 3d8f9d363e..345afd2b0c 100644
--- a/spark/src/main/scala/ai/chronon/spark/Join.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -65,8 +65,9 @@ class Join(joinConf: api.Join,
            tableUtils: TableUtils,
            skipFirstHole: Boolean = true,
            showDf: Boolean = false,
-           selectedJoinParts: Option[List[String]] = None)
-    extends JoinBase(joinConf, endPartition, tableUtils, skipFirstHole, showDf, selectedJoinParts) {
+           selectedJoinParts: Option[List[String]] = None,
+           unsetSemanticHash: Boolean = false)
+    extends JoinBase(joinConf, endPartition, tableUtils, skipFirstHole, showDf, selectedJoinParts, unsetSemanticHash) {

   private def padFields(df: DataFrame, structType: sql.types.StructType): DataFrame = {
     structType.foldLeft(df) {
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
index a5f75d6a80..24768291bb 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -22,17 +22,18 @@ import ai.chronon.api.Extensions._
 import ai.chronon.api.{Accuracy, Constants, JoinPart}
 import ai.chronon.online.Metrics
 import ai.chronon.spark.Extensions._
-import ai.chronon.spark.JoinUtils.{coalescedJoin, leftDf, tablesToRecompute, shouldRecomputeLeft}
+import ai.chronon.spark.JoinUtils.{coalescedJoin, leftDf}
+import ai.chronon.spark.SemanticHashUtils.{shouldRecomputeLeft, tablesToRecompute}
 import com.google.gson.Gson
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.util.sketch.BloomFilter
 import org.slf4j.LoggerFactory

-import java.time.Instant
+import java.time.Instant
+import java.util
 import scala.collection.JavaConverters._
 import scala.collection.Seq
-import java.util
 import scala.util.ScalaJavaConversions.ListOps

 abstract class JoinBase(joinConf: api.Join,
@@ -40,7 +41,8 @@ abstract class JoinBase(joinConf: api.Join,
                         tableUtils: TableUtils,
                         skipFirstHole: Boolean,
                         showDf: Boolean = false,
-                        selectedJoinParts: Option[Seq[String]] = None) {
+                        selectedJoinParts: Option[Seq[String]] = None,
+                        unsetSemanticHash: Boolean) {
   @transient lazy val logger = LoggerFactory.getLogger(getClass)
   assert(Option(joinConf.metaData.outputNamespace).nonEmpty, s"output namespace could not be empty or null")
   val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.JoinOffline, joinConf)
@@ -57,7 +59,13 @@ abstract class JoinBase(joinConf: api.Join,
   private val gson = new Gson()
   // Combine tableProperties set on conf with encoded Join
   protected val tableProps: Map[String, String] =
-    confTableProps ++ Map(Constants.SemanticHashKey -> gson.toJson(joinConf.semanticHash.asJava))
+    confTableProps ++ Map(
+      Constants.SemanticHashKey -> gson.toJson(joinConf.semanticHash(excludeTopic = true).asJava),
+      Constants.SemanticHashOptionsKey -> gson.toJson(
+        Map(
+          Constants.SemanticHashExcludeTopic -> "true"
+        ).asJava)
+    )

   def joinWithLeft(leftDf: DataFrame, rightDf: DataFrame, joinPart: JoinPart): DataFrame = {
     val partLeftKeys = joinPart.rightToLeft.values.toArray
@@ -342,13 +350,37 @@ abstract class JoinBase(joinConf: api.Join,
         .getOrElse(Seq.empty))
   }

+  private def performArchive(tables: Seq[String], autoArchive: Boolean, mainTable: String): Unit = {
+    if (autoArchive) {
+      val archivedAtTs = Instant.now()
+      tables.foreach(tableUtils.archiveOrDropTableIfExists(_, Some(archivedAtTs)))
+    } else {
+      val errorMsg = s"""Auto archive is disabled because the semantic hash is out of date.
+                        |Please verify whether your config involves true semantic changes. If so, archive the following tables:
+                        |${tables.map(t => s"- $t").mkString("\n")}
+                        |If not, please retry this job with the `--unset-semantic-hash` flag in your run.py args,
+                        |OR run the Spark SQL cmd: ALTER TABLE $mainTable UNSET TBLPROPERTIES ('semantic_hash') and then retry this job.
+                        |""".stripMargin
+      logger.error(errorMsg)
+      throw SemanticHashException(errorMsg)
+    }
+  }
+
   def computeLeft(overrideStartPartition: Option[String] = None): Unit = {
     // Runs the left side query for a join and saves the output to a table, for reuse by joinPart
     // computation in the parallelized joinPart execution mode.
-    if (shouldRecomputeLeft(joinConf, bootstrapTable, tableUtils)) {
-      logger.info(s"Detected semantic change in left side of join, archiving left table for recomputation.")
-      val archivedAtTs = Instant.now()
-      tableUtils.archiveOrDropTableIfExists(bootstrapTable, Some(archivedAtTs))
+    val (shouldRecompute, autoArchive) = shouldRecomputeLeft(joinConf, bootstrapTable, tableUtils, unsetSemanticHash)
+    if (!shouldRecompute) {
+      logger.info(s"No semantic change detected, leaving bootstrap table in place.")
+      // Still update the semantic hash of the bootstrap table.
+      // It is possible that while the bootstrap table's semantic_hash is unchanged, the rest has changed, so
+      // we keep everything in sync.
+      if (tableUtils.tableExists(bootstrapTable)) {
+        tableUtils.alterTableProperties(bootstrapTable, tableProps)
+      }
+    } else {
+      logger.info(s"Detected semantic change in left side of join, archiving bootstrap table for recomputation.")
+      performArchive(Seq(bootstrapTable), autoArchive, bootstrapTable)
     }

     val (rangeToFill, unfilledRanges) = getUnfilledRange(overrideStartPartition, bootstrapTable)
@@ -375,15 +407,15 @@ abstract class JoinBase(joinConf: api.Join,

   def computeFinalJoin(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): Unit

-  def computeFinal(overrideStartPartition: Option[String] = None) = {
+  def computeFinal(overrideStartPartition: Option[String] = None): Unit = {

     // Utilizes the same tablesToRecompute check as the monolithic spark job, because if any joinPart changes, then so does the output table
-    if (tablesToRecompute(joinConf, outputTable, tableUtils).isEmpty) {
+    val (tablesChanged, autoArchive) = tablesToRecompute(joinConf, outputTable, tableUtils, unsetSemanticHash)
+    if (tablesChanged.isEmpty) {
       logger.info(s"No semantic change detected, leaving output table in place.")
     } else {
       logger.info(s"Semantic changes detected, archiving output table.")
-      val archivedAtTs = Instant.now()
-      tableUtils.archiveOrDropTableIfExists(outputTable, Some(archivedAtTs))
+      performArchive(tablesChanged, autoArchive, outputTable)
     }

     val (rangeToFill, unfilledRanges) = getUnfilledRange(overrideStartPartition, outputTable)
@@ -448,11 +480,15 @@ abstract class JoinBase(joinConf: api.Join,
       leftDf(joinConf, PartitionRange(endPartition, endPartition)(tableUtils), tableUtils, limit = Some(1)).map(df =>
         df.schema)

-    val archivedAtTs = Instant.now()
-    // Check semantic hash before overwriting left side
-    // TODO: We should not archive the output table in the case of selected join parts mode
-    tablesToRecompute(joinConf, outputTable, tableUtils).foreach(
-      tableUtils.archiveOrDropTableIfExists(_, Some(archivedAtTs)))
+    // Run command to archive ALL tables that have changed semantically since the last run
+    // TODO: We should not archive the output table or other JPs' intermediate tables in the case of selected join parts mode
+    val (tablesChanged, autoArchive) = tablesToRecompute(joinConf, outputTable, tableUtils, unsetSemanticHash)
+    if (tablesChanged.isEmpty) {
+      logger.info(s"No semantic change detected, leaving output table in place.")
+    } else {
+      logger.info(s"Semantic changes detected, archiving output table.")
+      performArchive(tablesChanged, autoArchive, outputTable)
+    }

     // Overwriting Join Left with bootstrap table to simplify later computation
     val source = joinConf.left
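
For reference, a sketch of what the two properties stamped via `tableProps` look like on the output table, and the SQL they translate into; the hash values are illustrative (borrowed from the migration test below).

```scala
import scala.collection.JavaConverters._
import com.google.gson.Gson

val gson = new Gson()
val props = Map(
  "semantic_hash"         -> gson.toJson(Map("left_source" -> "vbQc07vaqm").asJava),
  "semantic_hash_options" -> gson.toJson(Map("exclude_topic" -> "true").asJava)
)
// tableUtils.alterTableProperties(outputTable, props) then issues roughly:
// ALTER TABLE <outputTable> SET TBLPROPERTIES (
//   'semantic_hash' = '{"left_source":"vbQc07vaqm"}',
//   'semantic_hash_options' = '{"exclude_topic":"true"}')
```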
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
index bb70355e22..46a4364339 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -21,7 +21,6 @@ import ai.chronon.api.Constants
 import ai.chronon.api.DataModel.Events
 import ai.chronon.api.Extensions.{JoinOps, _}
 import ai.chronon.spark.Extensions._
-import com.google.gson.Gson
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.{coalesce, col, udf}
@@ -31,12 +30,11 @@ import org.slf4j.LoggerFactory
 import java.util
 import scala.collection.compat._
 import scala.jdk.CollectionConverters._
-import scala.util.ScalaJavaConversions.MapOps

 object JoinUtils {
   @transient lazy val logger = LoggerFactory.getLogger(getClass)

-  /***
+  /**
    * Util methods for join computation
    */
@@ -97,7 +95,7 @@ object JoinUtils {
     }
   })

-  /***
+  /**
    * Compute partition range to be filled for given join conf
    */
  def getRangesToFill(leftSource: ai.chronon.api.Source,
@@ -118,7 +116,7 @@ object JoinUtils {
     PartitionRange(leftStart, leftEnd)(tableUtils)
   }

-  /***
+  /**
    * join left and right dataframes, merging any shared columns if exists by the coalesce rule.
    * fails if there is any data type mismatch between shared columns.
    *
@@ -160,7 +158,7 @@ object JoinUtils {
     finalDf
   }

-  /***
+  /**
    * Method to create or replace a view for feature table joining with labels.
    * Label columns will be prefixed with "label" or custom prefix for easy identification
    */
@@ -206,7 +204,7 @@ object JoinUtils {
     tableUtils.sql(sqlStatement)
   }

-  /***
+  /**
    * Method to create a view with latest available label_ds for a given ds. This view is built
    * on top of final label view which has all label versions available.
    * This view will inherit the final label view properties as well.
@@ -288,6 +286,7 @@ object JoinUtils {
   /**
    * Generate a Bloom filter for 'joinPart' when the row count to be backfilled falls below a specified threshold.
    * This method anticipates that there will likely be a substantial number of rows on the right side that need to be filtered out.
+   *
    * @return bloomfilter map option for right part
    */
@@ -317,15 +316,15 @@ object JoinUtils {

     logger.info(s"""
       Generating bloom filter for joinPart:
-      | part name : ${joinPart.groupBy.metaData.name},
-      | left type : ${joinConf.left.dataModel},
-      | right type: ${joinPart.groupBy.dataModel},
-      | accuracy : ${joinPart.groupBy.inferredAccuracy},
-      | part unfilled range: $unfilledRange,
-      | left row count: $leftRowCount
-      | bloom sizes: $bloomSizes
-      | groupBy: ${joinPart.groupBy.toString}
-      |""".stripMargin)
+                   | part name : ${joinPart.groupBy.metaData.name},
+                   | left type : ${joinConf.left.dataModel},
+                   | right type: ${joinPart.groupBy.dataModel},
+                   | accuracy : ${joinPart.groupBy.inferredAccuracy},
+                   | part unfilled range: $unfilledRange,
+                   | left row count: $leftRowCount
+                   | bloom sizes: $bloomSizes
+                   | groupBy: ${joinPart.groupBy.toString}
+                   |""".stripMargin)
     rightBlooms
   }
@@ -384,31 +383,4 @@ object JoinUtils {
     df.drop(columnsToDrop: _*)
   }

-  def tablesToRecompute(joinConf: ai.chronon.api.Join,
-                        outputTable: String,
-                        tableUtils: TableUtils): collection.Seq[String] = {
-    // Finds all join output tables (join parts and final table) that need recomputing (in monolithic spark job mode)
-    val gson = new Gson()
-    (for (
-      props <- tableUtils.getTableProperties(outputTable);
-      oldSemanticJson <- props.get(Constants.SemanticHashKey);
-      oldSemanticHash = gson.fromJson(oldSemanticJson, classOf[java.util.HashMap[String, String]]).toScala
-    ) yield {
-      logger.info(s"Comparing Hashes:\nNew: ${joinConf.semanticHash},\nOld: $oldSemanticHash")
-      joinConf.tablesToDrop(oldSemanticHash)
-    }).getOrElse(collection.Seq.empty)
-  }
-
-  def shouldRecomputeLeft(joinConf: ai.chronon.api.Join, outputTable: String, tableUtils: TableUtils): Boolean = {
-    // Determines if the saved left table of the join (includes bootstrap) needs to be recomputed due to semantic changes since last run
-    if (tableUtils.tableExists(outputTable)) {
-      val gson = new Gson()
-      val props = tableUtils.getTableProperties(outputTable);
-      val oldSemanticJson = props.get(Constants.SemanticHashKey);
-      val oldSemanticHash = gson.fromJson(oldSemanticJson, classOf[java.util.HashMap[String, String]]).toScala
-      joinConf.leftChanged(oldSemanticHash)
-    } else {
-      false
-    }
-  }
 }
diff --git a/spark/src/main/scala/ai/chronon/spark/SemanticHashUtils.scala b/spark/src/main/scala/ai/chronon/spark/SemanticHashUtils.scala
new file mode 100644
index 0000000000..1ae8b001a6
--- /dev/null
+++ b/spark/src/main/scala/ai/chronon/spark/SemanticHashUtils.scala
@@ -0,0 +1,141 @@
+package ai.chronon.spark
+
+import ai.chronon.api.Constants
+import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
+import ai.chronon.spark.JoinUtils.logger
+import com.google.gson.Gson
+
+import scala.util.ScalaJavaConversions.MapOps
+
+/*
+ * Metadata stored in Hive table properties to track semantic hashes and flags.
+ * The flags are used to indicate the versioned logic that was used to compute the semantic hash.
+ */
+case class SemanticHashHiveMetadata(semanticHash: Map[String, String], excludeTopic: Boolean)
+
+case class SemanticHashException(message: String) extends Exception(message)
+
+/*
+ * Utilities to handle semantic_hash computation, comparison, migration and table archiving.
+ */
+object SemanticHashUtils {
+
+  // Finds all join output tables (join parts and final table) that need recomputing (in monolithic spark job mode)
+  def tablesToRecompute(joinConf: ai.chronon.api.Join,
+                        outputTable: String,
+                        tableUtils: TableUtils,
+                        unsetSemanticHash: Boolean): (collection.Seq[String], Boolean) =
+    computeDiff(joinConf, outputTable, tableUtils, unsetSemanticHash, tableHashesChanged, Seq.empty)
+
+  // Determines if the saved left table of the join (includes bootstrap) needs to be recomputed due to semantic changes since last run
+  def shouldRecomputeLeft(joinConf: ai.chronon.api.Join,
+                          outputTable: String,
+                          tableUtils: TableUtils,
+                          unsetSemanticHash: Boolean): (Boolean, Boolean) = {
+    computeDiff(joinConf, outputTable, tableUtils, unsetSemanticHash, isLeftHashChanged, false)
+  }
+
+  /*
+   * When semantic_hash versions are different, a diff does not automatically mean a semantic change.
+   * In those scenarios, we print out the diff tables and defer to users on whether to drop them.
+   */
+  private def canAutoArchive(semanticHashHiveMetadata: SemanticHashHiveMetadata): Boolean = {
+    semanticHashHiveMetadata.excludeTopic
+  }
+
+  private def computeDiff[T](joinConf: ai.chronon.api.Join,
+                             outputTable: String,
+                             tableUtils: TableUtils,
+                             unsetSemanticHash: Boolean,
+                             computeDiffFunc: (Map[String, String], Map[String, String], ai.chronon.api.Join) => T,
+                             emptyFunc: => T): (T, Boolean) = {
+    val semanticHashHiveMetadata = if (unsetSemanticHash) {
+      None
+    } else {
+      getSemanticHashFromHive(outputTable, tableUtils)
+    }
+
+    if (semanticHashHiveMetadata.isDefined) {
+      val oldSemanticHash = semanticHashHiveMetadata.get.semanticHash
+      val newSemanticHash = joinConf.semanticHash(excludeTopic = semanticHashHiveMetadata.get.excludeTopic)
+      def prettyPrintMap(map: Map[String, String]): String = {
+        map.toSeq.sorted.map { case (key, value) => s"- $key: $value" }.mkString("\n")
+      }
+      logger.info(
+        s"""Comparing Hashes:
+           |Hive Flag:
+           |${Constants.SemanticHashExcludeTopic}: ${semanticHashHiveMetadata.get.excludeTopic}
+           |Old Hashes:
+           |${prettyPrintMap(oldSemanticHash)}
+           |New Hashes:
+           |${prettyPrintMap(newSemanticHash)}
+           |""".stripMargin
+      )
+      val diff = computeDiffFunc(oldSemanticHash, newSemanticHash, joinConf)
+      val autoArchive = canAutoArchive(semanticHashHiveMetadata.get)
+      (diff, autoArchive)
+    } else {
+      logger.info("No semantic hash found in Hive.
 Proceed to computation and table creation.")
+      (emptyFunc, true)
+    }
+  }
+
+  private def getSemanticHashFromHive(outputTable: String, tableUtils: TableUtils): Option[SemanticHashHiveMetadata] = {
+    val gson = new Gson()
+    val tablePropsOpt = tableUtils.getTableProperties(outputTable)
+    val oldSemanticJsonOpt = tablePropsOpt.flatMap(_.get(Constants.SemanticHashKey))
+    val oldSemanticHash =
+      oldSemanticJsonOpt.map(json => gson.fromJson(json, classOf[java.util.HashMap[String, String]]).toScala)
+
+    val oldSemanticHashOptions = tablePropsOpt
+      .flatMap(_.get(Constants.SemanticHashOptionsKey))
+      .map(m => gson.fromJson(m, classOf[java.util.HashMap[String, String]]).toScala)
+      .getOrElse(Map.empty)
+    val hasSemanticHashExcludeTopicFlag =
+      oldSemanticHashOptions.get(Constants.SemanticHashExcludeTopic).contains("true")
+
+    oldSemanticHash.map(hashes => SemanticHashHiveMetadata(hashes, hasSemanticHashExcludeTopicFlag))
+  }
+
+  private def isLeftHashChanged(oldSemanticHash: Map[String, String],
+                                newSemanticHash: Map[String, String],
+                                join: ai.chronon.api.Join): Boolean = {
+    // Checks for semantic changes in left or bootstrap, because those are saved together
+    val bootstrapExistsAndChanged = oldSemanticHash.contains(join.metaData.bootstrapTable) && oldSemanticHash.get(
+      join.metaData.bootstrapTable) != newSemanticHash.get(join.metaData.bootstrapTable)
+    logger.info(s"Bootstrap table changed: $bootstrapExistsAndChanged")
+    oldSemanticHash.get(join.leftSourceKey) != newSemanticHash.get(join.leftSourceKey) || bootstrapExistsAndChanged
+  }
+
+  private[spark] def tableHashesChanged(oldSemanticHash: Map[String, String],
+                                        newSemanticHash: Map[String, String],
+                                        join: ai.chronon.api.Join): Seq[String] = {
+    // only right join part hashes for convenience
+    def partHashes(semanticHashMap: Map[String, String]): Map[String, String] = {
+      semanticHashMap.filter { case (name, _) => name != join.leftSourceKey && name != join.derivedKey }
+    }
+
+    // drop everything if left source changes
+    val partsToDrop = if (isLeftHashChanged(oldSemanticHash, newSemanticHash, join)) {
+      partHashes(oldSemanticHash).keys.toSeq
+    } else {
+      val changed = partHashes(newSemanticHash).flatMap {
+        case (key, newVal) =>
+          oldSemanticHash.get(key).filter(_ != newVal).map(_ => key)
+      }
+      val deleted = partHashes(oldSemanticHash).keys.filterNot(newSemanticHash.contains)
+      (changed ++ deleted).toSeq
+    }
+    val added = newSemanticHash.keys.filter(!oldSemanticHash.contains(_)).filter {
+      // introduce bootstrapTable as a semantic_hash entry but skip dropping to avoid recompute if it is empty
+      case key if key == join.metaData.bootstrapTable => join.isSetBootstrapParts && !join.bootstrapParts.isEmpty
+      case _ => true
+    }
+    val derivedChanges = oldSemanticHash.get(join.derivedKey) != newSemanticHash.get(join.derivedKey)
+    // TODO: make this incremental, retain the main table and continue joining, dropping etc
+    val mainTable = if (partsToDrop.nonEmpty || added.nonEmpty || derivedChanges) {
+      Some(join.metaData.outputTable)
+    } else None
+    partsToDrop ++ mainTable
+  }
+}
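
A sketch of the resulting decision surface for callers (scenario values illustrative; the tests below cover these paths end to end):

```scala
// stored exclude_topic flag | hash diff found | behavior of the caller (JoinBase.performArchive)
// "true"                    | yes             | auto-archive the changed tables, then recompute
// "true"                    | no              | reuse tables as-is
// "false" / absent          | yes             | throw SemanticHashException: archive manually or rerun with --unset-semantic-hash
// "false" / absent          | no              | reuse tables; tblprops are re-stamped with the new-version hash on the next run
// no semantic_hash at all   | n/a             | (emptyFunc, autoArchive = true): fresh computation
val (tablesChanged, autoArchive) =
  SemanticHashUtils.tablesToRecompute(joinConf, outputTable, tableUtils, unsetSemanticHash = false)
```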
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index 16cd25e6b5..90aa6d93fe 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -317,7 +317,7 @@ case class TableUtils(sparkSession: SparkSession) {
       }
     }
     if (tableProperties != null && tableProperties.nonEmpty) {
-      sql(alterTablePropertiesSql(tableName, tableProperties))
+      alterTableProperties(tableName, tableProperties)
     }

     if (autoExpand) {
@@ -381,7 +381,7 @@ case class TableUtils(sparkSession: SparkSession) {
       sql(createTableSql(tableName, df.schema, Seq.empty[String], tableProperties, fileFormat))
     } else {
       if (tableProperties != null && tableProperties.nonEmpty) {
-        sql(alterTablePropertiesSql(tableName, tableProperties))
+        alterTableProperties(tableName, tableProperties)
       }
     }

@@ -551,7 +551,7 @@ case class TableUtils(sparkSession: SparkSession) {
     Seq(createFragment, partitionFragment, fileFormatString, propertiesFragment).mkString("\n")
   }

-  private def alterTablePropertiesSql(tableName: String, properties: Map[String, String]): String = {
+  def alterTableProperties(tableName: String, properties: Map[String, String]): Unit = {
     // Only SQL api exists for setting TBLPROPERTIES
     val propertiesString = properties
       .map {
@@ -559,7 +559,8 @@ case class TableUtils(sparkSession: SparkSession) {
           s"'$key' = '$value'"
       }
       .mkString(", ")
-    s"ALTER TABLE $tableName SET TBLPROPERTIES ($propertiesString)"
+    val query = s"ALTER TABLE $tableName SET TBLPROPERTIES ($propertiesString)"
+    sql(query)
   }

   def chunk(partitions: Set[String]): Seq[PartitionRange] = {
@@ -804,7 +805,7 @@ case class TableUtils(sparkSession: SparkSession) {
         sql(expandTableQueryOpt.get)

         // set a flag in table props to indicate that this is a dynamic table
-        sql(alterTablePropertiesSql(tableName, Map(Constants.ChrononDynamicTable -> true.toString)))
+        alterTableProperties(tableName, Map(Constants.ChrononDynamicTable -> true.toString))
       }
     }
   }
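
Since `alterTableProperties` is now public, the migration metadata can also be stamped or repaired directly, which is what the tests below do to simulate pre-migration state. A sketch (table name illustrative):

```scala
tableUtils.alterTableProperties(
  "test_namespace.test_join",
  Map(Constants.SemanticHashOptionsKey -> """{"exclude_topic":"true"}""")
)
// issues: ALTER TABLE test_namespace.test_join SET TBLPROPERTIES
//         ('semantic_hash_options' = '{"exclude_topic":"true"}')
```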
diff --git a/spark/src/main/scala/ai/chronon/spark/stats/SummaryJob.scala b/spark/src/main/scala/ai/chronon/spark/stats/SummaryJob.scala
index e6f88341a4..000922ee58 100644
--- a/spark/src/main/scala/ai/chronon/spark/stats/SummaryJob.scala
+++ b/spark/src/main/scala/ai/chronon/spark/stats/SummaryJob.scala
@@ -22,7 +22,7 @@ import ai.chronon.aggregator.row.StatsGenerator
 import ai.chronon.api.Extensions._
 import ai.chronon.api._
 import ai.chronon.spark.Extensions._
-import ai.chronon.spark.{JoinUtils, PartitionRange, TableUtils}
+import ai.chronon.spark.{PartitionRange, SemanticHashUtils, TableUtils}
 import org.apache.spark.sql.SparkSession

 /**
@@ -45,9 +45,13 @@ class SummaryJob(session: SparkSession, joinConf: Join, endDate: String) extends
           sample: Double = 0.1,
           forceBackfill: Boolean = false): Unit = {
     val uploadTable = joinConf.metaData.toUploadTable(outputTable)
-    val backfillRequired = (!JoinUtils.tablesToRecompute(joinConf, outputTable, tableUtils).isEmpty) || forceBackfill
+    val backfillRequired =
+      SemanticHashUtils
+        .tablesToRecompute(joinConf, outputTable, tableUtils, unsetSemanticHash = false)
+        ._1
+        .nonEmpty || forceBackfill
     if (backfillRequired)
-      Seq(outputTable, uploadTable).foreach(tableUtils.dropTableIfExists(_))
+      Seq(outputTable, uploadTable).foreach(tableUtils.dropTableIfExists)
     val unfilledRanges = tableUtils
       .unfilledRanges(outputTable, PartitionRange(null, endDate)(tableUtils), Some(Seq(inputTable)))
       .getOrElse(Seq.empty)
diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinFlowTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinFlowTest.scala
index d2a980db48..7b39614dc0 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/JoinFlowTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/JoinFlowTest.scala
@@ -105,19 +105,20 @@ class JoinFlowTest {
     )
     val endDs = tableUtils.partitions(queryTable).max

-    val joinJob = new Join(join, endDs, tableUtils)

     // compute left
-    joinJob.computeLeft()
+    val joinLeftJob = new Join(join.deepCopy(), endDs, tableUtils)
+    joinLeftJob.computeLeft()
     assertTrue(tableUtils.tableExists(join.metaData.bootstrapTable))

     // compute right
-    val joinPartJob = new Join(join, endDs, tableUtils, selectedJoinParts = Some(List(joinPart.fullPrefix)))
+    val joinPartJob = new Join(join.deepCopy(), endDs, tableUtils, selectedJoinParts = Some(List(joinPart.fullPrefix)))
     joinPartJob.computeJoinOpt(useBootstrapForLeft = true)
     assertTrue(tableUtils.tableExists(join.partOutputTable(joinPart)))

     // compute final
-    joinJob.computeFinal()
+    val joinFinalJob = new Join(join.deepCopy(), endDs, tableUtils)
+    joinFinalJob.computeFinal()
     assertTrue(tableUtils.tableExists(join.metaData.outputTable))
   }
 }
diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
index 6923384ba3..425646e758 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
@@ -18,12 +18,25 @@ package ai.chronon.spark.test

 import ai.chronon.aggregator.test.Column
 import ai.chronon.api
-import ai.chronon.api.{Accuracy, Builders, Constants, LongType, Operation, StringType, TimeUnit, Window}
+import ai.chronon.api.{
+  Accuracy,
+  Builders,
+  Constants,
+  JoinPart,
+  LongType,
+  Operation,
+  PartitionSpec,
+  StringType,
+  TimeUnit,
+  Window
+}
 import ai.chronon.api.Extensions._
 import ai.chronon.spark.Extensions._
 import ai.chronon.spark.GroupBy.renderDataSourceQuery
+import ai.chronon.spark.SemanticHashUtils.{tableHashesChanged, tablesToRecompute}
 import ai.chronon.spark._
 import ai.chronon.spark.stats.SummaryJob
+import com.google.gson.Gson
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructType, StringType => SparkStringType}
@@ -31,12 +44,12 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
 import org.junit.Assert._
 import org.junit.Test
 import org.scalatest.Assertions.intercept
-
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.Row

 import scala.collection.JavaConverters._
 import scala.util.ScalaJavaConversions.ListOps
+import scala.util.Try

 class JoinTest {

@@ -799,20 +812,21 @@ class JoinTest {
     oldJoin.computeJoin(Some(100))

     // Make sure that there is no versioning-detected changes at this phase
-    val joinPartsToRecomputeNoChange = JoinUtils.tablesToRecompute(joinConf, joinConf.metaData.outputTable, tableUtils)
-    assertEquals(joinPartsToRecomputeNoChange.size, 0)
+    val joinPartsToRecomputeNoChange = tablesToRecompute(joinConf, joinConf.metaData.outputTable, tableUtils, false)
+    assertEquals(joinPartsToRecomputeNoChange._1.size, 0)

     // First test changing the left side table - this should trigger a full recompute
     val leftChangeJoinConf = joinConf.deepCopy()
     leftChangeJoinConf.getLeft.getEvents.setTable("some_other_table_name")
     val leftChangeJoin = new Join(joinConf = leftChangeJoinConf, endPartition = dayAndMonthBefore, tableUtils)
     val leftChangeRecompute =
-      JoinUtils.tablesToRecompute(leftChangeJoinConf, leftChangeJoinConf.metaData.outputTable, tableUtils)
+      tablesToRecompute(leftChangeJoinConf, leftChangeJoinConf.metaData.outputTable, tableUtils, false)
     println(leftChangeRecompute)
-    assertEquals(leftChangeRecompute.size, 3)
+    assertEquals(leftChangeRecompute._1.size, 3)
     val partTable = s"${leftChangeJoinConf.metaData.outputTable}_user_unit_test_item_views"
-    assertEquals(leftChangeRecompute,
+    assertEquals(leftChangeRecompute._1,
                  Seq(partTable, leftChangeJoinConf.metaData.bootstrapTable, leftChangeJoinConf.metaData.outputTable))
+    assertTrue(leftChangeRecompute._2)

     // Test adding a joinPart
     val addPartJoinConf = joinConf.deepCopy()
@@ -820,10 +834,10 @@ class JoinTest {
     val newJoinPart = Builders.JoinPart(groupBy = getViewsGroupBy(suffix = "versioning"), prefix = "user_2")
     addPartJoinConf.setJoinParts(Seq(existingJoinPart, newJoinPart).asJava)
     val addPartJoin = new Join(joinConf = addPartJoinConf, endPartition = dayAndMonthBefore, tableUtils)
-    val addPartRecompute =
-      JoinUtils.tablesToRecompute(addPartJoinConf, addPartJoinConf.metaData.outputTable, tableUtils)
-    assertEquals(addPartRecompute.size, 1)
-    assertEquals(addPartRecompute, Seq(addPartJoinConf.metaData.outputTable))
+    val addPartRecompute = tablesToRecompute(addPartJoinConf, addPartJoinConf.metaData.outputTable, tableUtils, false)
+    assertEquals(addPartRecompute._1.size, 1)
+    assertEquals(addPartRecompute._1, Seq(addPartJoinConf.metaData.outputTable))
+    assertTrue(addPartRecompute._2)

     // Compute to ensure that it works and to set the stage for the next assertion
     addPartJoin.computeJoin(Some(100))
@@ -832,10 +846,11 @@ class JoinTest {
     rightModJoinConf.getJoinParts.get(1).setPrefix("user_3")
     val rightModJoin = new Join(joinConf = rightModJoinConf, endPartition = dayAndMonthBefore, tableUtils)
     val rightModRecompute =
-      JoinUtils.tablesToRecompute(rightModJoinConf, rightModJoinConf.metaData.outputTable, tableUtils)
-    assertEquals(rightModRecompute.size, 2)
+      tablesToRecompute(rightModJoinConf, rightModJoinConf.metaData.outputTable, tableUtils, false)
+    assertEquals(rightModRecompute._1.size, 2)
     val rightModPartTable = s"${addPartJoinConf.metaData.outputTable}_user_2_unit_test_item_views"
-    assertEquals(rightModRecompute, Seq(rightModPartTable, addPartJoinConf.metaData.outputTable))
+    assertEquals(rightModRecompute._1, Seq(rightModPartTable, addPartJoinConf.metaData.outputTable))
+    assertTrue(rightModRecompute._2)

     // Modify both
     rightModJoinConf.getJoinParts.get(0).setPrefix("user_4")
     val rightModBothJoin = new Join(joinConf = rightModJoinConf, endPartition = dayAndMonthBefore, tableUtils)
s"$namespace.${prefix}_left_table" + DataFrameGen + .events(spark, querySchema, 400, partitions = 10) + .where(col("user").isNotNull) + .dropDuplicates("user") + .save(queryTable) + val querySource = Builders.Source.events( + table = queryTable, + query = Builders.Query(Builders.Selects("user"), timeColumn = "ts") + ) + + // right part + val transactionSchema = Seq( + Column("user", LongType, 100), + Column("amount", LongType, 1000) + ) + val transactionsTable = s"$namespace.${prefix}_transactions" + DataFrameGen + .events(spark, transactionSchema, 2000, partitions = 50) + .where(col("user").isNotNull) + .save(transactionsTable) + + val joinPart: JoinPart = Builders.JoinPart(groupBy = Builders.GroupBy( + keyColumns = Seq("user"), + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Builders.Selects("amount"), + timeColumn = "ts" + ), + table = transactionsTable, + topic = "transactions_topic_v1" + )), + aggregations = Seq( + Builders + .Aggregation(operation = Operation.SUM, inputColumn = "amount", windows = Seq(new Window(3, TimeUnit.DAYS)))), + accuracy = Accuracy.SNAPSHOT, + metaData = Builders.MetaData(name = s"join_test.${prefix}_txn", namespace = namespace, team = "chronon") + )) + + // join + val join = Builders.Join( + left = querySource, + joinParts = Seq(joinPart), + metaData = Builders.MetaData(name = s"unit_test.${prefix}_join", namespace = namespace, team = "chronon") + ) + + val endDs = tableUtils.partitions(queryTable).max + (join, endDs) + } + + private def overwriteWithOldSemanticHash(join: api.Join, gson: Gson): Unit = { + // Compute and manually set the semantic_hash computed from using old logic + val oldVersionSemanticHash = join.semanticHash(excludeTopic = false) + val oldTableProperties = Map( + Constants.SemanticHashKey -> gson.toJson(oldVersionSemanticHash.asJava), + Constants.SemanticHashOptionsKey -> gson.toJson( + Map( + Constants.SemanticHashExcludeTopic -> "false" + ).asJava) + ) + tableUtils.alterTableProperties(join.metaData.outputTable, oldTableProperties) + tableUtils.alterTableProperties(join.metaData.bootstrapTable, oldTableProperties) + } + + private def hasExcludeTopicFlag(tableProps: Map[String, String], gson: Gson): Boolean = { + val optionsString = tableProps(Constants.SemanticHashOptionsKey) + val options = gson.fromJson(optionsString, classOf[java.util.HashMap[String, String]]).asScala + options.get(Constants.SemanticHashExcludeTopic).contains("true") + } + + @Test + def testMigrationForTopicSuccess(): Unit = { + val (join, endDs) = prepareTopicTestConfs("test_migration_for_topic_success") + def runJob(join: api.Join, shiftDays: Int): Unit = { + val deepCopy = join.deepCopy() + val joinJob = new Join(deepCopy, tableUtils.partitionSpec.shift(endDs, shiftDays), tableUtils) + joinJob.computeJoin() + } + runJob(join, -2) + + // Compute and manually set the semantic_hash computed from using old logic + val gson = new Gson() + overwriteWithOldSemanticHash(join, gson) + + // Compare semantic hash + val (tablesChanged, autoArchive) = + SemanticHashUtils.tablesToRecompute(join, join.metaData.outputTable, tableUtils, unsetSemanticHash = false) + + assertEquals(0, tablesChanged.length) + assertEquals(false, autoArchive) + + val (shouldRecomputeLeft, autoArchiveLeft) = + SemanticHashUtils.shouldRecomputeLeft(join, join.metaData.bootstrapTable, tableUtils, unsetSemanticHash = false) + assertEquals(false, shouldRecomputeLeft) + assertEquals(false, autoArchiveLeft) + + // Rerun job and update semantic_hash with new logic + 
+    runJob(join, -1)
+
+    val newVersionSemanticHash = join.semanticHash(excludeTopic = true)
+
+    val tablePropsV1 = tableUtils.getTableProperties(join.metaData.outputTable).get
+    assertTrue(hasExcludeTopicFlag(tablePropsV1, gson))
+    assertEquals(gson.toJson(newVersionSemanticHash.asJava), tablePropsV1(Constants.SemanticHashKey))
+
+    // Modify the topic and rerun
+    val joinPartNew = join.joinParts.get(0).deepCopy()
+    joinPartNew.groupBy.sources.asScala.head.getEvents.setTopic("transactions_topic_v2")
+    val joinNew = join.deepCopy()
+    joinNew.setJoinParts(Seq(joinPartNew).asJava)
+    runJob(joinNew, 0)
+
+    // Verify that the semantic hash has NOT changed
+    val tablePropsV2 = tableUtils.getTableProperties(join.metaData.outputTable).get
+    assertTrue(hasExcludeTopicFlag(tablePropsV2, gson))
+    assertEquals(gson.toJson(newVersionSemanticHash.asJava), tablePropsV2(Constants.SemanticHashKey))
+  }
+
+  @Test
+  def testMigrationForTopicManualArchive(): Unit = {
+    val (join, endDs) = prepareTopicTestConfs("test_migration_for_topic_manual_archive")
+    def runJob(join: api.Join, shiftDays: Int, unsetSemanticHash: Boolean = false): Unit = {
+      val deepCopy = join.deepCopy()
+      val joinJob = new Join(deepCopy,
+                             tableUtils.partitionSpec.shift(endDs, shiftDays),
+                             tableUtils,
+                             unsetSemanticHash = unsetSemanticHash)
+      joinJob.computeJoin()
+    }
+    runJob(join, -2)
+
+    // Compute and manually set the semantic_hash computed using the old logic
+    val gson = new Gson()
+    overwriteWithOldSemanticHash(join, gson)
+
+    // Make a real semantic change to the join_part
+    val joinPartNew = join.getJoinParts.get(0).deepCopy()
+    joinPartNew.getGroupBy.getSources.asScala.head.getEvents.setTopic("transactions_topic_v2")
+    joinPartNew.getGroupBy.getAggregations.asScala.head.setWindows(Seq(new Window(7, TimeUnit.DAYS)).asJava)
+    val joinNew = join.deepCopy()
+    joinNew.setJoinParts(Seq(joinPartNew).asJava)
+
+    // Rerun the job and update semantic_hash with the new logic.
+    // Expect that a failure is thrown to ask for a manual archive.
+    val runJobTry = Try(runJob(joinNew, -1))
+    assertTrue(runJobTry.isFailure)
+    assertTrue(runJobTry.failed.get.isInstanceOf[SemanticHashException])
+
+    // Explicitly unsetSemanticHash to rerun the job. Note: technically the correct behavior here
+    // would be to drop the table and rerun, but this exercises the unsetSemanticHash flag.
+    runJob(joinNew, 0, unsetSemanticHash = true)
+
+    // Verify that semantic_hash has been updated
+    val newVersionSemanticHash = join.semanticHash(excludeTopic = true)
+    val tableProps = tableUtils.getTableProperties(join.metaData.outputTable).get
+    assertTrue(hasExcludeTopicFlag(tableProps, gson))
+    assertNotEquals(gson.toJson(newVersionSemanticHash.asJava), tableProps(Constants.SemanticHashKey))
   }

   @Test