1 change: 1 addition & 0 deletions spark/BUILD.bazel
@@ -144,6 +144,7 @@ scala_test_suite(
srcs = glob([
"src/test/scala/ai/chronon/spark/test/batch/*.scala",
]),
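# Make the CSV fixtures under spark/src/test/resources available to the test at runtime (resolved via RUNFILES_DIR).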
data = ["//spark/src/test/resources:test-resources"],
jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
tags = ["medium"],
visibility = ["//visibility:public"],
4 changes: 4 additions & 0 deletions spark/src/test/resources/local_data_csv/sample_join.csv
@@ -0,0 +1,4 @@
request_id,listing_id,ts,ds
request_1,listing_1,2025-05-03 10:00:00,2025-05-03
request_2,listing_2,2025-05-09 10:00:00,2025-05-09
request_3,listing_3,2025-05-10 10:00:00,2025-05-10
2 changes: 2 additions & 0 deletions spark/src/test/resources/local_data_csv/sample_label_loose_source.csv
@@ -0,0 +1,2 @@
request_id,listing_id,attribution,ts,ds
request_1,listing_1,[view|click|cart|purchase],2025-05-10 00:00:00,2025-05-10
5 changes: 5 additions & 0 deletions spark/src/test/resources/local_data_csv/sample_label_tight_source.csv
@@ -0,0 +1,5 @@
request_id,listing_id,attribution,ts,ds
request_1,listing_1,[view|click],2025-05-03 11:00:01,2025-05-03
request_2,listing_2,[view|click],2025-05-09 10:10:10,2025-05-09
request_2,listing_2,[view|click|cart],2025-05-09 15:10:10,2025-05-09
request_3,listing_3,[view],2025-05-10 10:10:10,2025-05-10
22 changes: 21 additions & 1 deletion spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala
@@ -26,7 +26,7 @@ import ai.chronon.spark.catalog.TableUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, unix_timestamp}

object TestUtils {
def createViewsGroupBy(namespace: String,
@@ -451,4 +451,24 @@ object TestUtils {
accuracy = Accuracy.TEMPORAL
)
}

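// Loads a CSV fixture and writes it out as a partitioned table via insertPartitions.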
def createTableWithCsvData(tableUtils: TableUtils,
csvPath: String,
outputTableName: String,
tsColName: String = "ts",
tsFormat: String = "yyyy-MM-dd HH:mm:ss"): Unit = {
val df = createDataframeFromCsv(tableUtils.sparkSession, csvPath, tsColName, tsFormat)
tableUtils.insertPartitions(df = df, tableName = outputTableName)
}

def createDataframeFromCsv(spark: SparkSession,
csvPath: String,
tsColName: String = "ts",
tsFormat: String = "yyyy-MM-dd HH:mm:ss"): DataFrame = {
spark.read
.option("header", "true")
.option("inferSchema", "true")
.csv(csvPath)
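// unix_timestamp parses the string into epoch seconds; multiply by 1000 since ts is expected in epoch millis.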
.withColumn("ts", (unix_timestamp(col(tsColName), tsFormat) * 1000).cast("long"))
}
}
215 changes: 213 additions & 2 deletions spark/src/test/scala/ai/chronon/spark/test/batch/LabelJoinV2Test.scala
@@ -6,21 +6,22 @@ import ai.chronon.api.Extensions._
import ai.chronon.api._
import ai.chronon.spark.Extensions._
import ai.chronon.spark.batch._
-import ai.chronon.spark.test.{DataFrameGen, TableTestUtils}
+import ai.chronon.spark.test.{DataFrameGen, TableTestUtils, TestUtils}
import ai.chronon.spark.{GroupBy, Join, _}
import org.apache.spark.sql.SparkSession
import org.junit.Assert.assertEquals
import org.scalatest.flatspec.AnyFlatSpec
import org.slf4j.LoggerFactory

import java.nio.file.Paths

class LabelJoinV2Test extends AnyFlatSpec {

import ai.chronon.spark.submission

@transient private lazy val logger = LoggerFactory.getLogger(getClass)

val spark: SparkSession = submission.SparkSessionBuilder.build("LabelJoinV2Test", local = true)

private val tableUtils = TableTestUtils(spark)
private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
private val twentyNineDaysAgo = tableUtils.partitionSpec.minus(today, new Window(29, TimeUnit.DAYS))
@@ -673,4 +674,214 @@ class LabelJoinV2Test extends AnyFlatSpec {
assertEquals(0, diffCount)
}

private def setupJoinForLabelJoinTestingFromFixtures(tableUtils: TableTestUtils, namespace: String) = {
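// Bazel exposes the data deps declared in BUILD.bazel (the CSV fixtures) under RUNFILES_DIR.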
val runfilesDir = System.getenv("RUNFILES_DIR")

val joinName = "label_join_without_round_down"
TestUtils.createTableWithCsvData(
tableUtils,
Paths.get(runfilesDir, "chronon/spark/src/test/resources/local_data_csv/sample_join.csv").toAbsolutePath.toString,
s"$namespace.$joinName"
)

val tightSourceTable = s"$namespace.tight_events"
TestUtils.createTableWithCsvData(
tableUtils,
Paths
.get(runfilesDir, "chronon/spark/src/test/resources/local_data_csv/sample_label_tight_source.csv")
.toAbsolutePath
.toString,
tightSourceTable
)

val looseSourceTable = s"$namespace.loose_events"
TestUtils.createTableWithCsvData(
tableUtils,
Paths
.get(runfilesDir, "chronon/spark/src/test/resources/local_data_csv/sample_label_loose_source.csv")
.toAbsolutePath
.toString,
looseSourceTable
)

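// Two TEMPORAL label GroupBys over the same keys: LAST(attribution) over a tight 1d window,
// and LAST(attribution) over a loose 7d window.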
val tightLabelGroupBy = Builders.GroupBy(
sources = Seq(
Builders.Source.events(
query = Builders.Query(
selects = Builders.Selects(
"request_id",
"listing_id",
"attribution"
),
timeColumn = "ts"
),
table = tightSourceTable
)),
keyColumns = Seq("request_id", "listing_id"),
aggregations = Seq(
Builders.Aggregation(operation = Operation.LAST,
inputColumn = "attribution",
windows = Seq(new Window(1, TimeUnit.DAYS)))
),
metaData = Builders.MetaData(name = "unit_test.tight_temporal_labels", namespace = namespace),
accuracy = Accuracy.TEMPORAL
)
val looseLabelGroupBy = Builders.GroupBy(
sources = Seq(
Builders.Source.events(
query = Builders.Query(
selects = Builders.Selects(
"request_id",
"listing_id",
"attribution"
),
timeColumn = "ts"
),
table = looseSourceTable
)),
keyColumns = Seq("request_id", "listing_id"),
aggregations = Seq(
Builders.Aggregation(operation = Operation.LAST,
inputColumn = "attribution",
windows = Seq(new Window(7, TimeUnit.DAYS)))
),
metaData = Builders.MetaData(name = "unit_test.loose_temporal_labels", namespace = namespace),
accuracy = Accuracy.TEMPORAL
)

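// The table named after this join ($namespace.$joinName) was seeded from sample_join.csv above;
// it acts as the precomputed join output that the label columns attach to.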
val joinConf = Builders.Join(
// left and join parts are not actually used here
left = Builders.Source.events(Builders.Query(), table = ""),
joinParts = Seq(),
labelParts = Builders.LabelPart(
labels = Seq(
Builders.JoinPart(groupBy = tightLabelGroupBy),
Builders.JoinPart(groupBy = looseLabelGroupBy)
)
),
metaData = Builders.MetaData(name = joinName, namespace = namespace, team = "chronon")
)
joinConf
}

it should "test temporal label parts using test fixtures without round_down" in {
val sparkNoRoundDownWithFixture: SparkSession = submission.SparkSessionBuilder.build(
"LabelJoinV2TestWithoutRoundDown",
additionalConfig = Option(Map("spark.chronon.join.label_join.round_down_sub_day_windows" -> "false")),
local = true)
import sparkNoRoundDownWithFixture.implicits._

val namespace = "label_joinv2_temporal_with_fixtures_without_round_down"

val tableUtils = TableTestUtils(sparkNoRoundDownWithFixture)
tableUtils.createDatabase(namespace)
val joinConf = setupJoinForLabelJoinTestingFromFixtures(tableUtils, namespace)

// request_1 + listing_1 occurs on May 3, 2025. To fill its tight (1d) label value with the
// no-round-down flag, we need to run May 3 + 1 = May 4.
val may4Ds = "2025-05-04"
val labelJoinMay4 = new LabelJoinV2(joinConf, tableUtils, new api.DateRange(may4Ds, may4Ds))
val labelComputedMay4 = labelJoinMay4.compute()
labelComputedMay4.show(truncate = false)
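// 1746266400000 ms == 2025-05-03 10:00:00 UTC, request_1's event time from sample_join.csv.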
val expectedMay4Df = Seq(
("request_1", "listing_1", "1746266400000", null, "[view|click]", "2025-05-03")
// We expect only the tight label for request_1 + listing_1 to be filled; loose will be filled by the later May 10 run.
).toDF(
"request_id",
"listing_id",
"ts",
"label__unit_test_loose_temporal_labels_attribution_last_7d",
"label__unit_test_tight_temporal_labels_attribution_last_1d",
"ds"
)
assert(expectedMay4Df.except(labelComputedMay4).isEmpty)
assert(labelComputedMay4.except(expectedMay4Df).isEmpty)

// With round-down disabled, running May 10 fills the tight (1d) label for May 10 - 1 = May 9,
// and the loose (7d) label for May 10 - 7 = May 3.
val may10RunDs = "2025-05-10"
val labelJoinMay10 = new LabelJoinV2(joinConf, tableUtils, new api.DateRange(may10RunDs, may10RunDs))
val labelComputedMay10 = labelJoinMay10.compute()
labelComputedMay10.show(truncate = false)

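// Epoch values: 1746784800000 == 2025-05-09 10:00:00 UTC; 1746266400000 == 2025-05-03 10:00:00 UTC.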
val expectedMay10Df = Seq(
// We expect request_2 + listing_2 to have only its tight (last_1d) label column filled for now.
("request_2", "listing_2", "1746784800000", null, "[view|click|cart]", "2025-05-09"),
// We expect request_1 + listing_1 to now have its loose (last_7d) label column filled as well.
("request_1", "listing_1", "1746266400000", "[view|click|cart|purchase]", "[view|click]", "2025-05-03")
).toDF(
"request_id",
"listing_id",
"ts",
"label__unit_test_loose_temporal_labels_attribution_last_7d",
"label__unit_test_tight_temporal_labels_attribution_last_1d",
"ds"
)
assert(expectedMay10Df.except(labelComputedMay10).isEmpty)
assert(labelComputedMay10.except(expectedMay10Df).isEmpty)
}

it should "test temporal label parts using test fixtures with round_down" in {
val sparkRoundDownWithFixture: SparkSession = submission.SparkSessionBuilder.build(
"LabelJoinV2TestWithRoundDown",
additionalConfig = Option(Map("spark.chronon.join.label_join.round_down_sub_day_windows" -> "true")),
local = true)
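// Same fixtures as the previous test, but with round-down enabled: the 1d label for an event on
// ds D is filled by the run on D itself rather than on D + 1.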
import sparkRoundDownWithFixture.implicits._

val namespace = "label_joinv2_temporal_with_fixtures_round_down"

val tableUtils = TableTestUtils(sparkRoundDownWithFixture)
tableUtils.createDatabase(namespace)
val joinConf = setupJoinForLabelJoinTestingFromFixtures(tableUtils, namespace)

// request_1 + listing_1 occurs on May 3, 2025. To fill its tight (1d) label value with the
// round-down flag, we need to run May 3 + 0 = May 3.
val may3Ds = "2025-05-03"
val labelJoinMay3 = new LabelJoinV2(joinConf, tableUtils, new api.DateRange(may3Ds, may3Ds))
val labelComputedMay3 = labelJoinMay3.compute()
labelComputedMay3.show(truncate = false)
val expectedMay3Df = Seq(
("request_1",
"listing_1",
"1746266400000",
null,
"[view|click]",
// We expect only the tight label for request_1 + listing_1 to be filled; loose will be filled by the later May 10 run.
"2025-05-03")
).toDF(
"request_id",
"listing_id",
"ts",
"label__unit_test_loose_temporal_labels_attribution_last_7d",
"label__unit_test_tight_temporal_labels_attribution_last_1d",
"ds"
)
assert(expectedMay3Df.except(labelComputedMay3).isEmpty)
assert(labelComputedMay3.except(expectedMay3Df).isEmpty)

// With round-down enabled, running May 10 fills the tight (1d) label for the May 10 join output
// (the -1 offset from the previous test rounds down to 0),
// and the loose (7d) label for May 3.
val may10RunDs = "2025-05-10"
val labelJoinMay10 = new LabelJoinV2(joinConf, tableUtils, new api.DateRange(may10RunDs, may10RunDs))
val labelComputedMay10 = labelJoinMay10.compute()
labelComputedMay10.show(truncate = false)

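// 1746871200000 == 2025-05-10 10:00:00 UTC, request_3's event time.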
val expectedMay10Df = Seq(
// We expect request_3 + listing_3 to have only its tight (last_1d) label column filled for now.
("request_3", "listing_3", "1746871200000", null, "[view]", "2025-05-10"),
// We expect request_1 + listing_1 to now have its loose (last_7d) label column filled as well.
("request_1", "listing_1", "1746266400000", "[view|click|cart|purchase]", "[view|click]", "2025-05-03")
).toDF(
"request_id",
"listing_id",
"ts",
"label__unit_test_loose_temporal_labels_attribution_last_7d",
"label__unit_test_tight_temporal_labels_attribution_last_1d",
"ds"
)
assert(expectedMay10Df.except(labelComputedMay10).isEmpty)
assert(labelComputedMay10.except(expectedMay10Df).isEmpty)
}

}