From 96458e81a89dc8e2d8113ec6eef9872207d20d92 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 30 Oct 2024 17:42:58 -0400 Subject: [PATCH 1/2] Tweak spark test setup to tags and run tests appropriately --- .github/workflows/test_scala_and_python.yaml | 8 +-- .../ai/chronon/spark/test/FetcherTest.scala | 27 +++++---- .../ai/chronon/spark/test/JoinTest.scala | 58 ++++++++----------- .../ai/chronon/spark/test/MutationsTest.scala | 36 ++++++------ .../spark/test/TaggedFilterSuite.scala | 34 +++++++++++ 5 files changed, 93 insertions(+), 70 deletions(-) create mode 100644 spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala diff --git a/.github/workflows/test_scala_and_python.yaml b/.github/workflows/test_scala_and_python.yaml index 62bd22b3a4..99f5087c76 100644 --- a/.github/workflows/test_scala_and_python.yaml +++ b/.github/workflows/test_scala_and_python.yaml @@ -65,7 +65,7 @@ jobs: - name: Run other spark tests run: | export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED" - sbt "spark/testOnly -- -l ai.chronon.spark.JoinTest -l ai.chronon.spark.test.MutationsTest -l ai.chronon.spark.test.FetcherTest" + sbt "spark/testOnly" join_spark_tests: runs-on: ubuntu-latest @@ -84,7 +84,7 @@ jobs: - name: Run other spark tests run: | export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED" - sbt "spark/testOnly ai.chronon.spark.JoinTest" + sbt "spark/testOnly -- -n jointest" mutation_spark_tests: runs-on: ubuntu-latest @@ -103,7 +103,7 @@ jobs: - name: Run other spark tests run: | export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED" - sbt "spark/testOnly ai.chronon.spark.test.MutationsTest" + sbt "spark/testOnly -- -n mutationstest" fetcher_spark_tests: runs-on: ubuntu-latest @@ -122,7 +122,7 @@ jobs: - name: Run other spark tests run: | export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED" - sbt "spark/testOnly ai.chronon.spark.test.FetcherTest" + sbt "spark/testOnly -- -n fetchertest" scala_compile_fmt_fix : runs-on: ubuntu-latest diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala index a0e9576e9c..64d027edd7 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala @@ -37,7 +37,6 @@ import ai.chronon.spark.Extensions._ import ai.chronon.spark.stats.ConsistencyJob import ai.chronon.spark.{Join => _, _} import com.google.gson.GsonBuilder -import junit.framework.TestCase import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession @@ -48,6 +47,7 @@ import org.apache.spark.sql.functions.lit import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue +import org.scalatest.funsuite.AnyFunSuite import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -64,7 +64,11 @@ import scala.concurrent.duration.SECONDS import scala.io.Source import scala.util.ScalaJavaConversions._ -class FetcherTest extends TestCase { +// Run as follows: sbt "spark/testOnly -- -n fetchertest" +class FetcherTest extends AnyFunSuite with TaggedFilterSuite { + + override def tagName: String = "fetchertest" + @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) val sessionName = "FetcherTest" val spark: SparkSession = SparkSessionBuilder.build(sessionName, local = true) @@ -74,7 +78,7 @@ class FetcherTest extends TestCase { private val today = 
tableUtils.partitionSpec.at(System.currentTimeMillis()) private val yesterday = tableUtils.partitionSpec.before(today) - def testMetadataStore(): Unit = { + test("test metadata store") { implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) implicit val tableUtils: TableUtils = TableUtils(spark) @@ -114,7 +118,8 @@ class FetcherTest extends TestCase { val directoryDataSetDataSet = ChrononMetadataKey + "_directory_test" val directoryMetadataStore = new MetadataStore(inMemoryKvStore, directoryDataSetDataSet, timeoutMillis = 10000) inMemoryKvStore.create(directoryDataSetDataSet) - val directoryDataDirWalker = new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints) + val directoryDataDirWalker = + new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints) val directoryDataKvMap = directoryDataDirWalker.run val directoryPut = directoryDataKvMap.toSeq.map { case (_, kvMap) => directoryMetadataStore.put(kvMap, directoryDataSetDataSet) @@ -385,9 +390,8 @@ class FetcherTest extends TestCase { sources = Seq(Builders.Source.entities(query = Builders.Query(), snapshotTable = creditTable)), keyColumns = Seq("vendor_id"), aggregations = Seq( - Builders.Aggregation(operation = Operation.SUM, - inputColumn = "credit", - windows = Seq(new Window(3, TimeUnit.DAYS)))), + Builders + .Aggregation(operation = Operation.SUM, inputColumn = "credit", windows = Seq(new Window(3, TimeUnit.DAYS)))), metaData = Builders.MetaData(name = "unit_test/vendor_credit_derivation", namespace = namespace), derivations = Seq( Builders.Derivation("credit_sum_3d_test_rename", "credit_sum_3d"), @@ -527,7 +531,6 @@ class FetcherTest extends TestCase { println("saved all data hand written for fetcher test") val startPartition = "2021-04-07" - val leftSource = Builders.Source.events( @@ -717,13 +720,13 @@ class FetcherTest extends TestCase { assertEquals(0, diff.count()) } - def testTemporalFetchJoinDeterministic(): Unit = { + test("test temporal fetch join deterministic") { val namespace = "deterministic_fetch" val joinConf = generateMutationData(namespace) compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) } - def testTemporalFetchJoinGenerated(): Unit = { + test("test temporal fetch join generated") { val namespace = "generated_fetch" val joinConf = generateRandomData(namespace) compareTemporalFetch(joinConf, @@ -733,14 +736,14 @@ class FetcherTest extends TestCase { dropDsOnWrite = false) } - def testTemporalTiledFetchJoinDeterministic(): Unit = { + test("test temporal tiled fetch join deterministic") { val namespace = "deterministic_tiled_fetch" val joinConf = generateEventOnlyData(namespace, groupByCustomJson = Some("{\"enable_tiling\": true}")) compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) } // test soft-fail on missing keys - def testEmptyRequest(): Unit = { + test("test empty request") { val namespace = "empty_request" val joinConf = generateRandomData(namespace, 5, 5) implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala index e8e3c3b382..0608ff258b 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala @@ -40,13 +40,13 @@ import 
org.apache.spark.sql.types.StructType import org.apache.spark.sql.types._ import org.apache.spark.sql.types.{StringType => SparkStringType} import org.junit.Assert._ -import org.junit.Test -import org.scalatest.Assertions.intercept +import org.scalatest.funsuite.AnyFunSuite import scala.collection.JavaConverters._ import scala.util.ScalaJavaConversions.ListOps -class JoinTest { +// Run as follows: sbt "spark/testOnly -- -n jointest" +class JoinTest extends AnyFunSuite with TaggedFilterSuite { val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true) private implicit val tableUtils = TableUtils(spark) @@ -59,8 +59,9 @@ class JoinTest { private val namespace = "test_namespace_jointest" tableUtils.createDatabase(namespace) - @Test - def testEventsEntitiesSnapshot(): Unit = { + override def tagName: String = "jointest" + + test("test events entities snapshot") { val dollarTransactions = List( Column("user", StringType, 100), Column("user_name", api.StringType, 100), @@ -263,8 +264,7 @@ class JoinTest { assertEquals(0, diff2.count()) } - @Test - def testEntitiesEntities(): Unit = { + test("test entities entities") { // untimed/unwindowed entities on right // right side val weightSchema = List( @@ -384,8 +384,7 @@ class JoinTest { */ } - @Test - def testEntitiesEntitiesNoHistoricalBackfill(): Unit = { + test("test entities entities no historical backfill") { // Only backfill latest partition if historical_backfill is turned off val weightSchema = List( Column("user", api.StringType, 1000), @@ -438,8 +437,7 @@ class JoinTest { assertEquals(allPartitions.toList(0), end) } - @Test - def testEventsEventsSnapshot(): Unit = { + test("test events events snapshot") { val viewsSchema = List( Column("user", api.StringType, 10000), Column("item", api.StringType, 100), @@ -508,8 +506,7 @@ class JoinTest { assertEquals(diff.count(), 0) } - @Test - def testEventsEventsTemporal(): Unit = { + test("test events events temporal") { val joinConf = getEventsEventsTemporal("temporal") val viewsSchema = List( @@ -586,8 +583,7 @@ class JoinTest { assertEquals(diff.count(), 0) } - @Test - def testEventsEventsCumulative(): Unit = { + test("test events events cumulative") { // Create a cumulative source GroupBy val viewsTable = s"$namespace.view_cumulative" val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true) @@ -686,8 +682,7 @@ class JoinTest { } - @Test - def testNoAgg(): Unit = { + test("test no agg") { // Left side entities, right side entities no agg // Also testing specific select statement (rather than select *) val namesSchema = List( @@ -767,8 +762,7 @@ class JoinTest { assertEquals(diff.count(), 0) } - @Test - def testVersioning(): Unit = { + test("test versioning") { val joinConf = getEventsEventsTemporal("versioning") // Run the old join to ensure that tables exist @@ -922,8 +916,7 @@ class JoinTest { } - @Test - def testEndPartitionJoin(): Unit = { + test("test end partition join") { val join = getEventsEventsTemporal("end_partition_test") val start = join.getLeft.query.startPartition val end = tableUtils.partitionSpec.after(start) @@ -940,12 +933,11 @@ class JoinTest { assertTrue(ds.first().getString(0) < today) } - @Test - def testSkipBloomFilterJoinBackfill(): Unit = { - val testSpark: SparkSession = SparkSessionBuilder.build( - "JoinTest", - local = true, - additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100"))) + test("test skip bloom filter join backfill") { + val testSpark: SparkSession = + 
SparkSessionBuilder.build("JoinTest", + local = true, + additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100"))) val testTableUtils = TableUtils(testSpark) val viewsSchema = List( Column("user", api.StringType, 10000), @@ -990,8 +982,7 @@ class JoinTest { assertEquals(leftSideCount, skipBloomComputed.count()) } - @Test - def testStructJoin(): Unit = { + test("test struct join") { val nameSuffix = "_struct_test" val itemQueries = List(Column("item", api.StringType, 100)) val itemQueriesTable = s"$namespace.item_queries_$nameSuffix" @@ -1049,8 +1040,7 @@ class JoinTest { new SummaryJob(spark, join, today).dailyRun(stepDays = Some(30)) } - @Test - def testMigration(): Unit = { + test("test migration") { // Left val itemQueriesTable = s"$namespace.item_queries" @@ -1099,8 +1089,7 @@ class JoinTest { assertEquals(0, join.tablesToDrop(productionHashV2).length) } - @Test - def testKeyMappingOverlappingFields(): Unit = { + test("testKeyMappingOverlappingFields") { // test the scenario when a key_mapping is a -> b, (right key b is mapped to left key a) and // a happens to be another field in the same group by @@ -1158,8 +1147,7 @@ class JoinTest { * Run computeJoin(). * Check if the selected join part is computed and the other join parts are not computed. */ - @Test - def testSelectedJoinParts(): Unit = { + test("test selected join parts") { // Left val itemQueries = List( Column("item", api.StringType, 100), diff --git a/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala index 4a8df9831f..6b7de749d7 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType -import org.junit.Test +import org.scalatest.funsuite.AnyFunSuite import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -46,11 +46,17 @@ import org.slf4j.LoggerFactory * Left is an event source with definite ts. * Right is an entity with snapshots and mutation values through the day. * Join is the events and the entity value at the exact timestamp of the ts. + * To run: sbt "spark/testOnly -- -n mutationstest" */ -class MutationsTest { +class MutationsTest extends AnyFunSuite with TaggedFilterSuite { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) - val spark: SparkSession = SparkSessionBuilder.build("MutationsTest", local = true) //, additionalConfig = Some(Map("spark.chronon.backfill.validation.enabled" -> "false"))) + override def tagName: String = "mutationstest" + + val spark: SparkSession = + SparkSessionBuilder.build("MutationsTest", + local = true + ) //, additionalConfig = Some(Map("spark.chronon.backfill.validation.enabled" -> "false"))) private implicit val tableUtils: TableUtils = TableUtils(spark) private def namespace(suffix: String) = s"test_mutations_$suffix" @@ -443,8 +449,7 @@ class MutationsTest { * * Compute Join for when mutations are just insert on values. */ - @Test - def testSimplestCase(): Unit = { + test("test simplest case") { val suffix = "simple" val leftData = Seq( // {listing_id, some_col, ts, ds} @@ -502,8 +507,7 @@ class MutationsTest { * * Compute Join when mutations have an update on values. 
*/ - @Test - def testUpdateValueCase(): Unit = { + test("test update value case") { val suffix = "update_value" val leftData = Seq( // {listing_id, ts, event, ds} @@ -554,8 +558,7 @@ class MutationsTest { * * Compute Join when mutations have an update on keys. */ - @Test - def testUpdateKeyCase(): Unit = { + test("test update key case") { val suffix = "update_key" val leftData = Seq( Row(1, 1, millis("2021-04-10 01:00:00"), "2021-04-10"), @@ -612,8 +615,7 @@ class MutationsTest { * For this test we request a value for id 2, w/ mutations happening in the day before and after the time requested. * The consistency constraint here is that snapshot 4/8 + mutations 4/8 = snapshot 4/9 */ - @Test - def testInconsistentTsLeftCase(): Unit = { + test("test inconsistent ts left case") { val suffix = "inconsistent_ts" val leftData = Seq( Row(1, 1, millis("2021-04-10 01:00:00"), "2021-04-10"), @@ -682,8 +684,7 @@ class MutationsTest { * Compute Join, the snapshot aggregation should decay, this is the main reason to have * resolution in snapshot IR */ - @Test - def testDecayedWindowCase(): Unit = { + test("test decayed window case") { val suffix = "decayed" val leftData = Seq( Row(2, 1, millis("2021-04-09 01:30:00"), "2021-04-10"), @@ -754,8 +755,7 @@ class MutationsTest { * Compute Join, the snapshot aggregation should decay. * When there are no mutations returning the collapsed is not enough depending on the time. */ - @Test - def testDecayedWindowCaseNoMutation(): Unit = { + test("test decayed window case no mutation") { val suffix = "decayed_v2" val leftData = Seq( Row(2, 1, millis("2021-04-10 01:00:00"), "2021-04-10"), @@ -803,8 +803,7 @@ class MutationsTest { * Compute Join, the snapshot aggregation should decay. * When there's no snapshot the value would depend only on mutations of the day. */ - @Test - def testNoSnapshotJustMutation(): Unit = { + test("test no snapshot just mutation") { val suffix = "no_mutation" val leftData = Seq( Row(2, 1, millis("2021-04-10 00:07:00"), "2021-04-10"), @@ -844,8 +843,7 @@ class MutationsTest { assert(compareResult(result, expected)) } - @Test - def testWithGeneratedData(): Unit = { + test("test with generated data") { val suffix = "generated" val reviews = List( Column("listing_id", api.StringType, 10), diff --git a/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala b/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala new file mode 100644 index 0000000000..4b8f17e362 --- /dev/null +++ b/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala @@ -0,0 +1,34 @@ +package ai.chronon.spark.test + +import org.scalatest.{Args, Filter, Status, SucceededStatus, Suite, SuiteMixin} + +/** + * SuiteMixin that skips execution of the tests in a suite if the tests are not triggered + * by the specific tagName. 
As an example:
+ * sbt test -> Will skip the test suite
+ * sbt spark/test -> Will skip the test suite
+ * sbt "spark/testOnly -- -n foo" -> Will include the tests in the suite if tagName = foo
+ * This allows us to skip some tests selectively by default while still being able to invoke them individually
+ */
+trait TaggedFilterSuite extends SuiteMixin { this: Suite =>
+
+  def tagName: String
+
+  // Runs the suite only if its tag is explicitly requested via scalatest's -n option
+  abstract override def run(testName: Option[String], args: Args): Status = {
+    // Run normally only if this suite's tag was explicitly included
+    val include = args.filter.tagsToInclude match {
+      case Some(tags) => tags.contains(tagName)
+      case _          => false
+    }
+    // Clear the tag filter so the (untagged) tests inside the suite actually execute
+    val emptyFilter = Filter.apply()
+    val argsWithTagsCleared = args.copy(filter = emptyFilter)
+    if (include) {
+      super.run(testName, argsWithTagsCleared)
+    } else {
+      // Otherwise skip the suite, reporting success so untagged runs stay green
+      SucceededStatus
+    }
+  }
+}

From 9157deb31f0d992984641f8e08b5d868a427dd61 Mon Sep 17 00:00:00 2001
From: Piyush Narang
Date: Wed, 30 Oct 2024 18:52:12 -0400
Subject: [PATCH 2/2] Fix all

---
 .../scala/ai/chronon/spark/test/TaggedFilterSuite.scala | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala b/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala
index 4b8f17e362..46e76d866b 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala
@@ -1,6 +1,11 @@
 package ai.chronon.spark.test
 
-import org.scalatest.{Args, Filter, Status, SucceededStatus, Suite, SuiteMixin}
+import org.scalatest.Args
+import org.scalatest.Filter
+import org.scalatest.Status
+import org.scalatest.SucceededStatus
+import org.scalatest.Suite
+import org.scalatest.SuiteMixin
 
 /**
  * SuiteMixin that skips execution of the tests in a suite if the tests are not triggered
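
For reference, here is a minimal sketch of how a new suite would opt into this mechanism, mirroring the pattern this series applies to FetcherTest, JoinTest, and MutationsTest. The suite name and tag below are hypothetical, not part of the patch:

```scala
package ai.chronon.spark.test

import org.scalatest.funsuite.AnyFunSuite

// Hypothetical example: skipped by a plain `sbt spark/test`, and executed only
// when its tag is requested explicitly: sbt "spark/testOnly -- -n exampletest"
class ExampleTaggedTest extends AnyFunSuite with TaggedFilterSuite {

  // Tag that must be passed via -n for this suite to run
  override def tagName: String = "exampletest"

  test("runs only when the exampletest tag is requested") {
    assert(1 + 1 == 2)
  }
}
```

A dedicated CI job would then invoke `sbt "spark/testOnly -- -n exampletest"`, matching the workflow wiring for the fetcher, join, and mutation jobs in the first patch, while the catch-all `sbt "spark/testOnly"` job skips the suite automatically.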