8 changes: 4 additions & 4 deletions .github/workflows/test_scala_and_python.yaml
@@ -65,7 +65,7 @@ jobs:
- name: Run other spark tests
run: |
export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
sbt "spark/testOnly -- -l ai.chronon.spark.JoinTest -l ai.chronon.spark.test.MutationsTest -l ai.chronon.spark.test.FetcherTest"
sbt "spark/testOnly"

join_spark_tests:
runs-on: ubuntu-latest
@@ -84,7 +84,7 @@ jobs:
- name: Run other spark tests
run: |
export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
sbt "spark/testOnly ai.chronon.spark.JoinTest"
sbt "spark/testOnly -- -n jointest"

mutation_spark_tests:
runs-on: ubuntu-latest
@@ -103,7 +103,7 @@ jobs:
- name: Run other spark tests
run: |
export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
sbt "spark/testOnly ai.chronon.spark.test.MutationsTest"
sbt "spark/testOnly -- -n mutationstest"

fetcher_spark_tests:
runs-on: ubuntu-latest
@@ -122,7 +122,7 @@ jobs:
- name: Run other spark tests
run: |
export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
sbt "spark/testOnly ai.chronon.spark.test.FetcherTest"
sbt "spark/testOnly -- -n fetchertest"

scala_compile_fmt_fix :
runs-on: ubuntu-latest
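Note on the workflow change above: the catch-all job now runs a bare sbt "spark/testOnly", while the heavyweight suites (JoinTest, MutationsTest, FetcherTest) are opted in by ScalaTest tag ("-- -n <tag>") rather than opted out with "-l" filters. A minimal sketch of the tag mechanics that "-n" (include) and "-l" (exclude) operate on follows; the tag object, suite, and assertion are illustrative only, and this PR actually centralizes the tagging in a TaggedFilterSuite trait rather than tagging each test by hand:

import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite

// A plain ScalaTest tag: `sbt "testOnly -- -n jointest"` runs only tests
// tagged "jointest", while `-- -l jointest` runs everything except them.
object JoinTestTag extends Tag("jointest")

class ExampleTaggedSuite extends AnyFunSuite {
  test("runs only when -n jointest is passed", JoinTestTag) {
    assert(1 + 1 == 2)
  }
}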
27 changes: 15 additions & 12 deletions spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -37,7 +37,6 @@ import ai.chronon.spark.Extensions._
import ai.chronon.spark.stats.ConsistencyJob
import ai.chronon.spark.{Join => _, _}
import com.google.gson.GsonBuilder
- import junit.framework.TestCase
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
@@ -48,6 +47,7 @@ import org.apache.spark.sql.functions.lit
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
+ import org.scalatest.funsuite.AnyFunSuite
import org.slf4j.Logger
import org.slf4j.LoggerFactory

@@ -64,7 +64,11 @@ import scala.concurrent.duration.SECONDS
import scala.io.Source
import scala.util.ScalaJavaConversions._

- class FetcherTest extends TestCase {
+ // Run as follows: sbt "spark/testOnly -- -n fetchertest"
+ class FetcherTest extends AnyFunSuite with TaggedFilterSuite {
+
+ override def tagName: String = "fetchertest"
+
@transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
val sessionName = "FetcherTest"
val spark: SparkSession = SparkSessionBuilder.build(sessionName, local = true)
Expand All @@ -74,7 +78,7 @@ class FetcherTest extends TestCase {
private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
private val yesterday = tableUtils.partitionSpec.before(today)

- def testMetadataStore(): Unit = {
+ test("test metadata store") {
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
implicit val tableUtils: TableUtils = TableUtils(spark)

@@ -114,7 +118,8 @@ class FetcherTest extends TestCase {
val directoryDataSetDataSet = ChrononMetadataKey + "_directory_test"
val directoryMetadataStore = new MetadataStore(inMemoryKvStore, directoryDataSetDataSet, timeoutMillis = 10000)
inMemoryKvStore.create(directoryDataSetDataSet)
- val directoryDataDirWalker = new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints)
+ val directoryDataDirWalker =
+   new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints)
val directoryDataKvMap = directoryDataDirWalker.run
val directoryPut = directoryDataKvMap.toSeq.map {
case (_, kvMap) => directoryMetadataStore.put(kvMap, directoryDataSetDataSet)
@@ -385,9 +390,8 @@ class FetcherTest extends TestCase {
sources = Seq(Builders.Source.entities(query = Builders.Query(), snapshotTable = creditTable)),
keyColumns = Seq("vendor_id"),
aggregations = Seq(
- Builders.Aggregation(operation = Operation.SUM,
-                      inputColumn = "credit",
-                      windows = Seq(new Window(3, TimeUnit.DAYS)))),
+ Builders
+   .Aggregation(operation = Operation.SUM, inputColumn = "credit", windows = Seq(new Window(3, TimeUnit.DAYS)))),
metaData = Builders.MetaData(name = "unit_test/vendor_credit_derivation", namespace = namespace),
derivations = Seq(
Builders.Derivation("credit_sum_3d_test_rename", "credit_sum_3d"),
@@ -527,7 +531,6 @@ class FetcherTest extends TestCase {
println("saved all data hand written for fetcher test")

val startPartition = "2021-04-07"
-

val leftSource =
Builders.Source.events(
@@ -717,13 +720,13 @@ class FetcherTest extends TestCase {
assertEquals(0, diff.count())
}

- def testTemporalFetchJoinDeterministic(): Unit = {
+ test("test temporal fetch join deterministic") {
val namespace = "deterministic_fetch"
val joinConf = generateMutationData(namespace)
compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
}

- def testTemporalFetchJoinGenerated(): Unit = {
+ test("test temporal fetch join generated") {
val namespace = "generated_fetch"
val joinConf = generateRandomData(namespace)
compareTemporalFetch(joinConf,
Expand All @@ -733,14 +736,14 @@ class FetcherTest extends TestCase {
dropDsOnWrite = false)
}

- def testTemporalTiledFetchJoinDeterministic(): Unit = {
+ test("test temporal tiled fetch join deterministic") {
val namespace = "deterministic_tiled_fetch"
val joinConf = generateEventOnlyData(namespace, groupByCustomJson = Some("{\"enable_tiling\": true}"))
compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
}

// test soft-fail on missing keys
- def testEmptyRequest(): Unit = {
+ test("test empty request") {
val namespace = "empty_request"
val joinConf = generateRandomData(namespace, 5, 5)
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
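FetcherTest here (and JoinTest below) mix in a TaggedFilterSuite trait whose definition is not part of this diff. A minimal sketch of what such a trait might look like, assuming it skips the whole suite unless its tag is passed with -n (which would explain why the untagged sbt "spark/testOnly" job above no longer needs -l exclusions); this is a hypothetical reconstruction, not the repo's actual trait:

import org.scalatest.{Args, Filter, Status, SucceededStatus}
import org.scalatest.funsuite.AnyFunSuite

// Hypothetical sketch of TaggedFilterSuite; the real trait lives elsewhere
// in the repo and may differ.
trait TaggedFilterSuite extends AnyFunSuite {

  // Tag that must be passed as `sbt "spark/testOnly -- -n <tagName>"`
  // for this suite to run.
  def tagName: String

  override def run(testName: Option[String], args: Args): Status = {
    val included = args.filter.tagsToInclude.getOrElse(Set.empty[String])
    if (included.contains(tagName))
      // Drop the tag filter before delegating, so the individual (untagged)
      // tests inside this suite are not filtered out by the -n argument.
      super.run(testName, args.copy(filter = Filter.default))
    else
      SucceededStatus // tag not requested: skip the whole suite
  }
}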
58 changes: 23 additions & 35 deletions spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
@@ -40,13 +40,13 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.{StringType => SparkStringType}
import org.junit.Assert._
- import org.junit.Test
import org.scalatest.Assertions.intercept
+ import org.scalatest.funsuite.AnyFunSuite

import scala.collection.JavaConverters._
import scala.util.ScalaJavaConversions.ListOps

- class JoinTest {
+ // Run as follows: sbt "spark/testOnly -- -n jointest"
+ class JoinTest extends AnyFunSuite with TaggedFilterSuite {

val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true)
private implicit val tableUtils = TableUtils(spark)
@@ -59,8 +59,9 @@ class JoinTest {
private val namespace = "test_namespace_jointest"
tableUtils.createDatabase(namespace)

- @Test
- def testEventsEntitiesSnapshot(): Unit = {
+ override def tagName: String = "jointest"
+
+ test("test events entities snapshot") {
val dollarTransactions = List(
Column("user", StringType, 100),
Column("user_name", api.StringType, 100),
@@ -263,8 +264,7 @@ class JoinTest {
assertEquals(0, diff2.count())
}

- @Test
- def testEntitiesEntities(): Unit = {
+ test("test entities entities") {
// untimed/unwindowed entities on right
// right side
val weightSchema = List(
@@ -384,8 +384,7 @@
*/
}

- @Test
- def testEntitiesEntitiesNoHistoricalBackfill(): Unit = {
+ test("test entities entities no historical backfill") {
// Only backfill latest partition if historical_backfill is turned off
val weightSchema = List(
Column("user", api.StringType, 1000),
@@ -438,8 +437,7 @@
assertEquals(allPartitions.toList(0), end)
}

- @Test
- def testEventsEventsSnapshot(): Unit = {
+ test("test events events snapshot") {
val viewsSchema = List(
Column("user", api.StringType, 10000),
Column("item", api.StringType, 100),
@@ -508,8 +506,7 @@
assertEquals(diff.count(), 0)
}

- @Test
- def testEventsEventsTemporal(): Unit = {
+ test("test events events temporal") {

val joinConf = getEventsEventsTemporal("temporal")
val viewsSchema = List(
@@ -586,8 +583,7 @@
assertEquals(diff.count(), 0)
}

- @Test
- def testEventsEventsCumulative(): Unit = {
+ test("test events events cumulative") {
// Create a cumulative source GroupBy
val viewsTable = s"$namespace.view_cumulative"
val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true)
@@ -686,8 +682,7 @@

}

- @Test
- def testNoAgg(): Unit = {
+ test("test no agg") {
// Left side entities, right side entities no agg
// Also testing specific select statement (rather than select *)
val namesSchema = List(
@@ -767,8 +762,7 @@
assertEquals(diff.count(), 0)
}

- @Test
- def testVersioning(): Unit = {
+ test("test versioning") {
val joinConf = getEventsEventsTemporal("versioning")

// Run the old join to ensure that tables exist
@@ -922,8 +916,7 @@

}

- @Test
- def testEndPartitionJoin(): Unit = {
+ test("test end partition join") {
val join = getEventsEventsTemporal("end_partition_test")
val start = join.getLeft.query.startPartition
val end = tableUtils.partitionSpec.after(start)
@@ -940,12 +933,11 @@
assertTrue(ds.first().getString(0) < today)
}

- @Test
- def testSkipBloomFilterJoinBackfill(): Unit = {
-   val testSpark: SparkSession = SparkSessionBuilder.build(
-     "JoinTest",
-     local = true,
-     additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
+ test("test skip bloom filter join backfill") {
+   val testSpark: SparkSession =
+     SparkSessionBuilder.build("JoinTest",
+                               local = true,
+                               additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))

Comment on lines +937 to +940

coderabbitai bot (Contributor) commented on Oct 30, 2024:

🛠️ Refactor suggestion

Consider adding error handling for SparkSession creation.

The SparkSession creation could benefit from explicit error handling to gracefully handle initialization failures.

Consider wrapping the SparkSession creation in a try-catch block:

-    val testSpark: SparkSession =
-      SparkSessionBuilder.build("JoinTest",
-                                local = true,
-                                additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
+    val testSpark: SparkSession = try {
+      SparkSessionBuilder.build("JoinTest",
+                                local = true,
+                                additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException("Failed to initialize SparkSession for testing", e)
+    }


piyush-zlai (Contributor, Author) replied:

existing code, leaving as-is


coderabbitai bot (Contributor) replied:

@piyush-zlai, understood, leaving the existing code as-is.


val testTableUtils = TableUtils(testSpark)
val viewsSchema = List(
Column("user", api.StringType, 10000),
@@ -990,8 +982,7 @@
assertEquals(leftSideCount, skipBloomComputed.count())
}

- @Test
- def testStructJoin(): Unit = {
+ test("test struct join") {
val nameSuffix = "_struct_test"
val itemQueries = List(Column("item", api.StringType, 100))
val itemQueriesTable = s"$namespace.item_queries_$nameSuffix"
@@ -1049,8 +1040,7 @@
new SummaryJob(spark, join, today).dailyRun(stepDays = Some(30))
}

- @Test
- def testMigration(): Unit = {
+ test("test migration") {

// Left
val itemQueriesTable = s"$namespace.item_queries"
@@ -1099,8 +1089,7 @@
assertEquals(0, join.tablesToDrop(productionHashV2).length)
}

- @Test
- def testKeyMappingOverlappingFields(): Unit = {
+ test("testKeyMappingOverlappingFields") {
// test the scenario when a key_mapping is a -> b, (right key b is mapped to left key a) and
// a happens to be another field in the same group by

@@ -1158,8 +1147,7 @@
* Run computeJoin().
* Check if the selected join part is computed and the other join parts are not computed.
*/
- @Test
- def testSelectedJoinParts(): Unit = {
+ test("test selected join parts") {
// Left
val itemQueries = List(
Column("item", api.StringType, 100),