Skip to content

Commit f0da575

Browse files
author
Sophie Wang
committed
comments
1 parent cc4139c commit f0da575

File tree

6 files changed: +26 −24 lines changed

api/src/main/scala/ai/chronon/api/Constants.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ object Constants {
3939
val ContextualSourceValues: String = "contextual_values"
4040
val TeamOverride: String = "team_override"
4141
val LabelColumnPrefix: String = "label"
42-
val LabelViewPropertyFeatureTable: String = "featureTable"
43-
val LabelViewPropertyKeyLabelTable: String = "labelTable"
42+
val LabelViewPropertyFeatureTable: String = "feature_table"
43+
val LabelViewPropertyKeyLabelTable: String = "label_table"
4444
}

api/src/main/scala/ai/chronon/api/Extensions.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import com.fasterxml.jackson.databind.ObjectMapper
88
import java.io.{PrintWriter, StringWriter}
99
import java.util
1010
import scala.collection.mutable
11+
import scala.collection.JavaConverters._
1112
import scala.util.{Failure, ScalaVersionSpecificCollectionsConverter, Success, Try}
1213

1314
object Extensions {
@@ -567,9 +568,12 @@ object Extensions {
567568
.distinct
568569
}
569570

570-
// a list of columns which can identify a row on left
571-
def rowIdentifier: Array[String] = {
572-
leftKeyCols ++ Array(Constants.PartitionColumn)
571+
// a list of columns which can identify a row on left, use user specified columns by default
572+
def rowIdentifier(userRowId: util.List[String] = null): Array[String] = {
573+
if (userRowId != null && !userRowId.isEmpty)
574+
userRowId.asScala.toArray
575+
else
576+
leftKeyCols ++ Array(Constants.PartitionColumn)
573577
}
574578
}
575579

spark/src/main/scala/ai/chronon/spark/JoinUtils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ object JoinUtils {
125125
tableUtils.getSchemaFromTable(rightTable)
126126
.filterNot(field => joinKeys.contains(field.name))
127127
.map(field => {
128-
if(field.name.startsWith("label")) {
128+
if(field.name.startsWith(labelColumnPrefix)) {
129129
s"r.${field.name}"
130130
} else {
131131
s"r.${field.name} AS ${labelColumnPrefix}_${field.name}"
@@ -199,7 +199,7 @@ object JoinUtils {
199199

200200
/**
201201
* compute the mapping label_ds -> PartitionRange of ds which has this label_ds as latest version
202-
* - Get all partitions from table and
202+
* - Get all partitions from table
203203
* - For each ds, find the latest available label_ds
204204
* - Reverse the mapping and get the ds partition range for each label version(label_ds)
205205
*

spark/src/main/scala/ai/chronon/spark/LabelJoin.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,10 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
5959
} else {
6060
// creating final join view with feature join output table
6161
println(s"Joining label table : ${outputLabelTable} with joined output table : ${joinConf.metaData.outputTable}")
62-
val joinKeys: Array[String] = if (joinConf.rowIds != null && !joinConf.rowIds.isEmpty)
63-
joinConf.rowIds.asScala.toArray else labelJoinConf.rowIdentifier
6462
JoinUtils.createOrReplaceView(joinConf.metaData.outputFinalView,
6563
leftTable = joinConf.metaData.outputTable,
6664
rightTable = outputLabelTable,
67-
joinKeys = joinKeys,
65+
joinKeys = labelJoinConf.rowIdentifier(joinConf.rowIds),
6866
tableUtils = tableUtils,
6967
viewProperties = Map(Constants.LabelViewPropertyKeyLabelTable -> outputLabelTable,
7068
Constants.LabelViewPropertyFeatureTable -> joinConf.metaData.outputTable))
@@ -128,7 +126,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
128126
finalResult
129127
}
130128

131-
def computeRange(leftDf: DataFrame, leftRange: PartitionRange, labelDs: String): DataFrame = {
129+
def computeRange(leftDf: DataFrame, leftRange: PartitionRange, sanitizedLabelDs: String): DataFrame = {
132130
val leftDfCount = leftDf.count()
133131
val leftBlooms = labelJoinConf.leftKeyCols.par.map { key =>
134132
key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
@@ -145,7 +143,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
145143
}
146144
}
147145

148-
val rowIdentifier = labelJoinConf.rowIdentifier
146+
val rowIdentifier = labelJoinConf.rowIdentifier(joinConf.rowIds)
149147
println("Label Join filtering left df with only row identifier:", rowIdentifier.mkString(", "))
150148
val leftFiltered = JoinUtils.filterColumns(leftDf, rowIdentifier)
151149

@@ -154,7 +152,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
154152
}
155153

156154
// assign label ds value and drop duplicates
157-
val updatedJoin = joined.withColumn(Constants.LabelPartitionColumn, lit(labelDS))
155+
val updatedJoin = joined.withColumn(Constants.LabelPartitionColumn, lit(sanitizedLabelDs))
158156
.dropDuplicates(rowIdentifier)
159157
updatedJoin.explain()
160158
updatedJoin.drop(Constants.TimePartitionColumn)

spark/src/test/scala/ai/chronon/spark/test/FeatureWithLabelJoinTest.scala

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,22 @@ class FeatureWithLabelJoinTest {
3636
val runner = new LabelJoin(joinConf, tableUtils, labelDS)
3737
val labelDf = runner.computeLabelJoin()
3838
println(" == First Run Label version 2022-10-30 == ")
39-
prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier).show()
39+
prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier()).show()
4040
val featureDf = tableUtils.sparkSession.table(joinConf.metaData.outputTable)
4141
println(" == Features == ")
4242
featureDf.show()
4343
val computed = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView}")
4444
computed.show()
45-
val expectedFinal = featureDf.join(prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier),
46-
labelJoinConf.rowIdentifier,
45+
val expectedFinal = featureDf.join(prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier()),
46+
labelJoinConf.rowIdentifier(),
4747
"left_outer")
4848
println(" == Expected == ")
4949
expectedFinal.show()
5050
val diff = Comparison.sideBySide(computed,
5151
expectedFinal,
5252
List("listing",
5353
"ds",
54-
"label_ds",
55-
"feature_review",
56-
"feature_locale",
57-
"label_listing_labels_dim_bedrooms",
58-
"label_listing_labels_dim_room_type"))
54+
"label_ds"))
5955
if (diff.count() > 0) {
6056
println(s"Actual count: ${computed.count()}")
6157
println(s"Expected count: ${expectedFinal.count()}")
@@ -98,7 +94,7 @@ class FeatureWithLabelJoinTest {
9894
println(exceptions.mkString(", "))
9995
val renamedColumns = df.columns
10096
.map(col => {
101-
if(exceptions.contains(col) || col.startsWith("label")) {
97+
if(exceptions.contains(col) || col.startsWith(prefix)) {
10298
df(col)
10399
} else {
104100
df(col).as(s"$prefix$col")

spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ class JoinUtilsTest {
254254
val keys = Array("listing_id", Constants.PartitionColumn)
255255

256256
JoinUtils.createOrReplaceView(finalViewName, leftTableName, rightTableName, keys, tableUtils,
257-
viewProperties = Map("featureTable" -> leftTableName, "labelTable" -> rightTableName))
257+
viewProperties = Map(Constants.LabelViewPropertyFeatureTable -> leftTableName,
258+
Constants.LabelViewPropertyKeyLabelTable -> rightTableName))
258259
val view = tableUtils.sql(s"select * from $finalViewName")
259260
view.show()
260261
assertEquals(6, view.count())
@@ -269,10 +270,13 @@ class JoinUtilsTest {
269270
assertEquals(0, latest.filter(latest("listing_id") === "3").count())
270271
assertEquals("2022-11-22", latest.where(latest("ds") === "2022-10-07").
271272
select("label_ds").first().get(0))
273+
// label_ds should be unique per ds + listing
274+
val removeDup = latest.dropDuplicates(Seq("label_ds", "ds"))
275+
assertEquals(removeDup.count(), latest.count())
272276

273277
val properties = tableUtils.getTableProperties(latestLabelView)
274278
assertTrue(properties.isDefined)
275-
assertEquals(properties.get.get("featureTable"), Some(leftTableName))
279+
assertEquals(properties.get.get(Constants.LabelViewPropertyFeatureTable), Some(leftTableName))
276280
assertEquals(properties.get.get("newProperties"), Some("value"))
277281
}
278282

0 commit comments

Comments (0)