Skip to content

Commit f2f5786

Browse files
author
Sophie Wang
committed
Add final view sql & refactor
1 parent 8e91a8c commit f2f5786

File tree

10 files changed

+340
-50
lines changed

10 files changed

+340
-50
lines changed

api/src/main/scala/ai/chronon/api/Constants.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ object Constants {
3737
val ContextualSourceName: String = "contextual"
3838
val ContextualSourceKeys: String = "contextual_keys"
3939
val ContextualSourceValues: String = "contextual_values"
40+
val LabelColumnPrefix: String = "label"
4041
}

api/src/main/scala/ai/chronon/api/Extensions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ object Extensions {
7676

7777
def outputTable = s"${metaData.outputNamespace}.${metaData.cleanName}"
7878
def outputLabelTable = s"${metaData.outputNamespace}.${metaData.cleanName}_labels"
79-
def outputFinalView = s"${metaData.outputNamespace}.${metaData.cleanName}_final_table"
79+
def outputFinalView = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled"
8080
def loggedTable = s"${outputTable}_logged"
8181
def bootstrapTable = s"${outputTable}_bootstrap"
8282
private def comparisonPrefix = "comparison"

spark/src/main/scala/ai/chronon/spark/JoinUtils.scala

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ object JoinUtils {
8181
val leftDataType = leftDf.schema(leftDf.schema.fieldIndex(column)).dataType
8282
val rightDataType = rightDf.schema(rightDf.schema.fieldIndex(column)).dataType
8383
assert(leftDataType == rightDataType,
84-
s"Column '$column' has mismatched data types - left type: $leftDataType vs. right type $rightDataType")
84+
s"Column '$column' has mismatched data types - left type: $leftDataType vs. right type $rightDataType")
8585
}
8686

8787
val joinedDf = leftDf.join(rightDf, keys, joinType)
@@ -106,4 +106,53 @@ object JoinUtils {
106106
val finalDf = joinedDf.select(selects: _*)
107107
finalDf
108108
}
109+
110+
/***
111+
* Method to create or replace a view for feature table joining with labels.
112+
* Label columns will be prefixed with "label" or custom prefix for easy identification
113+
*/
114+
def createOrReplaceView(viewName: String,
115+
leftTable: String,
116+
rightTable: String,
117+
joinKeys: Array[String],
118+
tableUtils: TableUtils,
119+
viewProperties: Map[String, String] = null,
120+
labelColumnPrefix: String = Constants.LabelColumnPrefix): Unit = {
121+
val fieldDefinitions = joinKeys.map(field => s"l.${field}") ++
122+
tableUtils.getSchemaFromTable(leftTable)
123+
.filterNot(field => joinKeys.contains(field.name))
124+
.map(field => s"l.${field.name}") ++
125+
tableUtils.getSchemaFromTable(rightTable)
126+
.filterNot(field => joinKeys.contains(field.name))
127+
.map(field => {
128+
if(field.name.startsWith("label")) {
129+
s"r.${field.name}"
130+
} else {
131+
s"r.${field.name} AS ${labelColumnPrefix}_${field.name}"
132+
}
133+
})
134+
val joinKeyDefinitions = joinKeys.map(key => s"l.${key} = r.${key}")
135+
val createFragment =
136+
s"""CREATE OR REPLACE VIEW $viewName
137+
| AS SELECT
138+
| ${fieldDefinitions.mkString(",\n ")}
139+
| FROM ${leftTable} AS l LEFT OUTER JOIN ${rightTable} AS r
140+
| ON ${joinKeyDefinitions.mkString(" AND ")}""".stripMargin
141+
142+
val propertiesFragment = if (viewProperties != null && viewProperties.nonEmpty) {
143+
s"""TBLPROPERTIES (
144+
| ${viewProperties.transform((k, v) => s"'$k'='$v'").values.mkString(",\n ")}
145+
|)""".stripMargin
146+
} else {
147+
""
148+
}
149+
val sqlStatement = Seq(createFragment, propertiesFragment).mkString("\n")
150+
tableUtils.sql(sqlStatement)
151+
}
152+
153+
def filterColumns(df: DataFrame, filter: Seq[String]): DataFrame = {
154+
val columnsToDrop = df.columns
155+
.filterNot(col => filter.contains(col))
156+
df.drop(columnsToDrop:_*)
157+
}
109158
}

spark/src/main/scala/ai/chronon/spark/LabelJoin.scala

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
2323
)
2424

2525
val metrics = Metrics.Context(Metrics.Environment.LabelJoin, joinConf)
26-
private val outputTable = joinConf.metaData.outputLabelTable
26+
private val outputLabelTable = joinConf.metaData.outputLabelTable
2727
private val labelJoinConf = joinConf.labelPart
2828
private val confTableProps = Option(joinConf.metaData.tableProperties)
2929
.map(_.asScala.toMap)
@@ -32,7 +32,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
3232
val leftStart = Constants.Partition.minus(labelDS, new Window(labelJoinConf.leftStartOffset, TimeUnit.DAYS))
3333
val leftEnd = Constants.Partition.minus(labelDS, new Window(labelJoinConf.leftEndOffset, TimeUnit.DAYS))
3434

35-
def computeLabelJoin(stepDays: Option[Int] = None): DataFrame = {
35+
def computeLabelJoin(stepDays: Option[Int] = None, skipFinalJoin: Boolean = false): DataFrame = {
3636
// validations
3737
assert(Option(joinConf.left.dataModel).equals(Option(Events)),
3838
s"join.left.dataMode needs to be Events for label join ${joinConf.metaData.name}")
@@ -52,26 +52,33 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
5252
}
5353

5454
labelJoinConf.setups.foreach(tableUtils.sql)
55-
val labelDf = compute(joinConf.left, stepDays, Option(labelDS))
56-
// val joinOutputDf = tableUtils.sparkSession.table(joinConf.metaData.outputTable)
57-
// val finalDf = createFinalJoinView(labelDf, joinOutputDf)
58-
labelDf
55+
val labelTable = compute(joinConf.left, stepDays, Option(labelDS))
56+
57+
if (skipFinalJoin) {
58+
labelTable
59+
} else {
60+
// creating final join view with feature join output table
61+
println(s"Joining label table : ${outputLabelTable} with joined output table : ${joinConf.metaData.outputTable}")
62+
JoinUtils.createOrReplaceView(joinConf.metaData.outputFinalView,
63+
joinConf.metaData.outputTable,
64+
outputLabelTable,
65+
labelJoinConf.rowIdentifier,
66+
tableUtils)
67+
labelTable
68+
}
5969
}
60-
//
61-
// def createFinalJoinView(labelDf: DataFrame, joinOutput: DataFrame): DataFrame = {
62-
//
63-
// }
70+
6471

6572
def compute(left: Source, stepDays: Option[Int] = None, labelDS: Option[String] = None): DataFrame = {
6673
val rangeToFill = PartitionRange(leftStart, leftEnd)
6774
val today = Constants.Partition.at(System.currentTimeMillis())
6875
val sanitizedLabelDs = labelDS.getOrElse(today)
6976
println(s"Label join range to fill $rangeToFill")
70-
def finalResult = tableUtils.sql(rangeToFill.genScanQuery(null, outputTable))
77+
def finalResult = tableUtils.sql(rangeToFill.genScanQuery(null, outputLabelTable))
7178
//TODO: use unfilledRanges instead of dropPartitionsAfterHole
7279
val earliestHoleOpt =
7380
tableUtils.dropPartitionsAfterHole(left.table,
74-
outputTable,
81+
outputLabelTable,
7582
rangeToFill,
7683
Map(Constants.LabelPartitionColumn -> sanitizedLabelDs))
7784
if (earliestHoleOpt.forall(_ > rangeToFill.end)) {
@@ -101,14 +108,14 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
101108
println(s"Computing join for range: $range ${labelDS.getOrElse(today)} $progress")
102109
JoinUtils.leftDf(joinConf, range, tableUtils).map { leftDfInRange =>
103110
computeRange(leftDfInRange, range, sanitizedLabelDs)
104-
.save(outputTable, confTableProps, Seq(Constants.LabelPartitionColumn, Constants.PartitionColumn), true)
111+
.save(outputLabelTable, confTableProps, Seq(Constants.LabelPartitionColumn, Constants.PartitionColumn), true)
105112
val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
106113
metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
107114
metrics.gauge(Metrics.Name.PartitionCount, range.partitions.length)
108-
println(s"Wrote to table $outputTable, into partitions: $range $progress in $elapsedMins mins")
115+
println(s"Wrote to table $outputLabelTable, into partitions: $range $progress in $elapsedMins mins")
109116
}
110117
}
111-
println(s"Wrote to table $outputTable, into partitions: $leftUnfilledRange")
118+
println(s"Wrote to table $outputLabelTable, into partitions: $leftUnfilledRange")
112119
finalResult
113120
}
114121

@@ -129,13 +136,17 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
129136
}
130137
}
131138

132-
val joined = rightDfs.zip(labelJoinConf.labels.asScala).foldLeft(leftDf) {
139+
val rowIdentifier = labelJoinConf.rowIdentifier
140+
println("Label Join filtering left df with only row identifier:", rowIdentifier.mkString(", "))
141+
val leftFiltered = JoinUtils.filterColumns(leftDf, rowIdentifier)
142+
143+
val joined = rightDfs.zip(labelJoinConf.labels.asScala).foldLeft(leftFiltered) {
133144
case (partialDf, (rightDf, joinPart)) => joinWithLeft(partialDf, rightDf, joinPart)
134145
}
135146

136147
// assign label ds value and drop duplicates
137148
val updatedJoin = joined.withColumn(Constants.LabelPartitionColumn, lit(labelDS))
138-
.dropDuplicates(labelJoinConf.rowIdentifier)
149+
.dropDuplicates(rowIdentifier)
139150
updatedJoin.explain()
140151
updatedJoin.drop(Constants.TimePartitionColumn)
141152
}

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import java.time.{Instant, ZoneId}
1111
import scala.collection.mutable
1212
import scala.util.{Success, Try}
1313

14+
1415
case class TableUtils(sparkSession: SparkSession) {
1516

1617
private val ARCHIVE_TIMESTAMP_FORMAT = "yyyyMMddHHmmss"
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package ai.chronon.spark.test
2+
3+
import ai.chronon.api.Extensions.{LabelPartOps, MetadataOps}
4+
import ai.chronon.api.{Builders, LongType, StringType, StructField, StructType}
5+
import ai.chronon.spark.{Comparison, LabelJoin, SparkSessionBuilder, TableUtils}
6+
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
7+
import org.junit.Assert.{assertEquals}
8+
import org.junit.Test
9+
10+
class FeatureWithLabelJoinTest {
11+
val spark: SparkSession = SparkSessionBuilder.build("FeatureWithLabelJoinTest", local = true)
12+
13+
private val namespace = "final_join"
14+
private val tableName = "test_feature_label_join"
15+
spark.sql(s"CREATE DATABASE IF NOT EXISTS $namespace")
16+
private val tableUtils = TableUtils(spark)
17+
18+
private val labelDS = "2022-10-30"
19+
private val viewsGroupBy = TestUtils.createViewsGroupBy(namespace, spark)
20+
private val left = viewsGroupBy.groupByConf.sources.get(0)
21+
22+
@Test
23+
def testFinalView(): Unit = {
24+
// create test feature join table
25+
val featureTable = s"${namespace}.${tableName}"
26+
createTestFeatureTable().write.saveAsTable(featureTable)
27+
28+
val labelJoinConf = createTestLabelJoin(50, 20)
29+
val joinConf = Builders.Join(
30+
Builders.MetaData(name = tableName, namespace = namespace, team = "chronon"),
31+
left,
32+
labelPart = labelJoinConf
33+
)
34+
35+
val runner = new LabelJoin(joinConf, tableUtils, labelDS)
36+
val labelDf = runner.computeLabelJoin()
37+
println(" == First Run Label version 2022-10-30 == ")
38+
prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier).show()
39+
val featureDf = tableUtils.sparkSession.table(joinConf.metaData.outputTable)
40+
println(" == Features == ")
41+
featureDf.show()
42+
val computed = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView}")
43+
computed.show()
44+
val expectedFinal = featureDf.join(prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier),
45+
labelJoinConf.rowIdentifier,
46+
"left_outer")
47+
println(" == Expected == ")
48+
expectedFinal.show()
49+
val diff = Comparison.sideBySide(computed,
50+
expectedFinal,
51+
List("listing",
52+
"ds",
53+
"label_ds",
54+
"feature_review",
55+
"feature_locale",
56+
"label_listing_labels_dim_bedrooms",
57+
"label_listing_labels_dim_room_type"))
58+
if (diff.count() > 0) {
59+
println(s"Actual count: ${computed.count()}")
60+
println(s"Expected count: ${expectedFinal.count()}")
61+
println(s"Diff count: ${diff.count()}")
62+
println(s"diff result rows")
63+
diff.show()
64+
}
65+
assertEquals(0, diff.count())
66+
67+
// add another label version
68+
val secondRun = new LabelJoin(joinConf, tableUtils, "2022-11-11")
69+
val secondLabel = secondRun.computeLabelJoin()
70+
println(" == Second Run Label version 2022-11-11 == ")
71+
secondLabel.show()
72+
val view = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView} order by label_ds")
73+
view.show()
74+
// listing 4 should not have any 2022-11-11 version labels
75+
assertEquals(null, view.where(view("label_ds") === "2022-11-11" && view("listing") === "4")
76+
.select("label_listing_labels_dim_bedrooms").first().get(0))
77+
// 11-11 label record number should be same as 10-30 label version record number
78+
assertEquals(view.where(view("label_ds") === "2022-10-30").count(),
79+
view.where(view("label_ds") === "2022-11-11").count())
80+
// listing 5 should not not have any label
81+
assertEquals(null, view.where(view("listing") === "5")
82+
.select("label_ds").first().get(0))
83+
}
84+
85+
private def prefixColumnName(df: DataFrame,
86+
prefix: String = "label_",
87+
exceptions: Array[String] = null): DataFrame = {
88+
println("exceptions")
89+
println(exceptions.mkString(", "))
90+
val renamedColumns = df.columns
91+
.map(col => {
92+
if(exceptions.contains(col) || col.startsWith("label")) {
93+
df(col)
94+
} else {
95+
df(col).as(s"$prefix$col")
96+
}
97+
})
98+
df.select(renamedColumns: _*)
99+
}
100+
101+
def createTestLabelJoin(startOffset: Int,
102+
endOffset: Int,
103+
groupByTableName: String = "listing_labels"): ai.chronon.api.LabelPart = {
104+
val labelGroupBy = TestUtils.createAttributesGroupBy(namespace, spark, groupByTableName)
105+
Builders.LabelPart(
106+
labels = Seq(
107+
Builders.JoinPart(groupBy = labelGroupBy.groupByConf)
108+
),
109+
leftStartOffset = startOffset,
110+
leftEndOffset = endOffset
111+
)
112+
}
113+
114+
def createTestFeatureTable(): DataFrame = {
115+
val schema = StructType(
116+
tableName,
117+
Array(
118+
StructField("listing", LongType),
119+
StructField("feature_review", LongType),
120+
StructField("feature_locale", StringType),
121+
StructField("ds", StringType),
122+
StructField("ts", StringType)
123+
)
124+
)
125+
val rows = List(
126+
Row(1L, 20L, "US", "2022-10-01", "2022-10-01 10:00:00"),
127+
Row(2L, 38L, "US", "2022-10-02", "2022-10-02 11:00:00"),
128+
Row(3L, 19L, "CA", "2022-10-01", "2022-10-01 08:00:00"),
129+
Row(4L, 2L, "MX", "2022-10-02", "2022-10-02 18:00:00"),
130+
Row(5L, 139L, "EU", "2022-10-01", "2022-10-01 22:00:00"),
131+
Row(1L, 24L, "US", "2022-10-02", "2022-10-02 16:00:00")
132+
)
133+
TestUtils.makeDf(spark, schema, rows)
134+
}
135+
}

spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package ai.chronon.spark.test
22

33
import ai.chronon.spark.JoinUtils.{contains_any, set_add}
4-
import ai.chronon.spark.{JoinUtils, SparkSessionBuilder}
4+
import ai.chronon.spark.{JoinUtils, SparkSessionBuilder, TableUtils}
55
import org.apache.spark.rdd.RDD
66
import org.apache.spark.sql.functions._
77
import org.apache.spark.sql.types._
@@ -11,9 +11,11 @@ import org.junit.Test
1111

1212
import scala.collection.mutable
1313
import scala.util.Try
14+
1415
class JoinUtilsTest {
1516

16-
val spark: SparkSession = SparkSessionBuilder.build("JoinUtilsTest", local = true)
17+
lazy val spark: SparkSession = SparkSessionBuilder.build("JoinUtilsTest", local = true)
18+
private val tableUtils = TableUtils(spark)
1719

1820
@Test
1921
def testUDFSetAdd(): Unit = {
@@ -210,4 +212,51 @@ class JoinUtilsTest {
210212
)
211213
assertEquals(3, df.get.columns.length)
212214
}
215+
216+
@Test
217+
def testCreateJoinView(): Unit = {
218+
val finalViewName = "testCreateView"
219+
val leftTableName = "joinUtil.testFeatureTable"
220+
val rightTableName = "joinUtil.testLabelTable"
221+
spark.sql("CREATE DATABASE IF NOT EXISTS joinUtil")
222+
TestUtils.createSampleFeatureTableDf(spark).write.saveAsTable(leftTableName)
223+
TestUtils.createSampleLabelTableDf(spark).write.saveAsTable(rightTableName)
224+
val keys = Array("listing_id", Constants.PartitionColumn)
225+
226+
JoinUtils.createOrReplaceView(finalViewName, leftTableName, rightTableName, keys, tableUtils)
227+
val view = tableUtils.sql(s"select * from $finalViewName")
228+
view.show()
229+
assertEquals(6, view.count())
230+
assertEquals(null, view.where(view("ds") === "2022-10-01" && view("listing_id") === "5")
231+
.select("label_room_type").first().get(0))
232+
assertEquals("SUPER_HOST", view.where(view("ds") === "2022-10-07" && view("listing_id") === "1")
233+
.select("label_host_type").first().get(0))
234+
}
235+
236+
@Test
237+
def testFilterColumns(): Unit ={
238+
val testDf = createSampleTable()
239+
val filter = Array("listing", "ds", "feature_review")
240+
val filteredDf = JoinUtils.filterColumns(testDf, filter)
241+
assertTrue(filteredDf.schema.fieldNames.sorted sameElements filter.sorted)
242+
}
243+
244+
def createSampleTable(tableName:String = "testSampleTable"): DataFrame = {
245+
val schema = ai.chronon.api.StructType(
246+
tableName,
247+
Array(
248+
StructField("listing", LongType),
249+
StructField("feature_review", LongType),
250+
StructField("feature_locale", StringType),
251+
StructField("ds", StringType),
252+
StructField("ts", StringType)
253+
)
254+
)
255+
val rows = List(
256+
Row(1L, 20L, "US", "2022-10-01", "2022-10-01 10:00:00"),
257+
Row(2L, 38L, "US", "2022-10-02", "2022-10-02 11:00:00"),
258+
Row(3L, 19L, "CA", "2022-10-01", "2022-10-01 08:00:00")
259+
)
260+
TestUtils.makeDf(spark, schema, rows)
261+
}
213262
}

0 commit comments

Comments
 (0)