3 changes: 2 additions & 1 deletion api/src/main/scala/ai/chronon/api/Constants.scala
@@ -31,8 +31,9 @@ object Constants {
val GroupByServingInfoKey = "group_by_serving_info"
val UTF8 = "UTF-8"
val TopicInvalidSuffix = "_invalid"
val lineTab = "\n "
val SemanticHashKey = "semantic_hash"
val SemanticHashOptionsKey = "semantic_hash_options"
val SemanticHashExcludeTopic = "exclude_topic"
val SchemaHash = "schema_hash"
val BootstrapHash = "bootstrap_hash"
val MatchedHashes = "matched_hashes"
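These two new keys pair with the existing SemanticHashKey: the Spark job in JoinBase.scala (later in this diff) writes the hash map under semantic_hash and a small flag map under semantic_hash_options. A hedged sketch of the resulting Hive table properties — the key names come from these constants, while the digests and table name are invented for illustration:

    // semantic_hash         -> {"left_source": "9a0364b9e99b...", "db.my_team_my_join_my_group_by": "c81e728d9d4c..."}
    // semantic_hash_options -> {"exclude_topic": "true"}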
99 changes: 40 additions & 59 deletions api/src/main/scala/ai/chronon/api/Extensions.scala
@@ -30,6 +30,7 @@ import java.util.regex.Pattern
import scala.collection.{Seq, mutable}
import scala.util.ScalaJavaConversions.{IteratorOps, ListOps, MapOps}
import scala.util.{Failure, Success, Try}
import scala.collection.JavaConverters._

object Extensions {

@@ -845,14 +846,14 @@ object Extensions {
def partOutputTable(jp: JoinPart): String =
(Seq(join.metaData.outputTable) ++ Option(jp.prefix) :+ jp.groupBy.metaData.cleanName).mkString("_")

private val leftSourceKey: String = "left_source"
private val derivedKey: String = "derived"
val leftSourceKey: String = "left_source"
val derivedKey: String = "derived"

/*
* semanticHash contains hashes of left side and each join part, and is used to detect join definition
* changes and determine whether any intermediate/final tables of the join need to be recomputed.
*/
def semanticHash: Map[String, String] = {
private[api] def baseSemanticHash: Map[String, String] = {
val leftHash = ThriftJsonCodec.md5Digest(join.left)
logger.info(s"Join Left Object: ${ThriftJsonCodec.toJsonStr(join.left)}")
val partHashes = join.joinParts.toScala.map { jp => partOutputTable(jp) -> jp.groupBy.semanticHash }.toMap
@@ -867,6 +868,42 @@
partHashes ++ Map(leftSourceKey -> leftHash, join.metaData.bootstrapTable -> bootstrapHash) ++ derivedHashMap
}
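Each key of this map corresponds to a physical table (or the special left_source / derived entries), which is what lets the recompute logic diff old versus new hashes table by table. A hedged sketch of its shape, assuming from the collapsed lines above that the derived entry appears only when derivations are configured:

    // Map(
    //   "left_source"                -> md5 of the left source's thrift JSON,
    //   partOutputTable(jp)          -> jp.groupBy.semanticHash,  // one entry per join part
    //   join.metaData.bootstrapTable -> bootstrapHash,
    //   "derived"                    -> hash of derivations       // assumed: only when set
    // )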

/*
* Unset topic / mutationTopic everywhere in this join, recursively.
* The input join is modified in place.
*/
private def cleanTopic(join: Join): Join = {
def cleanTopicInSource(source: Source): Unit = {
if (source.isSetEvents) {
source.getEvents.unsetTopic()
} else if (source.isSetEntities) {
source.getEntities.unsetMutationTopic()
} else if (source.isSetJoinSource) {
cleanTopic(source.getJoinSource.getJoin)
}
}

cleanTopicInSource(join.left)
join.getJoinParts.toScala.foreach(_.groupBy.sources.toScala.foreach(cleanTopicInSource))
join
}

/*
* Compute variants of semantic_hash with different flags. A flag is stored in Hive table metadata and used to
* indicate which version of semantic_hash logic to use.
*/
def semanticHash(excludeTopic: Boolean): Map[String, String] = {
if (excludeTopic) {
// WARN: deepCopy doesn't guarantee the same semantic_hash will be produced, due to possible reordering
// of map keys, but the behavior is deterministic
val joinCopy = join.deepCopy()
cleanTopic(joinCopy)
joinCopy.baseSemanticHash
} else {
baseSemanticHash
}
}
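A hedged illustration of what the flag buys: joinA and joinB below are hypothetical configs that are identical except for the Kafka topic on the left events source, so only the topic-inclusive hashes differ:

    // joinA and joinB are hypothetical; they differ only in the left source's topic.
    assert(joinA.semanticHash(excludeTopic = true) == joinB.semanticHash(excludeTopic = true))
    assert(joinA.semanticHash(excludeTopic = false) != joinB.semanticHash(excludeTopic = false))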

/*
External features computed in online env and logged
This method will get the external feature column names
@@ -887,62 +924,6 @@
.getOrElse(Seq.empty)
}

/*
* onlineSemanticHash includes everything in semanticHash as well as hashes of each onlineExternalParts (which only
* affect online serving but not offline table generation).
* It is used to detect join definition change in online serving and to update ttl-cached conf files.
*/
def onlineSemanticHash: Map[String, String] = {
if (join.onlineExternalParts == null) {
return Map.empty[String, String]
}

val externalPartHashes = join.onlineExternalParts.toScala.map { part => part.fullName -> part.semanticHash }.toMap

externalPartHashes ++ semanticHash
}

def leftChanged(oldSemanticHash: Map[String, String]): Boolean = {
// Checks for semantic changes in left or bootstrap, because those are saved together
val bootstrapExistsAndChanged = oldSemanticHash.contains(join.metaData.bootstrapTable) && oldSemanticHash.get(
join.metaData.bootstrapTable) != semanticHash.get(join.metaData.bootstrapTable)
logger.info(s"Bootstrap table changed: $bootstrapExistsAndChanged")
logger.info(s"Old Semantic Hash: $oldSemanticHash")
logger.info(s"New Semantic Hash: $semanticHash")
oldSemanticHash.get(leftSourceKey) != semanticHash.get(leftSourceKey) || bootstrapExistsAndChanged
}

def tablesToDrop(oldSemanticHash: Map[String, String]): Seq[String] = {
val newSemanticHash = semanticHash
// only right join part hashes for convenience
def partHashes(semanticHashMap: Map[String, String]): Map[String, String] = {
semanticHashMap.filter { case (name, _) => name != leftSourceKey && name != derivedKey }
}

// drop everything if left source changes
val partsToDrop = if (leftChanged(oldSemanticHash)) {
partHashes(oldSemanticHash).keys.toSeq
} else {
val changed = partHashes(newSemanticHash).flatMap {
case (key, newVal) =>
oldSemanticHash.get(key).filter(_ != newVal).map(_ => key)
}
val deleted = partHashes(oldSemanticHash).keys.filterNot(newSemanticHash.contains)
(changed ++ deleted).toSeq
}
val added = newSemanticHash.keys.filter(!oldSemanticHash.contains(_)).filter {
// introduce bootstrapTable as a semantic_hash but skip dropping to avoid recompute if it is empty
case key if key == join.metaData.bootstrapTable => join.isSetBootstrapParts && !join.bootstrapParts.isEmpty
case _ => true
}
val derivedChanges = oldSemanticHash.get(derivedKey) != newSemanticHash.get(derivedKey)
// TODO: make this incremental, retain the main table and continue joining, dropping etc
val mainTable = if (partsToDrop.nonEmpty || added.nonEmpty || derivedChanges) {
Some(join.metaData.outputTable)
} else None
partsToDrop ++ mainTable
}

def isProduction: Boolean = join.getMetaData.isProduction

def team: String = join.getMetaData.getTeam
25 changes: 20 additions & 5 deletions spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -61,6 +61,17 @@ object Driver {
def parseConf[T <: TBase[_, _]: Manifest: ClassTag](confPath: String): T =
ThriftJsonCodec.fromJsonFile[T](confPath, check = true)

trait JoinBackfillSubcommand {
this: ScallopConf =>
val unsetSemanticHash: ScallopOption[Boolean] =
opt[Boolean](
required = false,
default = Some(false),
descr =
"When set to true, semantic_hash is unset in join tblprops to allow for config update without recompute."
)
}
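For reference, a hedged sketch of how this option would surface on the command line — Scallop converts the camelCase name to kebab-case, and the surrounding arguments are assumptions based on the subcommands below:

    // Hypothetical argv as run.py might pass it to the `join` subcommand;
    // the boolean opt acts as a toggle, so no value is needed:
    // Seq("--conf-path", "production/joins/team/my_join.v1",
    //     "--end-date", "2023-10-01",
    //     "--unset-semantic-hash")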

trait OfflineSubcommand {
this: ScallopConf =>
val confPath: ScallopOption[String] = opt[String](required = true, descr = "Path to conf")
@@ -236,6 +247,7 @@ object Driver {
class Args
extends Subcommand("join")
with OfflineSubcommand
with JoinBackfillSubcommand
with LocalExportTableAbility
with ResultValidationAbility {
val selectedJoinParts: ScallopOption[List[String]] =
@@ -256,7 +268,8 @@
args.endDate(),
args.buildTableUtils(),
!args.runFirstHole(),
selectedJoinParts = args.selectedJoinParts.toOption
selectedJoinParts = args.selectedJoinParts.toOption,
unsetSemanticHash = args.unsetSemanticHash.getOrElse(false)
)

if (args.selectedJoinParts.isDefined) {
@@ -291,19 +304,20 @@ object Driver {
class Args
extends Subcommand("join-left")
with OfflineSubcommand
with JoinBackfillSubcommand
with LocalExportTableAbility
with ResultValidationAbility {
lazy val joinConf: api.Join = parseConf[api.Join](confPath())
override def subcommandName() = s"join_left_${joinConf.metaData.name}"
}

def run(args: Args): Unit = {
val tableUtils = args.buildTableUtils()
val join = new Join(
args.joinConf,
args.endDate(),
args.buildTableUtils(),
!args.runFirstHole()
!args.runFirstHole(),
unsetSemanticHash = args.unsetSemanticHash.getOrElse(false)
)
join.computeLeft(args.startPartitionOverride.toOption)
}
@@ -314,19 +328,20 @@
class Args
extends Subcommand("join-final")
with OfflineSubcommand
with JoinBackfillSubcommand
with LocalExportTableAbility
with ResultValidationAbility {
lazy val joinConf: api.Join = parseConf[api.Join](confPath())
override def subcommandName() = s"join_final_${joinConf.metaData.name}"
}

def run(args: Args): Unit = {
val tableUtils = args.buildTableUtils()
val join = new Join(
args.joinConf,
args.endDate(),
args.buildTableUtils(),
!args.runFirstHole()
!args.runFirstHole(),
unsetSemanticHash = args.unsetSemanticHash.getOrElse(false)
)
join.computeFinal(args.startPartitionOverride.toOption)
}
5 changes: 3 additions & 2 deletions spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -65,8 +65,9 @@ class Join(joinConf: api.Join,
tableUtils: TableUtils,
skipFirstHole: Boolean = true,
showDf: Boolean = false,
selectedJoinParts: Option[List[String]] = None)
extends JoinBase(joinConf, endPartition, tableUtils, skipFirstHole, showDf, selectedJoinParts) {
selectedJoinParts: Option[List[String]] = None,
unsetSemanticHash: Boolean = false)
extends JoinBase(joinConf, endPartition, tableUtils, skipFirstHole, showDf, selectedJoinParts, unsetSemanticHash) {

private def padFields(df: DataFrame, structType: sql.types.StructType): DataFrame = {
structType.foldLeft(df) {
72 changes: 54 additions & 18 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -22,25 +22,27 @@ import ai.chronon.api.Extensions._
import ai.chronon.api.{Accuracy, Constants, JoinPart}
import ai.chronon.online.Metrics
import ai.chronon.spark.Extensions._
import ai.chronon.spark.JoinUtils.{coalescedJoin, leftDf, tablesToRecompute, shouldRecomputeLeft}
import ai.chronon.spark.JoinUtils.{coalescedJoin, leftDf}
import ai.chronon.spark.SemanticHashUtils.{shouldRecomputeLeft, tablesToRecompute}
import com.google.gson.Gson
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.util.sketch.BloomFilter
import org.slf4j.LoggerFactory
import java.time.Instant

import java.time.Instant
import java.util
import scala.collection.JavaConverters._
import scala.collection.Seq
import java.util
import scala.util.ScalaJavaConversions.ListOps

abstract class JoinBase(joinConf: api.Join,
endPartition: String,
tableUtils: TableUtils,
skipFirstHole: Boolean,
showDf: Boolean = false,
selectedJoinParts: Option[Seq[String]] = None) {
selectedJoinParts: Option[Seq[String]] = None,
unsetSemanticHash: Boolean) {
@transient lazy val logger = LoggerFactory.getLogger(getClass)
assert(Option(joinConf.metaData.outputNamespace).nonEmpty, s"output namespace could not be empty or null")
val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.JoinOffline, joinConf)
@@ -57,7 +59,13 @@ abstract class JoinBase(joinConf: api.Join,
private val gson = new Gson()
// Combine tableProperties set on conf with encoded Join
protected val tableProps: Map[String, String] =
confTableProps ++ Map(Constants.SemanticHashKey -> gson.toJson(joinConf.semanticHash.asJava))
confTableProps ++ Map(
Constants.SemanticHashKey -> gson.toJson(joinConf.semanticHash(excludeTopic = true).asJava),
Constants.SemanticHashOptionsKey -> gson.toJson(
[Review comment — Contributor]
nit: with SemanticHashOptionsKey, we might not need to pass something like excludeTopic to .semanticHash(). All arguments can be packaged into SemanticHashOptionsKey.

[Reply — Collaborator Author]
@donghanz I think we will still keep SemanticHashKey and SemanticHashOptionsKey as two separate properties in Hive tables:

  • SemanticHashKey: stores the semantic hashes (a map from string to hash string)
  • SemanticHashOptionsKey: stores a few flags (a map from string to boolean)

I pass excludeTopic to .semanticHash() only because excludeTopic determines the logic used to compute the hash; the flag itself is not stored in the hash.

Lmk if that makes sense.
Map(
Constants.SemanticHashExcludeTopic -> "true"
).asJava)
)
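On the read side, the stored options tell the comparison logic (SemanticHashUtils, imported above but not shown in this diff) which hash variant to recompute. A hedged sketch of that decoding — an assumption about the shape of the logic, not the actual SemanticHashUtils implementation:

    // Sketch: pick the semantic_hash logic matching what was stored in tblprops.
    def excludeTopicFlag(storedProps: Map[String, String]): Boolean =
      storedProps
        .get(Constants.SemanticHashOptionsKey)
        .map(json => gson.fromJson(json, classOf[java.util.Map[String, String]]).asScala)
        .exists(opts => opts.get(Constants.SemanticHashExcludeTopic).contains("true"))

    // val expectedHashes = joinConf.semanticHash(excludeTopic = excludeTopicFlag(existingProps))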

def joinWithLeft(leftDf: DataFrame, rightDf: DataFrame, joinPart: JoinPart): DataFrame = {
val partLeftKeys = joinPart.rightToLeft.values.toArray
@@ -342,13 +350,37 @@
.getOrElse(Seq.empty))
}

private def performArchive(tables: Seq[String], autoArchive: Boolean, mainTable: String): Unit = {
if (autoArchive) {
val archivedAtTs = Instant.now()
tables.foreach(tableUtils.archiveOrDropTableIfExists(_, Some(archivedAtTs)))
} else {
val errorMsg = s"""Auto archive is disabled because the semantic hash is out of date.
|Please verify whether your config involves true semantic changes. If so, archive the following tables:
|${tables.map(t => s"- $t").mkString("\n")}
|If not, please retry this job with the `--unset-semantic-hash` flag in your run.py args,
|OR run the Spark SQL cmd: ALTER TABLE $mainTable UNSET TBLPROPERTIES ('semantic_hash') and then retry this job.
|""".stripMargin
logger.error(errorMsg)
throw SemanticHashException(errorMsg)
}
}

def computeLeft(overrideStartPartition: Option[String] = None): Unit = {
// Runs the left side query for a join and saves the output to a table, for reuse by joinPart
// Computation in parallelized joinPart execution mode.
if (shouldRecomputeLeft(joinConf, bootstrapTable, tableUtils)) {
logger.info(s"Detected semantic change in left side of join, archiving left table for recomputation.")
val archivedAtTs = Instant.now()
tableUtils.archiveOrDropTableIfExists(bootstrapTable, Some(archivedAtTs))
val (shouldRecompute, autoArchive) = shouldRecomputeLeft(joinConf, bootstrapTable, tableUtils, unsetSemanticHash)
if (!shouldRecompute) {
logger.info(s"No semantic change detected, leaving bootstrap table in place.")
// Still update the semantic hash of the bootstrap table.
// It is possible that while bootstrap_table's semantic_hash is unchanged, the rest has changed, so
// we keep everything in sync.
if (tableUtils.tableExists(bootstrapTable)) {
tableUtils.alterTableProperties(bootstrapTable, tableProps)
}
} else {
logger.info(s"Detected semantic change in left side of join, archiving bootstrap table for recomputation.")
performArchive(Seq(bootstrapTable), autoArchive, bootstrapTable)
}

val (rangeToFill, unfilledRanges) = getUnfilledRange(overrideStartPartition, bootstrapTable)
@@ -375,15 +407,15 @@

def computeFinalJoin(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): Unit

def computeFinal(overrideStartPartition: Option[String] = None) = {
def computeFinal(overrideStartPartition: Option[String] = None): Unit = {

// Utilizes the same tablesToRecompute check as the monolithic spark job, because if any joinPart changes, then so does the output table
if (tablesToRecompute(joinConf, outputTable, tableUtils).isEmpty) {
val (tablesChanged, autoArchive) = tablesToRecompute(joinConf, outputTable, tableUtils, unsetSemanticHash)
if (tablesChanged.isEmpty) {
logger.info(s"No semantic change detected, leaving output table in place.")
} else {
logger.info(s"Semantic changes detected, archiving output table.")
val archivedAtTs = Instant.now()
tableUtils.archiveOrDropTableIfExists(outputTable, Some(archivedAtTs))
performArchive(tablesChanged, autoArchive, outputTable)
}

val (rangeToFill, unfilledRanges) = getUnfilledRange(overrideStartPartition, outputTable)
@@ -448,11 +480,15 @@
leftDf(joinConf, PartitionRange(endPartition, endPartition)(tableUtils), tableUtils, limit = Some(1)).map(df =>
df.schema)

val archivedAtTs = Instant.now()
// Check semantic hash before overwriting left side
// TODO: We should not archive the output table in the case of selected join parts mode
tablesToRecompute(joinConf, outputTable, tableUtils).foreach(
tableUtils.archiveOrDropTableIfExists(_, Some(archivedAtTs)))
// Run command to archive ALL tables that have changed semantically since the last run
// TODO: We should not archive the output table or other JP's intermediate tables in the case of selected join parts mode
val (tablesChanged, autoArchive) = tablesToRecompute(joinConf, outputTable, tableUtils, unsetSemanticHash)
if (tablesChanged.isEmpty) {
logger.info(s"No semantic change detected, leaving output table in place.")
} else {
logger.info(s"Semantic changes detected, archiving output table.")
performArchive(tablesChanged, autoArchive, outputTable)
}

// Overwriting Join Left with bootstrap table to simplify later computation
val source = joinConf.left