3 changes: 3 additions & 0 deletions api/src/main/scala/ai/chronon/api/Constants.scala
@@ -38,4 +38,7 @@ object Constants {
val ContextualSourceKeys: String = "contextual_keys"
val ContextualSourceValues: String = "contextual_values"
val TeamOverride: String = "team_override"
val LabelColumnPrefix: String = "label"
val LabelViewPropertyFeatureTable: String = "feature_table"
val LabelViewPropertyKeyLabelTable: String = "label_table"
}
12 changes: 12 additions & 0 deletions api/src/main/scala/ai/chronon/api/Extensions.scala
@@ -8,6 +8,7 @@ import com.fasterxml.jackson.databind.ObjectMapper
import java.io.{PrintWriter, StringWriter}
import java.util
import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.util.{Failure, ScalaVersionSpecificCollectionsConverter, Success, Try}

object Extensions {
@@ -75,6 +76,9 @@ object Extensions {
def cleanName: String = metaData.name.sanitize

def outputTable = s"${metaData.outputNamespace}.${metaData.cleanName}"
def outputLabelTable = s"${metaData.outputNamespace}.${metaData.cleanName}_labels"
def outputFinalView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled"
def outputLatestLabelView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled_latest"
def loggedTable = s"${outputTable}_logged"
def bootstrapTable = s"${outputTable}_bootstrap"
private def comparisonPrefix = "comparison"
@@ -563,6 +567,14 @@ object Extensions {
.flatMap(_.groupBy.setups)
.distinct
}

// A list of columns that can identify a row on the left; uses user-specified columns when provided
def rowIdentifier(userRowId: util.List[String] = null): Array[String] = {
if (userRowId != null && !userRowId.isEmpty)
userRowId.asScala.toArray
else
leftKeyCols ++ Array(Constants.PartitionColumn)
}
}
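A quick sketch of the fallback behavior, assuming this lives on the label part ops, `leftKeyCols` is `Array("user_id")`, and `Constants.PartitionColumn` is `"ds"` (names are illustrative, not from this PR):

// hypothetical: user-specified rowIds win; otherwise fall back to left keys + partition column
labelJoinConf.rowIdentifier(util.Arrays.asList("event_id")) // => Array("event_id")
labelJoinConf.rowIdentifier(null) // => Array("user_id", "ds")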

implicit class BootstrapPartOps(val bootstrapPart: BootstrapPart) extends Serializable {
4 changes: 4 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/DataRange.scala
@@ -58,6 +58,10 @@ case class PartitionRange(start: String, end: String) extends DataRange with Ord
(startClause ++ endClause).toSeq
}

def betweenClauses: String = {
s"${Constants.PartitionColumn} BETWEEN '$start' AND '$end'"
}
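For example, assuming Constants.PartitionColumn is "ds":

PartitionRange("2022-10-01", "2022-10-07").betweenClauses
// => "ds BETWEEN '2022-10-01' AND '2022-10-07'"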

def substituteMacros(template: String): String = {
val substitutions = Seq(Constants.StartPartitionMacro -> Option(start), Constants.EndPartitionMacro -> Option(end))
substitutions.foldLeft(template) {
128 changes: 128 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -106,4 +106,132 @@ object JoinUtils {
val finalDf = joinedDf.select(selects: _*)
finalDf
}

/**
* Creates or replaces a view joining a feature table with a label table.
* Label columns are prefixed with "label" (or a custom prefix) for easy identification.
*/
def createOrReplaceView(viewName: String,
leftTable: String,
rightTable: String,
joinKeys: Array[String],
tableUtils: TableUtils,
viewProperties: Map[String, String] = null,
labelColumnPrefix: String = Constants.LabelColumnPrefix): Unit = {
val fieldDefinitions = joinKeys.map(field => s"l.${field}") ++
tableUtils.getSchemaFromTable(leftTable)
.filterNot(field => joinKeys.contains(field.name))
.map(field => s"l.${field.name}") ++
tableUtils.getSchemaFromTable(rightTable)
.filterNot(field => joinKeys.contains(field.name))
.map(field => {
if (field.name.startsWith(labelColumnPrefix)) {
s"r.${field.name}"
} else {
s"r.${field.name} AS ${labelColumnPrefix}_${field.name}"
}
})
val joinKeyDefinitions = joinKeys.map(key => s"l.${key} = r.${key}")
val createFragment = s"""CREATE OR REPLACE VIEW $viewName"""
val queryFragment =
s"""
| AS SELECT
| ${fieldDefinitions.mkString(",\n ")}
| FROM ${leftTable} AS l LEFT OUTER JOIN ${rightTable} AS r
| ON ${joinKeyDefinitions.mkString(" AND ")}""".stripMargin

val propertiesFragment = if (viewProperties != null && viewProperties.nonEmpty) {
s""" TBLPROPERTIES (
| ${viewProperties.transform((k, v) => s"'$k'='$v'").values.mkString(",\n ")}
| )""".stripMargin
} else {
""
}
val sqlStatement = Seq(createFragment, propertiesFragment, queryFragment).mkString("\n")
tableUtils.sql(sqlStatement)
}
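As a rough illustration with hypothetical names — tables db.my_join (features) and db.my_join_labels (labels), join keys Array("user_id", "ds"), a feature column feature_sum_7d, and label columns label_ds and is_fraud — the emitted statement would look like:

CREATE OR REPLACE VIEW db.my_join_labeled
 TBLPROPERTIES (
  'label_table'='db.my_join_labels',
  'feature_table'='db.my_join'
 )
 AS SELECT
  l.user_id,
  l.ds,
  l.feature_sum_7d,
  r.label_ds,
  r.is_fraud AS label_is_fraud
 FROM db.my_join AS l LEFT OUTER JOIN db.my_join_labels AS r
 ON l.user_id = r.user_id AND l.ds = r.ds

Note how label_ds already starts with the prefix and passes through unchanged, while is_fraud gets the label_ prefix.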

/**
* Creates a view that, for each ds, keeps only the rows with the latest available label_ds. This
* view is built on top of the final label view, which contains all label versions, and it
* inherits the final label view's table properties.
*/
def createLatestLabelView(viewName: String,
baseView: String,
tableUtils: TableUtils,
propertiesOverride: Map[String, String] = null): Unit = {
val baseViewProperties = tableUtils.getTableProperties(baseView).getOrElse(Map.empty)
val labelTableName = baseViewProperties.getOrElse(Constants.LabelViewPropertyKeyLabelTable, "")
assert(labelTableName.nonEmpty, s"Unable to locate the underlying label table for view ${baseView}")

val labelMapping = getLatestLabelMapping(labelTableName, tableUtils)
val caseDefinitions = labelMapping.flatMap {
case (labelDs, ranges) =>
ranges.map(range => s"WHEN ${range.betweenClauses} THEN ${Constants.LabelPartitionColumn} = '$labelDs'")
}

val createFragment = s"""CREATE OR REPLACE VIEW $viewName"""
val queryFragment =
s"""
| AS SELECT *
| FROM ${baseView}
| WHERE (
| CASE
| ${caseDefinitions.mkString("\n ")}
| ELSE true
| END
| )
| """.stripMargin

val mergedProperties = if (propertiesOverride != null) baseViewProperties ++ propertiesOverride
else baseViewProperties
val propertiesFragment = if (mergedProperties.nonEmpty) {
s"""TBLPROPERTIES (
| ${mergedProperties.transform((k, v) => s"'$k'='$v'").values.mkString(",\n ")}
|)""".stripMargin
} else {
""
}
val sqlStatement = Seq(createFragment, propertiesFragment, queryFragment).mkString("\n")
tableUtils.sql(sqlStatement)
}
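Under the same hypothetical tables, and assuming the label partitions worked through in the mapping example below, the generated view would look roughly like:

CREATE OR REPLACE VIEW db.my_join_labeled_latest
TBLPROPERTIES (
 'label_table'='db.my_join_labels',
 'feature_table'='db.my_join'
)
 AS SELECT *
 FROM db.my_join_labeled
 WHERE (
 CASE
 WHEN ds BETWEEN '2022-10-01' AND '2022-10-02' THEN label_ds = '2022-11-01'
 WHEN ds BETWEEN '2022-10-03' AND '2022-10-03' THEN label_ds = '2022-10-15'
 ELSE true
 END
 )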

/**
* Computes the mapping label_ds -> PartitionRange of the ds partitions for which that label_ds is
* the latest available version:
* - Get all partitions from the table
* - For each ds, find the latest available label_ds
* - Reverse the mapping to get the ds partition ranges for each label version (label_ds)
*
* @return mapping of label_ds -> partition ranges of ds for which this label is the latest available
*/
def getLatestLabelMapping(tableName: String, tableUtils: TableUtils): Map[String, Seq[PartitionRange]] = {
val partitions = tableUtils.allPartitions(tableName)
assert(
partitions(0).keys.equals(Set(Constants.PartitionColumn, Constants.LabelPartitionColumn)),
s"""Table must have both partition columns for latest label computation: `${Constants.PartitionColumn}`
| & `${Constants.LabelPartitionColumn}`
|table: ${tableName}
|""".stripMargin
)

val labelMap = collection.mutable.Map[String, String]()
partitions.foreach { partition =>
val dsValue = partition(Constants.PartitionColumn)
val labelValue = partition(Constants.LabelPartitionColumn)
// keep the greatest (latest) label_ds seen for each ds
val latest = labelMap.get(dsValue).fold(labelValue)(existing => Seq(existing, labelValue).max)
labelMap.put(dsValue, latest)
}

labelMap.groupBy(_._2).map { case (v, kvs) => (v, tableUtils.chunk(kvs.map(_._1).toSet)) }
}
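A worked example, with hypothetical label table partitions:

// label table partitions as (ds, label_ds):
//   (2022-10-01, 2022-10-15), (2022-10-01, 2022-11-01),
//   (2022-10-02, 2022-11-01), (2022-10-03, 2022-10-15)
// latest label_ds per ds:
//   2022-10-01 -> 2022-11-01, 2022-10-02 -> 2022-11-01, 2022-10-03 -> 2022-10-15
getLatestLabelMapping("db.my_join_labels", tableUtils)
// => Map("2022-11-01" -> Seq(PartitionRange("2022-10-01", "2022-10-02")),
//        "2022-10-15" -> Seq(PartitionRange("2022-10-03", "2022-10-03")))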

def filterColumns(df: DataFrame, filter: Seq[String]): DataFrame = {
val columnsToDrop = df.columns
.filterNot(col => filter.contains(col))
df.drop(columnsToDrop: _*)
}
}
50 changes: 38 additions & 12 deletions spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
@@ -23,7 +23,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
)

val metrics = Metrics.Context(Metrics.Environment.LabelJoin, joinConf)
private val outputTable = joinConf.metaData.outputTable
private val outputLabelTable = joinConf.metaData.outputLabelTable
private val labelJoinConf = joinConf.labelPart
private val confTableProps = Option(joinConf.metaData.tableProperties)
.map(_.asScala.toMap)
@@ -32,7 +32,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
val leftStart = Constants.Partition.minus(labelDS, new Window(labelJoinConf.leftStartOffset, TimeUnit.DAYS))
val leftEnd = Constants.Partition.minus(labelDS, new Window(labelJoinConf.leftEndOffset, TimeUnit.DAYS))

def computeLabelJoin(stepDays: Option[Int] = None): DataFrame = {
def computeLabelJoin(stepDays: Option[Int] = None, skipFinalJoin: Boolean = false): DataFrame = {
// validations
assert(Option(joinConf.left.dataModel).equals(Option(Events)),
s"join.left.dataModel needs to be Events for label join ${joinConf.metaData.name}")
@@ -52,19 +52,40 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
}

labelJoinConf.setups.foreach(tableUtils.sql)
compute(joinConf.left, stepDays, Option(labelDS))
val labelTable = compute(joinConf.left, stepDays, Option(labelDS))

if (skipFinalJoin) {
labelTable
} else {
// creating final join view with feature join output table
println(s"Joining label table : ${outputLabelTable} with joined output table : ${joinConf.metaData.outputTable}")
JoinUtils.createOrReplaceView(joinConf.metaData.outputFinalView,
leftTable = joinConf.metaData.outputTable,
rightTable = outputLabelTable,
joinKeys = labelJoinConf.rowIdentifier(joinConf.rowIds),
tableUtils = tableUtils,
viewProperties = Map(Constants.LabelViewPropertyKeyLabelTable -> outputLabelTable,
Constants.LabelViewPropertyFeatureTable -> joinConf.metaData.outputTable))
println(s"Final labeled view created: ${joinConf.metaData.outputFinalView}")
JoinUtils.createLatestLabelView(joinConf.metaData.outputLatestLabelView,
baseView = joinConf.metaData.outputFinalView,
tableUtils)
println(s"Final view with latest label created: ${joinConf.metaData.outputLatestLabelView}")
labelTable
}
}
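A minimal driver-side sketch of the new flow (conf loading elided; values are illustrative):

val labelJoin = new LabelJoin(joinConf, tableUtils, labelDS = "2022-11-01")
// materializes <namespace>.<name>_labels, then creates the <name>_labeled and <name>_labeled_latest views
val labelDf = labelJoin.computeLabelJoin(stepDays = Some(30))
// or, to materialize only the label table without the views:
val labelOnlyDf = labelJoin.computeLabelJoin(stepDays = Some(30), skipFinalJoin = true)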


def compute(left: Source, stepDays: Option[Int] = None, labelDS: Option[String] = None): DataFrame = {
val rangeToFill = PartitionRange(leftStart, leftEnd)
val today = Constants.Partition.at(System.currentTimeMillis())
val sanitizedLabelDs = labelDS.getOrElse(today)
println(s"Label join range to fill $rangeToFill")
def finalResult = tableUtils.sql(rangeToFill.genScanQuery(null, outputTable))
def finalResult = tableUtils.sql(rangeToFill.genScanQuery(null, outputLabelTable))
//TODO: use unfilledRanges instead of dropPartitionsAfterHole
val earliestHoleOpt =
tableUtils.dropPartitionsAfterHole(left.table,
outputTable,
outputLabelTable,
rangeToFill,
Map(Constants.LabelPartitionColumn -> sanitizedLabelDs))
if (earliestHoleOpt.forall(_ > rangeToFill.end)) {
@@ -94,18 +115,18 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
println(s"Computing join for range: $range ${labelDS.getOrElse(today)} $progress")
JoinUtils.leftDf(joinConf, range, tableUtils).map { leftDfInRange =>
computeRange(leftDfInRange, range, sanitizedLabelDs)
.save(outputTable, confTableProps, Seq(Constants.LabelPartitionColumn, Constants.PartitionColumn), true)
.save(outputLabelTable, confTableProps, Seq(Constants.LabelPartitionColumn, Constants.PartitionColumn), true)
val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
metrics.gauge(Metrics.Name.PartitionCount, range.partitions.length)
println(s"Wrote to table $outputTable, into partitions: $range $progress in $elapsedMins mins")
println(s"Wrote to table $outputLabelTable, into partitions: $range $progress in $elapsedMins mins")
}
}
println(s"Wrote to table $outputTable, into partitions: $leftUnfilledRange")
println(s"Wrote to table $outputLabelTable, into partitions: $leftUnfilledRange")
finalResult
}

def computeRange(leftDf: DataFrame, leftRange: PartitionRange, labelDs: String): DataFrame = {
def computeRange(leftDf: DataFrame, leftRange: PartitionRange, sanitizedLabelDs: String): DataFrame = {
val leftDfCount = leftDf.count()
val leftBlooms = labelJoinConf.leftKeyCols.par.map { key =>
key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
@@ -122,12 +143,17 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
}
}

val joined = rightDfs.zip(labelJoinConf.labels.asScala).foldLeft(leftDf) {
val rowIdentifier = labelJoinConf.rowIdentifier(joinConf.rowIds)
println("Label Join filtering left df with only row identifier:", rowIdentifier.mkString(", "))
val leftFiltered = JoinUtils.filterColumns(leftDf, rowIdentifier)

val joined = rightDfs.zip(labelJoinConf.labels.asScala).foldLeft(leftFiltered) {
case (partialDf, (rightDf, joinPart)) => joinWithLeft(partialDf, rightDf, joinPart)
}

// assign label ds value to avoid null cases
val updatedJoin = joined.withColumn(Constants.LabelPartitionColumn, lit(labelDS))
// assign label ds value and drop duplicates
val updatedJoin = joined.withColumn(Constants.LabelPartitionColumn, lit(sanitizedLabelDs))
.dropDuplicates(rowIdentifier)
updatedJoin.explain()
updatedJoin.drop(Constants.TimePartitionColumn)
}
25 changes: 24 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -39,7 +39,30 @@ case class TableUtils(sparkSession: SparkSession) {
schema.fieldNames.contains(Constants.PartitionColumn)
}

def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
// Returns all partitions of a table, one Map[partitionColumnName -> partitionValue] per partition,
// optionally restricted to the given partition columns
def allPartitions(tableName: String,
partitionColumnsFilter: Seq[String] = Seq.empty): Seq[Map[String, String]] = {
if (!tableExists(tableName)) return Seq.empty[Map[String, String]]
if (isIcebergTable(tableName)) {
throw new NotImplementedError("Multi-partition-column retrieval is not supported on Iceberg tables yet. " +
"For single-partition-column retrieval, please use the 'partitions' method.")
}
sparkSession.sqlContext
.sql(s"SHOW PARTITIONS $tableName")
.collect()
.map { row =>
val partitionMap = parsePartition(row.getString(0))
if (partitionColumnsFilter.isEmpty) {
partitionMap
} else {
partitionMap.filterKeys(partitionColumnsFilter.contains)
}
}
}
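For instance, with a hypothetical table partitioned by ds and label_ds:

tableUtils.allPartitions("db.my_join_labels")
// => Seq(Map("ds" -> "2022-10-01", "label_ds" -> "2022-10-15"),
//        Map("ds" -> "2022-10-01", "label_ds" -> "2022-11-01"), ...)
tableUtils.allPartitions("db.my_join_labels", Seq("label_ds"))
// => Seq(Map("label_ds" -> "2022-10-15"), Map("label_ds" -> "2022-11-01"), ...)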

def partitions(tableName: String,
subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
if (!tableExists(tableName)) return Seq.empty[String]
if (isIcebergTable(tableName)) {
if (subPartitionsFilter.nonEmpty) {