diff --git a/api/src/main/scala/ai/chronon/api/Constants.scala b/api/src/main/scala/ai/chronon/api/Constants.scala
index bafc209eee..217cab9df3 100644
--- a/api/src/main/scala/ai/chronon/api/Constants.scala
+++ b/api/src/main/scala/ai/chronon/api/Constants.scala
@@ -38,4 +38,7 @@ object Constants {
   val ContextualSourceKeys: String = "contextual_keys"
   val ContextualSourceValues: String = "contextual_values"
   val TeamOverride: String = "team_override"
+  val LabelColumnPrefix: String = "label"
+  val LabelViewPropertyFeatureTable: String = "feature_table"
+  val LabelViewPropertyKeyLabelTable: String = "label_table"
 }
diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala
index f0b9276afc..6dbced546f 100644
--- a/api/src/main/scala/ai/chronon/api/Extensions.scala
+++ b/api/src/main/scala/ai/chronon/api/Extensions.scala
@@ -8,6 +8,7 @@ import com.fasterxml.jackson.databind.ObjectMapper
 import java.io.{PrintWriter, StringWriter}
 import java.util
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 import scala.util.{Failure, ScalaVersionSpecificCollectionsConverter, Success, Try}

 object Extensions {
@@ -75,6 +76,9 @@ object Extensions {
     def cleanName: String = metaData.name.sanitize

     def outputTable = s"${metaData.outputNamespace}.${metaData.cleanName}"
+    def outputLabelTable = s"${metaData.outputNamespace}.${metaData.cleanName}_labels"
+    def outputFinalView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled"
+    def outputLatestLabelView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled_latest"
     def loggedTable = s"${outputTable}_logged"
     def bootstrapTable = s"${outputTable}_bootstrap"
     private def comparisonPrefix = "comparison"
@@ -563,6 +567,14 @@ object Extensions {
         .flatMap(_.groupBy.setups)
         .distinct
     }
+
+    // Columns that uniquely identify a row on the left side: the user-specified row ids
+    // when provided, otherwise the left key columns plus the partition column.
+    def rowIdentifier(userRowId: util.List[String] = null): Array[String] = {
+      if (userRowId != null && !userRowId.isEmpty)
+        userRowId.asScala.toArray
+      else
+        leftKeyCols ++ Array(Constants.PartitionColumn)
+    }
   }

 implicit class BootstrapPartOps(val bootstrapPart: BootstrapPart) extends Serializable {
diff --git a/spark/src/main/scala/ai/chronon/spark/DataRange.scala b/spark/src/main/scala/ai/chronon/spark/DataRange.scala
index e6b50a792b..a941a16e2a 100644
--- a/spark/src/main/scala/ai/chronon/spark/DataRange.scala
+++ b/spark/src/main/scala/ai/chronon/spark/DataRange.scala
@@ -58,6 +58,10 @@ case class PartitionRange(start: String, end: String) extends DataRange with Ord
     (startClause ++ endClause).toSeq
   }

+  def betweenClauses: String = {
+    s"${Constants.PartitionColumn} BETWEEN '$start' AND '$end'"
+  }
+
   def substituteMacros(template: String): String = {
     val substitutions = Seq(Constants.StartPartitionMacro -> Option(start), Constants.EndPartitionMacro -> Option(end))
     substitutions.foldLeft(template) {
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
index 3f78e21c03..18cfb487ed 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -106,4 +106,132 @@ object JoinUtils {
     val finalDf = joinedDf.select(selects: _*)
     finalDf
   }
+
+  /***
+   * Create or replace a view that joins the feature table with its label table.
+   * Label columns are prefixed with "label" (or a custom prefix) for easy identification.
+   */
+  def createOrReplaceView(viewName: String,
+                          leftTable: String,
+                          rightTable: String,
+                          joinKeys: Array[String],
+                          tableUtils: TableUtils,
+                          viewProperties: Map[String, String] = null,
+                          labelColumnPrefix: String = Constants.LabelColumnPrefix): Unit = {
+    val fieldDefinitions = joinKeys.map(field => s"l.${field}") ++
+      tableUtils.getSchemaFromTable(leftTable)
+        .filterNot(field => joinKeys.contains(field.name))
+        .map(field => s"l.${field.name}") ++
+      tableUtils.getSchemaFromTable(rightTable)
+        .filterNot(field => joinKeys.contains(field.name))
+        .map(field => {
+          if (field.name.startsWith(labelColumnPrefix)) {
+            s"r.${field.name}"
+          } else {
+            s"r.${field.name} AS ${labelColumnPrefix}_${field.name}"
+          }
+        })
+    val joinKeyDefinitions = joinKeys.map(key => s"l.${key} = r.${key}")
+    val createFragment = s"""CREATE OR REPLACE VIEW $viewName"""
+    val queryFragment =
+      s"""
+         |  AS SELECT
+         |    ${fieldDefinitions.mkString(",\n    ")}
+         |  FROM ${leftTable} AS l LEFT OUTER JOIN ${rightTable} AS r
+         |  ON ${joinKeyDefinitions.mkString(" AND ")}""".stripMargin
+
+    val propertiesFragment = if (viewProperties != null && viewProperties.nonEmpty) {
+      s""" TBLPROPERTIES (
+         |    ${viewProperties.transform((k, v) => s"'$k'='$v'").values.mkString(",\n    ")}
+         |  )""".stripMargin
+    } else {
+      ""
+    }
+    val sqlStatement = Seq(createFragment, propertiesFragment, queryFragment).mkString("\n")
+    tableUtils.sql(sqlStatement)
+  }
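For intuition, a rough sketch of the statement this method emits; the view, table, and column names below are hypothetical and not part of the patch:

  // Hypothetical: a feature table and a label table sharing the keys ("listing", "ds").
  JoinUtils.createOrReplaceView(
    viewName = "db.my_join_labeled",
    leftTable = "db.my_join",         // columns: listing, ds, feature_a
    rightTable = "db.my_join_labels", // columns: listing, ds, dim_bedrooms
    joinKeys = Array("listing", "ds"),
    tableUtils = tableUtils,
    viewProperties = Map("label_table" -> "db.my_join_labels")
  )
  // Approximately the generated SQL; non-key right-side columns pick up the "label_" prefix:
  //   CREATE OR REPLACE VIEW db.my_join_labeled
  //    TBLPROPERTIES ('label_table'='db.my_join_labels')
  //    AS SELECT
  //      l.listing, l.ds, l.feature_a,
  //      r.dim_bedrooms AS label_dim_bedrooms
  //    FROM db.my_join AS l LEFT OUTER JOIN db.my_join_labels AS r
  //    ON l.listing = r.listing AND l.ds = r.ds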
+
+  /***
+   * Create a view that, for each ds, keeps only the rows carrying the latest available
+   * label_ds. It is built on top of the final label view, which holds all label versions,
+   * and inherits the final label view's properties as well.
+   */
+  def createLatestLabelView(viewName: String,
+                            baseView: String,
+                            tableUtils: TableUtils,
+                            propertiesOverride: Map[String, String] = null): Unit = {
+    val baseViewProperties = tableUtils.getTableProperties(baseView).getOrElse(Map.empty)
+    val labelTableName = baseViewProperties.getOrElse(Constants.LabelViewPropertyKeyLabelTable, "")
+    assert(labelTableName.nonEmpty,
+           s"Unable to locate the underlying label table in the properties of base view $baseView")
+
+    val labelMapping = getLatestLabelMapping(labelTableName, tableUtils)
+    val caseDefinitions = labelMapping.flatMap {
+      case (labelDs, ranges) =>
+        ranges.map(range => s"WHEN ${range.betweenClauses} THEN ${Constants.LabelPartitionColumn} = '$labelDs'")
+    }
+
+    val createFragment = s"""CREATE OR REPLACE VIEW $viewName"""
+    val queryFragment =
+      s"""
+         |  AS SELECT *
+         |  FROM ${baseView}
+         |  WHERE (
+         |    CASE
+         |      ${caseDefinitions.mkString("\n      ")}
+         |      ELSE true
+         |    END
+         |  )
+         |  """.stripMargin
+
+    val mergedProperties =
+      if (propertiesOverride != null) baseViewProperties ++ propertiesOverride
+      else baseViewProperties
+    val propertiesFragment = if (mergedProperties.nonEmpty) {
+      s"""TBLPROPERTIES (
+         |    ${mergedProperties.transform((k, v) => s"'$k'='$v'").values.mkString(",\n    ")}
+         |)""".stripMargin
+    } else {
+      ""
+    }
+    val sqlStatement = Seq(createFragment, propertiesFragment, queryFragment).mkString("\n")
+    tableUtils.sql(sqlStatement)
+  }
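The WHERE clause built here is a boolean CASE over ds ranges: a row survives only if its label_ds is the latest label version for its ds. A sketch of the resulting statement, reusing the hypothetical names above and the mapping worked through after getLatestLabelMapping below:

  // CREATE OR REPLACE VIEW db.my_join_labeled_latest
  // TBLPROPERTIES ('label_table'='db.my_join_labels', 'feature_table'='db.my_join')
  //  AS SELECT *
  //  FROM db.my_join_labeled
  //  WHERE (
  //    CASE
  //      WHEN ds BETWEEN '2022-10-01' AND '2022-10-02' THEN label_ds = '2022-11-09'
  //      WHEN ds BETWEEN '2022-10-05' AND '2022-10-05' THEN label_ds = '2022-11-09'
  //      ELSE true
  //    END
  //  )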
+
+  /**
+   * Compute the mapping label_ds -> PartitionRanges of ds which have that label_ds as their latest version:
+   *  - get all partitions from the table
+   *  - for each ds, find the latest available label_ds
+   *  - reverse the mapping and collect the ds partition ranges for each label version (label_ds)
+   *
+   * @return Mapping of label ds -> partition ranges of ds which have this label available as latest
+   */
+  def getLatestLabelMapping(tableName: String, tableUtils: TableUtils): Map[String, Seq[PartitionRange]] = {
+    val partitions = tableUtils.allPartitions(tableName)
+    assert(
+      partitions(0).keys.equals(Set(Constants.PartitionColumn, Constants.LabelPartitionColumn)),
+      s""" Table must have label partition columns for latest label computation: `${Constants.PartitionColumn}`
+         | & `${Constants.LabelPartitionColumn}`
+         |inputView: ${tableName}
+         |""".stripMargin
+    )
+
+    val labelMap = collection.mutable.Map[String, String]()
+    partitions.foreach { par =>
+      val dsValue = par(Constants.PartitionColumn)
+      val labelValue = par(Constants.LabelPartitionColumn)
+      if (!labelMap.contains(dsValue)) {
+        labelMap.put(dsValue, labelValue)
+      } else {
+        labelMap.put(dsValue, Seq(labelMap(dsValue), labelValue).max)
+      }
+    }
+
+    labelMap.groupBy(_._2).map { case (v, kvs) => (v, tableUtils.chunk(kvs.map(_._1).toSet)) }
+  }
+
+  def filterColumns(df: DataFrame, filter: Seq[String]): DataFrame = {
+    val columnsToDrop = df.columns
+      .filterNot(col => filter.contains(col))
+    df.drop(columnsToDrop: _*)
+  }
 }
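A worked example of the reversal, using the same partitions as testAllPartitionsAndGetLatestLabelMapping in TableUtilsTest below:

  // Partitions (ds, label_ds):
  //   (2022-10-01, 2022-11-01), (2022-10-02, 2022-11-02), (2022-10-05, 2022-11-05),
  //   (2022-10-01, 2022-11-09), (2022-10-02, 2022-11-09), (2022-10-05, 2022-11-09)
  // Step 1 - latest label_ds per ds (lexicographic max on the date strings):
  //   2022-10-01 -> 2022-11-09, 2022-10-02 -> 2022-11-09, 2022-10-05 -> 2022-11-09
  // Step 2 - reverse the map and chunk contiguous ds values into ranges:
  JoinUtils.getLatestLabelMapping("db.test_show_partitions", tableUtils)
  // Map("2022-11-09" -> Seq(PartitionRange("2022-10-01", "2022-10-02"),
  //                         PartitionRange("2022-10-05", "2022-10-05")))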
diff --git a/spark/src/main/scala/ai/chronon/spark/LabelJoin.scala b/spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
index ce6310fab0..eb0216683c 100644
--- a/spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
+++ b/spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
@@ -23,7 +23,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
   )

   val metrics = Metrics.Context(Metrics.Environment.LabelJoin, joinConf)
-  private val outputTable = joinConf.metaData.outputTable
+  private val outputLabelTable = joinConf.metaData.outputLabelTable
   private val labelJoinConf = joinConf.labelPart
   private val confTableProps = Option(joinConf.metaData.tableProperties)
     .map(_.asScala.toMap)
@@ -32,7 +32,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
   val leftStart = Constants.Partition.minus(labelDS, new Window(labelJoinConf.leftStartOffset, TimeUnit.DAYS))
   val leftEnd = Constants.Partition.minus(labelDS, new Window(labelJoinConf.leftEndOffset, TimeUnit.DAYS))

-  def computeLabelJoin(stepDays: Option[Int] = None): DataFrame = {
+  def computeLabelJoin(stepDays: Option[Int] = None, skipFinalJoin: Boolean = false): DataFrame = {
     // validations
     assert(Option(joinConf.left.dataModel).equals(Option(Events)),
            s"join.left.dataModel needs to be Events for label join ${joinConf.metaData.name}")
@@ -52,19 +52,40 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
     }

     labelJoinConf.setups.foreach(tableUtils.sql)
-    compute(joinConf.left, stepDays, Option(labelDS))
+    val labelTable = compute(joinConf.left, stepDays, Option(labelDS))
+
+    if (skipFinalJoin) {
+      labelTable
+    } else {
+      // create the final joined view on top of the feature join output table
+      println(s"Joining label table: ${outputLabelTable} with joined output table: ${joinConf.metaData.outputTable}")
+      JoinUtils.createOrReplaceView(joinConf.metaData.outputFinalView,
+                                    leftTable = joinConf.metaData.outputTable,
+                                    rightTable = outputLabelTable,
+                                    joinKeys = labelJoinConf.rowIdentifier(joinConf.rowIds),
+                                    tableUtils = tableUtils,
+                                    viewProperties = Map(Constants.LabelViewPropertyKeyLabelTable -> outputLabelTable,
+                                                         Constants.LabelViewPropertyFeatureTable -> joinConf.metaData.outputTable))
+      println(s"Final labeled view created: ${joinConf.metaData.outputFinalView}")
+      JoinUtils.createLatestLabelView(joinConf.metaData.outputLatestLabelView,
+                                      baseView = joinConf.metaData.outputFinalView,
+                                      tableUtils)
+      println(s"Final view with latest label created: ${joinConf.metaData.outputLatestLabelView}")
+      labelTable
+    }
   }

+
   def compute(left: Source, stepDays: Option[Int] = None, labelDS: Option[String] = None): DataFrame = {
     val rangeToFill = PartitionRange(leftStart, leftEnd)
     val today = Constants.Partition.at(System.currentTimeMillis())
     val sanitizedLabelDs = labelDS.getOrElse(today)
     println(s"Label join range to fill $rangeToFill")
-    def finalResult = tableUtils.sql(rangeToFill.genScanQuery(null, outputTable))
+    def finalResult = tableUtils.sql(rangeToFill.genScanQuery(null, outputLabelTable))
     //TODO: use unfilledRanges instead of dropPartitionsAfterHole
     val earliestHoleOpt = tableUtils.dropPartitionsAfterHole(left.table,
-                                                             outputTable,
+                                                             outputLabelTable,
                                                              rangeToFill,
                                                              Map(Constants.LabelPartitionColumn -> sanitizedLabelDs))
     if (earliestHoleOpt.forall(_ > rangeToFill.end)) {
@@ -94,18 +115,18 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
       println(s"Computing join for range: $range ${labelDS.getOrElse(today)} $progress")
       JoinUtils.leftDf(joinConf, range, tableUtils).map { leftDfInRange =>
         computeRange(leftDfInRange, range, sanitizedLabelDs)
-          .save(outputTable, confTableProps, Seq(Constants.LabelPartitionColumn, Constants.PartitionColumn), true)
+          .save(outputLabelTable, confTableProps, Seq(Constants.LabelPartitionColumn, Constants.PartitionColumn), true)
         val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
         metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
         metrics.gauge(Metrics.Name.PartitionCount, range.partitions.length)
-        println(s"Wrote to table $outputTable, into partitions: $range $progress in $elapsedMins mins")
+        println(s"Wrote to table $outputLabelTable, into partitions: $range $progress in $elapsedMins mins")
       }
     }
-    println(s"Wrote to table $outputTable, into partitions: $leftUnfilledRange")
+    println(s"Wrote to table $outputLabelTable, into partitions: $leftUnfilledRange")
     finalResult
   }

-  def computeRange(leftDf: DataFrame, leftRange: PartitionRange, labelDs: String): DataFrame = {
+  def computeRange(leftDf: DataFrame, leftRange: PartitionRange, sanitizedLabelDs: String): DataFrame = {
     val leftDfCount = leftDf.count()
     val leftBlooms = labelJoinConf.leftKeyCols.par.map { key =>
       key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
@@ -122,12 +143,17 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
       }
     }

-    val joined = rightDfs.zip(labelJoinConf.labels.asScala).foldLeft(leftDf) {
+    val rowIdentifier = labelJoinConf.rowIdentifier(joinConf.rowIds)
+    println(s"Label Join filtering left df down to row identifier columns: ${rowIdentifier.mkString(", ")}")
+    val leftFiltered = JoinUtils.filterColumns(leftDf, rowIdentifier)
+
+    val joined = rightDfs.zip(labelJoinConf.labels.asScala).foldLeft(leftFiltered) {
       case (partialDf, (rightDf, joinPart)) => joinWithLeft(partialDf, rightDf, joinPart)
     }

-    // assign label ds value to avoid null cases
-    val updatedJoin = joined.withColumn(Constants.LabelPartitionColumn, lit(labelDS))
+    // assign the label ds value and drop duplicates
+    val updatedJoin = joined.withColumn(Constants.LabelPartitionColumn, lit(sanitizedLabelDs))
+      .dropDuplicates(rowIdentifier)
     updatedJoin.explain()
     updatedJoin.drop(Constants.TimePartitionColumn)
   }
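A minimal usage sketch of the new flag; joinConf is assumed to be a label-join conf whose output table is db.my_join:

  val runner = new LabelJoin(joinConf, tableUtils, labelDS = "2022-10-30")
  // Materialize only the label table (db.my_join_labels), skipping view creation:
  runner.computeLabelJoin(skipFinalJoin = true)
  // Materialize labels and also (re)create the db.my_join_labeled and
  // db.my_join_labeled_latest views on top of the feature output table:
  runner.computeLabelJoin()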
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index b6b3e2eb0e..e381ee9b1e 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -39,7 +39,30 @@ case class TableUtils(sparkSession: SparkSession) {
     schema.fieldNames.contains(Constants.PartitionColumn)
   }

-  def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
+  // Return all partitions of a table, each as a Map of partition column name -> partition value,
+  // optionally restricted to the given subset of partition columns.
+  def allPartitions(tableName: String,
+                    partitionColumnsFilter: Seq[String] = Seq.empty): Seq[Map[String, String]] = {
+    if (!tableExists(tableName)) return Seq.empty[Map[String, String]]
+    if (isIcebergTable(tableName)) {
+      throw new NotImplementedError("Multi-partition retrieval is not supported on Iceberg tables yet. " +
+        "For single partition retrieval, please use the 'partitions' method.")
+    }
+    sparkSession.sqlContext
+      .sql(s"SHOW PARTITIONS $tableName")
+      .collect()
+      .map { row =>
+        val partitionMap = parsePartition(row.getString(0))
+        if (partitionColumnsFilter.isEmpty) {
+          partitionMap
+        } else {
+          partitionMap.filterKeys(key => partitionColumnsFilter.contains(key))
+        }
+      }
+  }
+
+  def partitions(tableName: String,
+                 subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
     if (!tableExists(tableName)) return Seq.empty[String]
     if (isIcebergTable(tableName)) {
       if (subPartitionsFilter.nonEmpty) {
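For a table partitioned by ds and label_ds, the returned shape looks like this (mirroring testAllPartitionsAndGetLatestLabelMapping further down):

  tableUtils.allPartitions("db.test_show_partitions")
  // Seq(Map("ds" -> "2022-10-01", "label_ds" -> "2022-11-01"),
  //     Map("ds" -> "2022-10-01", "label_ds" -> "2022-11-09"), ...)
  tableUtils.allPartitions("db.test_show_partitions", Seq("label_ds"))
  // Seq(Map("label_ds" -> "2022-11-01"), Map("label_ds" -> "2022-11-09"), ...)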
diff --git a/spark/src/test/scala/ai/chronon/spark/test/FeatureWithLabelJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FeatureWithLabelJoinTest.scala
new file mode 100644
index 0000000000..b4e3ed649e
--- /dev/null
+++ b/spark/src/test/scala/ai/chronon/spark/test/FeatureWithLabelJoinTest.scala
@@ -0,0 +1,140 @@
+package ai.chronon.spark.test
+
+import ai.chronon.api.Extensions.{LabelPartOps, MetadataOps}
+import ai.chronon.api.{Builders, LongType, StringType, StructField, StructType}
+import ai.chronon.spark.{Comparison, LabelJoin, SparkSessionBuilder, TableUtils}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.functions.{max, min}
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+class FeatureWithLabelJoinTest {
+  val spark: SparkSession = SparkSessionBuilder.build("FeatureWithLabelJoinTest", local = true)
+
+  private val namespace = "final_join"
+  private val tableName = "test_feature_label_join"
+  spark.sql(s"CREATE DATABASE IF NOT EXISTS $namespace")
+  private val tableUtils = TableUtils(spark)
+
+  private val labelDS = "2022-10-30"
+  private val viewsGroupBy = TestUtils.createViewsGroupBy(namespace, spark)
+  private val left = viewsGroupBy.groupByConf.sources.get(0)
+
+  @Test
+  def testFinalViews(): Unit = {
+    // create the test feature join table
+    val featureTable = s"${namespace}.${tableName}"
+    createTestFeatureTable().write.saveAsTable(featureTable)
+
+    val labelJoinConf = createTestLabelJoin(50, 20)
+    val joinConf = Builders.Join(
+      Builders.MetaData(name = tableName, namespace = namespace, team = "chronon"),
+      left,
+      labelPart = labelJoinConf
+    )
+
+    val runner = new LabelJoin(joinConf, tableUtils, labelDS)
+    val labelDf = runner.computeLabelJoin()
+    println(" == First Run Label version 2022-10-30 == ")
+    prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier()).show()
+    val featureDf = tableUtils.sparkSession.table(joinConf.metaData.outputTable)
+    println(" == Features == ")
+    featureDf.show()
+    val computed = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView}")
+    computed.show()
+    val expectedFinal = featureDf.join(prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier()),
+                                       labelJoinConf.rowIdentifier(),
+                                       "left_outer")
+    println(" == Expected == ")
+    expectedFinal.show()
+    val diff = Comparison.sideBySide(computed,
+                                     expectedFinal,
+                                     List("listing", "ds", "label_ds"))
+    if (diff.count() > 0) {
+      println(s"Actual count: ${computed.count()}")
+      println(s"Expected count: ${expectedFinal.count()}")
+      println(s"Diff count: ${diff.count()}")
+      println("diff result rows")
+      diff.show()
+    }
+    assertEquals(0, diff.count())
+
+    // add another label version
+    val secondRun = new LabelJoin(joinConf, tableUtils, "2022-11-11")
+    val secondLabel = secondRun.computeLabelJoin()
+    println(" == Second Run Label version 2022-11-11 == ")
+    secondLabel.show()
+    val view = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView} order by label_ds")
+    view.show()
+    // listing 4 should not have any 2022-11-11 version labels
+    assertEquals(null,
+                 view.where(view("label_ds") === "2022-11-11" && view("listing") === "4")
+                   .select("label_listing_labels_dim_bedrooms").first().get(0))
+    // the 11-11 label record count should match the 10-30 label version record count
+    assertEquals(view.where(view("label_ds") === "2022-10-30").count(),
+                 view.where(view("label_ds") === "2022-11-11").count())
+    // listing 5 should not have any label
+    assertEquals(null, view.where(view("listing") === "5").select("label_ds").first().get(0))
+
+    // validate the latest label view
+    val latest = tableUtils.sql(s"select * from ${joinConf.metaData.outputLatestLabelView} order by label_ds")
+    latest.show()
+    // the latest labels should all be the same version: "2022-11-11"
+    assertEquals(latest.agg(max("label_ds")).first().getString(0),
+                 latest.agg(min("label_ds")).first().getString(0))
+    assertEquals("2022-11-11", latest.agg(max("label_ds")).first().getString(0))
+  }
+
+  private def prefixColumnName(df: DataFrame,
+                               prefix: String = "label_",
+                               exceptions: Array[String] = null): DataFrame = {
+    println("exceptions")
+    println(exceptions.mkString(", "))
+    val renamedColumns = df.columns
+      .map(col => {
+        if (exceptions.contains(col) || col.startsWith(prefix)) {
+          df(col)
+        } else {
+          df(col).as(s"$prefix$col")
+        }
+      })
+    df.select(renamedColumns: _*)
+  }
+
+  def createTestLabelJoin(startOffset: Int,
+                          endOffset: Int,
+                          groupByTableName: String = "listing_labels"): ai.chronon.api.LabelPart = {
+    val labelGroupBy = TestUtils.createAttributesGroupBy(namespace, spark, groupByTableName)
+    Builders.LabelPart(
+      labels = Seq(
+        Builders.JoinPart(groupBy = labelGroupBy.groupByConf)
+      ),
+      leftStartOffset = startOffset,
+      leftEndOffset = endOffset
+    )
+  }
+
+  def createTestFeatureTable(): DataFrame = {
+    val schema = StructType(
+      tableName,
+      Array(
+        StructField("listing", LongType),
+        StructField("feature_review", LongType),
+        StructField("feature_locale", StringType),
+        StructField("ds", StringType),
+        StructField("ts", StringType)
+      )
+    )
+    val rows = List(
+      Row(1L, 20L, "US", "2022-10-01", "2022-10-01 10:00:00"),
+      Row(2L, 38L, "US", "2022-10-02", "2022-10-02 11:00:00"),
+      Row(3L, 19L, "CA", "2022-10-01", "2022-10-01 08:00:00"),
+      Row(4L, 2L, "MX", "2022-10-02", "2022-10-02 18:00:00"),
+      Row(5L, 139L, "EU", "2022-10-01", "2022-10-01 22:00:00"),
+      Row(1L, 24L, "US", "2022-10-02", "2022-10-02 16:00:00")
+    )
+    TestUtils.makeDf(spark, schema, rows)
+  }
+}
diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala
index f23ad4bf03..060921808e 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala
@@ -1,7 +1,8 @@
 package ai.chronon.spark.test

+import ai.chronon.api.Constants
 import ai.chronon.spark.JoinUtils.{contains_any, set_add}
-import ai.chronon.spark.{JoinUtils, SparkSessionBuilder}
+import ai.chronon.spark.{JoinUtils, SparkSessionBuilder, TableUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -11,9 +12,11 @@ import org.junit.Test
 import scala.collection.mutable
 import scala.util.Try

+
 class JoinUtilsTest {
-  val spark: SparkSession = SparkSessionBuilder.build("JoinUtilsTest", local = true)
+  lazy val spark: SparkSession = SparkSessionBuilder.build("JoinUtilsTest", local = true)
+  private val tableUtils = TableUtils(spark)

   @Test
   def testUDFSetAdd(): Unit = {
@@ -210,4 +213,99 @@ class JoinUtilsTest {
     )
     assertEquals(3, df.get.columns.length)
   }
+
+  @Test
+  def testCreateJoinView(): Unit = {
+    val finalViewName = "testCreateView"
+    val leftTableName = "joinUtil.testFeatureTable"
+    val rightTableName = "joinUtil.testLabelTable"
+    spark.sql("CREATE DATABASE IF NOT EXISTS joinUtil")
+    TestUtils.createSampleFeatureTableDf(spark).write.saveAsTable(leftTableName)
+    TestUtils.createSampleLabelTableDf(spark).write.saveAsTable(rightTableName)
+    val keys = Array("listing_id", Constants.PartitionColumn)
+
+    JoinUtils.createOrReplaceView(finalViewName, leftTableName, rightTableName, keys, tableUtils,
+                                  viewProperties = Map("featureTable" -> leftTableName, "labelTable" -> rightTableName))
+
+    val view = tableUtils.sql(s"select * from $finalViewName")
+    view.show()
+    assertEquals(6, view.count())
+    assertEquals(null,
+                 view.where(view("ds") === "2022-10-01" && view("listing_id") === "5")
+                   .select("label_room_type").first().get(0))
+    assertEquals("SUPER_HOST",
+                 view.where(view("ds") === "2022-10-07" && view("listing_id") === "1")
+                   .select("label_host_type").first().get(0))
+
+    val properties = tableUtils.getTableProperties(finalViewName)
+    assertTrue(properties.isDefined)
+    assertEquals(properties.get.get("featureTable"), Some(leftTableName))
+    assertEquals(properties.get.get("labelTable"), Some(rightTableName))
+  }
+
+  @Test
+  def testCreateLatestLabelView(): Unit = {
+    val finalViewName = "joinUtil.testFinalView"
+    val leftTableName = "joinUtil.testFeatureTable2"
+    val rightTableName = "joinUtil.testLabelTable2"
+    spark.sql("CREATE DATABASE IF NOT EXISTS joinUtil")
+    TestUtils.createSampleFeatureTableDf(spark).write.saveAsTable(leftTableName)
+    tableUtils.insertPartitions(TestUtils.createSampleLabelTableDf(spark),
+                                rightTableName,
+                                partitionColumns = Seq(Constants.PartitionColumn, Constants.LabelPartitionColumn))
+    val keys = Array("listing_id", Constants.PartitionColumn)
+
+    JoinUtils.createOrReplaceView(finalViewName, leftTableName, rightTableName, keys, tableUtils,
+                                  viewProperties = Map(Constants.LabelViewPropertyFeatureTable -> leftTableName,
+                                                       Constants.LabelViewPropertyKeyLabelTable -> rightTableName))
+    val view = tableUtils.sql(s"select * from $finalViewName")
+    view.show()
+    assertEquals(6, view.count())
+
+    // verify the latest label view
+    val latestLabelView = "testLatestLabel"
+    JoinUtils.createLatestLabelView(latestLabelView, finalViewName, tableUtils,
+                                    propertiesOverride = Map("newProperties" -> "value"))
+    val latest = tableUtils.sql(s"select * from $latestLabelView")
+    latest.show()
+    assertEquals(2, latest.count())
+    assertEquals(0, latest.filter(latest("listing_id") === "3").count())
+    assertEquals("2022-11-22",
+                 latest.where(latest("ds") === "2022-10-07").select("label_ds").first().get(0))
+    // label_ds should be unique per ds + listing
+    val removeDup = latest.dropDuplicates(Seq("label_ds", "ds"))
+    assertEquals(removeDup.count(), latest.count())
+
+    val properties = tableUtils.getTableProperties(latestLabelView)
+    assertTrue(properties.isDefined)
+    assertEquals(properties.get.get(Constants.LabelViewPropertyFeatureTable), Some(leftTableName))
+    assertEquals(properties.get.get("newProperties"), Some("value"))
+  }
+
+  @Test
+  def testFilterColumns(): Unit = {
+    val testDf = createSampleTable()
+    val filter = Array("listing", "ds", "feature_review")
+    val filteredDf = JoinUtils.filterColumns(testDf, filter)
+    assertTrue(filteredDf.schema.fieldNames.sorted sameElements filter.sorted)
+  }
+
+  import ai.chronon.api.{LongType, StringType, StructField, StructType}
+
+  def createSampleTable(tableName: String = "testSampleTable"): DataFrame = {
+    val schema = StructType(
+      tableName,
+      Array(
+        StructField("listing", LongType),
+        StructField("feature_review", LongType),
+        StructField("feature_locale", StringType),
+        StructField("ds", StringType),
+        StructField("ts", StringType)
+      )
+    )
+    val rows = List(
+      Row(1L, 20L, "US", "2022-10-01", "2022-10-01 10:00:00"),
+      Row(2L, 38L, "US", "2022-10-02", "2022-10-02 11:00:00"),
+      Row(3L, 19L, "CA", "2022-10-01", "2022-10-01 08:00:00")
+    )
+    TestUtils.makeDf(spark, schema, rows)
+  }
 }
"3").count()) + assertEquals("2022-11-22", latest.where(latest("ds") === "2022-10-07"). + select("label_ds").first().get(0)) + // label_ds should be unique per ds + listing + val removeDup = latest.dropDuplicates(Seq("label_ds", "ds")) + assertEquals(removeDup.count(), latest.count()) + + val properties = tableUtils.getTableProperties(latestLabelView) + assertTrue(properties.isDefined) + assertEquals(properties.get.get(Constants.LabelViewPropertyFeatureTable), Some(leftTableName)) + assertEquals(properties.get.get("newProperties"), Some("value")) + } + + @Test + def testFilterColumns(): Unit ={ + val testDf = createSampleTable() + val filter = Array("listing", "ds", "feature_review") + val filteredDf = JoinUtils.filterColumns(testDf, filter) + assertTrue(filteredDf.schema.fieldNames.sorted sameElements filter.sorted) + } + + import ai.chronon.api.{LongType, StringType, StructField, StructType} + + def createSampleTable(tableName:String = "testSampleTable"): DataFrame = { + val schema = StructType( + tableName, + Array( + StructField("listing", LongType), + StructField("feature_review", LongType), + StructField("feature_locale", StringType), + StructField("ds", StringType), + StructField("ts", StringType) + ) + ) + val rows = List( + Row(1L, 20L, "US", "2022-10-01", "2022-10-01 10:00:00"), + Row(2L, 38L, "US", "2022-10-02", "2022-10-02 11:00:00"), + Row(3L, 19L, "CA", "2022-10-01", "2022-10-01 08:00:00") + ) + TestUtils.makeDf(spark, schema, rows) + } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala index 103f3d15c8..72dda682c4 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala @@ -29,18 +29,15 @@ class LabelJoinTest { labelPart = labelJoinConf ) val runner = new LabelJoin(joinConf, tableUtils, labelDS) - val computed = runner.computeLabelJoin() + val computed = runner.computeLabelJoin(skipFinalJoin = true) println(" == Computed == ") computed.show() val expected = tableUtils.sql(s""" SELECT v.listing_id as listing, - ts, - m_guests, - m_views, dim_bedrooms as listing_attributes_dim_bedrooms, dim_room_type as listing_attributes_dim_room_type, - v.ds, - a.ds as label_ds + a.ds as label_ds, + v.ds FROM label_join.listing_views as v LEFT OUTER JOIN label_join.listing_attributes as a ON v.listing_id = a.listing_id @@ -52,18 +49,11 @@ class LabelJoinTest { val diff = Comparison.sideBySide(computed, expected, - List("listing", - "ds", - "label_ds", - "m_guests", - "m_views", - "listing_attributes_dim_room_type", - "listing_attributes_dim_bedrooms")) + List("listing", "ds")) if (diff.count() > 0) { println(s"Actual count: ${computed.count()}") println(s"Expected count: ${expected.count()}") println(s"Diff count: ${diff.count()}") - println(s"diff result rows") diff.show() } assertEquals(0, diff.count()) @@ -79,7 +69,7 @@ class LabelJoinTest { ) // label ds does not exist in label table, labels should be null val runner = new LabelJoin(joinConf, tableUtils, "2022-11-01") - val computed = runner.computeLabelJoin() + val computed = runner.computeLabelJoin(skipFinalJoin = true) println(" == Computed == ") computed.show() assertEquals(computed.select("label_ds").first().get(0), "2022-11-01") @@ -100,7 +90,7 @@ class LabelJoinTest { ) val runner = new LabelJoin(joinConf, tableUtils, labelDS) - val computed = runner.computeLabelJoin() + val computed = runner.computeLabelJoin(skipFinalJoin = true) println(" == Computed 
== ") computed.show() assertEquals(computed.count(), 6) @@ -111,7 +101,7 @@ class LabelJoinTest { subPartitionFilters = Map(Constants.LabelPartitionColumn -> labelDS)) val runner2 = new LabelJoin(joinConf, tableUtils, labelDS) - val refreshed = runner2.computeLabelJoin() + val refreshed = runner2.computeLabelJoin(skipFinalJoin = true) println(" == Refreshed == ") refreshed.show() assertEquals(refreshed.count(), 6) @@ -128,7 +118,7 @@ class LabelJoinTest { labelPart = labelJoinConf ) val runner = new LabelJoin(joinConf, tableUtils, labelDS) - val computed = runner.computeLabelJoin() + val computed = runner.computeLabelJoin(skipFinalJoin = true) println(" == First Run == ") computed.show() assertEquals(computed.count(), 6) @@ -148,7 +138,7 @@ class LabelJoinTest { labelPart = updatedLabelJoin ) val runner2 = new LabelJoin(updatedJoinConf, tableUtils, "2022-11-01") - val updated = runner2.computeLabelJoin() + val updated = runner2.computeLabelJoin(skipFinalJoin = true) println(" == Updated Run == ") updated.show() assertEquals(updated.count(), 12) diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala index be4c1c3420..45726a705b 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala @@ -2,12 +2,13 @@ package ai.chronon.spark.test import ai.chronon.api._ import ai.chronon.spark._ +import ai.chronon.spark.test.TestUtils.makeDf import org.apache.spark.sql.functions.col import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.junit.Assert.{assertEquals, assertTrue} import org.junit.Test -import scala.util.{ScalaVersionSpecificCollectionsConverter, Try} +import scala.util.{Try} class TableUtilsTest { lazy val spark: SparkSession = SparkSessionBuilder.build("TableUtilsTest", local = true) @@ -87,13 +88,6 @@ class TableUtilsTest { }) } - private def makeDf(schema: StructType, rows: List[Row]): DataFrame = { - spark.createDataFrame( - ScalaVersionSpecificCollectionsConverter.convertScalaListToJava(rows), - Conversions.fromChrononSchema(schema) - ) - } - @Test def testInsertPartitionsAddColumns(): Unit = { val tableName = "db.test_table_1" @@ -104,6 +98,7 @@ class TableUtilsTest { StructField("string_field", StringType) ) val df1 = makeDf( + spark, StructType( tableName, columns1 :+ StructField("ds", StringType) @@ -114,6 +109,7 @@ class TableUtilsTest { ) val df2 = makeDf( + spark, StructType( tableName, columns1 @@ -138,6 +134,7 @@ class TableUtilsTest { StructField("string_field", StringType) ) val df1 = makeDf( + spark, StructType( tableName, columns1 @@ -150,6 +147,7 @@ class TableUtilsTest { ) val df2 = makeDf( + spark, StructType( tableName, columns1 :+ StructField("ds", StringType) @@ -170,6 +168,7 @@ class TableUtilsTest { StructField("int_field", IntType) ) val df1 = makeDf( + spark, StructType( tableName, columns1 @@ -182,6 +181,7 @@ class TableUtilsTest { ) val df2 = makeDf( + spark, StructType( tableName, columns1 @@ -216,6 +216,7 @@ class TableUtilsTest { StructField("label_ds", StringType) ) val df1 = makeDf( + spark, StructType( tableName, columns1 @@ -246,6 +247,50 @@ class TableUtilsTest { ))) } + @Test + def testAllPartitionsAndGetLatestLabelMapping(): Unit = { + val tableName = "db.test_show_partitions" + spark.sql("CREATE DATABASE IF NOT EXISTS db") + + val columns1 = Array( + StructField("long_field", LongType), + StructField("int_field", IntType), + StructField("ds", 
diff --git a/spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala b/spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala
index 347a261712..4c45065157 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala
@@ -3,7 +3,7 @@ package ai.chronon.spark.test
 import ai.chronon.api.{Accuracy, Builders, IntType, LongType, Operation, StringType, StructField, StructType}
 import ai.chronon.spark.Conversions
 import ai.chronon.spark.Extensions._
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}

 import scala.util.ScalaVersionSpecificCollectionsConverter

@@ -90,7 +90,9 @@ object TestUtils {
       Row(3L, 1, "PRIVATE_ROOM", "2022-10-30"),
       Row(4L, 1, "PRIVATE_ROOM", "2022-10-30"),
       Row(5L, 1, "PRIVATE_ROOM", "2022-10-30"),
-      Row(1L, 4, "ENTIRE_HOME_2", "2022-11-11")
+      Row(1L, 4, "ENTIRE_HOME_2", "2022-11-11"),
+      Row(2L, 3, "ENTIRE_HOME_2", "2022-11-11"),
+      Row(3L, 1, "PRIVATE_ROOM_2", "2022-11-11")
     )
     val source = Builders.Source.entities(
       query = Builders.Query(
@@ -172,4 +174,54 @@ object TestUtils {
       df
     )
   }
+
+  def createSampleLabelTableDf(spark: SparkSession, tableName: String = "listing_labels"): DataFrame = {
+    val schema = StructType(
+      tableName,
+      Array(
+        StructField("listing_id", LongType),
+        StructField("room_type", StringType),
+        StructField("host_type", StringType),
+        StructField("ds", StringType),
+        StructField("label_ds", StringType)
+      )
+    )
+    val rows = List(
+      Row(1L, "PRIVATE_ROOM", "SUPER_HOST", "2022-10-01", "2022-11-01"),
+      Row(2L, "PRIVATE_ROOM", "NEW_HOST", "2022-10-02", "2022-11-01"),
+      Row(3L, "ENTIRE_HOME", "SUPER_HOST", "2022-10-03", "2022-11-01"),
+      Row(4L, "PRIVATE_ROOM", "SUPER_HOST", "2022-10-04", "2022-11-01"),
+      Row(5L, "ENTIRE_HOME", "NEW_HOST", "2022-10-07", "2022-11-01"),
+      Row(1L, "PRIVATE_ROOM", "SUPER_HOST", "2022-10-07", "2022-11-22")
+    )
+    makeDf(spark, schema, rows)
+  }
+
+  def createSampleFeatureTableDf(spark: SparkSession, tableName: String = "listing_features"): DataFrame = {
+    val schema = StructType(
+      tableName,
+      Array(
+        StructField("listing_id", LongType),
+        StructField("m_guests", LongType),
+        StructField("m_views", LongType),
+        StructField("ds", StringType)
+      )
+    )
+    val rows = List(
+      Row(1L, 2L, 20L, "2022-10-01"),
+      Row(2L, 3L, 30L, "2022-10-01"),
+      Row(3L, 1L, 10L, "2022-10-01"),
+      Row(4L, 2L, 20L, "2022-10-01"),
+      Row(5L, 3L, 35L, "2022-10-01"),
+      Row(1L, 5L, 15L, "2022-10-07")
+    )
+    makeDf(spark, schema, rows)
+  }
+
+  def makeDf(spark: SparkSession, schema: StructType, rows: List[Row]): DataFrame = {
+    spark.createDataFrame(
+      ScalaVersionSpecificCollectionsConverter.convertScalaListToJava(rows),
+      Conversions.fromChrononSchema(schema)
+    )
+  }
 }