From 4291b99751d76d1991efc8af33152e40185b1155 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 09:35:22 -0500 Subject: [PATCH 01/14] unifying all tests into flatspec except vertx ones --- .../aggregator/test/ApproxDistinctTest.scala | 8 +- .../aggregator/test/ApproxHistogramTest.scala | 20 ++--- .../test/ApproxPercentilesTest.scala | 8 +- .../aggregator/test/FrequentItemsTest.scala | 16 ++-- .../chronon/aggregator/test/MinHeapTest.scala | 6 +- .../chronon/aggregator/test/MomentTest.scala | 14 ++-- .../aggregator/test/RowAggregatorTest.scala | 6 +- .../test/SawtoothAggregatorTest.scala | 8 +- .../test/SawtoothOnlineAggregatorTest.scala | 6 +- .../test/TwoStackLiteAggregatorTest.scala | 8 +- .../aggregator/test/VarianceTest.scala | 6 +- .../api/test/DataTypeConversionTest.scala | 7 +- .../ai/chronon/api/test/ExtensionsTest.scala | 37 ++++----- .../api/test/ParametricMacroTest.scala | 7 +- .../aws/DynamoDBKVStoreTest.scala | 20 ++--- .../cloud_gcp/BigQueryCatalogTest.scala | 16 ++-- .../cloud_gcp/BigTableKVStoreTest.scala | 34 +++----- .../cloud_gcp/DataprocSubmitterTest.scala | 12 +-- .../cloud_gcp/GCSFormatTest.scala | 10 +-- .../flink/test/AsyncKVStoreWriterTest.scala | 10 +-- .../flink/test/FlinkJobIntegrationTest.scala | 10 +-- .../test/SparkExpressionEvalFnTest.scala | 7 +- .../FlinkRowAggregationFunctionTest.scala | 14 ++-- .../flink/test/window/KeySelectorTest.scala | 10 +-- .../online/test/CatalystUtilHiveUDFTest.scala | 8 +- .../online/test/CatalystUtilTest.scala | 77 +++++++------------ .../online/test/DataStreamBuilderTest.scala | 10 +-- .../chronon/online/test/FetcherBaseTest.scala | 27 +++---- .../online/test/FetcherCacheTest.scala | 44 ++++------- .../chronon/online/test/JoinCodecTest.scala | 7 +- .../ai/chronon/online/test/LRUCacheTest.scala | 13 ++-- .../ai/chronon/online/test/TagsTest.scala | 7 +- .../online/test/ThriftDecodingTest.scala | 7 +- .../chronon/online/test/TileCodecTest.scala | 10 +-- .../online/test/stats/DriftMetricsTest.scala | 12 +-- .../ai/chronon/spark/test/AnalyzerTest.scala | 34 +++----- .../ai/chronon/spark/test/AvroTest.scala | 7 +- .../spark/test/ChainingFetcherTest.scala | 8 +- .../ai/chronon/spark/test/CompareTest.scala | 22 ++---- .../ai/chronon/spark/test/DataRangeTest.scala | 7 +- .../chronon/spark/test/EditDistanceTest.scala | 7 +- .../spark/test/ExternalSourcesTest.scala | 7 +- .../spark/test/FeatureWithLabelJoinTest.scala | 10 +-- .../ai/chronon/spark/test/FetcherTest.scala | 14 ++-- .../ai/chronon/spark/test/GroupByTest.scala | 43 ++++------- .../spark/test/GroupByUploadTest.scala | 16 ++-- .../ai/chronon/spark/test/JoinTest.scala | 34 ++++---- .../ai/chronon/spark/test/JoinUtilsTest.scala | 46 ++++------- .../spark/test/KafkaStreamBuilderTest.scala | 7 +- .../ai/chronon/spark/test/LabelJoinTest.scala | 34 +++----- .../spark/test/LocalDataLoaderTest.scala | 10 +-- .../test/LocalExportTableAbilityTest.scala | 13 ++-- .../spark/test/LocalTableExporterTest.scala | 10 +-- .../spark/test/MetadataExporterTest.scala | 6 +- .../spark/test/MigrationCompareTest.scala | 16 ++-- .../ai/chronon/spark/test/MutationsTest.scala | 20 ++--- .../spark/test/OfflineSubcommandTest.scala | 13 ++-- .../test/ResultValidationAbilityTest.scala | 16 ++-- .../spark/test/SchemaEvolutionTest.scala | 14 ++-- .../chronon/spark/test/StagingQueryTest.scala | 16 ++-- .../chronon/spark/test/StatsComputeTest.scala | 19 ++--- .../ai/chronon/spark/test/StreamingTest.scala | 6 +- .../spark/test/TableUtilsFormatTest.scala | 12 +-- 
.../chronon/spark/test/TableUtilsTest.scala | 55 +++++-------- .../spark/test/bootstrap/DerivationTest.scala | 22 ++---- .../test/bootstrap/LogBootstrapTest.scala | 7 +- .../test/bootstrap/TableBootstrapTest.scala | 10 +-- 67 files changed, 435 insertions(+), 648 deletions(-) diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxDistinctTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxDistinctTest.scala index 2416a894f5..cec97db0f8 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxDistinctTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxDistinctTest.scala @@ -17,10 +17,10 @@ package ai.chronon.aggregator.test import ai.chronon.aggregator.base.ApproxDistinctCount -import junit.framework.TestCase import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec -class ApproxDistinctTest extends TestCase { +class ApproxDistinctTest extends AnyFlatSpec { def testErrorBound(uniques: Int, errorBound: Int, lgK: Int): Unit = { val uniqueElems = 1 to uniques val duplicates = uniqueElems ++ uniqueElems ++ uniqueElems @@ -50,13 +50,13 @@ class ApproxDistinctTest extends TestCase { assertTrue(Math.abs(estimated - uniques) < errorBound) } - def testErrorBounds(): Unit = { + it should "error bounds" in { testErrorBound(uniques = 100, errorBound = 1, lgK = 10) testErrorBound(uniques = 1000, errorBound = 20, lgK = 10) testErrorBound(uniques = 10000, errorBound = 300, lgK = 10) } - def testMergingErrorBounds(): Unit = { + it should "merging error bounds" in { testMergingErrorBound(uniques = 100, errorBound = 1, lgK = 10, merges = 10) testMergingErrorBound(uniques = 1000, errorBound = 20, lgK = 10, merges = 4) testMergingErrorBound(uniques = 10000, errorBound = 400, lgK = 10, merges = 100) diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxHistogramTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxHistogramTest.scala index f1b2cb039a..91733fad43 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxHistogramTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxHistogramTest.scala @@ -2,14 +2,14 @@ package ai.chronon.aggregator.test import ai.chronon.aggregator.base.ApproxHistogram import ai.chronon.aggregator.base.ApproxHistogramIr -import junit.framework.TestCase import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec import java.util import scala.jdk.CollectionConverters._ -class ApproxHistogramTest extends TestCase { - def testHistogram(): Unit = { +class ApproxHistogramTest extends AnyFlatSpec { + it should "histogram" in { val approxHistogram = new ApproxHistogram[String](3) val counts = (1L to 3).map(i => i.toString -> i).toMap val ir = makeIr(approxHistogram, counts) @@ -19,7 +19,7 @@ class ApproxHistogramTest extends TestCase { assertEquals(toHashMap(counts), approxHistogram.finalize(ir)) } - def testSketch(): Unit = { + it should "sketch" in { val approxHistogram = new ApproxHistogram[String](3) val counts = (1L to 4).map(i => i.toString -> i).toMap val expected = counts.toSeq.sortBy(_._2).reverse.take(3).toMap @@ -30,7 +30,7 @@ class ApproxHistogramTest extends TestCase { assertEquals(toHashMap(expected), approxHistogram.finalize(ir)) } - def testMergeSketches(): Unit = { + it should "merge sketches" in { val approxHistogram = new ApproxHistogram[String](3) val counts1: Map[String, Long] = Map("5" -> 5L, "4" -> 4, "2" -> 2, "1" -> 1) val counts2: Map[String, Long] = Map("6" -> 6L, "4" -> 4, "2" -> 2, "1" -> 1) 
@@ -52,7 +52,7 @@ class ApproxHistogramTest extends TestCase {
 assertTrue(ir.histogram.isEmpty)
 }
- def testMergeHistograms(): Unit = {
+ it should "merge histograms" in {
 val approxHistogram = new ApproxHistogram[String](3)
 val counts1: Map[String, Long] = Map("4" -> 4L, "2" -> 2)
 val counts2: Map[String, Long] = Map("3" -> 3L, "2" -> 2)
@@ -74,7 +74,7 @@ class ApproxHistogramTest extends TestCase {
 assertTrue(ir.sketch.isEmpty)
 }
- def testMergeHistogramsToSketch(): Unit = {
+ it should "merge histograms to sketch" in {
 val approxHistogram = new ApproxHistogram[String](3)
 val counts1: Map[String, Long] = Map("4" -> 4L, "3" -> 3)
 val counts2: Map[String, Long] = Map("2" -> 2L, "1" -> 1)
@@ -97,7 +97,7 @@ class ApproxHistogramTest extends TestCase {
 assertTrue(ir.histogram.isEmpty)
 }
- def testMergeSketchAndHistogram(): Unit = {
+ it should "merge sketch and histogram" in {
 val approxHistogram = new ApproxHistogram[String](3)
 val counts1: Map[String, Long] = Map("5" -> 5L, "3" -> 3, "2" -> 2, "1" -> 1)
 val counts2: Map[String, Long] = Map("2" -> 2L)
@@ -119,7 +119,7 @@ class ApproxHistogramTest extends TestCase {
 assert(ir.histogram.isEmpty)
 }
- def testNormalizeHistogram(): Unit = {
+ it should "normalize histogram" in {
 val approxHistogram = new ApproxHistogram[String](3)
 val counts = (1L to 3).map(i => i.toString -> i).toMap
 val ir = makeIr(approxHistogram, counts)
@@ -129,7 +129,7 @@ class ApproxHistogramTest extends TestCase {
 assertEquals(ir, normalized)
 }
- def testNormalizeSketch(): Unit = {
+ it should "normalize sketch" in {
 val approxHistogram = new ApproxHistogram[String](3)
 val counts = (1L to 4).map(i => i.toString -> i).toMap
 val expected = counts.toSeq.sortBy(_._2).reverse.take(3).toMap
diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxPercentilesTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxPercentilesTest.scala
index 8cb92e4dad..ae83db6bfd 100644
--- a/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxPercentilesTest.scala
+++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/ApproxPercentilesTest.scala
@@ -18,15 +18,15 @@ package ai.chronon.aggregator.test
 import ai.chronon.aggregator.base.ApproxPercentiles
 import ai.chronon.aggregator.row.StatsGenerator
-import junit.framework.TestCase
 import org.apache.datasketches.kll.KllFloatsSketch
 import org.junit.Assert._
+import org.scalatest.flatspec.AnyFlatSpec
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
 import scala.util.Random
-class ApproxPercentilesTest extends TestCase {
+class ApproxPercentilesTest extends AnyFlatSpec {
 @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
 def basicImplTestHelper(nums: Int, slide: Int, k: Int, percentiles: Array[Double], errorPercent: Float): Unit = {
@@ -56,7 +56,7 @@ class ApproxPercentilesTest extends TestCase {
 diffs.foreach(diff => assertTrue(diff < errorMargin))
 }
- def testBasicPercentiles: Unit = {
+ it should "basic percentiles" in {
 val percentiles_tested: Int = 31
 val percentiles: Array[Double] = (0 to percentiles_tested).toArray.map(i => i * 1.0 / percentiles_tested)
 basicImplTestHelper(3000, 5, 100, percentiles, errorPercent = 4)
@@ -74,7 +74,7 @@ class ApproxPercentilesTest extends TestCase {
 drift
 }
- def testPSIDrifts(): Unit = {
+ it should "psi drifts" in {
 assertTrue(
 getPSIDrift(
 Array(1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7).map(_.toFloat),
diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/FrequentItemsTest.scala
b/aggregator/src/test/scala/ai/chronon/aggregator/test/FrequentItemsTest.scala index e117f2e49f..c3c979307c 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/FrequentItemsTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/FrequentItemsTest.scala @@ -4,14 +4,14 @@ import ai.chronon.aggregator.base.FrequentItemType import ai.chronon.aggregator.base.FrequentItems import ai.chronon.aggregator.base.FrequentItemsFriendly import ai.chronon.aggregator.base.ItemsSketchIR -import junit.framework.TestCase import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec import java.util import scala.jdk.CollectionConverters._ -class FrequentItemsTest extends TestCase { - def testNonPowerOfTwoAndTruncate(): Unit = { +class FrequentItemsTest extends AnyFlatSpec { + it should "non power of two and truncate" in { val size = 3 val items = new FrequentItems[String](size) val ir = items.prepare("4") @@ -32,7 +32,7 @@ class FrequentItemsTest extends TestCase { )), result) } - def testLessItemsThanSize(): Unit = { + it should "less items than size" in { val size = 10 val items = new FrequentItems[java.lang.Long](size) val ir = items.prepare(3) @@ -52,7 +52,7 @@ class FrequentItemsTest extends TestCase { )), result) } - def testZeroSize(): Unit = { + it should "zero size" in { val size = 0 val items = new FrequentItems[java.lang.Double](size) val ir = items.prepare(3.0) @@ -68,7 +68,7 @@ class FrequentItemsTest extends TestCase { assertEquals(new util.HashMap[String, Double](), result) } - def testSketchSizes(): Unit = { + it should "sketch sizes" in { val expectedSketchSizes = Map( -1 -> 2, @@ -87,7 +87,7 @@ class FrequentItemsTest extends TestCase { assertEquals(expectedSketchSizes, actualSketchSizes) } - def testNormalization(): Unit = { + it should "normalization" in { val testValues = (1 to 4) .map(i => i -> i) .toMap @@ -118,7 +118,7 @@ class FrequentItemsTest extends TestCase { assertEquals(expectedStringValues, actualStringValues) } - def testBulkMerge(): Unit = { + it should "bulk merge" in { val sketch = new FrequentItems[String](3) val irs = Seq( diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/MinHeapTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/MinHeapTest.scala index 5cf5dda1a5..7d3db30b95 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/MinHeapTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/MinHeapTest.scala @@ -17,14 +17,14 @@ package ai.chronon.aggregator.test import ai.chronon.aggregator.base.MinHeap -import junit.framework.TestCase import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec import java.util import scala.collection.JavaConverters._ -class MinHeapTest extends TestCase { - def testInserts(): Unit = { +class MinHeapTest extends AnyFlatSpec { + it should "inserts" in { val mh = new MinHeap[Int](maxSize = 4, Ordering.Int) def make_container = new util.ArrayList[Int](4) diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/MomentTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/MomentTest.scala index a81045984e..b6de29ce11 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/MomentTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/MomentTest.scala @@ -1,12 +1,12 @@ package ai.chronon.aggregator.test import ai.chronon.aggregator.base._ -import junit.framework.TestCase import org.apache.commons.math3.stat.descriptive.moment.{Kurtosis => ApacheKurtosis} import 
org.apache.commons.math3.stat.descriptive.moment.{Skewness => ApacheSkew} import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec -class MomentTest extends TestCase { +class MomentTest extends AnyFlatSpec { def makeAgg(aggregator: MomentAggregator, values: Seq[Double]): (MomentAggregator, MomentsIR) = { var ir = aggregator.prepare(values.head) @@ -36,32 +36,32 @@ class MomentTest extends TestCase { assertEquals(expected(v1 ++ v2), agg.finalize(ir), 0.1) } - def testUpdate(): Unit = { + it should "update" in { val values = Seq(1.1, 2.2, 3.3, 4.4, 5.5) assertUpdate(new Skew(), values, expectedSkew) assertUpdate(new Kurtosis(), values, expectedKurtosis) } - def testInsufficientSizes(): Unit = { + it should "insufficient sizes" in { val values = Seq(1.1, 2.2, 3.3, 4.4) assertUpdate(new Skew(), values.take(2), _ => Double.NaN) assertUpdate(new Kurtosis(), values.take(3), _ => Double.NaN) } - def testNoVariance(): Unit = { + it should "no variance" in { val values = Seq(1.0, 1.0, 1.0, 1.0) assertUpdate(new Skew(), values, _ => Double.NaN) assertUpdate(new Kurtosis(), values, _ => Double.NaN) } - def testMerge(): Unit = { + it should "merge" in { val values1 = Seq(1.1, 2.2, 3.3) val values2 = Seq(4.4, 5.5) assertMerge(new Kurtosis(), values1, values2, expectedKurtosis) assertMerge(new Skew(), values1, values2, expectedSkew) } - def testNormalize(): Unit = { + it should "normalize" in { val values = Seq(1.0, 2.0, 3.0, 4.0, 5.0) val (agg, ir) = makeAgg(new Kurtosis, values) diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/RowAggregatorTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/RowAggregatorTest.scala index c96045d7d7..dcb7a0a9c3 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/RowAggregatorTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/RowAggregatorTest.scala @@ -18,8 +18,8 @@ package ai.chronon.aggregator.test import ai.chronon.aggregator.row.RowAggregator import ai.chronon.api._ -import junit.framework.TestCase import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec import java.util import scala.collection.JavaConverters._ @@ -48,8 +48,8 @@ object TestRow { def apply(inputsArray: Any*): TestRow = new TestRow(inputsArray: _*)() } -class RowAggregatorTest extends TestCase { - def testUpdate(): Unit = { +class RowAggregatorTest extends AnyFlatSpec { + it should "update" in { val rows = List( TestRow(1L, 4, 5.0f, "A", Seq(5, 3, 4), Seq("D", "A", "B", "A"), Map("A" -> 1, "B" -> 2)), TestRow(2L, 3, 4.0f, "B", Seq(6, null), Seq(), null), diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothAggregatorTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothAggregatorTest.scala index 72f97d6712..d58ef03b32 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothAggregatorTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothAggregatorTest.scala @@ -22,8 +22,8 @@ import ai.chronon.aggregator.windowing._ import ai.chronon.api.Extensions.AggregationOps import ai.chronon.api._ import com.google.gson.Gson -import junit.framework.TestCase import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -46,9 +46,9 @@ class Timer { } } -class SawtoothAggregatorTest extends TestCase { +class SawtoothAggregatorTest extends AnyFlatSpec { - def testTailAccuracy(): Unit = { + it should "tail accuracy" in { val timer = new Timer val queries = CStream.genTimestamps(new 
Window(30, TimeUnit.DAYS), 10000, 5 * 60 * 1000) @@ -119,7 +119,7 @@ class SawtoothAggregatorTest extends TestCase { } } - def testRealTimeAccuracy(): Unit = { + it should "real time accuracy" in { val timer = new Timer val queries = CStream.genTimestamps(new Window(1, TimeUnit.DAYS), 1000) val columns = Seq(Column("ts", LongType, 180), diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala index 7341bcf542..37f559beae 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/SawtoothOnlineAggregatorTest.scala @@ -24,17 +24,17 @@ import ai.chronon.api.Extensions.WindowOps import ai.chronon.api.Extensions.WindowUtils import ai.chronon.api._ import com.google.gson.Gson -import junit.framework.TestCase import org.junit.Assert.assertEquals +import org.scalatest.flatspec.AnyFlatSpec import java.time.Instant import java.time.ZoneOffset import java.time.format.DateTimeFormatter import java.util.Locale -class SawtoothOnlineAggregatorTest extends TestCase { +class SawtoothOnlineAggregatorTest extends AnyFlatSpec { - def testConsistency(): Unit = { + it should "consistency" in { val queryEndTs = TsUtils.round(System.currentTimeMillis(), WindowUtils.Day.millis) val batchEndTs = queryEndTs - WindowUtils.Day.millis val queries = CStream.genTimestamps(new Window(1, TimeUnit.DAYS), 1000) diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/TwoStackLiteAggregatorTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/TwoStackLiteAggregatorTest.scala index f529223c59..8db84d48bf 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/TwoStackLiteAggregatorTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/TwoStackLiteAggregatorTest.scala @@ -32,13 +32,13 @@ import ai.chronon.api.StructType import ai.chronon.api.TimeUnit import ai.chronon.api.Window import com.google.gson.Gson -import junit.framework.TestCase import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec import scala.collection.Seq -class TwoStackLiteAggregatorTest extends TestCase{ - def testBufferWithTopK(): Unit = { +class TwoStackLiteAggregatorTest extends AnyFlatSpec { + it should "buffer with top k" in { val topK = new TopK[Integer](IntType, 2) val bankersBuffer = new TwoStackLiteAggregationBuffer(topK, 5) assertEquals(null, bankersBuffer.query) // null @@ -63,7 +63,7 @@ class TwoStackLiteAggregatorTest extends TestCase{ assertBufferEquals(Seq(10), bankersBuffer.query) } - def testAgainstSawtooth(): Unit = { + it should "against sawtooth" in { val timer = new Timer val queries = CStream.genTimestamps(new Window(30, TimeUnit.DAYS), 100000, 5 * 60 * 1000) diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/VarianceTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/VarianceTest.scala index b7922189f6..4ec3a97b42 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/VarianceTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/VarianceTest.scala @@ -17,12 +17,12 @@ package ai.chronon.aggregator.test import ai.chronon.aggregator.base.Variance -import junit.framework.TestCase import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory -class VarianceTest extends TestCase { +class VarianceTest extends AnyFlatSpec { @transient lazy 
val logger: Logger = LoggerFactory.getLogger(getClass)
 def mean(elems: Seq[Double]): Double = elems.sum / elems.length
@@ -60,7 +60,7 @@ class VarianceTest extends TestCase {
 assertTrue((naiveResult - welfordResult) / naiveResult < 0.0000001)
 }
- def testVariance: Unit = {
+ it should "variance" in {
 compare(1000000)
 compare(1000000, min = 100000, max = 100001)
 }
diff --git a/api/src/test/scala/ai/chronon/api/test/DataTypeConversionTest.scala b/api/src/test/scala/ai/chronon/api/test/DataTypeConversionTest.scala
index 1ae2819bb0..703afd913b 100644
--- a/api/src/test/scala/ai/chronon/api/test/DataTypeConversionTest.scala
+++ b/api/src/test/scala/ai/chronon/api/test/DataTypeConversionTest.scala
@@ -20,14 +20,13 @@ import ai.chronon.api._
 import ai.chronon.api.thrift.TSerializer
 import ai.chronon.api.thrift.protocol.TSimpleJSONProtocol
 import org.junit.Assert._
-import org.junit.Test
+import org.scalatest.flatspec.AnyFlatSpec
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
-class DataTypeConversionTest {
+class DataTypeConversionTest extends AnyFlatSpec {
 @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
- @Test
- def testDataTypeToThriftAndBack(): Unit = {
+ it should "data type to thrift and back" in {
 // build some complex type
 val dType = StructType(
 "root",
diff --git a/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala b/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala
index f128de98b5..66f1916095 100644
--- a/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala
+++ b/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala
@@ -25,16 +25,15 @@ import ai.chronon.api.ScalaJavaConversions._
 import org.junit.Assert.assertEquals
 import org.junit.Assert.assertFalse
 import org.junit.Assert.assertTrue
-import org.junit.Test
 import org.mockito.Mockito.spy
 import org.mockito.Mockito.when
+import org.scalatest.flatspec.AnyFlatSpec
 import java.util.Arrays
-class ExtensionsTest {
+class ExtensionsTest extends AnyFlatSpec {
- @Test
- def testSubPartitionFilters(): Unit = {
+ it should "sub partition filters" in {
 val source = Builders.Source.events(query = null, table = "db.table/system=mobile/currency=USD")
 assertEquals(
 Map("system" -> "mobile", "currency" -> "USD"),
@@ -42,8 +41,7 @@ class ExtensionsTest {
 )
 }
- @Test
- def testOwningTeam(): Unit = {
+ it should "owning team" in {
 val metadata =
 Builders.MetaData(
 customJson = "{\"check_consistency\": true, \"lag\": 0, \"team_override\": \"ml_infra\"}",
@@ -61,22 +59,19 @@ class ExtensionsTest {
 )
 }
- @Test
- def testRowIdentifier(): Unit = {
+ it should "row identifier" in {
 val labelPart = Builders.LabelPart();
 val res = labelPart.rowIdentifier(Arrays.asList("yoyo", "yujia"), "ds")
 assertTrue(res.contains("ds"))
 }
- @Test
- def partSkewFilterShouldReturnNoneWhenNoSkewKey(): Unit = {
+ it should "part skew filter should return none when no skew key" in {
 val joinPart = Builders.JoinPart()
 val join = Builders.Join(joinParts = Seq(joinPart))
 assertTrue(join.partSkewFilter(joinPart).isEmpty)
 }
- @Test
- def partSkewFilterShouldReturnCorrectlyWithSkewKeys(): Unit = {
+ it should "part skew filter should return correctly with skew keys" in {
 val groupByMetadata = Builders.MetaData(name = "test")
 val groupBy = Builders.GroupBy(keyColumns = Seq("a", "c"), metaData = groupByMetadata)
 val joinPart = Builders.JoinPart(groupBy = groupBy)
@@ -85,8 +80,7 @@ class ExtensionsTest {
 assertEquals("a NOT IN (b) OR c NOT IN (d)", join.partSkewFilter(joinPart).get)
 }
- @Test
- def
partSkewFilterShouldReturnCorrectlyWithPartialSkewKeys(): Unit = { + it should "part skew filter should return correctly with partial skew keys" in { val groupByMetadata = Builders.MetaData(name = "test") val groupBy = Builders.GroupBy(keyColumns = Seq("c"), metaData = groupByMetadata) @@ -97,8 +91,7 @@ class ExtensionsTest { assertEquals("c NOT IN (d)", join.partSkewFilter(joinPart).get) } - @Test - def partSkewFilterShouldReturnCorrectlyWithSkewKeysWithMapping(): Unit = { + it should "part skew filter should return correctly with skew keys with mapping" in { val groupByMetadata = Builders.MetaData(name = "test") val groupBy = Builders.GroupBy(keyColumns = Seq("x", "c"), metaData = groupByMetadata) @@ -109,8 +102,7 @@ class ExtensionsTest { assertEquals("x NOT IN (b) OR c NOT IN (d)", join.partSkewFilter(joinPart).get) } - @Test - def partSkewFilterShouldReturnNoneIfJoinPartHasNoRelatedKeys(): Unit = { + it should "part skew filter should return none if join part has no related keys" in { val groupByMetadata = Builders.MetaData(name = "test") val groupBy = Builders.GroupBy(keyColumns = Seq("non_existent"), metaData = groupByMetadata) @@ -120,8 +112,7 @@ class ExtensionsTest { assertTrue(join.partSkewFilter(joinPart).isEmpty) } - @Test - def groupByKeysShouldContainPartitionColumn(): Unit = { + it should "group by keys should contain partition column" in { val groupBy = spy(new GroupBy()) val baseKeys = List("a", "b") val partitionColumn = "ds" @@ -135,8 +126,7 @@ class ExtensionsTest { assertEquals(3, keys.size) } - @Test - def groupByKeysShouldContainTimeColumnForTemporalAccuracy(): Unit = { + it should "group by keys should contain time column for temporal accuracy" in { val groupBy = spy(new GroupBy()) val baseKeys = List("a", "b") val partitionColumn = "ds" @@ -151,8 +141,7 @@ class ExtensionsTest { assertEquals(4, keys.size) } - @Test - def testIsTilingEnabled(): Unit = { + it should "is tiling enabled" in { def buildGroupByWithCustomJson(customJson: String = null): GroupBy = Builders.GroupBy( metaData = Builders.MetaData(name = "featureGroupName", customJson = customJson) diff --git a/api/src/test/scala/ai/chronon/api/test/ParametricMacroTest.scala b/api/src/test/scala/ai/chronon/api/test/ParametricMacroTest.scala index b5df6993a6..8328d83f1d 100644 --- a/api/src/test/scala/ai/chronon/api/test/ParametricMacroTest.scala +++ b/api/src/test/scala/ai/chronon/api/test/ParametricMacroTest.scala @@ -18,11 +18,10 @@ package ai.chronon.api.test import ai.chronon.api.ParametricMacro import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class ParametricMacroTest { - @Test - def testSubstitution(): Unit = { +class ParametricMacroTest extends AnyFlatSpec { + it should "substitution" in { val mc = ParametricMacro("something", { x => "st:" + x.keys.mkString("/") + "|" + x.values.mkString("/") }) val str = "something nothing-{{ something( a_1=b, 3.1, c=d) }}-something after-{{ thing:a1=b1 }}{{ something }}" val replaced = mc.replace(str) diff --git a/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala b/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala index eefe22540d..3ceab3d83b 100644 --- a/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala +++ b/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala @@ -11,9 +11,8 @@ import io.circe.generic.auto._ import io.circe.parser._ import io.circe.syntax._ import org.junit.After -import 
org.junit.Assert.fail import org.junit.Before -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.must.Matchers.be import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import software.amazon.awssdk.auth.credentials.AwsBasicCredentials @@ -33,7 +32,7 @@ import scala.util.Try case class Model(modelId: String, modelName: String, online: Boolean) case class TimeSeries(joinName: String, featureName: String, tileTs: Long, metric: String, summary: Array[Double]) -class DynamoDBKVStoreTest { +class DynamoDBKVStoreTest extends AnyFlatSpec { import DynamoDBKVStoreConstants._ @@ -82,8 +81,7 @@ class DynamoDBKVStoreTest { } // Test creation of a table with primary keys only (e.g. model) - @Test - def testCreatePKeyOnlyTable(): Unit = { + it should "create p key only table" in { val dataset = "models" val props = Map(isTimedSorted -> "false") val kvStore = new DynamoDBKVStoreImpl(client) @@ -96,8 +94,7 @@ class DynamoDBKVStoreTest { } // Test creation of a table with primary + sort keys (e.g. time series) - @Test - def testCreatePKeyAndSortKeyTable(): Unit = { + it should "create p key and sort key table" in { val dataset = "timeseries" val props = Map(isTimedSorted -> "true") val kvStore = new DynamoDBKVStoreImpl(client) @@ -110,8 +107,7 @@ class DynamoDBKVStoreTest { } // Test table scan with pagination - @Test - def testTableScanWithPagination(): Unit = { + it should "table scan with pagination" in { val dataset = "models" val props = Map(isTimedSorted -> "false") val kvStore = new DynamoDBKVStoreImpl(client) @@ -141,8 +137,7 @@ class DynamoDBKVStoreTest { } // Test write & read of a simple blob dataset - @Test - def testBlobDataRoundTrip(): Unit = { + it should "blob data round trip" in { val dataset = "models" val props = Map(isTimedSorted -> "false") val kvStore = new DynamoDBKVStoreImpl(client) @@ -174,8 +169,7 @@ class DynamoDBKVStoreTest { } // Test write and query of a time series dataset - @Test - def testTimeSeriesQuery(): Unit = { + it should "time series query" in { val dataset = "timeseries" val props = Map(isTimedSorted -> "true") val kvStore = new DynamoDBKVStoreImpl(client) diff --git a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigQueryCatalogTest.scala b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigQueryCatalogTest.scala index 156443234f..2beac27584 100644 --- a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigQueryCatalogTest.scala +++ b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigQueryCatalogTest.scala @@ -10,10 +10,10 @@ import com.google.cloud.hadoop.gcsio.GoogleCloudStorageFileSystem import org.apache.spark.sql.SparkSession import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.mockito.MockitoSugar -class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar { +class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar { lazy val spark: SparkSession = SparkSessionBuilder.build( "BigQuerySparkTest", @@ -30,11 +30,11 @@ class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar { ) lazy val tableUtils: TableUtils = TableUtils(spark) - test("hive uris are set") { + it should "hive uris are set" in { assertEquals("thrift://localhost:9083", spark.sqlContext.getConf("hive.metastore.uris")) } - test("google runtime classes are available") { + it should "google runtime classes are available" in { 
assertTrue(GoogleHadoopFileSystemConfiguration.BLOCK_SIZE.isInstanceOf[HadoopConfigurationProperty[Long]]) assertCompiles("classOf[GoogleHadoopFileSystem]") assertCompiles("classOf[GoogleHadoopFS]") @@ -42,14 +42,14 @@ class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar { } - test("verify dynamic classloading of GCP providers") { + it should "verify dynamic classloading of GCP providers" in { assertTrue(tableUtils.tableReadFormat("data.sample_native") match { case BQuery(_, _) => true case _ => false }) } - ignore("integration testing bigquery native table") { + it should "integration testing bigquery native table" ignore { val nativeTable = "data.sample_native" val table = tableUtils.loadTable(nativeTable) table.show @@ -60,7 +60,7 @@ class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar { println(allParts) } - ignore("integration testing bigquery external table") { + it should "integration testing bigquery external table" ignore { val externalTable = "data.checkouts_parquet" val bs = GoogleHadoopFileSystemConfiguration.BLOCK_SIZE @@ -74,7 +74,7 @@ class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar { println(allParts) } - ignore("integration testing bigquery partitions") { + it should "integration testing bigquery partitions" ignore { // TODO(tchow): This test is ignored because it requires a running instance of the bigquery. Need to figure out stubbing locally. // to run this: // 1. Set up a tunnel to dataproc federation proxy: diff --git a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala index a45de17d27..2aa43bce8d 100644 --- a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala +++ b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala @@ -18,12 +18,12 @@ import com.google.cloud.bigtable.data.v2.models.RowMutation import com.google.cloud.bigtable.emulator.v2.BigtableEmulatorRule import org.junit.Before import org.junit.Rule -import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when import org.mockito.Mockito.withSettings +import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.scalatestplus.mockito.MockitoSugar.mock @@ -34,7 +34,7 @@ import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ @RunWith(classOf[JUnit4]) -class BigTableKVStoreTest { +class BigTableKVStoreTest extends AnyFlatSpec { import BigTableKVStore._ @@ -71,8 +71,7 @@ class BigTableKVStoreTest { adminClient = BigtableTableAdminClient.create(adminSettings) } - @Test - def testBigTableCreation(): Unit = { + it should "big table creation" in { val kvStore = new BigTableKVStoreImpl(dataClient, adminClient) val dataset = "test-table" kvStore.create(dataset) @@ -82,8 +81,7 @@ class BigTableKVStoreTest { } // Test write & read of a simple blob dataset - @Test - def testBlobDataRoundTrip(): Unit = { + it should "blob data round trip" in { val dataset = "models" val kvStore = new BigTableKVStoreImpl(dataClient, adminClient) kvStore.create(dataset) @@ -113,8 +111,7 @@ class BigTableKVStoreTest { validateBlobValueExpectedPayload(getResult2.head, value2) } - @Test - def testBlobDataUpdates(): Unit = { + it should "blob data updates" in { val dataset = "models" val kvStore = new BigTableKVStoreImpl(dataClient, 
adminClient) kvStore.create(dataset) @@ -149,8 +146,7 @@ class BigTableKVStoreTest { validateBlobValueExpectedPayload(getResultUpdated.head, valueUpdated) } - @Test - def testListWithPagination(): Unit = { + it should "list with pagination" in { val dataset = "models" val kvStore = new BigTableKVStoreImpl(dataClient, adminClient) kvStore.create(dataset) @@ -191,8 +187,7 @@ class BigTableKVStoreTest { .toSet } - @Test - def testMultiputFailures(): Unit = { + it should "multiput failures" in { val mockDataClient: BigtableDataClient = mock[BigtableDataClient](withSettings().mockMaker("mock-maker-inline")) val mockAdminClient = mock[BigtableTableAdminClient] val kvStoreWithMocks = new BigTableKVStoreImpl(mockDataClient, mockAdminClient) @@ -214,8 +209,7 @@ class BigTableKVStoreTest { putResults shouldBe Seq(false, false) } - @Test - def testMultigetFailures(): Unit = { + it should "multiget failures" in { val mockDataClient: BigtableDataClient = mock[BigtableDataClient](withSettings().mockMaker("mock-maker-inline")) val mockAdminClient = mock[BigtableTableAdminClient] val kvStoreWithMocks = new BigTableKVStoreImpl(mockDataClient, mockAdminClient) @@ -244,8 +238,7 @@ class BigTableKVStoreTest { } // Test write and query of a simple time series dataset - @Test - def testTimeSeriesQuery_MultipleDays(): Unit = { + it should "time series query_multiple days" in { val dataset = "TILE_SUMMARIES" val kvStore = new BigTableKVStoreImpl(dataClient, adminClient) kvStore.create(dataset) @@ -265,8 +258,7 @@ class BigTableKVStoreTest { validateTimeSeriesValueExpectedPayload(getResult1.head, expectedTimeSeriesPoints, fakePayload) } - @Test - def testMultipleDatasetTimeSeriesQuery_OneDay(): Unit = { + it should "multiple dataset time series query_one day" in { val dataset = "TILE_SUMMARIES" val kvStore = new BigTableKVStoreImpl(dataClient, adminClient) kvStore.create(dataset) @@ -286,8 +278,7 @@ class BigTableKVStoreTest { validateTimeSeriesValueExpectedPayload(getResult1.head, expectedTimeSeriesPoints, fakePayload) } - @Test - def testMultipleDatasetTimeSeriesQuery_SameDay(): Unit = { + it should "multiple dataset time series query_same day" in { val dataset = "TILE_SUMMARIES" val kvStore = new BigTableKVStoreImpl(dataClient, adminClient) kvStore.create(dataset) @@ -307,8 +298,7 @@ class BigTableKVStoreTest { validateTimeSeriesValueExpectedPayload(getResult1.head, expectedTimeSeriesPoints, fakePayload) } - @Test - def testMultipleDatasetTimeSeriesQuery_DaysWithoutData(): Unit = { + it should "multiple dataset time series query_days without data" in { val dataset = "TILE_SUMMARIES" val kvStore = new BigTableKVStoreImpl(dataClient, adminClient) kvStore.create(dataset) diff --git a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala index 94afd02767..a33130abf5 100644 --- a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala +++ b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala @@ -7,12 +7,12 @@ import com.google.cloud.spark.bigquery.BigQueryUtilScala import org.junit.Assert.assertEquals import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.mockito.MockitoSugar -class DataprocSubmitterTest extends AnyFunSuite with MockitoSugar { +class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar { 
- test("DataprocClient should return job id when a job is submitted") { +"DataprocClient" should "return job id when a job is submitted" in { // Mock dataproc job client. val jobId = "mock-job-id" @@ -43,11 +43,11 @@ class DataprocSubmitterTest extends AnyFunSuite with MockitoSugar { assertEquals(submittedJobId, jobId) } - test("Verify classpath with spark-bigquery-connector") { + it should "Verify classpath with spark-bigquery-connector" in { BigQueryUtilScala.validateScalaVersionCompatibility() } - ignore("Used to iterate locally. Do not enable this in CI/CD!") { + it should "Used to iterate locally. Do not enable this in CI/CD!" ignore { val submitter = DataprocSubmitter() val submittedJobId = @@ -63,7 +63,7 @@ class DataprocSubmitterTest extends AnyFunSuite with MockitoSugar { println(submittedJobId) } - ignore("Used to test GBU bulk load locally. Do not enable this in CI/CD!") { + it should "Used to test GBU bulk load locally. Do not enable this in CI/CD!" ignore { val submitter = DataprocSubmitter() val submittedJobId = diff --git a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/GCSFormatTest.scala b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/GCSFormatTest.scala index 53b21cd88e..45251bf256 100644 --- a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/GCSFormatTest.scala +++ b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/GCSFormatTest.scala @@ -9,18 +9,18 @@ import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType import org.junit.Assert.assertEquals -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.flatspec.AnyFlatSpec import java.nio.file.Files -class GCSFormatTest extends AnyFunSuite { +class GCSFormatTest extends AnyFlatSpec { lazy val spark: SparkSession = SparkSessionBuilder.build( "BigQuerySparkTest", local = true ) - test("partitions method should return correctly parsed partitions as maps") { + it should "partitions method should return correctly parsed partitions as maps" in { val testData = List( ("20241223", "b", "c"), @@ -40,7 +40,7 @@ class GCSFormatTest extends AnyFunSuite { } - test("partitions method should handle empty partitions gracefully") { + it should "partitions method should handle empty partitions gracefully" in { val testData = List( ("20241223", "b", "c"), @@ -60,7 +60,7 @@ class GCSFormatTest extends AnyFunSuite { } - test("partitions method should handle date types") { + it should "partitions method should handle date types" in { val testData = List( Row("2024-12-23", "b", "c"), Row("2024-12-24", "e", "f"), diff --git a/flink/src/test/scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala b/flink/src/test/scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala index f3374c62ce..1cd13cc858 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala @@ -8,18 +8,17 @@ import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.scala.DataStream import org.apache.flink.streaming.api.scala.DataStreamUtils import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.mockito.MockitoSugar.mock -class AsyncKVStoreWriterTest { +class AsyncKVStoreWriterTest extends AnyFlatSpec { val eventTs = 1519862400075L def createKVRequest(key: String, value: String, dataset: String, ts: Long): PutRequest = 
PutRequest(key.getBytes, value.getBytes, dataset, Some(ts)) - @Test - def testAsyncWriterSuccessWrites(): Unit = { + it should "async writer success writes" in { val env = StreamExecutionEnvironment.getExecutionEnvironment val source: DataStream[PutRequest] = env .fromCollection( @@ -41,8 +40,7 @@ class AsyncKVStoreWriterTest { // ensure that if we get an event that would cause the operator to throw an exception, // we don't crash the app - @Test - def testAsyncWriterHandlesPoisonPillWrites(): Unit = { + it should "async writer handles poison pill writes" in { val env = StreamExecutionEnvironment.getExecutionEnvironment val source: DataStream[KVStore.PutRequest] = env .fromCollection( diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala index 9e10356c03..04cc8b03a4 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala @@ -14,13 +14,13 @@ import org.apache.spark.sql.Encoders import org.junit.After import org.junit.Assert.assertEquals import org.junit.Before -import org.junit.Test import org.mockito.Mockito.withSettings +import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.mockito.MockitoSugar.mock import scala.jdk.CollectionConverters.asScalaBufferConverter -class FlinkJobIntegrationTest { +class FlinkJobIntegrationTest extends AnyFlatSpec { val flinkCluster = new MiniClusterWithClientResource( new MiniClusterResourceConfiguration.Builder() @@ -64,8 +64,7 @@ class FlinkJobIntegrationTest { CollectSink.values.clear() } - @Test - def testFlinkJobEndToEnd(): Unit = { + it should "flink job end to end" in { implicit val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment val elements = Seq( @@ -101,8 +100,7 @@ class FlinkJobIntegrationTest { assertEquals(writeEventCreatedDS.map(_.status), Seq(true, true, true)) } - @Test - def testTiledFlinkJobEndToEnd(): Unit = { + it should "tiled flink job end to end" in { implicit val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment // Create some test events with multiple different ids so we can check if tiling/pre-aggregation works correctly diff --git a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala index 4af6541931..920bafc15f 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala @@ -6,12 +6,11 @@ import org.apache.flink.streaming.api.scala.DataStream import org.apache.flink.streaming.api.scala.DataStreamUtils import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.spark.sql.Encoders -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class SparkExpressionEvalFnTest { +class SparkExpressionEvalFnTest extends AnyFlatSpec { - @Test - def testBasicSparkExprEvalSanity(): Unit = { + it should "basic spark expr eval sanity" in { val elements = Seq( E2ETestEvent("test1", 12, 1.5, 1699366993123L), E2ETestEvent("test2", 13, 1.6, 1699366993124L), diff --git a/flink/src/test/scala/ai/chronon/flink/test/window/FlinkRowAggregationFunctionTest.scala b/flink/src/test/scala/ai/chronon/flink/test/window/FlinkRowAggregationFunctionTest.scala index 7472a42395..804cb903f9 100644 --- 
a/flink/src/test/scala/ai/chronon/flink/test/window/FlinkRowAggregationFunctionTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/window/FlinkRowAggregationFunctionTest.scala @@ -3,13 +3,12 @@ package ai.chronon.flink.test.window import ai.chronon.api._ import ai.chronon.flink.window.FlinkRowAggregationFunction import ai.chronon.online.TileCodec -import org.junit.Assert.fail -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import scala.util.Failure import scala.util.Try -class FlinkRowAggregationFunctionTest { +class FlinkRowAggregationFunctionTest extends AnyFlatSpec { private val aggregations: Seq[Aggregation] = Seq( Builders.Aggregation( Operation.AVERAGE, @@ -51,8 +50,7 @@ class FlinkRowAggregationFunctionTest { "title" -> StringType ) - @Test - def testFlinkAggregatorProducesCorrectResults(): Unit = { + it should "flink aggregator produces correct results" in { val groupByMetadata = Builders.MetaData(name = "my_group_by") val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations) val aggregateFunc = new FlinkRowAggregationFunction(groupBy, schema) @@ -94,8 +92,7 @@ class FlinkRowAggregationFunctionTest { assert(finalResult sameElements expectedResult) } - @Test - def testFlinkAggregatorResultsCanBeMergedWithOtherPreAggregates(): Unit = { + it should "flink aggregator results can be merged with other pre aggregates" in { val groupByMetadata = Builders.MetaData(name = "my_group_by") val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations) val aggregateFunc = new FlinkRowAggregationFunction(groupBy, schema) @@ -159,8 +156,7 @@ class FlinkRowAggregationFunctionTest { assert(finalResult sameElements expectedResult) } - @Test - def testFlinkAggregatorProducesCorrectResultsIfInputIsInIncorrectOrder(): Unit = { + it should "flink aggregator produces correct results if input is in incorrect order" in { val groupByMetadata = Builders.MetaData(name = "my_group_by") val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations) val aggregateFunc = new FlinkRowAggregationFunction(groupBy, schema) diff --git a/flink/src/test/scala/ai/chronon/flink/test/window/KeySelectorTest.scala b/flink/src/test/scala/ai/chronon/flink/test/window/KeySelectorTest.scala index b81c39aabc..9958fb17c9 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/window/KeySelectorTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/window/KeySelectorTest.scala @@ -2,11 +2,10 @@ package ai.chronon.flink.test.window import ai.chronon.api.Builders import ai.chronon.flink.window.KeySelector -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class KeySelectorTest { - @Test - def TestChrononFlinkJobCorrectlyKeysByAGroupbysEntityKeys(): Unit = { +class KeySelectorTest extends AnyFlatSpec { + it should "chronon flink job correctly keys by a groupbys entity keys" in { // We expect something like this to come out of the SparkExprEval operator val sampleSparkExprEvalOutput: Map[String, Any] = Map("number" -> 4242, "ip" -> "192.168.0.1", "user" -> "abc") @@ -24,8 +23,7 @@ class KeySelectorTest { ) } - @Test - def testKeySelectorFunctionReturnsSameHashesForListsWithTheSameContent(): Unit = { + it should "key selector function returns same hashes for lists with the same content" in { // This is more of a sanity check. It's not comprehensive. 
// SINGLE ENTITY KEY val map1: Map[String, Any] = diff --git a/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala b/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala index 5f537d60f6..5f4233f545 100644 --- a/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala @@ -1,14 +1,12 @@ package ai.chronon.online.test import ai.chronon.online.CatalystUtil -import junit.framework.TestCase import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class CatalystUtilHiveUDFTest extends TestCase with CatalystUtilTestSparkSQLStructs { +class CatalystUtilHiveUDFTest extends AnyFlatSpec with CatalystUtilTestSparkSQLStructs { - @Test - def testHiveUDFsViaSetupsShouldWork(): Unit = { + it should "hive ud fs via setups should work" in { val setups = Seq( "CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'", "CREATE FUNCTION CAT_STR AS 'ai.chronon.online.test.Cat_Str'", diff --git a/online/src/test/scala/ai/chronon/online/test/CatalystUtilTest.scala b/online/src/test/scala/ai/chronon/online/test/CatalystUtilTest.scala index bfe85beeba..e291f83ac4 100644 --- a/online/src/test/scala/ai/chronon/online/test/CatalystUtilTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/CatalystUtilTest.scala @@ -18,11 +18,10 @@ package ai.chronon.online.test import ai.chronon.api._ import ai.chronon.online.CatalystUtil -import junit.framework.TestCase import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import java.util @@ -167,10 +166,9 @@ trait CatalystUtilTestSparkSQLStructs { } -class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { +class CatalystUtilTest extends AnyFlatSpec with CatalystUtilTestSparkSQLStructs { - @Test - def testSelectStarWithCommonScalarsShouldReturnAsIs(): Unit = { + it should "select star with common scalars should return as is" in { val selects = Seq( "bool_x" -> "bool_x", "int32_x" -> "int32_x", @@ -190,8 +188,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertArrayEquals(res.get("bytes_x").asInstanceOf[Array[Byte]], "world".getBytes()) } - @Test - def testMathWithCommonScalarsShouldFollowOrderOfOperations(): Unit = { + it should "math with common scalars should follow order of operations" in { val selects = Seq( "a" -> "4 + 5 * 32 - 2", "b" -> "(int32_x - 1) / 6 * 3 + 7 % 3", @@ -209,8 +206,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("e"), 1.5) } - @Test - def testCommonFunctionsWithCommonScalarsShouldWork(): Unit = { + it should "common functions with common scalars should work" in { val selects = Seq( "a" -> "ABS(CAST(-1.0 * `int32_x` + 1.5 AS LONG))", "b" -> "BASE64('Spark SQL')", @@ -241,8 +237,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("k"), Int.MaxValue) } - @Test - def testDatetimeWithCommonScalarsShouldWork(): Unit = { + it should "datetime with common scalars should work" in { val selects = Seq( "a" -> "FROM_UNIXTIME(int32_x)", "b" -> "CURRENT_TIMESTAMP()", @@ -260,8 +255,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("e"), 5) } - @Test - def testSimpleUdfsWithCommonScalarsShouldWork(): Unit = { + it should 
"simple udfs with common scalars should work" in { CatalystUtil.session.udf.register("bool_udf", (x: Boolean) => x ^ x) CatalystUtil.session.udf.register("INT32_UDF", (x: Int) => x - 1) CatalystUtil.session.udf.register("int64_UDF", (x: Long) => x - 1) @@ -287,8 +281,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertArrayEquals(res.get("bytes_x").asInstanceOf[Array[Byte]], "worldworld".getBytes()) } - @Test - def testComplexUdfsWithCommonScalarsShouldWork(): Unit = { + it should "complex udfs with common scalars should work" in { CatalystUtil.session.udf.register("two_param_udf", (x: Int, y: Long) => y - x) val add_one = (x: Int) => x + 1 CatalystUtil.session.udf.register("add_two_udf", (x: Int) => add_one(add_one(x))) @@ -308,8 +301,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("recursive_udf"), 21) } - @Test - def testDefinitelyFalseFilterWithCommonScalarsShouldReturnNone(): Unit = { + it should "definitely false filter with common scalars should return none" in { // aka. optimized False, LocalTableScanExec case val selects = Seq("a" -> "int32_x") val wheres = Seq("FALSE AND int64_x > `int32_x`") @@ -318,8 +310,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertTrue(res.isEmpty) } - @Test - def testTrueFilterWithCommonScalarsShouldReturnData(): Unit = { + it should "true filter with common scalars should return data" in { val selects = Seq("a" -> "int32_x") val wheres = Seq("FALSE OR int64_x > `int32_x`") val cu = new CatalystUtil(CommonScalarsStruct, selects, wheres) @@ -328,8 +319,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("a"), Int.MaxValue) } - @Test - def testFalseFilterWithCommonScalarsShouldReturnNone(): Unit = { + it should "false filter with common scalars should return none" in { val selects = Seq("a" -> "int32_x") val wheres = Seq("FALSE OR int64_x < `int32_x`") val cu = new CatalystUtil(CommonScalarsStruct, selects, wheres) @@ -337,8 +327,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertTrue(res.isEmpty) } - @Test - def testTrueFiltersWithCommonScalarsShouldReturnData(): Unit = { + it should "true filters with common scalars should return data" in { val selects = Seq("a" -> "int32_x") val wheres = Seq("int64_x > `int32_x`", "FALSE OR int64_x > `int32_x`") val cu = new CatalystUtil(CommonScalarsStruct, selects, wheres) @@ -347,8 +336,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("a"), Int.MaxValue) } - @Test - def testFalseFiltersWithCommonScalarsShouldReturnNone(): Unit = { + it should "false filters with common scalars should return none" in { val selects = Seq("a" -> "int32_x") val wheres = Seq("int64_x > `int32_x`", "FALSE OR int64_x < `int32_x`") val cu = new CatalystUtil(CommonScalarsStruct, selects, wheres) @@ -356,8 +344,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertTrue(res.isEmpty) } - @Test - def testEmptySeqFiltersWithCommonScalarsShouldReturnData(): Unit = { + it should "empty seq filters with common scalars should return data" in { val selects = Seq("a" -> "int32_x") val wheres = Seq() val cu = new CatalystUtil(CommonScalarsStruct, selects, wheres) @@ -366,8 +353,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("a"), Int.MaxValue) } - @Test - def 
testFunctionInFilterWithCommonScalarsShouldWork(): Unit = { + it should "function in filter with common scalars should work" in { CatalystUtil.session.udf.register("sub_one", (x: Int) => x - 1) val selects = Seq("a" -> "int32_x") val wheres = Seq("COALESCE(NULL, NULL, int32_x, int64_x, NULL) = `int32_x`") @@ -377,8 +363,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("a"), Int.MaxValue) } - @Test - def testUdfInFilterWithCommonScalarsShouldWork(): Unit = { + it should "udf in filter with common scalars should work" in { CatalystUtil.session.udf.register("sub_one", (x: Int) => x - 1) val selects = Seq("a" -> "int32_x") val wheres = Seq("int32_x - 1 = SUB_ONE(int32_x)") @@ -388,8 +373,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("a"), Int.MaxValue) } - @Test - def testSelectStarWithCommonScalarsNullShouldReturnNulls(): Unit = { + it should "select star with common scalars null should return nulls" in { val selects = Seq( "bool_x" -> "bool_x", "int32_x" -> "int32_x", @@ -409,8 +393,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("bytes_x"), null) } - @Test - def testSelectWithNestedShouldWork(): Unit = { + it should "select with nested should work" in { val selects = Seq( "inner_req" -> "inner_req", "inner_opt" -> "inner_opt", @@ -430,8 +413,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("inner_opt_int32_opt"), 78) } - @Test - def testSelectWithNestedNullsShouldWork(): Unit = { + it should "select with nested nulls should work" in { val selects = Seq( "inner_req" -> "inner_req", "inner_opt" -> "inner_opt", @@ -447,8 +429,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("inner_req_int32_opt"), null) } - @Test - def testSelectStarWithListContainersShouldReturnAsIs(): Unit = { + it should "select star with list containers should return as is" in { val selects = Seq( "bools" -> "bools", "int32s" -> "int32s", @@ -475,8 +456,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { // Array inputs passed to the performSql method. This takes place when // we're dealing with Derivations in GroupBys that contain aggregations such // as ApproxPercentiles. 
- @Test - def testSelectStarWithListArrayContainersShouldReturnAsIs(): Unit = { + it should "select star with list array containers should return as is" in { val selects = Seq( "bools" -> "bools", "int32s" -> "int32s", @@ -499,8 +479,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertArrayEquals(res_bytess.get(1).asInstanceOf[Array[Byte]], "world".getBytes()) } - @Test - def testIndexingWithListContainersShouldWork(): Unit = { + it should "indexing with list containers should work" in { val selects = Seq( "a" -> "int64s[1] + int32s[2]" ) @@ -510,8 +489,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("a"), 8L) } - @Test - def testFunctionsWithListContainersShouldWork(): Unit = { + it should "functions with list containers should work" in { val selects = Seq( "a" -> "ARRAY(2, 4, 6)", "b" -> "ARRAY_REPEAT('123', 2)", @@ -529,8 +507,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("e"), 3) } - @Test - def testSelectStarWithMapContainersShouldReturnAsIs(): Unit = { + it should "select star with map containers should return as is" in { val selects = Seq( "bools" -> "bools", "int32s" -> "int32s", @@ -553,8 +530,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertArrayEquals(res_bytess.get("b").asInstanceOf[Array[Byte]], "world".getBytes()) } - @Test - def testIndexingWithMapContainersShouldWork(): Unit = { + it should "indexing with map containers should work" in { val selects = Seq( "a" -> "int32s[2]", "b" -> "strings['a']" @@ -566,8 +542,7 @@ class CatalystUtilTest extends TestCase with CatalystUtilTestSparkSQLStructs { assertEquals(res.get("b"), "hello") } - @Test - def testFunctionsWithMapContainersShouldWork(): Unit = { + it should "functions with map containers should work" in { val selects = Seq( "a" -> "MAP(1, '2', 3, '4')", "b" -> "map_keys(int32s)", diff --git a/online/src/test/scala/ai/chronon/online/test/DataStreamBuilderTest.scala b/online/src/test/scala/ai/chronon/online/test/DataStreamBuilderTest.scala index c549ebac73..56a4d8a4cc 100644 --- a/online/src/test/scala/ai/chronon/online/test/DataStreamBuilderTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/DataStreamBuilderTest.scala @@ -31,11 +31,11 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession import org.junit.Assert.assertTrue -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory -class DataStreamBuilderTest { +class DataStreamBuilderTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) lazy val spark: SparkSession = { System.setSecurityManager(null) @@ -47,8 +47,7 @@ class DataStreamBuilderTest { spark } - @Test - def testDataStreamQueryEvent(): Unit = { + it should "data stream query event" in { val topicInfo = TopicInfo.parse("kafka://topic_name/schema=my_schema/host=X/port=Y") val df = testDataFrame() // todo: test start/ end partition in where clause @@ -64,8 +63,7 @@ class DataStreamBuilderTest { assertTrue(dataStream.df.count() == 6) } - @Test - def testTopicInfoParsing(): Unit = { + it should "topic info parsing" in { checkTopicInfo(parse("kafka://topic_name/schema=test_schema/host=X/port=Y"), TopicInfo("topic_name", "kafka", Map("schema" -> "test_schema", "host" -> "X", "port" -> "Y"))) checkTopicInfo(parse("topic_name/host=X/port=Y"), diff --git 
a/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala b/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala index 70ffcfecc9..c8deb4ba98 100644 --- a/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala @@ -29,14 +29,14 @@ import ai.chronon.online.KVStore.TimedValue import ai.chronon.online._ import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue -import org.junit.Before -import org.junit.Test import org.mockito.Answers import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterAll +import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import org.scalatestplus.mockito.MockitoSugar @@ -48,7 +48,7 @@ import scala.util.Failure import scala.util.Success import scala.util.Try -class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { +class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with MockitoHelper with BeforeAndAfterAll { val GroupBy = "relevance.short_term_user_features" val Column = "pdp_view_count_14d" val GuestKey = "guest" @@ -58,8 +58,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { var fetcherBase: FetcherBase = _ var kvStore: KVStore = _ - @Before - def setup(): Unit = { + override def beforeAll(): Unit = { kvStore = mock[KVStore](Answers.RETURNS_DEEP_STUBS) // The KVStore execution context is implicitly used for // Future compositions in the Fetcher so provision it in @@ -68,8 +67,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { fetcherBase = spy(new FetcherBase(kvStore)) } - @Test - def testFetchColumnsSingleQuery(): Unit = { + it should "fetch columns single query" in { // Fetch a single query val keyMap = Map(GuestKey -> GuestId) val query = ColumnSpec(GroupBy, Column, None, Some(keyMap)) @@ -97,8 +95,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { actualRequest.get.keys shouldBe query.keyMapping.get } - @Test - def testFetchColumnsBatch(): Unit = { + it should "fetch columns batch" in { // Fetch a batch of queries val guestKeyMap = Map(GuestKey -> GuestId) val guestQuery = ColumnSpec(GroupBy, Column, Some(GuestKey), Some(guestKeyMap)) @@ -131,8 +128,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { actualRequests(1).keys shouldBe hostQuery.keyMapping.get } - @Test - def testFetchColumnsMissingResponse(): Unit = { + it should "fetch columns missing response" in { // Fetch a single query val keyMap = Map(GuestKey -> GuestId) val query = ColumnSpec(GroupBy, Column, None, Some(keyMap)) @@ -161,8 +157,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { } // updateServingInfo() is called when the batch response is from the KV store. 
- @Test - def testGetServingInfoShouldCallUpdateServingInfoIfBatchResponseIsFromKvStore(): Unit = { + it should "get serving info should call update serving info if batch response is from kv store" in { val oldServingInfo = mock[GroupByServingInfoParsed] val updatedServingInfo = mock[GroupByServingInfoParsed] doReturn(updatedServingInfo).when(fetcherBase).updateServingInfo(any(), any()) @@ -179,8 +174,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { // If a batch response is cached, the serving info should be refreshed. This is needed to prevent // the serving info from becoming stale if all the requests are cached. - @Test - def testGetServingInfoShouldRefreshServingInfoIfBatchResponseIsCached(): Unit = { + it should "get serving info should refresh serving info if batch response is cached" in { val ttlCache = mock[TTLCache[String, Try[GroupByServingInfoParsed]]] doReturn(ttlCache).when(fetcherBase).getGroupByServingInfo @@ -202,8 +196,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { verify(fetcherBase, never()).updateServingInfo(any(), any()) } - @Test - def testIsCachingEnabledCorrectlyDetermineIfCacheIsEnabled(): Unit = { + it should "is caching enabled correctly determine if cache is enabled" in { val flagStore: FlagStore = (flagName: String, attributes: java.util.Map[String, String]) => { flagName match { case "enable_fetcher_batch_ir_cache" => diff --git a/online/src/test/scala/ai/chronon/online/test/FetcherCacheTest.scala b/online/src/test/scala/ai/chronon/online/test/FetcherCacheTest.scala index d24503f078..84dc46bcba 100644 --- a/online/src/test/scala/ai/chronon/online/test/FetcherCacheTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/FetcherCacheTest.scala @@ -13,12 +13,11 @@ import ai.chronon.online.Metrics.Context import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals import org.junit.Assert.assertNull -import org.junit.Assert.fail -import org.junit.Test import org.mockito.ArgumentMatchers.any import org.mockito.Mockito import org.mockito.Mockito._ import org.mockito.stubbing.Stubber +import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.mockito.MockitoSugar import scala.collection.JavaConverters._ @@ -34,14 +33,13 @@ trait MockitoHelper extends MockitoSugar { } } -class FetcherCacheTest extends MockitoHelper { +class FetcherCacheTest extends AnyFlatSpec with MockitoHelper { class TestableFetcherCache(cache: Option[BatchIrCache]) extends FetcherCache { override val maybeBatchIrCache: Option[BatchIrCache] = cache } val batchIrCacheMaximumSize = 50 - @Test - def testBatchIrCacheCorrectlyCachesBatchIrs(): Unit = { + it should "batch ir cache correctly caches batch irs" in { val cacheName = "test" val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) val dataset = "TEST_GROUPBY_BATCH" @@ -63,8 +61,7 @@ class FetcherCacheTest extends MockitoHelper { }) } - @Test - def testBatchIrCacheCorrectlyCachesMapResponse(): Unit = { + it should "batch ir cache correctly caches map response" in { val cacheName = "test" val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) val dataset = "TEST_GROUPBY_BATCH" @@ -88,8 +85,7 @@ class FetcherCacheTest extends MockitoHelper { // Test that the cache keys are compared by equality, not by reference. In practice, this means that if two keys // have the same (dataset, keys, batchEndTsMillis), they will only be stored once in the cache. 
- @Test - def testBatchIrCacheKeysAreComparedByEquality(): Unit = { + it should "batch ir cache keys are compared by equality" in { val cacheName = "test" val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) @@ -108,8 +104,7 @@ class FetcherCacheTest extends MockitoHelper { assert(batchIrCache.cache.estimatedSize() == 1) } - @Test - def testGetCachedRequestsReturnsCorrectCachedDataWhenCacheIsEnabled(): Unit = { + it should "get cached requests returns correct cached data when cache is enabled" in { val cacheName = "test" val testCache = Some(new BatchIrCache(cacheName, batchIrCacheMaximumSize)) val fetcherCache = new TestableFetcherCache(testCache) { @@ -144,8 +139,7 @@ class FetcherCacheTest extends MockitoHelper { assert(cachedRequestsAfterAddingItem.head._2 == finalBatchIr) } - @Test - def testGetCachedRequestsDoesNotCacheWhenCacheIsDisabledForGroupBy(): Unit = { + it should "get cached requests does not cache when cache is disabled for group by" in { val testCache = new BatchIrCache("test", batchIrCacheMaximumSize) val spiedTestCache = spy(testCache) val fetcherCache = new TestableFetcherCache(Some(testCache)) { @@ -171,8 +165,7 @@ class FetcherCacheTest extends MockitoHelper { verify(spiedTestCache, never()).cache } - @Test - def testGetBatchBytesReturnsLatestTimedValueBytesIfGreaterThanBatchEnd(): Unit = { + it should "get batch bytes returns latest timed value bytes if greater than batch end" in { val kvStoreResponse = Success( Seq(TimedValue(Array(1.toByte), 1000L), TimedValue(Array(2.toByte), 2000L)) ) @@ -181,8 +174,7 @@ class FetcherCacheTest extends MockitoHelper { assertArrayEquals(Array(2.toByte), batchBytes) } - @Test - def testGetBatchBytesReturnsNullIfLatestTimedValueTimestampIsLessThanBatchEnd(): Unit = { + it should "get batch bytes returns null if latest timed value timestamp is less than batch end" in { val kvStoreResponse = Success( Seq(TimedValue(Array(1.toByte), 1000L), TimedValue(Array(2.toByte), 1500L)) ) @@ -191,24 +183,21 @@ class FetcherCacheTest extends MockitoHelper { assertNull(batchBytes) } - @Test - def testGetBatchBytesReturnsNullWhenCachedBatchResponse(): Unit = { + it should "get batch bytes returns null when cached batch response" in { val finalBatchIr = mock[FinalBatchIr] val batchResponses = BatchResponses(finalBatchIr) val batchBytes = batchResponses.getBatchBytes(1000L) assertNull(batchBytes) } - @Test - def testGetBatchBytesReturnsNullWhenKvStoreBatchResponseFails(): Unit = { + it should "get batch bytes returns null when kv store batch response fails" in { val kvStoreResponse = Failure(new RuntimeException("KV Store error")) val batchResponses = BatchResponses(kvStoreResponse) val batchBytes = batchResponses.getBatchBytes(1000L) assertNull(batchBytes) } - @Test - def testGetBatchIrFromBatchResponseReturnsCorrectIRsWithCacheEnabled(): Unit = { + it should "get batch ir from batch response returns correct IRs with cache enabled" in { // Use a real cache val batchIrCache = new BatchIrCache("test_cache", batchIrCacheMaximumSize) @@ -249,8 +238,7 @@ class FetcherCacheTest extends MockitoHelper { verify(toBatchIr, times(1))(any(), any()) // decoding did happen } - @Test - def testGetBatchIrFromBatchResponseDecodesBatchBytesIfCacheDisabled(): Unit = { + it should "get batch ir from batch response decodes batch bytes if cache disabled" in { // Set up mocks and dummy data val servingInfo = mock[GroupByServingInfoParsed] val batchBytes = Array[Byte](1, 2, 3) @@ -269,8 +257,7 @@ class FetcherCacheTest extends MockitoHelper {
assertEquals(finalBatchIr, ir) } - @Test - def testGetBatchIrFromBatchResponseReturnsCorrectMapResponseWithCacheEnabled(): Unit = { + it should "get batch ir from batch response returns correct map response with cache enabled" in { // Use a real cache val batchIrCache = new BatchIrCache("test_cache", batchIrCacheMaximumSize) // Set up mocks and dummy data @@ -315,8 +302,7 @@ class FetcherCacheTest extends MockitoHelper { assertEquals(batchIrCache.cache.getIfPresent(cacheKey), CachedMapBatchResponse(mapResponse2)) // key was added } - @Test - def testGetMapResponseFromBatchResponseDecodesBatchBytesIfCacheDisabled(): Unit = { + it should "get map response from batch response decodes batch bytes if cache disabled" in { // Set up mocks and dummy data val servingInfo = mock[GroupByServingInfoParsed] val batchBytes = Array[Byte](1, 2, 3) diff --git a/online/src/test/scala/ai/chronon/online/test/JoinCodecTest.scala b/online/src/test/scala/ai/chronon/online/test/JoinCodecTest.scala index aa4d8692e4..c185261f05 100644 --- a/online/src/test/scala/ai/chronon/online/test/JoinCodecTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/JoinCodecTest.scala @@ -18,11 +18,10 @@ package ai.chronon.online.test import ai.chronon.online.OnlineDerivationUtil.reintroduceExceptions import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class JoinCodecTest { - @Test - def testReintroduceException(): Unit = { +class JoinCodecTest extends AnyFlatSpec { + it should "reintroduce exception" in { val preDerived = Map("group_by_2_exception" -> "ex", "group_by_1_exception" -> "ex", "group_by_4_exception" -> "ex") val derived = Map( diff --git a/online/src/test/scala/ai/chronon/online/test/LRUCacheTest.scala b/online/src/test/scala/ai/chronon/online/test/LRUCacheTest.scala index 179586bcdf..ea9afe3459 100644 --- a/online/src/test/scala/ai/chronon/online/test/LRUCacheTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/LRUCacheTest.scala @@ -2,28 +2,25 @@ package ai.chronon.online.test import ai.chronon.online.LRUCache import com.github.benmanes.caffeine.cache.{Cache => CaffeineCache} -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class LRUCacheTest { +class LRUCacheTest extends AnyFlatSpec { val testCache: CaffeineCache[String, String] = LRUCache[String, String]("testCache") - @Test - def testGetsNothingWhenThereIsNothing(): Unit = { + it should "gets nothing when there is nothing" in { assert(testCache.getIfPresent("key") == null) assert(testCache.estimatedSize() == 0) } - @Test - def testGetsSomethingWhenThereIsSomething(): Unit = { + it should "gets something when there is something" in { assert(testCache.getIfPresent("key") == null) testCache.put("key", "value") assert(testCache.getIfPresent("key") == "value") assert(testCache.estimatedSize() == 1) } - @Test - def testEvictsWhenSomethingIsSet(): Unit = { + it should "evicts when something is set" in { assert(testCache.estimatedSize() == 0) assert(testCache.getIfPresent("key") == null) testCache.put("key", "value") diff --git a/online/src/test/scala/ai/chronon/online/test/TagsTest.scala b/online/src/test/scala/ai/chronon/online/test/TagsTest.scala index c65545b4f5..25261cb2f9 100644 --- a/online/src/test/scala/ai/chronon/online/test/TagsTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/TagsTest.scala @@ -22,9 +22,9 @@ import ai.chronon.online.Metrics.Context import ai.chronon.online.Metrics.Environment import ai.chronon.online.TTLCache import org.junit.Assert.assertEquals 
-import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class TagsTest { +class TagsTest extends AnyFlatSpec { // test that ttlCache of context is creates non duplicated entries // copied from the private NonBlockingStatsDClient.tagString @@ -45,8 +45,7 @@ class TagsTest { sb.toString } - @Test - def testCachedTagsAreComputedTags(): Unit = { + it should "cached tags are computed tags" in { val cache = new TTLCache[Metrics.Context, String]( { ctx => ctx.toTags.mkString(",") }, { ctx => ctx }, diff --git a/online/src/test/scala/ai/chronon/online/test/ThriftDecodingTest.scala b/online/src/test/scala/ai/chronon/online/test/ThriftDecodingTest.scala index 2956a142a8..2b70c31b38 100644 --- a/online/src/test/scala/ai/chronon/online/test/ThriftDecodingTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/ThriftDecodingTest.scala @@ -24,14 +24,13 @@ import ai.chronon.online.SerializableFunction import ai.chronon.online.TBaseDecoderFactory import com.google.gson.Gson import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import java.util -class ThriftDecodingTest { +class ThriftDecodingTest extends AnyFlatSpec { - @Test - def testDecoding(): Unit = { + it should "decoding" in { val tokens = new util.HashSet[String]() Seq("left", "source", "events", "derivations", "name", "expression") .foreach(tokens.add) diff --git a/online/src/test/scala/ai/chronon/online/test/TileCodecTest.scala b/online/src/test/scala/ai/chronon/online/test/TileCodecTest.scala index 1824dbdb8d..6cd15195e3 100644 --- a/online/src/test/scala/ai/chronon/online/test/TileCodecTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/TileCodecTest.scala @@ -20,13 +20,13 @@ import ai.chronon.api._ import ai.chronon.online.ArrayRow import ai.chronon.online.TileCodec import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ -class TileCodecTest { +class TileCodecTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) private val histogram = Map[String, Int]("A" -> 3, "B" -> 2).asJava @@ -92,8 +92,7 @@ class TileCodecTest { new ArrayRow(values.map(_._2), ts) } - @Test - def testTileCodecIrSerRoundTrip(): Unit = { + it should "tile codec ir ser round trip" in { val groupByMetadata = Builders.MetaData(name = "my_group_by") val (aggregations, expectedVals) = aggregationsAndExpected.unzip val expectedFlattenedVals = expectedVals.flatten @@ -127,8 +126,7 @@ class TileCodecTest { } } - @Test - def testTileCodecIrSerRoundTrip_WithBuckets(): Unit = { + it should "tile codec ir ser round trip with buckets" in { val groupByMetadata = Builders.MetaData(name = "my_group_by") val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = bucketedAggregations) val tileCodec = new TileCodec(groupBy, schema) diff --git a/online/src/test/scala/ai/chronon/online/test/stats/DriftMetricsTest.scala b/online/src/test/scala/ai/chronon/online/test/stats/DriftMetricsTest.scala index bfe1a2a073..e86643fbc9 100644 --- a/online/src/test/scala/ai/chronon/online/test/stats/DriftMetricsTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/stats/DriftMetricsTest.scala @@ -4,10 +4,10 @@ import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.observability.DriftMetric import ai.chronon.online.stats.DriftMetrics.histogramDistance import ai.chronon.online.stats.DriftMetrics.percentileDistance -import
org.scalatest.funsuite.AnyFunSuite +import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -class DriftMetricsTest extends AnyFunSuite with Matchers { +class DriftMetricsTest extends AnyFlatSpec with Matchers { def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { val stdDev = math.sqrt(variance) @@ -75,7 +75,7 @@ class DriftMetricsTest extends AnyFunSuite with Matchers { ) } - test("Low drift - similar distributions") { + it should "Low drift - similar distributions" in { val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 101.0, varianceB = 225.0) // JSD assertions @@ -89,7 +89,7 @@ class DriftMetricsTest extends AnyFunSuite with Matchers { hellingerHisto should be < 0.05 } - test("Moderate drift - slightly different distributions") { + it should "Moderate drift - slightly different distributions" in { val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) // JSD assertions @@ -101,7 +101,7 @@ class DriftMetricsTest extends AnyFunSuite with Matchers { hellingerPercentile should (be >= 0.05 and be <= 0.15) } - test("Severe drift - different means") { + it should "Severe drift - different means" in { val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 110.0, varianceB = 225.0) // JSD assertions @@ -113,7 +113,7 @@ class DriftMetricsTest extends AnyFunSuite with Matchers { hellingerPercentile should be > 0.15 } - test("Severe drift - different variances") { + it should "Severe drift - different variances" in { val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) // JSD assertions diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index 9819a5f244..9b3ac66be6 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -28,11 +28,11 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.lit import org.junit.Assert.assertTrue -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory -class AnalyzerTest { +class AnalyzerTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) private val tableUtils = TableUtils(spark) @@ -46,8 +46,7 @@ class AnalyzerTest { private val viewsSource = getTestEventSource() - @Test - def testJoinAnalyzerSchemaWithValidation(): Unit = { + it should "join analyzer schema with validation" in { val viewsGroupBy = getViewsGroupBy("join_analyzer_test.item_gb", Operation.AVERAGE) val anotherViewsGroupBy = getViewsGroupBy("join_analyzer_test.another_item_gb", Operation.SUM) @@ -82,8 +81,7 @@ class AnalyzerTest { assertTrue(expectedSchema sameElements analyzerSchema) } - @Test(expected = classOf[java.lang.AssertionError]) - def testJoinAnalyzerValidationFailure(): Unit = { + it should "join analyzer validation failure" in { val viewsGroupBy = getViewsGroupBy("join_analyzer_test.item_gb", Operation.AVERAGE, source = getTestGBSource()) val usersGroupBy = getUsersGroupBy("join_analyzer_test.user_gb", Operation.AVERAGE, source = getTestGBSource()) @@ -111,8 +109,7 @@ class AnalyzerTest { analyzer.analyzeJoin(joinConf, validationAssert = true) } - 
@Test(expected = classOf[java.lang.AssertionError]) - def testJoinAnalyzerValidationDataAvailability(): Unit = { + it should "join analyzer validation data availability" in { // left side val itemQueries = List(Column("item", api.StringType, 100), Column("guest", api.StringType, 100)) val itemQueriesTable = s"$namespace.item_queries_with_user_table" @@ -147,8 +144,7 @@ class AnalyzerTest { analyzer.analyzeJoin(joinConf, validationAssert = true) } - @Test - def testJoinAnalyzerValidationDataAvailabilityMultipleSources(): Unit = { + it should "join analyzer validation data availability multiple sources" in { val leftSchema = List(Column("item", api.StringType, 100)) val leftTable = s"$namespace.multiple_sources_left_table" val leftData = DataFrameGen.events(spark, leftSchema, 10, partitions = 1) @@ -214,8 +210,7 @@ class AnalyzerTest { analyzer.analyzeJoin(joinConf, validationAssert = true) } - @Test - def testJoinAnalyzerCheckTimestampHasValues(): Unit = { + it should "join analyzer check timestamp has values" in { // left side // create the event source with values @@ -246,8 +241,7 @@ class AnalyzerTest { } - @Test(expected = classOf[java.lang.AssertionError]) - def testJoinAnalyzerCheckTimestampOutOfRange(): Unit = { + it should "join analyzer check timestamp out of range" in { // left side // create the event source with values out of range @@ -278,8 +272,7 @@ class AnalyzerTest { } - @Test(expected = classOf[java.lang.AssertionError]) - def testJoinAnalyzerCheckTimestampAllNulls(): Unit = { + it should "join analyzer check timestamp all nulls" in { // left side // create the event source with nulls @@ -310,8 +303,7 @@ class AnalyzerTest { } - @Test - def testGroupByAnalyzerCheckTimestampHasValues(): Unit = { + it should "group by analyzer check timestamp has values" in { val tableGroupBy = Builders.GroupBy( sources = Seq(getTestGBSourceWithTs()), @@ -329,8 +321,7 @@ class AnalyzerTest { } - @Test(expected = classOf[java.lang.AssertionError]) - def testGroupByAnalyzerCheckTimestampAllNulls(): Unit = { + it should "group by analyzer check timestamp all nulls" in { val tableGroupBy = Builders.GroupBy( sources = Seq(getTestGBSourceWithTs("nulls")), @@ -347,8 +338,7 @@ class AnalyzerTest { analyzer.analyzeGroupBy(tableGroupBy) } - @Test(expected = classOf[java.lang.AssertionError]) - def testGroupByAnalyzerCheckTimestampOutOfRange(): Unit = { + it should "group by analyzer check timestamp out of range" in { val tableGroupBy = Builders.GroupBy( sources = Seq(getTestGBSourceWithTs("out_of_range")), diff --git a/spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala index d542c4095a..44af2b3aa7 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala @@ -25,17 +25,16 @@ import ai.chronon.spark.TableUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.DecimalType -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class AvroTest { +class AvroTest extends AnyFlatSpec { val spark: SparkSession = SparkSessionBuilder.build("AvroTest", local = true) private val tableUtils = TableUtils(spark) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) private val twoMonthsAgo = tableUtils.partitionSpec.minus(today, new Window(60, TimeUnit.DAYS)) - @Test - def testDecimal(): 
Unit = { + it should "decimal" in { val namespace = "test_decimal" tableUtils.createDatabase(namespace) diff --git a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala index 56cce9b4c3..964a546549 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala @@ -29,7 +29,6 @@ import ai.chronon.online.SparkConversions import ai.chronon.spark.Extensions._ import ai.chronon.spark.utils.MockApi import ai.chronon.spark.{Join => _, _} -import junit.framework.TestCase import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession @@ -37,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.functions.lit import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -46,7 +46,7 @@ import java.util.concurrent.Executors import scala.collection.Seq import scala.concurrent.ExecutionContext -class ChainingFetcherTest extends TestCase { +class ChainingFetcherTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) val sessionName = "ChainingFetcherTest" val spark: SparkSession = SparkSessionBuilder.build(sessionName, local = true) @@ -318,14 +318,14 @@ class ChainingFetcherTest extends TestCase { assertEquals(0, diff.count()) } - def testFetchParentJoin(): Unit = { + it should "fetch parent join" in { val namespace = "parent_join_fetch" val joinConf = generateMutationData(namespace, Accuracy.TEMPORAL) val (expected, fetcherResponse) = executeFetch(joinConf, "2021-04-15", namespace) compareTemporalFetch(joinConf, "2021-04-15", expected, fetcherResponse, "user") } - def testFetchChainingDeterministic(): Unit = { + it should "fetch chaining deterministic" in { val namespace = "chaining_fetch" val chainingJoinConf = generateChainingJoinData(namespace, Accuracy.TEMPORAL) assertTrue(chainingJoinConf.joinParts.get(0).groupBy.sources.get(0).isSetJoinSource) diff --git a/spark/src/test/scala/ai/chronon/spark/test/CompareTest.scala b/spark/src/test/scala/ai/chronon/spark/test/CompareTest.scala index 897987fe82..a17cce59d5 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/CompareTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/CompareTest.scala @@ -24,11 +24,11 @@ import ai.chronon.spark.TimedKvRdd import ai.chronon.spark.stats.CompareBaseJob import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory -class CompareTest { +class CompareTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) lazy val spark: SparkSession = SparkSessionBuilder.build("CompareTest", local = true) @@ -53,8 +53,7 @@ class CompareTest { val leftColumns: Seq[String] = Seq("serial", "value", "rating", "keyId", "ts", "ds") val rightColumns: Seq[String] = Seq("rev_serial", "rev_value", "rev_rating", "keyId", "ts", "ds") - @Test - def basicTest(): Unit = { + it should "basic" in { val leftRdd = spark.sparkContext.parallelize(leftData) val leftDf = spark.createDataFrame(leftRdd).toDF(leftColumns: _*) val rightRdd = spark.sparkContext.parallelize(rightData) @@ -85,8 +84,7 @@ class CompareTest { } } - @Test - def mappingTest(): 
Unit = { + it should "mapping" in { val leftRdd = spark.sparkContext.parallelize(leftData) val leftDf = spark.createDataFrame(leftRdd).toDF(leftColumns: _*) val rightRdd = spark.sparkContext.parallelize(rightData) @@ -122,8 +120,7 @@ class CompareTest { } } - @Test - def checkKeysTest(): Unit = { + it should "check keys" in { val leftRdd = spark.sparkContext.parallelize(leftData) val leftDf = spark.createDataFrame(leftRdd).toDF(leftColumns: _*) val rightRdd = spark.sparkContext.parallelize(rightData) @@ -138,8 +135,7 @@ class CompareTest { runFailureScenario(leftDf, rightDf, keys2, mapping2) } - @Test - def checkDataTypeTest(): Unit = { + it should "check data type" in { val leftData = Seq( (1, Some(1), 1.0, "a", toTs("2021-04-10 09:00:00"), "2021-04-10") ) @@ -161,8 +157,7 @@ class CompareTest { runFailureScenario(leftDf, rightDf, keys, mapping) } - @Test - def checkForWrongColumnCount(): Unit = { + it should "check for wrong column count" in { val leftData = Seq( (1, Some(1), 1.0, "a", "2021-04-10") ) @@ -184,8 +179,7 @@ class CompareTest { runFailureScenario(leftDf, rightDf, keys, mapping) } - @Test - def checkForMappingConsistency(): Unit = { + it should "check for mapping consistency" in { val leftData = Seq( (1, Some(1), 1.0, "a", toTs("2021-04-10 09:00:00"), "2021-04-10") ) diff --git a/spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala b/spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala index b80ed5b391..b92b8b5f93 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala @@ -22,14 +22,13 @@ import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import org.apache.spark.sql.SparkSession import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class DataRangeTest { +class DataRangeTest extends AnyFlatSpec { val spark: SparkSession = SparkSessionBuilder.build("DataRangeTest", local = true) private implicit val partitionSpec: PartitionSpec = TableUtils(spark).partitionSpec - @Test - def testIntersect(): Unit = { + it should "intersect" in { val range1 = PartitionRange(null, null) val range2 = PartitionRange("2023-01-01", "2023-01-02") assertEquals(range2, range1.intersect(range2)) diff --git a/spark/src/test/scala/ai/chronon/spark/test/EditDistanceTest.scala b/spark/src/test/scala/ai/chronon/spark/test/EditDistanceTest.scala index 7b4a3fbdd0..ed17d393b5 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/EditDistanceTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/EditDistanceTest.scala @@ -18,12 +18,11 @@ package ai.chronon.spark.test import ai.chronon.spark.stats.EditDistance import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class EditDistanceTest { +class EditDistanceTest extends AnyFlatSpec { - @Test - def basic(): Unit = { + it should "basic" in { def of(a: Any, b: Any) = EditDistance.between(a, b) def ofString(a: String, b: String) = EditDistance.betweenStrings(a, b) diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala index de0282ba2b..62ad4babd0 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala @@ -21,7 +21,7 @@ import ai.chronon.online.Fetcher.Request import ai.chronon.spark.LoggingSchema import ai.chronon.spark.utils.MockApi 
import org.junit.Assert._ -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import java.util.Base64 import scala.collection.mutable @@ -29,9 +29,8 @@ import scala.concurrent.Await import scala.concurrent.duration.Duration import scala.concurrent.duration.SECONDS -class ExternalSourcesTest { - @Test - def testFetch(): Unit = { +class ExternalSourcesTest extends AnyFlatSpec { + it should "fetch" in { val plusOneSource = Builders.ExternalSource( metadata = Builders.MetaData( name = "plus_one" diff --git a/spark/src/test/scala/ai/chronon/spark/test/FeatureWithLabelJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FeatureWithLabelJoinTest.scala index 2b05c224b2..81edd74c61 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FeatureWithLabelJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FeatureWithLabelJoinTest.scala @@ -33,11 +33,11 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.max import org.apache.spark.sql.functions.min import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory -class FeatureWithLabelJoinTest { +class FeatureWithLabelJoinTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) val spark: SparkSession = SparkSessionBuilder.build("FeatureWithLabelJoinTest", local = true) @@ -50,8 +50,7 @@ class FeatureWithLabelJoinTest { private val viewsGroupBy = TestUtils.createViewsGroupBy(namespace, spark) private val left = viewsGroupBy.groupByConf.sources.get(0) - @Test - def testFinalViews(): Unit = { + it should "final views" in { // create test feature join table val featureTable = s"${namespace}.${tableName}" createTestFeatureTable().write.saveAsTable(featureTable) @@ -113,8 +112,7 @@ class FeatureWithLabelJoinTest { assertEquals("2022-11-11", latest.agg(max("label_ds")).first().getString(0)) } - @Test - def testFinalViewsWithAggLabel(): Unit = { + it should "final views with agg label" in { // create test feature join table val tableName = "label_agg_table" val featureTable = s"${namespace}.${tableName}" diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala index 3d07ee5b5c..b7b87b9753 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.functions.lit import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -66,7 +66,7 @@ import scala.concurrent.duration.SECONDS import scala.io.Source // Run as follows: sbt "spark/testOnly -- -n fetchertest" -class FetcherTest extends AnyFunSuite with TaggedFilterSuite { +class FetcherTest extends AnyFlatSpec with TaggedFilterSuite { override def tagName: String = "fetchertest" @@ -79,7 +79,7 @@ class FetcherTest extends AnyFunSuite with TaggedFilterSuite { private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) private val yesterday = tableUtils.partitionSpec.before(today) - test("test metadata store") { + it should "test metadata store" in { implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) implicit val tableUtils: TableUtils = 
TableUtils(spark) @@ -721,13 +721,13 @@ class FetcherTest extends AnyFunSuite with TaggedFilterSuite { assertEquals(0, diff.count()) } - test("test temporal fetch join deterministic") { + it should "test temporal fetch join deterministic" in { val namespace = "deterministic_fetch" val joinConf = generateMutationData(namespace) compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) } - test("test temporal fetch join generated") { + it should "test temporal fetch join generated" in { val namespace = "generated_fetch" val joinConf = generateRandomData(namespace) compareTemporalFetch(joinConf, @@ -737,14 +737,14 @@ class FetcherTest extends AnyFunSuite with TaggedFilterSuite { dropDsOnWrite = false) } - test("test temporal tiled fetch join deterministic") { + it should "test temporal tiled fetch join deterministic" in { val namespace = "deterministic_tiled_fetch" val joinConf = generateEventOnlyData(namespace, groupByCustomJson = Some("{\"enable_tiling\": true}")) compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) } // test soft-fail on missing keys - test("test empty request") { + it should "test empty request" in { val namespace = "empty_request" val joinConf = generateRandomData(namespace, 5, 5) implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala index fd1b6d9422..d1da1b6b42 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala @@ -48,18 +48,17 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.{LongType => SparkLongType} import org.apache.spark.sql.types.{StringType => SparkStringType} import org.junit.Assert._ -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import scala.collection.mutable -class GroupByTest { +class GroupByTest extends AnyFlatSpec { lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTest", local = true) val tableUtils: TableUtils = TableUtils(spark) implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec - @Test - def testSnapshotEntities(): Unit = { + it should "snapshot entities" in { val schema = List( Column("user", StringType, 10), Column(Constants.TimeColumn, LongType, 100), // ts = last 100 days @@ -92,8 +91,7 @@ class GroupByTest { assertEquals(0, diff.count()) } - @Test - def testSnapshotEvents(): Unit = { + it should "snapshot events" in { val schema = List( Column("user", StringType, 10), // ts = last 10 days Column("session_length", IntType, 2), @@ -144,8 +142,7 @@ class GroupByTest { assertEquals(0, diff.count()) } - @Test - def eventsLastKTest(): Unit = { + it should "events last k" in { val eventSchema = List( Column("user", StringType, 10), Column("listing_view", StringType, 100) @@ -214,8 +211,7 @@ class GroupByTest { } } } - @Test - def testTemporalEvents(): Unit = { + it should "temporal events" in { val eventSchema = List( Column("user", StringType, 10), Column("session_length", IntType, 10000) @@ -279,8 +275,7 @@ class GroupByTest { } // Test that the output of Group by with Step Days is the same as the output without Steps (full data range) - @Test - def testStepDaysConsistency(): Unit = { + it should "step days consistency" in { val (source, endPartition) = createTestSource() val tableUtils = 
TableUtils(spark) @@ -299,8 +294,7 @@ class GroupByTest { assertEquals(0, diff.count()) } - @Test - def testGroupByAnalyzer(): Unit = { + it should "group by analyzer" in { val (source, endPartition) = createTestSource(30) val tableUtils = TableUtils(spark) @@ -328,8 +322,7 @@ class GroupByTest { }) } - @Test - def testGroupByNoAggregationAnalyzer(): Unit = { + it should "group by no aggregation analyzer" in { val (source, endPartition) = createTestSource(30) val testName = "unit_analyze_test_item_no_agg" @@ -359,8 +352,7 @@ class GroupByTest { } // test that OrderByLimit and OrderByLimitTimed serialization works well with Spark's data type - @Test - def testFirstKLastKTopKBottomKApproxUniqueCount(): Unit = { + it should "first k last k top k bottom k approx unique count" in { val (source, endPartition) = createTestSource() val tableUtils = TableUtils(spark) @@ -476,8 +468,7 @@ class GroupByTest { } // Test percentile Impl on Spark. - @Test - def testPercentiles(): Unit = { + it should "percentiles" in { val (source, endPartition) = createTestSource(suffix = "_percentile") val tableUtils = TableUtils(spark) val namespace = "test_percentiles" @@ -500,8 +491,7 @@ class GroupByTest { additionalAgg = aggs) } - @Test - def testApproxHistograms(): Unit = { + it should "approx histograms" in { val (source, endPartition) = createTestSource(suffix = "_approx_histogram") val tableUtils = TableUtils(spark) val namespace = "test_approx_histograms" @@ -554,8 +544,7 @@ class GroupByTest { assert(!histogramValues.contains(0)) } - @Test - def testReplaceJoinSource(): Unit = { + it should "replace join source" in { val namespace = "replace_join_source_ns" val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) @@ -570,8 +559,7 @@ class GroupByTest { assertEquals(query, newGroupBy.sources.get(0).query) } - @Test - def testGroupByFromChainingGB(): Unit = { + it should "group by from chaining gb" in { val namespace = "test_chaining_gb" val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) val joinName = "parent_join_table" @@ -630,8 +618,7 @@ class GroupByTest { assertEquals(0, diff.count()) } - @Test - def testDescriptiveStats(): Unit = { + it should "descriptive stats" in { val (source, endPartition) = createTestSource(suffix = "_descriptive_stats") val tableUtils = TableUtils(spark) val namespace = "test_descriptive_stats" diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala index 9b1eafa843..c14cbdcd98 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala @@ -30,22 +30,21 @@ import ai.chronon.spark.utils.MockApi import com.google.gson.Gson import org.apache.spark.sql.SparkSession import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory import scala.concurrent.Await import scala.concurrent.duration.DurationInt -class GroupByUploadTest { +class GroupByUploadTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByUploadTest", local = true) private val namespace = "group_by_upload_test" private val tableUtils = TableUtils(spark) - @Test - def temporalEventsLastKTest(): Unit = { + it should "temporal events last k" in { val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) 
val yesterday = tableUtils.partitionSpec.before(today) tableUtils.createDatabase(namespace) @@ -73,8 +72,7 @@ class GroupByUploadTest { GroupByUpload.run(groupByConf, endDs = yesterday) } - @Test - def structSupportTest(): Unit = { + it should "struct support" in { val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) val yesterday = tableUtils.partitionSpec.before(today) tableUtils.createDatabase(namespace) @@ -115,8 +113,7 @@ class GroupByUploadTest { GroupByUpload.run(groupByConf, endDs = yesterday) } - @Test - def multipleAvgCountersTest(): Unit = { + it should "multiple avg counters" in { val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) val yesterday = tableUtils.partitionSpec.before(today) tableUtils.createDatabase(namespace) @@ -150,8 +147,7 @@ class GroupByUploadTest { // joinLeft = (review, category, rating) [ratings] // joinPart = (review, user, listing) [reviews] // groupBy = keys:[listing, category], aggs:[avg(rating)] - @Test - def listingRatingCategoryJoinSourceTest(): Unit = { + it should "listing rating category join source" in { tableUtils.createDatabase(namespace) tableUtils.sql(s"USE $namespace") diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala index 8b595f9136..4462ed1ea4 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.{StringType => SparkStringType} import org.junit.Assert._ -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.flatspec.AnyFlatSpec import scala.collection.JavaConverters._ @@ -54,7 +54,7 @@ object TestRow { } } // Run as follows: sbt "spark/testOnly -- -n jointest" -class JoinTest extends AnyFunSuite with TaggedFilterSuite { +class JoinTest extends AnyFlatSpec with TaggedFilterSuite { val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true) private implicit val tableUtils = TableTestUtils(spark) @@ -69,7 +69,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { override def tagName: String = "jointest" - test("testing basic spark dynamic partition overwrite") { + it should "testing basic spark dynamic partition overwrite" in { import org.apache.spark.sql.SaveMode import spark.implicits._ @@ -108,7 +108,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { } - test("test events entities snapshot") { + it should "test events entities snapshot" in { val dollarTransactions = List( Column("user", StringType, 100), Column("user_name", api.StringType, 100), @@ -311,7 +311,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { assertEquals(0, diff2.count()) } - test("test entities entities") { + it should "test entities entities" in { // untimed/unwindowed entities on right // right side val weightSchema = List( @@ -431,7 +431,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { */ } - test("test entities entities no historical backfill") { + it should "test entities entities no historical backfill" in { // Only backfill latest partition if historical_backfill is turned off val weightSchema = List( Column("user", api.StringType, 1000), @@ -484,7 +484,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { assertEquals(allPartitions.toList(0), end) } - test("test events events snapshot") { + it should "test events events snapshot" in { val 
viewsSchema = List( Column("user", api.StringType, 10000), Column("item", api.StringType, 100), @@ -553,7 +553,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { assertEquals(diff.count(), 0) } - test("test events events temporal") { + it should "test events events temporal" in { val joinConf = getEventsEventsTemporal("temporal") val viewsSchema = List( @@ -630,7 +630,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { assertEquals(diff.count(), 0) } - test("test events events cumulative") { + it should "test events events cumulative" in { // Create a cumulative source GroupBy val viewsTable = s"$namespace.view_cumulative" val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true) @@ -729,7 +729,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { } - test("test no agg") { + it should "test no agg" in { // Left side entities, right side entities no agg // Also testing specific select statement (rather than select *) val namesSchema = List( @@ -809,7 +809,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { assertEquals(diff.count(), 0) } - test("test versioning") { + it should "test versioning" in { val joinConf = getEventsEventsTemporal("versioning") // Run the old join to ensure that tables exist @@ -963,7 +963,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { } - test("test end partition join") { + it should "test end partition join" in { val join = getEventsEventsTemporal("end_partition_test") val start = join.getLeft.query.startPartition val end = tableUtils.partitionSpec.after(start) @@ -980,7 +980,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { assertTrue(ds.first().getString(0) < today) } - test("test skip bloom filter join backfill") { + it should "test skip bloom filter join backfill" in { val testSpark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true, @@ -1029,7 +1029,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { assertEquals(leftSideCount, skipBloomComputed.count()) } - test("test struct join") { + it should "test struct join" in { val nameSuffix = "_struct_test" val itemQueries = List(Column("item", api.StringType, 100)) val itemQueriesTable = s"$namespace.item_queries_$nameSuffix" @@ -1085,7 +1085,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { toCompute.computeJoin() } - test("test migration") { + it should "test migration" in { // Left val itemQueriesTable = s"$namespace.item_queries" @@ -1134,7 +1134,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { assertEquals(0, join.tablesToDrop(productionHashV2).length) } - test("testKeyMappingOverlappingFields") { + it should "testKeyMappingOverlappingFields" in { // test the scenario when a key_mapping is a -> b, (right key b is mapped to left key a) and // a happens to be another field in the same group by @@ -1192,7 +1192,7 @@ class JoinTest extends AnyFunSuite with TaggedFilterSuite { * Run computeJoin(). * Check if the selected join part is computed and the other join parts are not computed. 
*/ - test("test selected join parts") { + it should "test selected join parts" in { // Left val itemQueries = List( Column("item", api.StringType, 100), diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala index aa1e47b8e2..8da0796a50 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala @@ -35,19 +35,18 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.junit.Assert._ -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import scala.collection.mutable import scala.util.Try -class JoinUtilsTest { +class JoinUtilsTest extends AnyFlatSpec { lazy val spark: SparkSession = SparkSessionBuilder.build("JoinUtilsTest", local = true) private val tableUtils = TableUtils(spark) private implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec private val namespace = "joinUtil" - @Test - def testUDFSetAdd(): Unit = { + it should "udf set add" in { val data = Seq( Row(Seq("a", "b", "c"), "a"), Row(Seq("a", "b", "c"), "d"), @@ -80,8 +79,7 @@ class JoinUtilsTest { } } - @Test - def testUDFContainsAny(): Unit = { + it should "udf contains any" in { val data = Seq( Row(Seq("a", "b", "c"), Seq("a")), Row(Seq("a", "b", "c"), Seq("a", "b")), @@ -130,8 +128,7 @@ class JoinUtilsTest { df } - @Test - def testCoalescedJoinMismatchedKeyColumns(): Unit = { + it should "coalesced join mismatched key columns" in { // mismatch data type on join keys testJoinScenario( new StructType() @@ -145,8 +142,7 @@ class JoinUtilsTest { ) } - @Test - def testCoalescedJoinMismatchedSharedColumns(): Unit = { + it should "coalesced join mismatched shared columns" in { // mismatch data type on shared columns testJoinScenario( new StructType() @@ -160,8 +156,7 @@ class JoinUtilsTest { ) } - @Test - def testCoalescedJoinMissingKeys(): Unit = { + it should "coalesced join missing keys" in { // missing some keys testJoinScenario( new StructType() @@ -176,8 +171,7 @@ class JoinUtilsTest { ) } - @Test - def testCoalescedJoinNoSharedColumns(): Unit = { + it should "coalesced join no shared columns" in { // test no shared columns val df = testJoinScenario( new StructType() @@ -192,8 +186,7 @@ class JoinUtilsTest { assertEquals(3, df.get.columns.length) } - @Test - def testCoalescedJoinSharedColumns(): Unit = { + it should "coalesced join shared columns" in { // test shared columns val df = testJoinScenario( new StructType() @@ -210,8 +203,7 @@ class JoinUtilsTest { assertEquals(4, df.get.columns.length) } - @Test - def testCoalescedJoinOneSidedLeft(): Unit = { + it should "coalesced join one sided left" in { // test when left side only has keys val df = testJoinScenario( new StructType() @@ -226,8 +218,7 @@ class JoinUtilsTest { assertEquals(3, df.get.columns.length) } - @Test - def testCoalescedJoinOneSidedRight(): Unit = { + it should "coalesced join one sided right" in { // test when right side only has keys val df = testJoinScenario( new StructType() @@ -242,8 +233,7 @@ class JoinUtilsTest { assertEquals(3, df.get.columns.length) } - @Test - def testCreateJoinView(): Unit = { + it should "create join view" in { val finalViewName = "testCreateView" val leftTableName = "joinUtil.testFeatureTable" val rightTableName = "joinUtil.testLabelTable" @@ -281,8 +271,7 @@ class JoinUtilsTest { assertEquals(properties.get.get("labelTable"), Some(rightTableName)) } - @Test 
- def testCreateLatestLabelView(): Unit = { + it should "create latest label view" in { val finalViewName = "joinUtil.testFinalView" val leftTableName = "joinUtil.testFeatureTable2" val rightTableName = "joinUtil.testLabelTable2" @@ -327,16 +316,14 @@ class JoinUtilsTest { assertEquals(properties.get.get("newProperties"), Some("value")) } - @Test - def testFilterColumns(): Unit = { + it should "filter columns" in { val testDf = createSampleTable() val filter = Array("listing", "ds", "feature_review") val filteredDf = JoinUtils.filterColumns(testDf, filter) assertTrue(filteredDf.schema.fieldNames.sorted sameElements filter.sorted) } - @Test - def testGetRangesToFill(): Unit = { + it should "get ranges to fill" in { tableUtils.createDatabase(namespace) // left table val itemQueries = List(Column("item", api.StringType, 100)) @@ -352,8 +339,7 @@ class JoinUtilsTest { assertEquals(range, PartitionRange(startPartition, endPartition)) } - @Test - def testGetRangesToFillWithOverride(): Unit = { + it should "get ranges to fill with override" in { tableUtils.createDatabase(namespace) // left table val itemQueries = List(Column("item", api.StringType, 100)) diff --git a/spark/src/test/scala/ai/chronon/spark/test/KafkaStreamBuilderTest.scala b/spark/src/test/scala/ai/chronon/spark/test/KafkaStreamBuilderTest.scala index 31bda24fb8..e7889c7067 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/KafkaStreamBuilderTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/KafkaStreamBuilderTest.scala @@ -20,14 +20,13 @@ import ai.chronon.online.TopicInfo import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.streaming.KafkaStreamBuilder import org.apache.spark.sql.SparkSession -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class KafkaStreamBuilderTest { +class KafkaStreamBuilderTest extends AnyFlatSpec { private val spark: SparkSession = SparkSessionBuilder.build("KafkaStreamBuilderTest", local = true) - @Test(expected = classOf[RuntimeException]) - def testKafkaStreamDoesNotExist(): Unit = { + it should "kafka stream does not exist" in { val topicInfo = TopicInfo.parse("kafka://test_topic/schema=my_schema/host=X/port=Y") KafkaStreamBuilder.from(topicInfo)(spark, Map.empty) } diff --git a/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala index c009ac7f54..cdb8f83302 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala @@ -26,10 +26,10 @@ import ai.chronon.spark._ import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.LoggerFactory -class LabelJoinTest { +class LabelJoinTest extends AnyFlatSpec { @transient private lazy val logger = LoggerFactory.getLogger(getClass) val spark: SparkSession = SparkSessionBuilder.build("LabelJoinTest", local = true) @@ -44,8 +44,7 @@ class LabelJoinTest { private val labelGroupBy = TestUtils.createRoomTypeGroupBy(namespace, spark) private val left = viewsGroupBy.groupByConf.sources.get(0) - @Test - def testLabelJoin(): Unit = { + it should "label join" in { val labelGroupBy = TestUtils.createRoomTypeGroupBy(namespace, spark, "listing_attributes").groupByConf val labelJoinConf = createTestLabelJoin(30, 20, Seq(labelGroupBy)) val joinConf = Builders.Join( @@ -81,8 +80,7 @@ class LabelJoinTest { assertEquals(0, diff.count()) } - @Test - 
def testLabelJoinMultiLabels(): Unit = { + it should "label join multi labels" in { val labelGroupBy1 = TestUtils.createRoomTypeGroupBy(namespace, spark).groupByConf val labelGroupBy2 = TestUtils.createReservationGroupBy(namespace, spark).groupByConf val labelJoinConf = createTestLabelJoin(30, 20, Seq(labelGroupBy1, labelGroupBy2)) @@ -134,8 +132,7 @@ class LabelJoinTest { assertEquals(0, diff.count()) } - @Test - def testLabelDsDoesNotExist(): Unit = { + it should "label ds does not exist" in { val labelGroupBy = TestUtils.createRoomTypeGroupBy(namespace, spark, "listing_label_not_exist").groupByConf val labelJoinConf = createTestLabelJoin(30, 20, Seq(labelGroupBy)) val joinConf = Builders.Join( @@ -157,8 +154,7 @@ class LabelJoinTest { null) } - @Test - def testLabelRefresh(): Unit = { + it should "label refresh" in { val labelGroupBy = TestUtils.createRoomTypeGroupBy(namespace, spark, "listing_attributes_refresh").groupByConf val labelJoinConf = createTestLabelJoin(60, 20, Seq(labelGroupBy)) val joinConf = Builders.Join( @@ -188,8 +184,7 @@ class LabelJoinTest { assertEquals(computedRows.toSet, refreshedRows.toSet) } - @Test - def testLabelEvolution(): Unit = { + it should "label evolution" in { val labelGroupBy = TestUtils.createRoomTypeGroupBy(namespace, spark, "listing_labels").groupByConf val labelJoinConf = createTestLabelJoin(30, 20, Seq(labelGroupBy)) val tableName = "label_evolution" @@ -236,8 +231,7 @@ class LabelJoinTest { "NEW_HOST") } - @Test(expected = classOf[AssertionError]) - def testLabelJoinInvalidSource(): Unit = { + it should "label join invalid source" in { // Invalid left data model entities val labelJoin = Builders.LabelPart( labels = Seq( @@ -256,8 +250,7 @@ class LabelJoinTest { new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() } - @Test(expected = classOf[AssertionError]) - def testLabelJoinInvalidLabelGroupByDataModal(): Unit = { + it should "label join invalid label group by data modal" in { // Invalid data model entities with aggregations, expected Events val agg_label_conf = Builders.GroupBy( sources = Seq(labelGroupBy.groupByConf.sources.get(0)), @@ -289,8 +282,7 @@ class LabelJoinTest { new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() } - @Test(expected = classOf[AssertionError]) - def testLabelJoinInvalidAggregations(): Unit = { + it should "label join invalid aggregations" in { // multi window aggregations val agg_label_conf = Builders.GroupBy( sources = Seq(labelGroupBy.groupByConf.sources.get(0)), @@ -322,8 +314,7 @@ class LabelJoinTest { new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() } - @Test - def testLabelAggregations(): Unit = { + it should "label aggregations" in { // left : listing_id, _, _, ts, ds val rows = List( Row(1L, 20L, "2022-10-02 11:00:00", "2022-10-02"), @@ -380,8 +371,7 @@ class LabelJoinTest { assertEquals(0, diff.count()) } - @Test - def testLabelAggregationsWithLargerDataset(): Unit = { + it should "label aggregations with larger dataset" in { val labelTableName = s"$namespace.listing_status" val listingTableName = s"$namespace.listing_views_agg_left" val listingTable = TestUtils.buildListingTable(spark, listingTableName) diff --git a/spark/src/test/scala/ai/chronon/spark/test/LocalDataLoaderTest.scala b/spark/src/test/scala/ai/chronon/spark/test/LocalDataLoaderTest.scala index 45193b2181..2c41989a46 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/LocalDataLoaderTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/LocalDataLoaderTest.scala @@ -24,7 
+24,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.sql.SparkSession import org.junit.AfterClass import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import java.io.File @@ -41,10 +41,9 @@ object LocalDataLoaderTest { } } -class LocalDataLoaderTest { +class LocalDataLoaderTest extends AnyFlatSpec { - @Test - def loadDataFileAsTableShouldBeCorrect(): Unit = { + it should "load data file as table should be correct" in { val resourceURL = Option(getClass.getResource("/local_data_csv/test_table_1_data.csv")) .getOrElse(throw new IllegalStateException("Required test resource not found")) val file = new File(resourceURL.getFile) @@ -58,8 +57,7 @@ class LocalDataLoaderTest { assertEquals(3, loadedDataDf.count()) } - @Test - def loadDataRecursivelyShouldBeCorrect(): Unit = { + it should "load data recursively should be correct" in { val resourceURI = getClass.getResource("/local_data_csv") val path = new File(resourceURI.getFile) LocalDataLoader.loadDataRecursively(path, spark) diff --git a/spark/src/test/scala/ai/chronon/spark/test/LocalExportTableAbilityTest.scala b/spark/src/test/scala/ai/chronon/spark/test/LocalExportTableAbilityTest.scala index 947f1221d8..e78600c726 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/LocalExportTableAbilityTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/LocalExportTableAbilityTest.scala @@ -25,15 +25,15 @@ import org.apache.spark.sql.SparkSession import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue -import org.junit.Test import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.doNothing import org.mockito.Mockito.mock import org.mockito.Mockito.times import org.mockito.Mockito.verify import org.rogach.scallop.ScallopConf +import org.scalatest.flatspec.AnyFlatSpec -class LocalExportTableAbilityTest { +class LocalExportTableAbilityTest extends AnyFlatSpec { class TestArgs(args: Array[String], localTableExporter: LocalTableExporter) extends ScallopConf(args) with OfflineSubcommand @@ -47,23 +47,20 @@ class LocalExportTableAbilityTest { protected override def buildLocalTableExporter(tableUtils: TableUtils): LocalTableExporter = localTableExporter } - @Test - def localTableExporterIsNotUsedWhenNotInLocalMode(): Unit = { + it should "local table exporter is not used when not in local mode" in { val argList = Seq("--conf-path", "joins/team/example_join.v1", "--end-date", "2023-03-03") val args = new TestArgs(argList.toArray, mock(classOf[LocalTableExporter])) assertFalse(args.shouldExport()) } - @Test - def localTableExporterIsNotUsedWhenNotExportPathIsNotSpecified(): Unit = { + it should "local table exporter is not used when not export path is not specified" in { val argList = Seq("--conf-path", "joins/team/example_join.v1", "--end-date", "2023-03-03", "--local-data-path", "somewhere") val args = new TestArgs(argList.toArray, mock(classOf[LocalTableExporter])) assertFalse(args.shouldExport()) } - @Test - def localTableExporterIsUsedWhenNecessary(): Unit = { + it should "local table exporter is used when necessary" in { val targetOutputPath = "path/to/somewhere" val targetFormat = "parquet" val prefix = "test_prefix" diff --git a/spark/src/test/scala/ai/chronon/spark/test/LocalTableExporterTest.scala b/spark/src/test/scala/ai/chronon/spark/test/LocalTableExporterTest.scala index 23149d6c7a..005e0e3934 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/LocalTableExporterTest.scala +++ 
b/spark/src/test/scala/ai/chronon/spark/test/LocalTableExporterTest.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.SparkSession import org.junit.AfterClass import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import java.io.File @@ -50,10 +50,9 @@ object LocalTableExporterTest { } } -class LocalTableExporterTest { +class LocalTableExporterTest extends AnyFlatSpec { - @Test - def exporterExportsTablesCorrectly(): Unit = { + it should "exporter exports tables correctly" in { val schema = List( Column("user", StringType, 10), Column(Constants.TimeColumn, LongType, 10000), // ts = last 10000 days to avoid conflict @@ -83,8 +82,7 @@ class LocalTableExporterTest { generatedData.zip(loadedData).foreach { case (g, l) => assertEquals(g, l) } } - @Test - def exporterExportsMultipleTablesWithFilesInCorrectPlace(): Unit = { + it should "exporter exports multiple tables with files in correct place" in { val schema = List( Column("user", StringType, 100000), Column(Constants.TimeColumn, LongType, 10000), diff --git a/spark/src/test/scala/ai/chronon/spark/test/MetadataExporterTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MetadataExporterTest.scala index b83f14bfda..f37d756bc6 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MetadataExporterTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MetadataExporterTest.scala @@ -25,16 +25,16 @@ import ai.chronon.spark.TableUtils import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import com.google.common.io.Files -import junit.framework.TestCase import org.apache.spark.sql.SparkSession import org.junit.Assert.assertEquals +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory import java.io.File import scala.io.Source -class MetadataExporterTest extends TestCase { +class MetadataExporterTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) val sessionName = "MetadataExporter" @@ -64,7 +64,7 @@ class MetadataExporterTest extends TestCase { } } - def testMetadataExport(): Unit = { + it should "metadata export" in { // Create the tables. 
val namespace = "example_namespace" val tablename = "table" diff --git a/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala index 172980c922..b1b2471274 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala @@ -28,9 +28,9 @@ import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.CompareJob import org.apache.spark.sql.SparkSession -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec -class MigrationCompareTest { +class MigrationCompareTest extends AnyFlatSpec { lazy val spark: SparkSession = SparkSessionBuilder.build("MigrationCompareTest", local = true) private val tableUtils = TableUtils(spark) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) @@ -95,8 +95,7 @@ class MigrationCompareTest { (joinConf, stagingQueryConf) } - @Test - def testMigrateCompare(): Unit = { + it should "migrate compare" in { val (joinConf, stagingQueryConf) = setupTestData() val (compareDf, metricsDf, metrics: DataMetrics) = @@ -105,8 +104,7 @@ class MigrationCompareTest { assert(result.size == 0) } - @Test - def testMigrateCompareWithLessColumns(): Unit = { + it should "migrate compare with less columns" in { val (joinConf, _) = setupTestData() // Run the staging query to generate the corresponding table for comparison @@ -124,8 +122,7 @@ class MigrationCompareTest { assert(result.size == 0) } - @Test - def testMigrateCompareWithWindows(): Unit = { + it should "migrate compare with windows" in { val (joinConf, stagingQueryConf) = setupTestData() val (compareDf, metricsDf, metrics: DataMetrics) = @@ -134,8 +131,7 @@ class MigrationCompareTest { assert(result.size == 0) } - @Test - def testMigrateCompareWithLessData(): Unit = { + it should "migrate compare with less data" in { val (joinConf, _) = setupTestData() val stagingQueryConf = Builders.StagingQuery( diff --git a/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala index 6b7de749d7..f85c4f8c5c 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -48,7 +48,7 @@ import org.slf4j.LoggerFactory * Join is the events and the entity value at the exact timestamp of the ts. * To run: sbt "spark/testOnly -- -n mutationstest" */ -class MutationsTest extends AnyFunSuite with TaggedFilterSuite { +class MutationsTest extends AnyFlatSpec with TaggedFilterSuite { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) override def tagName: String = "mutationstest" @@ -449,7 +449,7 @@ class MutationsTest extends AnyFunSuite with TaggedFilterSuite { * * Compute Join for when mutations are just insert on values. 
*/ - test("test simplest case") { + it should "test simplest case" in { val suffix = "simple" val leftData = Seq( // {listing_id, some_col, ts, ds} @@ -507,7 +507,7 @@ class MutationsTest extends AnyFunSuite with TaggedFilterSuite { * * Compute Join when mutations have an update on values. */ - test("test update value case") { + it should "test update value case" in { val suffix = "update_value" val leftData = Seq( // {listing_id, ts, event, ds} @@ -558,7 +558,7 @@ class MutationsTest extends AnyFunSuite with TaggedFilterSuite { * * Compute Join when mutations have an update on keys. */ - test("test update key case") { + it should "test update key case" in { val suffix = "update_key" val leftData = Seq( Row(1, 1, millis("2021-04-10 01:00:00"), "2021-04-10"), @@ -615,7 +615,7 @@ class MutationsTest extends AnyFunSuite with TaggedFilterSuite { * For this test we request a value for id 2, w/ mutations happening in the day before and after the time requested. * The consistency constraint here is that snapshot 4/8 + mutations 4/8 = snapshot 4/9 */ - test("test inconsistent ts left case") { + it should "test inconsistent ts left case" in { val suffix = "inconsistent_ts" val leftData = Seq( Row(1, 1, millis("2021-04-10 01:00:00"), "2021-04-10"), @@ -684,7 +684,7 @@ class MutationsTest extends AnyFunSuite with TaggedFilterSuite { * Compute Join, the snapshot aggregation should decay, this is the main reason to have * resolution in snapshot IR */ - test("test decayed window case") { + it should "test decayed window case" in { val suffix = "decayed" val leftData = Seq( Row(2, 1, millis("2021-04-09 01:30:00"), "2021-04-10"), @@ -755,7 +755,7 @@ class MutationsTest extends AnyFunSuite with TaggedFilterSuite { * Compute Join, the snapshot aggregation should decay. * When there are no mutations returning the collapsed is not enough depending on the time. */ - test("test decayed window case no mutation") { + it should "test decayed window case no mutation" in { val suffix = "decayed_v2" val leftData = Seq( Row(2, 1, millis("2021-04-10 01:00:00"), "2021-04-10"), @@ -803,7 +803,7 @@ class MutationsTest extends AnyFunSuite with TaggedFilterSuite { * Compute Join, the snapshot aggregation should decay. * When there's no snapshot the value would depend only on mutations of the day. 
*/ - test("test no snapshot just mutation") { + it should "test no snapshot just mutation" in { val suffix = "no_mutation" val leftData = Seq( Row(2, 1, millis("2021-04-10 00:07:00"), "2021-04-10"), @@ -843,7 +843,7 @@ class MutationsTest extends AnyFunSuite with TaggedFilterSuite { assert(compareResult(result, expected)) } - test("test with generated data") { + it should "test with generated data" in { val suffix = "generated" val reviews = List( Column("listing_id", api.StringType, 10), diff --git a/spark/src/test/scala/ai/chronon/spark/test/OfflineSubcommandTest.scala b/spark/src/test/scala/ai/chronon/spark/test/OfflineSubcommandTest.scala index 33c69329c5..285b3b7709 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/OfflineSubcommandTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/OfflineSubcommandTest.scala @@ -22,15 +22,15 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue -import org.junit.Test import org.rogach.scallop.ScallopConf +import org.scalatest.flatspec.AnyFlatSpec import org.yaml.snakeyaml.Yaml import scala.io.Source import collection.JavaConverters._ -class OfflineSubcommandTest { +class OfflineSubcommandTest extends AnyFlatSpec { class TestArgs(args: Array[String]) extends ScallopConf(args) with OfflineSubcommand { verify() @@ -42,16 +42,14 @@ class OfflineSubcommandTest { override def isLocal: Boolean = true } - @Test - def basicIsParsedCorrectly(): Unit = { + it should "basic is parsed correctly" in { val confPath = "joins/team/example_join.v1" val args = new TestArgs(Seq("--conf-path", confPath).toArray) assertEquals(confPath, args.confPath()) assertTrue(args.localTableMapping.isEmpty) } - @Test - def localTableMappingIsParsedCorrectly(): Unit = { + it should "local table mapping is parsed correctly" in { val confPath = "joins/team/example_join.v1" val endData = "2023-03-03" val argList = Seq("--local-table-mapping", "a=b", "c=d", "--conf-path", confPath, "--end-date", endData) @@ -63,8 +61,7 @@ class OfflineSubcommandTest { assertEquals(endData, args.endDate()) } - @Test - def additionalConfsParsedCorrectly(): Unit = { + it should "additional confs parsed correctly" in { implicit val formats: Formats = DefaultFormats val url = getClass.getClassLoader.getResource("test-driver-additional-confs.yaml") diff --git a/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala index f7f1dfb027..96d245b9e6 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala @@ -27,13 +27,13 @@ import org.apache.spark.sql.SparkSession import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue import org.junit.Before -import org.junit.Test import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.mock import org.mockito.Mockito.when import org.rogach.scallop.ScallopConf +import org.scalatest.flatspec.AnyFlatSpec -class ResultValidationAbilityTest { +class ResultValidationAbilityTest extends AnyFlatSpec { val confPath = "joins/team/example_join.v1" val spark: SparkSession = SparkSessionBuilder.build("test", local = true) val mockTableUtils: TableUtils = mock(classOf[TableUtils]) @@ -51,20 +51,17 @@ class ResultValidationAbilityTest { override def buildSparkSession(): SparkSession = spark } - @Test - def shouldNotValidateWhenComparisonTableIsNotSpecified(): Unit 
= { + it should "should not validate when comparison table is not specified" in { val args = new TestArgs(Seq("--conf-path", confPath).toArray) assertFalse(args.shouldPerformValidate()) } - @Test - def shouldValidateWhenComparisonTableIsSpecified(): Unit = { + it should "should validate when comparison table is specified" in { val args = new TestArgs(Seq("--conf-path", confPath, "--expected-result-table", "a_table").toArray) assertTrue(args.shouldPerformValidate()) } - @Test - def testSuccessfulValidation(): Unit = { + it should "successful validation" in { val args = new TestArgs(Seq("--conf-path", confPath, "--expected-result-table", "a_table").toArray) // simple testing, more comprehensive testing are already done in CompareTest.scala @@ -78,8 +75,7 @@ class ResultValidationAbilityTest { assertTrue(args.validateResult(df, Seq("keyId", "ds"), mockTableUtils)) } - @Test - def testFailedValidation(): Unit = { + it should "failed validation" in { val args = new TestArgs(Seq("--conf-path", confPath, "--expected-result-table", "a_table").toArray) val columns = Seq("serial", "value", "rating", "keyId", "ds") diff --git a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala index a5e98bfe0a..fdc10eca86 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala @@ -29,7 +29,6 @@ import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import ai.chronon.spark.utils.InMemoryKvStore import ai.chronon.spark.utils.MockApi -import junit.framework.TestCase import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession @@ -39,6 +38,7 @@ import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertNotEquals import org.junit.Assert.assertTrue +import org.scalatest.flatspec.AnyFlatSpec import java.nio.charset.StandardCharsets import java.util.Base64 @@ -73,7 +73,7 @@ object JoinTestSuite { } } -class SchemaEvolutionTest extends TestCase { +class SchemaEvolutionTest extends AnyFlatSpec { val spark: SparkSession = SparkSessionBuilder.build("SchemaEvolutionTest", local = true) TimeZone.setDefault(TimeZone.getTimeZone("UTC")) @@ -315,8 +315,10 @@ class SchemaEvolutionTest extends TestCase { } def testSchemaEvolution(namespace: String, joinSuiteV1: JoinTestSuite, joinSuiteV2: JoinTestSuite): Unit = { - assert(joinSuiteV1.joinConf.metaData.name == joinSuiteV2.joinConf.metaData.name, - message = "Schema evolution can only be tested on changes of the SAME join") + + require(joinSuiteV1.joinConf.metaData.name == joinSuiteV2.joinConf.metaData.name, + "Schema evolution can only be tested on changes of the SAME join") + val tableUtils: TableUtils = TableUtils(spark) val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore(namespace) val mockApi = new MockApi(() => inMemoryKvStore, namespace) @@ -441,12 +443,12 @@ class SchemaEvolutionTest extends TestCase { assertTrue(removedFeatures.forall(flattenedDf34.schema.fieldNames.contains(_))) } - def testAddFeatures(): Unit = { + it should "add features" in { val namespace = "add_features" testSchemaEvolution(namespace, createV1Join(namespace), createV2Join(namespace)) } - def testRemoveFeatures(): Unit = { + it should "remove features" in { val namespace = "remove_features" testSchemaEvolution(namespace, createV2Join(namespace), createV1Join(namespace)) } diff --git 
a/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala b/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala index 616f98f837..92db1ba293 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala @@ -26,11 +26,11 @@ import ai.chronon.spark.StagingQuery import ai.chronon.spark.TableUtils import org.apache.spark.sql.SparkSession import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory -class StagingQueryTest { +class StagingQueryTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) lazy val spark: SparkSession = SparkSessionBuilder.build("StagingQueryTest", local = true) implicit private val tableUtils: TableUtils = TableUtils(spark) @@ -40,8 +40,7 @@ class StagingQueryTest { private val namespace = "staging_query_chronon_test" tableUtils.createDatabase(namespace) - @Test - def testStagingQuery(): Unit = { + it should "staging query" in { val schema = List( Column("user", StringType, 10), Column("session_length", IntType, 1000) @@ -110,8 +109,7 @@ class StagingQueryTest { /** Test Staging Query update with new feature/column added to the query. */ - @Test - def testStagingQueryAutoExpand(): Unit = { + it should "staging query auto expand" in { val schema = List( Column("user", StringType, 10), Column("session_length", IntType, 50), @@ -186,8 +184,7 @@ class StagingQueryTest { * Compute in several step ranges a trivial query and for the first step range (first partition) the latest_date * value should be that of the latest partition (today). */ - @Test - def testStagingQueryLatestDate(): Unit = { + it should "staging query latest date" in { val schema = List( Column("user", StringType, 10), Column("session_length", IntType, 1000) @@ -238,8 +235,7 @@ class StagingQueryTest { assertEquals(0, diff.count()) } - @Test - def testStagingQueryMaxDate(): Unit = { + it should "staging query max date" in { val schema = List( Column("user", StringType, 10), Column("session_length", IntType, 1000) diff --git a/spark/src/test/scala/ai/chronon/spark/test/StatsComputeTest.scala b/spark/src/test/scala/ai/chronon/spark/test/StatsComputeTest.scala index a77f35b18d..9e326066b8 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/StatsComputeTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/StatsComputeTest.scala @@ -25,18 +25,17 @@ import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.StatsCompute import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.lit -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory -class StatsComputeTest { +class StatsComputeTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) lazy val spark: SparkSession = SparkSessionBuilder.build("StatsComputeTest", local = true) implicit val tableUtils: TableUtils = TableUtils(spark) val namespace: String = "stats_compute_test" - @Test - def summaryTest(): Unit = { + it should "summary" in { val data = Seq( ("1", Some(1L), Some(1.0), Some("a")), ("1", Some(1L), None, Some("b")), @@ -53,8 +52,7 @@ class StatsComputeTest { stats.addDerivedMetrics(result, aggregator).show() } - @Test - def snapshotSummaryTest(): Unit = { + it should "snapshot summary" in { tableUtils.createDatabase(namespace) val data = Seq( ("1", Some(1L), 
Some(1.0), Some("a")), @@ -76,8 +74,7 @@ class StatsComputeTest { stats.addDerivedMetrics(result, aggregator).save(s"$namespace.testTablenameSnapshot") } - @Test - def generatedSummaryTest(): Unit = { + it should "generated summary" in { val schema = List( Column("user", StringType, 10), Column("session_length", IntType, 10000) @@ -104,8 +101,7 @@ class StatsComputeTest { denormalized.show(truncate = false) } - @Test - def generatedSummaryNoTsTest(): Unit = { + it should "generated summary no ts" in { val schema = List( Column("user", StringType, 10), Column("session_length", IntType, 10000) @@ -135,8 +131,7 @@ class StatsComputeTest { * Test to make sure aggregations are generated when it makes sense. * Example, percentiles are not currently supported for byte. */ - @Test - def generatedSummaryByteTest(): Unit = { + it should "generated summary byte" in { val schema = List( Column("user", StringType, 10), Column("session_length", IntType, 10000) diff --git a/spark/src/test/scala/ai/chronon/spark/test/StreamingTest.scala b/spark/src/test/scala/ai/chronon/spark/test/StreamingTest.scala index 963ca4dc80..cd6116e646 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/StreamingTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/StreamingTest.scala @@ -29,8 +29,8 @@ import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.StreamingTest.buildInMemoryKvStore import ai.chronon.spark.utils.InMemoryKvStore import ai.chronon.spark.{Join => _, _} -import junit.framework.TestCase import org.apache.spark.sql.SparkSession +import org.scalatest.flatspec.AnyFlatSpec import java.util.TimeZone import scala.collection.JavaConverters.asScalaBufferConverter @@ -42,7 +42,7 @@ object StreamingTest { } } -class StreamingTest extends TestCase { +class StreamingTest extends AnyFlatSpec { val spark: SparkSession = SparkSessionBuilder.build("StreamingTest", local = true) val tableUtils: TableUtils = TableUtils(spark) @@ -52,7 +52,7 @@ class StreamingTest extends TestCase { tableUtils.partitionSpec.before(today) private val yearAgo = tableUtils.partitionSpec.minus(today, new Window(365, TimeUnit.DAYS)) - def testStructInStreaming(): Unit = { + it should "struct in streaming" in { tableUtils.createDatabase(namespace) val topicName = "fake_topic" val inMemoryKvStore = buildInMemoryKvStore() diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala index 2157d1ec0b..e7ee1ad4bf 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala @@ -19,11 +19,11 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.col import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.flatspec.AnyFlatSpec import scala.util.Try -class TableUtilsFormatTest extends AnyFunSuite { +class TableUtilsFormatTest extends AnyFlatSpec { import TableUtilsFormatTest._ @@ -32,11 +32,11 @@ class TableUtilsFormatTest extends AnyFunSuite { val spark: SparkSession = SparkSessionBuilder.build("TableUtilsFormatTest", local = true) val tableUtils: TableUtils = TableUtils(spark) - test("testing dynamic classloading") { + it should "testing dynamic classloading" in { assertTrue(tableUtils.tableFormatProvider.isInstanceOf[DefaultFormatProvider]) } - test("test insertion of partitioned data and adding of columns") { + it should "test insertion of 
partitioned data and adding of columns" in { val dbName = s"db_${System.currentTimeMillis()}" val tableName = s"$dbName.test_table_1_$format" spark.sql(s"CREATE DATABASE IF NOT EXISTS $dbName") @@ -71,7 +71,7 @@ class TableUtilsFormatTest extends AnyFunSuite { testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") } - test("test insertion of partitioned data and removal of columns") { + it should "test insertion of partitioned data and removal of columns" in { val dbName = s"db_${System.currentTimeMillis()}" val tableName = s"$dbName.test_table_2_$format" spark.sql(s"CREATE DATABASE IF NOT EXISTS $dbName") @@ -106,7 +106,7 @@ class TableUtilsFormatTest extends AnyFunSuite { testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") } - test("test insertion of partitioned data and modification of columns") { + it should "test insertion of partitioned data and modification of columns" in { val dbName = s"db_${System.currentTimeMillis()}" val tableName = s"$dbName.test_table_3_$format" spark.sql(s"CREATE DATABASE IF NOT EXISTS $dbName") diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala index ce6280fbc9..74b85c4569 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import scala.util.Try @@ -44,13 +44,12 @@ class SimpleAddUDF extends UDF { } } -class TableUtilsTest { +class TableUtilsTest extends AnyFlatSpec { lazy val spark: SparkSession = SparkSessionBuilder.build("TableUtilsTest", local = true) private val tableUtils = TableTestUtils(spark) private implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec - @Test - def ColumnFromSqlTest(): Unit = { + it should "column from sql" in { val sampleSql = """ |SELECT @@ -75,8 +74,7 @@ class TableUtilsTest { assertEquals(expected, columns.sorted) } - @Test - def GetFieldNamesTest(): Unit = { + it should "get field names" in { val schema = types.StructType( Seq( types.StructField("name", types.StringType, nullable = true), @@ -141,8 +139,7 @@ class TableUtilsTest { }) } - @Test - def testInsertPartitionsAddColumns(): Unit = { + it should "insert partitions add columns" in { val tableName = "db.test_table_1" spark.sql("CREATE DATABASE IF NOT EXISTS db") val columns1 = Array( @@ -177,8 +174,7 @@ class TableUtilsTest { testInsertPartitions(tableName, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") } - @Test - def testInsertPartitionsRemoveColumns(): Unit = { + it should "insert partitions remove columns" in { val tableName = "db.test_table_2" spark.sql("CREATE DATABASE IF NOT EXISTS db") val columns1 = Array( @@ -212,8 +208,7 @@ class TableUtilsTest { testInsertPartitions(tableName, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") } - @Test - def testInsertPartitionsModifiedColumns(): Unit = { + it should "insert partitions modified columns" in { val tableName = "db.test_table_3" spark.sql("CREATE DATABASE IF NOT EXISTS db") val columns1 = Array( @@ -249,8 +244,7 @@ class TableUtilsTest { testInsertPartitions(tableName, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") } - @Test - def ChunkTest(): Unit = { + it should "chunk" in { 
val actual = tableUtils.chunk(Set("2021-01-01", "2021-01-02", "2021-01-05", "2021-01-07")) val expected = Seq( PartitionRange("2021-01-01", "2021-01-02"), @@ -260,8 +254,7 @@ class TableUtilsTest { assertEquals(expected, actual) } - @Test - def testDropPartitions(): Unit = { + it should "drop partitions" in { val tableName = "db.test_drop_partitions_table" spark.sql("CREATE DATABASE IF NOT EXISTS db") val columns1 = Array( @@ -303,8 +296,7 @@ class TableUtilsTest { ))) } - @Test - def testAllPartitionsAndGetLatestLabelMapping(): Unit = { + it should "all partitions and get latest label mapping" in { val tableName = "db.test_show_partitions" spark.sql("CREATE DATABASE IF NOT EXISTS db") @@ -375,8 +367,7 @@ class TableUtilsTest { } - @Test - def testLastAvailablePartition(): Unit = { + it should "last available partition" in { val tableName = "db.test_last_available_partition" prepareTestDataWithSubPartitions(tableName) Seq("2022-11-01", "2022-11-02", "2022-11-03").foreach { ds => @@ -385,8 +376,7 @@ class TableUtilsTest { } } - @Test - def testFirstAvailablePartition(): Unit = { + it should "first available partition" in { val tableName = "db.test_first_available_partition" prepareTestDataWithSubPartitions(tableName) Seq("2022-11-01", "2022-11-02", "2022-11-03").foreach { ds => @@ -395,8 +385,7 @@ class TableUtilsTest { } } - @Test - def testColumnSizeEstimator(): Unit = { + it should "column size estimator" in { val chrononType = StructType( "table_schema", Array( @@ -419,21 +408,18 @@ class TableUtilsTest { ) } - @Test - def testCheckTablePermission(): Unit = { + it should "check table permission" in { val tableName = "db.test_check_table_permission" prepareTestDataWithSubPartitions(tableName) assertTrue(tableUtils.checkTablePermission(tableName)) } - @Test - def testDoubleUDFRegistration(): Unit = { + it should "double udf registration" in { tableUtils.sql("CREATE TEMPORARY FUNCTION test AS 'ai.chronon.spark.test.SimpleAddUDF'") tableUtils.sql("CREATE TEMPORARY FUNCTION test AS 'ai.chronon.spark.test.SimpleAddUDF'") } - @Test - def testInsertPartitionsTableReachableAlready(): Unit = { + it should "insert partitions table reachable already" in { val tableName = "db.test_table_exists_already" spark.sql("CREATE DATABASE IF NOT EXISTS db") @@ -472,8 +458,7 @@ class TableUtilsTest { testInsertPartitions(tableName, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") } - @Test - def testCreateTableAlreadyExists(): Unit = { + it should "create table already exists" in { val tableName = "db.test_create_table_already_exists" spark.sql("CREATE DATABASE IF NOT EXISTS db") @@ -504,8 +489,7 @@ class TableUtilsTest { } } - @Test - def testCreateTable(): Unit = { + it should "create table" in { val tableName = "db.test_create_table" spark.sql("CREATE DATABASE IF NOT EXISTS db") try { @@ -531,8 +515,7 @@ class TableUtilsTest { } } - @Test - def testCreateTableBigQuery(): Unit = { + it should "create table big query" in { val tableName = "db.test_create_table_bigquery" spark.sql("CREATE DATABASE IF NOT EXISTS db") try { diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala index 5b1c1f68b6..545f2fc89a 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala @@ -32,22 +32,21 @@ import org.apache.spark.sql.functions._ import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse 
import org.junit.Assert.assertTrue -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory import scala.concurrent.Await import scala.concurrent.duration.Duration -class DerivationTest { +class DerivationTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) val spark: SparkSession = SparkSessionBuilder.build("DerivationTest", local = true) private val tableUtils = TableUtils(spark) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) - @Test - def testBootstrapToDerivations(): Unit = { + it should "bootstrap to derivations" in { val namespace = "test_derivations" tableUtils.createDatabase(namespace) val groupBy = BootstrapUtils.buildGroupBy(namespace, spark) @@ -293,8 +292,7 @@ class DerivationTest { assertEquals(0, diff.count()) } - @Test - def testBootstrapToDerivationsNoStar(): Unit = { + it should "bootstrap to derivations no star" in { val namespace = "test_derivations_no_star" tableUtils.createDatabase(namespace) @@ -367,13 +365,11 @@ class DerivationTest { assertEquals(0, diff.count()) } - @Test - def testLoggingNonStar(): Unit = { + it should "logging non star" in { runLoggingTest("test_derivations_logging_non_star", wildcardSelection = false) } - @Test - def testLogging(): Unit = { + it should "logging" in { runLoggingTest("test_derivations_logging", wildcardSelection = true) } @@ -501,8 +497,7 @@ class DerivationTest { assertEquals(0, diff.count()) } - @Test - def testContextual(): Unit = { + it should "contextual" in { val namespace = "test_contextual" tableUtils.createDatabase(namespace) val queryTable = BootstrapUtils.buildQuery(namespace, spark) @@ -628,8 +623,7 @@ class DerivationTest { assertFalse(schema4.contains("ext_contextual_context_2")) } - @Test - def testGroupByDerivations(): Unit = { + it should "group by derivations" in { val namespace = "test_group_by_derivations" tableUtils.createDatabase(namespace) val groupBy = BootstrapUtils.buildGroupBy(namespace, spark) diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala index 30ddde1de7..1141a8a96a 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala @@ -32,14 +32,14 @@ import ai.chronon.spark.utils.MockApi import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ import org.junit.Assert.assertEquals -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory import scala.concurrent.Await import scala.concurrent.duration.Duration -class LogBootstrapTest { +class LogBootstrapTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) val spark: SparkSession = SparkSessionBuilder.build("BootstrapTest", local = true) @@ -48,8 +48,7 @@ class LogBootstrapTest { tableUtils.createDatabase(namespace) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) - @Test - def testBootstrap(): Unit = { + it should "bootstrap" in { // group by val groupBy = BootstrapUtils.buildGroupBy(namespace, spark) diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala index 46e3fb8d0d..ce6188c534 100644 --- 
a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala @@ -28,11 +28,11 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse -import org.junit.Test +import org.scalatest.flatspec.AnyFlatSpec import org.slf4j.Logger import org.slf4j.LoggerFactory -class TableBootstrapTest { +class TableBootstrapTest extends AnyFlatSpec { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) val spark: SparkSession = SparkSessionBuilder.build("BootstrapTest", local = true) @@ -77,8 +77,7 @@ class TableBootstrapTest { (bootstrapPart, bootstrapDf) } - @Test - def testBootstrap(): Unit = { + it should "bootstrap" in { val namespace = "test_table_bootstrap" tableUtils.createDatabase(namespace) @@ -161,8 +160,7 @@ class TableBootstrapTest { assertEquals(0, diff.count()) } - @Test - def testBootstrapSameJoinPartMultipleSources(): Unit = { + it should "bootstrap same join part multiple sources" in { val namespace = "test_bootstrap_multi_source" tableUtils.createDatabase(namespace) From ed1d402d3376372ad1007e1af801bf722557befe Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 09:41:01 -0500 Subject: [PATCH 02/14] script used for translating - still requires last mile changes by hand. --- scripts/codemod/test_replace.py | 257 ++++++++++++++++++++++++++++++++ 1 file changed, 257 insertions(+) create mode 100644 scripts/codemod/test_replace.py diff --git a/scripts/codemod/test_replace.py b/scripts/codemod/test_replace.py new file mode 100644 index 0000000000..1466616506 --- /dev/null +++ b/scripts/codemod/test_replace.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 + + +import glob + + +def get_test_class_name(path): + # Get the file name from the path + filename = path.split("/")[-1] + # Remove 'Test.scala' and return + return filename.replace("Test.scala", "") + + +def convert_fun_suite_to_flatspec(lines, test_name): + modified_lines = [] + + for line in lines: + # Replace import statement + if "import org.scalatest.funsuite.AnyFunSuite" in line: + line = line.replace("funsuite.AnyFunSuite", "flatspec.AnyFlatSpec") + modified_lines.append(line) + continue + + # Replace AnyFunSuite with AnyFlatSpec + if "extends AnyFunSuite" in line: + line = line.replace("AnyFunSuite", "AnyFlatSpec") + modified_lines.append(line) + continue + + # Handle ignore tests and regular tests + if ("ignore(" in line or "test(" in line) and "{" in line: + start = line.find('"') + end = line.find('"', start + 1) + if start != -1 and end != -1: + test_desc = line[start + 1 : end] # Get description without quotes + words = test_desc.split() + + # Check if second word is "should" + if len(words) > 1 and words[1].lower() == "should": + subject = words[0] # Use first word as subject + remaining_desc = " ".join( + words[2:] + ) # Rest of description including "should" + new_desc = f'"{subject}" should "{remaining_desc}"' + else: + new_desc = f' it should "{test_desc}"' + + # Add appropriate suffix based on whether it's ignore or test + if "ignore(" in line: + new_line = f"{new_desc} ignore {{" + else: + new_line = f"{new_desc} in {{" + + modified_lines.append(new_line + "\n") + continue + + # Keep other lines unchanged + modified_lines.append(line) + + return "".join(modified_lines) + + +def split_camel_case(word): + if not word: + return [] + + result = [] + current_word = word[0].lower() + + for i in range(1, 
len(word)): + current_char = word[i] + prev_char = word[i - 1] + + # Split on transition from lowercase to uppercase + if current_char.isupper() and prev_char.islower(): + result.append(current_word) + current_word = current_char.lower() + # Split on transition from uppercase to lowercase, but only if it's not + # part of an acronym (i.e., if the previous char was also uppercase and + # not at the start of a word) + elif ( + current_char.islower() + and prev_char.isupper() + and i > 1 + and word[i - 2].isupper() + ): + result.append(current_word[:-1]) + current_word = prev_char.lower() + current_char + else: + current_word += current_char.lower() + + result.append(current_word) + return [token for token in result if token != "test"] + + +def convert_junit_to_flatspec(lines, test_name): + modified_lines = [] + is_test_method = False + class_modified = False + + for line in lines: + # Replace JUnit import with FlatSpec import + if "import org.junit.Test" in line: + modified_lines.append("import org.scalatest.flatspec.AnyFlatSpec\n") + continue + + # Handle class definition + if "class" in line and "Test" in line and (not class_modified): + class_modified = True + class_name = line.split("class")[1].split("{")[0].strip() + modified_lines.append(f"class {class_name} extends AnyFlatSpec {{\n") + continue + + # Mark start of a test method + if "@Test" in line: + is_test_method = True + continue + + # Convert only test methods marked with @Test and not private + if ( + is_test_method + and "def " in line + and "private" not in line + and (("(): Unit" in line) or ("): Unit" not in line)) + ): + is_test_method = False + + method_name = line.split("def ")[1].split("(")[0] + + test_description = " ".join(split_camel_case(method_name)) + + modified_lines.append(f' it should "{test_description}" in {{\n') + continue + + is_test_method = False + modified_lines.append(line) + + return "".join(modified_lines) + + +def convert_testcase_to_flatspec(lines, test_name): + modified_lines = [] + + for line in lines: + # Replace TestCase import with FlatSpec import + if "junit.framework.TestCase" in line: + modified_lines.append("import org.scalatest.flatspec.AnyFlatSpec\n") + continue + + # Handle imports that we want to keep + if line.startswith("import") and "TestCase" not in line: + modified_lines.append(line) + continue + + # Handle class definition + if "class" in line and "extends TestCase" in line: + class_name = line.split("class")[1].split("extends")[0].strip() + modified_lines.append(f"class {class_name} extends AnyFlatSpec {{\n") + continue + + # Convert test methods (they start with "def test") + if ( + "def test" in line + and "private" not in line + and ("(): Unit" in line or "): Unit" not in line) + ): + method_name = line.split("def test")[1].split("(")[0].strip() + # If there are parameters, capture them + + test_description = " ".join(split_camel_case(method_name)) + + modified_lines.append(f' it should "{test_description}" in {{\n') + continue + + modified_lines.append(line) + + return "".join(modified_lines) + + +def convert(handler, file_path): + test_name = get_test_class_name(file_path) + with open(file_path, "r") as file: + lines = file.readlines() + converted = handler(lines, test_name) + + with open(file_path, "w") as file: + file.write(converted) + + print(f"Converted {file_path}") + + +# Few challenging test cases below + +# convert( +# convert_junit_to_flatspec, +# "spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala", +# ) + +# convert( +# convert_junit_to_flatspec, +# 
"spark/src/test/scala/ai/chronon/spark/test/LocalExportTableAbilityTest.scala", +# ) + +# convert( +# convert_testcase_to_flatspec, +# "aggregator/src/test/scala/ai/chronon/aggregator/test/FrequentItemsTest.scala", +# ) + +# convert( +# convert_fun_suite_to_flatspec, +# "spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala", +# ) + + +if __name__ == "__main__": + test_files = glob.glob("**/*Test.scala", recursive=True) + + fun_suite_files = [] + junit_files = [] + others = [] + junit_test_case_files = [] + flat_spec_files = [] + + for file_path in test_files: + try: + with open(file_path, "r") as file: + content = file.read() + if "AnyFunSuite" in content: + fun_suite_files.append(file_path) + elif "import org.junit.Test" in content: + junit_files.append(file_path) + elif "extends TestCase" in content: + junit_test_case_files.append(file_path) + elif "extends AnyFlatSpec" in content: + flat_spec_files.append(file_path) + else: + others.append(file_path) + except Exception as e: + print(f"Error reading {file_path}: {e}") + + print(f"funsuite files:\n {"\n ".join(fun_suite_files)}") + + for file in fun_suite_files: + convert(convert_fun_suite_to_flatspec, file) + + print(f"junit files:\n {"\n ".join(junit_files)}") + + for file in junit_files: + convert(convert_junit_to_flatspec, file) + + print(f"test case files:\n {"\n ".join(junit_test_case_files)}") + + for file in junit_test_case_files: + convert(convert_testcase_to_flatspec, file) + + print(f"flat spec files:\n {"\n ".join(flat_spec_files)}") + print(f"Other files:\n {"\n ".join(others)}") From 60ab733c00cadf4363329e994cb7e2c9b43d3b1c Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 09:42:41 -0500 Subject: [PATCH 03/14] codemod script --- scripts/codemod/test_replace.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/scripts/codemod/test_replace.py b/scripts/codemod/test_replace.py index 1466616506..a141ff708f 100644 --- a/scripts/codemod/test_replace.py +++ b/scripts/codemod/test_replace.py @@ -3,6 +3,19 @@ import glob +""" +we have tests written in multiple flavors + +- extending junit TestCase class +- using @test annotation +- using AnyFunSuite +- using AnyFlatSpec +- using vertx junit runner + +bazel silently fails to run the tests when they are not uniform! + +This script translates almost all of the tests to AnyFlatSpec except for vertx tests. +""" def get_test_class_name(path): # Get the file name from the path From a1d41a0ba79ad49f39798e77a53558db69c5de44 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 09:43:21 -0500 Subject: [PATCH 04/14] comments --- scripts/codemod/test_replace.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/codemod/test_replace.py b/scripts/codemod/test_replace.py index a141ff708f..b5e49d8a82 100644 --- a/scripts/codemod/test_replace.py +++ b/scripts/codemod/test_replace.py @@ -15,6 +15,10 @@ bazel silently fails to run the tests when they are not uniform! This script translates almost all of the tests to AnyFlatSpec except for vertx tests. + +NOTE: CWD needs to be the root of the repo. 
+ +USAGE: python3 scripts/codemod/test_replace.py """ def get_test_class_name(path): From 6840dc9e6e0829aa0f07158e9a4438bae5aee101 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 09:53:26 -0500 Subject: [PATCH 05/14] before & after methods --- .../integrations/aws/DynamoDBKVStoreTest.scala | 11 ++++------- .../integrations/cloud_gcp/BigTableKVStoreTest.scala | 7 +++---- .../chronon/flink/test/FlinkJobIntegrationTest.scala | 11 ++++------- scripts/codemod/test_replace.py | 2 +- .../spark/test/ResultValidationAbilityTest.scala | 7 +++---- 5 files changed, 15 insertions(+), 23 deletions(-) diff --git a/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala b/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala index 3ceab3d83b..aaba986d2f 100644 --- a/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala +++ b/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala @@ -10,8 +10,7 @@ import com.amazonaws.services.dynamodbv2.local.server.DynamoDBProxyServer import io.circe.generic.auto._ import io.circe.parser._ import io.circe.syntax._ -import org.junit.After -import org.junit.Before +import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.must.Matchers.be import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper @@ -32,7 +31,7 @@ import scala.util.Try case class Model(modelId: String, modelName: String, online: Boolean) case class TimeSeries(joinName: String, featureName: String, tileTs: Long, metric: String, summary: Array[Double]) -class DynamoDBKVStoreTest extends AnyFlatSpec { +class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll{ import DynamoDBKVStoreConstants._ @@ -56,8 +55,7 @@ class DynamoDBKVStoreTest extends AnyFlatSpec { series.asJson.noSpaces.getBytes(StandardCharsets.UTF_8) } - @Before - def setup(): Unit = { + override def beforeAll(): Unit = { // Start the local DynamoDB instance server = ServerRunner.createServerFromCommandLineArgs(Array("-inMemory", "-port", "8000")) server.start() @@ -74,8 +72,7 @@ class DynamoDBKVStoreTest extends AnyFlatSpec { .build() } - @After - def tearDown(): Unit = { + override def afterAll(): Unit = { client.close() server.stop() } diff --git a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala index 2aa43bce8d..7f940a8dd4 100644 --- a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala +++ b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala @@ -16,13 +16,13 @@ import com.google.cloud.bigtable.data.v2.models.Query import com.google.cloud.bigtable.data.v2.models.Row import com.google.cloud.bigtable.data.v2.models.RowMutation import com.google.cloud.bigtable.emulator.v2.BigtableEmulatorRule -import org.junit.Before import org.junit.Rule import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when import org.mockito.Mockito.withSettings +import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.scalatestplus.mockito.MockitoSugar.mock @@ -34,7 +34,7 @@ import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ @RunWith(classOf[JUnit4]) -class BigTableKVStoreTest 
extends AnyFlatSpec { +class BigTableKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll{ import BigTableKVStore._ @@ -49,8 +49,7 @@ class BigTableKVStoreTest extends AnyFlatSpec { private val projectId = "test-project" private val instanceId = "test-instance" - @Before - def setup(): Unit = { + override def beforeAll(): Unit = { // Configure settings to use emulator val dataSettings = BigtableDataSettings .newBuilderForEmulator(bigtableEmulator.getPort) diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala index 04cc8b03a4..8a3382a6cd 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala @@ -11,16 +11,15 @@ import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.test.util.MiniClusterWithClientResource import org.apache.spark.sql.Encoders -import org.junit.After import org.junit.Assert.assertEquals -import org.junit.Before import org.mockito.Mockito.withSettings +import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.mockito.MockitoSugar.mock import scala.jdk.CollectionConverters.asScalaBufferConverter -class FlinkJobIntegrationTest extends AnyFlatSpec { +class FlinkJobIntegrationTest extends AnyFlatSpec with BeforeAndAfterAll{ val flinkCluster = new MiniClusterWithClientResource( new MiniClusterResourceConfiguration.Builder() @@ -52,14 +51,12 @@ class FlinkJobIntegrationTest extends AnyFlatSpec { TimestampedIR(tileIR._1, Some(timestampedTile.latestTsMillis)) } - @Before - def setup(): Unit = { + override def beforeAll(): Unit = { flinkCluster.before() CollectSink.values.clear() } - @After - def teardown(): Unit = { + override def afterAll: Unit = { flinkCluster.after() CollectSink.values.clear() } diff --git a/scripts/codemod/test_replace.py b/scripts/codemod/test_replace.py index b5e49d8a82..328e0ca65f 100644 --- a/scripts/codemod/test_replace.py +++ b/scripts/codemod/test_replace.py @@ -20,7 +20,7 @@ USAGE: python3 scripts/codemod/test_replace.py """ - +zz def get_test_class_name(path): # Get the file name from the path filename = path.split("/")[-1] diff --git a/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala index 96d245b9e6..ae979a9629 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala @@ -26,20 +26,19 @@ import ai.chronon.spark.TableUtils import org.apache.spark.sql.SparkSession import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue -import org.junit.Before import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.mock import org.mockito.Mockito.when import org.rogach.scallop.ScallopConf +import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec -class ResultValidationAbilityTest extends AnyFlatSpec { +class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfterAll{ val confPath = "joins/team/example_join.v1" val spark: SparkSession = SparkSessionBuilder.build("test", local = true) val mockTableUtils: TableUtils = mock(classOf[TableUtils]) - @Before - def setup(): Unit = { + override def beforeAll(): Unit = { 
when(mockTableUtils.partitionColumn).thenReturn("ds") when(mockTableUtils.partitionSpec).thenReturn(PartitionSpec("yyyy-MM-dd", WindowUtils.Day.millis)) } From 346b38d805320abceecba872ad9ee0a72394e9b5 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 10:21:21 -0500 Subject: [PATCH 06/14] before after --- .../ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala | 8 ++++---- .../integrations/cloud_gcp/BigTableKVStoreTest.scala | 6 +++--- .../ai/chronon/flink/test/FlinkJobIntegrationTest.scala | 8 ++++---- .../test/scala/ai/chronon/flink/test/FlinkTestUtils.scala | 1 + .../scala/ai/chronon/online/test/FetcherBaseTest.scala | 6 +++--- .../chronon/spark/test/ResultValidationAbilityTest.scala | 6 +++--- 6 files changed, 18 insertions(+), 17 deletions(-) diff --git a/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala b/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala index aaba986d2f..d759ec6ab1 100644 --- a/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala +++ b/cloud_aws/src/test/scala/ai/chronon/integrations/aws/DynamoDBKVStoreTest.scala @@ -10,7 +10,7 @@ import com.amazonaws.services.dynamodbv2.local.server.DynamoDBProxyServer import io.circe.generic.auto._ import io.circe.parser._ import io.circe.syntax._ -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.must.Matchers.be import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper @@ -31,7 +31,7 @@ import scala.util.Try case class Model(modelId: String, modelName: String, online: Boolean) case class TimeSeries(joinName: String, featureName: String, tileTs: Long, metric: String, summary: Array[Double]) -class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll{ +class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfter{ import DynamoDBKVStoreConstants._ @@ -55,7 +55,7 @@ class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll{ series.asJson.noSpaces.getBytes(StandardCharsets.UTF_8) } - override def beforeAll(): Unit = { + before { // Start the local DynamoDB instance server = ServerRunner.createServerFromCommandLineArgs(Array("-inMemory", "-port", "8000")) server.start() @@ -72,7 +72,7 @@ class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll{ .build() } - override def afterAll(): Unit = { + after { client.close() server.stop() } diff --git a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala index 7f940a8dd4..233b6adaee 100644 --- a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala +++ b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala @@ -22,7 +22,7 @@ import org.junit.runners.JUnit4 import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when import org.mockito.Mockito.withSettings -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.scalatestplus.mockito.MockitoSugar.mock @@ -34,7 +34,7 @@ import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ @RunWith(classOf[JUnit4]) -class BigTableKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll{ +class BigTableKVStoreTest extends AnyFlatSpec with BeforeAndAfter{ 
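// Aside (illustration only; not a hunk from patches 05/06): the changes above and below move
// these suites off JUnit's @Before/@After annotations, first onto ScalaTest's BeforeAndAfterAll
// (runs once per suite) and then onto BeforeAndAfter (runs around every test). A minimal,
// self-contained sketch of the per-test idiom; the suite name and the in-memory buffer are
// hypothetical stand-ins for the real DynamoDB/Bigtable/Flink fixtures:
import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable.ArrayBuffer

class LifecycleExampleSpec extends AnyFlatSpec with BeforeAndAfter {
  private val fixture = ArrayBuffer.empty[String]

  before { fixture += "initialized" } // runs before every test case
  after { fixture.clear() }           // runs after every test case

  it should "see the fixture populated by the before block" in {
    assert(fixture.contains("initialized"))
  }
}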
import BigTableKVStore._ @@ -49,7 +49,7 @@ class BigTableKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll{ private val projectId = "test-project" private val instanceId = "test-instance" - override def beforeAll(): Unit = { + before { // Configure settings to use emulator val dataSettings = BigtableDataSettings .newBuilderForEmulator(bigtableEmulator.getPort) diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala index 8a3382a6cd..5b2493eb1c 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala @@ -13,13 +13,13 @@ import org.apache.flink.test.util.MiniClusterWithClientResource import org.apache.spark.sql.Encoders import org.junit.Assert.assertEquals import org.mockito.Mockito.withSettings -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.mockito.MockitoSugar.mock import scala.jdk.CollectionConverters.asScalaBufferConverter -class FlinkJobIntegrationTest extends AnyFlatSpec with BeforeAndAfterAll{ +class FlinkJobIntegrationTest extends AnyFlatSpec with BeforeAndAfter{ val flinkCluster = new MiniClusterWithClientResource( new MiniClusterResourceConfiguration.Builder() @@ -51,12 +51,12 @@ class FlinkJobIntegrationTest extends AnyFlatSpec with BeforeAndAfterAll{ TimestampedIR(tileIR._1, Some(timestampedTile.latestTsMillis)) } - override def beforeAll(): Unit = { + before { flinkCluster.before() CollectSink.values.clear() } - override def afterAll: Unit = { + after { flinkCluster.after() CollectSink.values.clear() } diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkTestUtils.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkTestUtils.scala index ec580e6794..518af8d808 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/FlinkTestUtils.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkTestUtils.scala @@ -122,6 +122,7 @@ object FlinkTestUtils { PartitionSpec(format = "yyyy-MM-dd", spanMillis = WindowUtils.Day.millis) ) } + def makeGroupBy(keyColumns: Seq[String], filters: Seq[String] = Seq.empty): GroupBy = Builders.GroupBy( sources = Seq( diff --git a/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala b/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala index c8deb4ba98..09c30e8f73 100644 --- a/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala @@ -35,7 +35,7 @@ import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import org.scalatestplus.mockito.MockitoSugar @@ -48,7 +48,7 @@ import scala.util.Failure import scala.util.Success import scala.util.Try -class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with MockitoHelper with BeforeAndAfterAll { +class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with MockitoHelper with BeforeAndAfter { val GroupBy = "relevance.short_term_user_features" val Column = "pdp_view_count_14d" val GuestKey = "guest" @@ -58,7 +58,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers 
with M var fetcherBase: FetcherBase = _ var kvStore: KVStore = _ - override def beforeAll(): Unit = { + before { kvStore = mock[KVStore](Answers.RETURNS_DEEP_STUBS) // The KVStore execution context is implicitly used for // Future compositions in the Fetcher so provision it in diff --git a/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala index ae979a9629..75dcb84b87 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala @@ -30,15 +30,15 @@ import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.mock import org.mockito.Mockito.when import org.rogach.scallop.ScallopConf -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.scalatest.flatspec.AnyFlatSpec -class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfterAll{ +class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfter{ val confPath = "joins/team/example_join.v1" val spark: SparkSession = SparkSessionBuilder.build("test", local = true) val mockTableUtils: TableUtils = mock(classOf[TableUtils]) - override def beforeAll(): Unit = { + before { when(mockTableUtils.partitionColumn).thenReturn("ds") when(mockTableUtils.partitionSpec).thenReturn(PartitionSpec("yyyy-MM-dd", WindowUtils.Day.millis)) } From 0ea78426fd26cf740ed69adc3a38f7b81c7bc39a Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 12:22:25 -0500 Subject: [PATCH 07/14] asyncwriter jdk 17 opens --- .github/workflows/test_scala_no_spark.yaml | 1 + .../test/scala/ai/chronon/aggregator/test/VarianceTest.scala | 2 +- .../scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_scala_no_spark.yaml b/.github/workflows/test_scala_no_spark.yaml index b22fcf1814..d892e75ac6 100644 --- a/.github/workflows/test_scala_no_spark.yaml +++ b/.github/workflows/test_scala_no_spark.yaml @@ -48,6 +48,7 @@ jobs: - name: Run Flink tests run: | + export SBT_OPTS="-Xmx24G -Xms4G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED" sbt "++ 2.12.18 flink/test" - name: Run Aggregator tests diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/VarianceTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/VarianceTest.scala index 4ec3a97b42..fde24a59ff 100644 --- a/aggregator/src/test/scala/ai/chronon/aggregator/test/VarianceTest.scala +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/VarianceTest.scala @@ -60,7 +60,7 @@ class VarianceTest extends AnyFlatSpec { assertTrue((naiveResult - welfordResult) / naiveResult < 0.0000001) } - it should "variance: unit = {" in { + it should "match with naive approach" in { compare(1000000) compare(1000000, min = 100000, max = 100001) } diff --git a/flink/src/test/scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala b/flink/src/test/scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala index 1cd13cc858..4825c66075 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/AsyncKVStoreWriterTest.scala @@ -18,7 +18,7 @@ class AsyncKVStoreWriterTest extends AnyFlatSpec { def createKVRequest(key: String, value: String, dataset: String, ts: Long): PutRequest = PutRequest(key.getBytes, value.getBytes, dataset, Some(ts)) - it should "async writer success writes" in { + it 
should "write successfully" in { val env = StreamExecutionEnvironment.getExecutionEnvironment val source: DataStream[PutRequest] = env .fromCollection( @@ -40,7 +40,7 @@ class AsyncKVStoreWriterTest extends AnyFlatSpec { // ensure that if we get an event that would cause the operator to throw an exception, // we don't crash the app - it should "async writer handles poison pill writes" in { + it should "handle poison pill writes" in { val env = StreamExecutionEnvironment.getExecutionEnvironment val source: DataStream[KVStore.PutRequest] = env .fromCollection( From 0fef86d316881c1cd92fa720da72a5896f6fb013 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 23:16:29 -0500 Subject: [PATCH 08/14] error intercepts --- .../scala/ai/chronon/spark/Analyzer.scala | 16 ++--- .../ai/chronon/spark/stats/CompareJob.scala | 2 +- .../ai/chronon/spark/test/AnalyzerTest.scala | 66 +++++++++++-------- .../spark/test/KafkaStreamBuilderTest.scala | 6 +- .../ai/chronon/spark/test/LabelJoinTest.scala | 21 ++++-- 5 files changed, 66 insertions(+), 45 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index dc6d028a52..249160cd05 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -87,7 +87,7 @@ class Analyzer(tableUtils: TableUtils, endDate: String, count: Int = 64, sample: Double = 0.1, - enableHitter: Boolean = false, + skewDetection: Boolean = false, silenceMode: Boolean = false) { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) @@ -528,7 +528,7 @@ class Analyzer(tableUtils: TableUtils, // set max sample to 100 rows if larger input is provided val sampleN = if (sampleNumber > 100) { 100 } else { sampleNumber } - dataFrameToMap( + dataframeToMap( df.limit(sampleN) .agg( // will return 0 if all values are null @@ -591,7 +591,7 @@ class Analyzer(tableUtils: TableUtils, } - def dataFrameToMap(inputDf: DataFrame): Map[String, String] = { + private def dataframeToMap(inputDf: DataFrame): Map[String, String] = { val row: Row = inputDf.head() val schema = inputDf.schema val columns = schema.fieldNames @@ -600,7 +600,7 @@ class Analyzer(tableUtils: TableUtils, .zip(values) .map { case (column, value) => - (column, value.toString) + (column, Option(value).getOrElse("null").toString) } .toMap } @@ -610,12 +610,12 @@ class Analyzer(tableUtils: TableUtils, case confPath: String => if (confPath.contains("/joins/")) { val joinConf = parseConf[api.Join](confPath) - analyzeJoin(joinConf, enableHitter = enableHitter) + analyzeJoin(joinConf, enableHitter = skewDetection) } else if (confPath.contains("/group_bys/")) { val groupByConf = parseConf[api.GroupBy](confPath) - analyzeGroupBy(groupByConf, enableHitter = enableHitter) + analyzeGroupBy(groupByConf, enableHitter = skewDetection) } - case groupByConf: api.GroupBy => analyzeGroupBy(groupByConf, enableHitter = enableHitter) - case joinConf: api.Join => analyzeJoin(joinConf, enableHitter = enableHitter) + case groupByConf: api.GroupBy => analyzeGroupBy(groupByConf, enableHitter = skewDetection) + case joinConf: api.Join => analyzeJoin(joinConf, enableHitter = skewDetection) } } diff --git a/spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala b/spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala index e53fbec159..0a7f2e4820 100644 --- a/spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala +++ b/spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala @@ 
-98,7 +98,7 @@ class CompareJob( def validate(): Unit = { // Extract the schema of the Join, StagingQuery and the keys before calling this. - val analyzer = new Analyzer(tableUtils, joinConf, startDate, endDate, enableHitter = false) + val analyzer = new Analyzer(tableUtils, joinConf, startDate, endDate, skewDetection = false) val joinChrononSchema = analyzer.analyzeJoin(joinConf)._1 val joinSchema = joinChrononSchema.map { case (k, v) => (k, SparkConversions.fromChrononType(v)) } val finalStagingQuery = StagingQuery.substitute(tableUtils, stagingQueryConf.query, startDate, endDate, endDate) diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index 9b3ac66be6..da0202dc74 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -46,7 +46,7 @@ class AnalyzerTest extends AnyFlatSpec { private val viewsSource = getTestEventSource() - it should "join analyzer schema with validation" in { + it should "produce correct analyzer schema" in { val viewsGroupBy = getViewsGroupBy("join_analyzer_test.item_gb", Operation.AVERAGE) val anotherViewsGroupBy = getViewsGroupBy("join_analyzer_test.another_item_gb", Operation.SUM) @@ -70,7 +70,7 @@ class AnalyzerTest extends AnyFlatSpec { ) //run analyzer and validate output schema - val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, enableHitter = true) + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) val analyzerSchema = analyzer.analyzeJoin(joinConf)._1.map { case (k, v) => s"${k} => ${v}" }.toList.sorted val join = new Join(joinConf = joinConf, endPartition = oneMonthAgo, tableUtils) val computed = join.computeJoin() @@ -81,7 +81,7 @@ class AnalyzerTest extends AnyFlatSpec { assertTrue(expectedSchema sameElements analyzerSchema) } - it should "join analyzer validation failure" in { + it should "throw on validation failure" in { val viewsGroupBy = getViewsGroupBy("join_analyzer_test.item_gb", Operation.AVERAGE, source = getTestGBSource()) val usersGroupBy = getUsersGroupBy("join_analyzer_test.user_gb", Operation.AVERAGE, source = getTestGBSource()) @@ -104,15 +104,19 @@ class AnalyzerTest extends AnyFlatSpec { Builders.MetaData(name = "test_join_analyzer.item_type_mismatch", namespace = namespace, team = "chronon") ) - //run analyzer and validate output schema - val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, enableHitter = true) - analyzer.analyzeJoin(joinConf, validationAssert = true) + intercept[AssertionError] { + //run analyzer and validate output schema + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) + analyzer.analyzeJoin(joinConf, validationAssert = true) + } } - it should "join analyzer validation data availability" in { + it should "throw on data unavailability" in { + // left side val itemQueries = List(Column("item", api.StringType, 100), Column("guest", api.StringType, 100)) val itemQueriesTable = s"$namespace.item_queries_with_user_table" + DataFrameGen .events(spark, itemQueries, 500, partitions = 100) .save(itemQueriesTable) @@ -139,9 +143,10 @@ class AnalyzerTest extends AnyFlatSpec { metaData = Builders.MetaData(name = "test_join_analyzer.item_validation", namespace = namespace, team = "chronon") ) - //run analyzer and validate data availability - val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, enableHitter = true) - 
analyzer.analyzeJoin(joinConf, validationAssert = true) + intercept[AssertionError] { + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) + analyzer.analyzeJoin(joinConf, validationAssert = true) + } } it should "join analyzer validation data availability multiple sources" in { @@ -236,7 +241,7 @@ class AnalyzerTest extends AnyFlatSpec { ) //run analyzer an ensure ts timestamp values result in analyzer passing - val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, enableHitter = true) + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) analyzer.analyzeJoin(joinConf, validationAssert = true) } @@ -266,13 +271,14 @@ class AnalyzerTest extends AnyFlatSpec { metaData = Builders.MetaData(name = "test_join_analyzer.key_validation", namespace = namespace, team = "chronon") ) - //run analyzer an ensure ts timestamp values result in analyzer passing - val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, enableHitter = true) - analyzer.analyzeJoin(joinConf, validationAssert = true) - + intercept[AssertionError] { + //run analyzer and trigger assertion error when timestamps are out of range + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) + analyzer.analyzeJoin(joinConf, validationAssert = true) + } } - it should "join analyzer check timestamp all nulls" in { + it should "throw when join timestamps are all nulls" in { // left side // create the event source with nulls @@ -297,10 +303,11 @@ class AnalyzerTest extends AnyFlatSpec { metaData = Builders.MetaData(name = "test_join_analyzer.key_validation", namespace = namespace, team = "chronon") ) - //run analyzer an ensure ts timestamp values result in analyzer passing - val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, enableHitter = true) - analyzer.analyzeJoin(joinConf, validationAssert = true) - + intercept[AssertionError] { + //run analyzer and trigger assertion error when timestamps are all NULL + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) + analyzer.analyzeJoin(joinConf, validationAssert = true) + } } it should "group by analyzer check timestamp has values" in { @@ -321,7 +328,7 @@ class AnalyzerTest extends AnyFlatSpec { } - it should "group by analyzer check timestamp all nulls" in { + it should "throw when groupBy timestamps are all nulls" in { val tableGroupBy = Builders.GroupBy( sources = Seq(getTestGBSourceWithTs("nulls")), @@ -333,9 +340,11 @@ class AnalyzerTest extends AnyFlatSpec { accuracy = Accuracy.TEMPORAL ) - //run analyzer and trigger assertion error when timestamps are all NULL - val analyzer = new Analyzer(tableUtils, tableGroupBy, oneMonthAgo, today) - analyzer.analyzeGroupBy(tableGroupBy) + intercept[AssertionError] { + //run analyzer and trigger assertion error when timestamps are all NULL + val analyzer = new Analyzer(tableUtils, tableGroupBy, oneMonthAgo, today) + analyzer.analyzeGroupBy(tableGroupBy) + } } it should "group by analyzer check timestamp out of range" in { @@ -350,10 +359,11 @@ class AnalyzerTest extends AnyFlatSpec { accuracy = Accuracy.TEMPORAL ) - //run analyzer and trigger assertion error when timestamps are all NULL - val analyzer = new Analyzer(tableUtils, tableGroupBy, oneMonthAgo, today) - analyzer.analyzeGroupBy(tableGroupBy) - + intercept[AssertionError] { + //run analyzer and trigger assertion error when timestamps are out of range + val analyzer = new Analyzer(tableUtils, 
tableGroupBy, oneMonthAgo, today) + analyzer.analyzeGroupBy(tableGroupBy) + } } def getTestGBSourceWithTs(option: String = "default"): api.Source = { diff --git a/spark/src/test/scala/ai/chronon/spark/test/KafkaStreamBuilderTest.scala b/spark/src/test/scala/ai/chronon/spark/test/KafkaStreamBuilderTest.scala index e7889c7067..96426cf3f9 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/KafkaStreamBuilderTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/KafkaStreamBuilderTest.scala @@ -26,8 +26,10 @@ class KafkaStreamBuilderTest extends AnyFlatSpec { private val spark: SparkSession = SparkSessionBuilder.build("KafkaStreamBuilderTest", local = true) - it should "kafka stream does not exist" in { + it should "throw when kafka stream does not exist" in { val topicInfo = TopicInfo.parse("kafka://test_topic/schema=my_schema/host=X/port=Y") - KafkaStreamBuilder.from(topicInfo)(spark, Map.empty) + intercept[RuntimeException] { + KafkaStreamBuilder.from(topicInfo)(spark, Map.empty) + } } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala index cdb8f83302..45d74fe5c9 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala @@ -231,7 +231,7 @@ class LabelJoinTest extends AnyFlatSpec { "NEW_HOST") } - it should "label join invalid source" in { + it should "throw on invalid source" in { // Invalid left data model entities val labelJoin = Builders.LabelPart( labels = Seq( @@ -247,10 +247,13 @@ class LabelJoinTest extends AnyFlatSpec { joinParts = Seq.empty, labelParts = labelJoin ) - new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() + + intercept[AssertionError] { + new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() + } } - it should "label join invalid label group by data modal" in { + it should "throw on invalid label group-by data-model" in { // Invalid data model entities with aggregations, expected Events val agg_label_conf = Builders.GroupBy( sources = Seq(labelGroupBy.groupByConf.sources.get(0)), @@ -279,10 +282,13 @@ class LabelJoinTest extends AnyFlatSpec { joinParts = Seq.empty, labelParts = labelJoin ) - new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() + + intercept[AssertionError] { + new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() + } } - it should "label join invalid aggregations" in { + it should "throw on invalid aggregations" in { // multi window aggregations val agg_label_conf = Builders.GroupBy( sources = Seq(labelGroupBy.groupByConf.sources.get(0)), @@ -311,7 +317,10 @@ class LabelJoinTest extends AnyFlatSpec { joinParts = Seq.empty, labelParts = labelJoin ) - new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() + + intercept[AssertionError] { + new LabelJoin(invalidJoinConf, tableUtils, labelDS).computeLabelJoin() + } } it should "label aggregations" in { From 54d4eeca1af2a39770d10bc0bbcfde9b7f98f65c Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 23:25:26 -0500 Subject: [PATCH 09/14] lru test cache --- .../src/test/scala/ai/chronon/online/test/LRUCacheTest.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/online/src/test/scala/ai/chronon/online/test/LRUCacheTest.scala b/online/src/test/scala/ai/chronon/online/test/LRUCacheTest.scala index ea9afe3459..ef327b3b92 100644 --- a/online/src/test/scala/ai/chronon/online/test/LRUCacheTest.scala +++ 
b/online/src/test/scala/ai/chronon/online/test/LRUCacheTest.scala @@ -6,14 +6,16 @@ import org.scalatest.flatspec.AnyFlatSpec class LRUCacheTest extends AnyFlatSpec { - val testCache: CaffeineCache[String, String] = LRUCache[String, String]("testCache") + it should "gets nothing when there is nothing" in { + val testCache: CaffeineCache[String, String] = LRUCache[String, String]("testCache") assert(testCache.getIfPresent("key") == null) assert(testCache.estimatedSize() == 0) } it should "gets something when there is something" in { + val testCache: CaffeineCache[String, String] = LRUCache[String, String]("testCache") assert(testCache.getIfPresent("key") == null) testCache.put("key", "value") assert(testCache.getIfPresent("key") == "value") @@ -21,6 +23,7 @@ class LRUCacheTest extends AnyFlatSpec { } it should "evicts when something is set" in { + val testCache: CaffeineCache[String, String] = LRUCache[String, String]("testCache") assert(testCache.estimatedSize() == 0) assert(testCache.getIfPresent("key") == null) testCache.put("key", "value") From 3ba567c4f5be484d05057b2f0741f7ac007093e9 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 14 Jan 2025 23:45:22 -0500 Subject: [PATCH 10/14] flink test fix --- .../chronon/flink/test/SparkExpressionEvalFnTest.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala index 920bafc15f..8797a53259 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala @@ -2,6 +2,7 @@ package ai.chronon.flink.test import ai.chronon.flink.SparkExpressionEvalFn import org.apache.flink.api.scala._ +import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.scala.DataStream import org.apache.flink.streaming.api.scala.DataStreamUtils import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment @@ -25,7 +26,10 @@ class SparkExpressionEvalFnTest extends AnyFlatSpec { groupBy ) - val env = StreamExecutionEnvironment.getExecutionEnvironment + val config = new Configuration() + config.setString("classloader.check-leaked-classloader", "false") + val env = StreamExecutionEnvironment.createLocalEnvironment(1, config) + val source: DataStream[E2ETestEvent] = env.fromCollection(elements) val sparkExprEvalDS = source.flatMap(sparkExprEval) @@ -34,5 +38,8 @@ class SparkExpressionEvalFnTest extends AnyFlatSpec { assert(result.size == elements.size, "Expect result sets to include all 3 rows") // let's check the id field assert(result.map(_.apply("id")).toSet == Set("test1", "test2", "test3")) + + sparkExprEval.close() + env.close() } } From 4707538550b012d6a70f334e393cffa547695d16 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 15 Jan 2025 00:16:14 -0500 Subject: [PATCH 11/14] undo flink test fix --- .../ai/chronon/flink/test/SparkExpressionEvalFnTest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala index 8797a53259..941bd20b69 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala @@ -39,7 +39,7 @@ class SparkExpressionEvalFnTest extends AnyFlatSpec { // let's 
check the id field assert(result.map(_.apply("id")).toSet == Set("test1", "test2", "test3")) - sparkExprEval.close() - env.close() + // sparkExprEval.close() + // env.close() } } From f7428725b1f596d6e356a33961bcbe1b2d8b5005 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 15 Jan 2025 11:09:39 -0500 Subject: [PATCH 12/14] concurrent join modification fix --- .../test/SparkExpressionEvalFnTest.scala | 8 +- .../scala/ai/chronon/spark/Analyzer.scala | 91 ++++++++--------- .../main/scala/ai/chronon/spark/Driver.scala | 14 +-- .../main/scala/ai/chronon/spark/Join.scala | 37 +++---- .../scala/ai/chronon/spark/JoinBase.scala | 98 ++++++++++--------- .../scala/ai/chronon/spark/JoinUtils.scala | 6 ++ .../scala/ai/chronon/spark/TableUtils.scala | 23 +++-- .../ai/chronon/spark/test/AnalyzerTest.scala | 19 +++- 8 files changed, 163 insertions(+), 133 deletions(-) diff --git a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala index 941bd20b69..8bea1dce65 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala @@ -2,7 +2,6 @@ package ai.chronon.flink.test import ai.chronon.flink.SparkExpressionEvalFn import org.apache.flink.api.scala._ -import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.scala.DataStream import org.apache.flink.streaming.api.scala.DataStreamUtils import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment @@ -26,9 +25,7 @@ class SparkExpressionEvalFnTest extends AnyFlatSpec { groupBy ) - val config = new Configuration() - config.setString("classloader.check-leaked-classloader", "false") - val env = StreamExecutionEnvironment.createLocalEnvironment(1, config) + val env = StreamExecutionEnvironment.getExecutionEnvironment val source: DataStream[E2ETestEvent] = env.fromCollection(elements) val sparkExprEvalDS = source.flatMap(sparkExprEval) @@ -38,8 +35,5 @@ class SparkExpressionEvalFnTest extends AnyFlatSpec { assert(result.size == elements.size, "Expect result sets to include all 3 rows") // let's check the id field assert(result.map(_.apply("id")).toSet == Set("test1", "test2", "test3")) - - // sparkExprEval.close() - // env.close() } } diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index 249160cd05..ab37932fd9 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -207,17 +207,17 @@ class Analyzer(tableUtils: TableUtils, def analyzeGroupBy(groupByConf: api.GroupBy, prefix: String = "", includeOutputTableName: Boolean = false, - enableHitter: Boolean = false): (Array[AggregationMetadata], Map[String, DataType]) = { + skewDetection: Boolean = false): (Array[AggregationMetadata], Map[String, DataType]) = { groupByConf.setups.foreach(tableUtils.sql) - val groupBy = GroupBy.from(groupByConf, range, tableUtils, computeDependency = enableHitter, finalize = true) + val groupBy = GroupBy.from(groupByConf, range, tableUtils, computeDependency = skewDetection, finalize = true) val name = "group_by/" + prefix + groupByConf.metaData.name - println(s"""|Running GroupBy analysis for $name ...""".stripMargin) + logger.info(s"""Running GroupBy analysis for $name ...""".stripMargin) val timestampChecks = runTimestampChecks(groupBy.inputDf) validateTimestampChecks(timestampChecks, "GroupBy", 
name) val analysis = - if (enableHitter) + if (skewDetection) analyze(groupBy.inputDf, groupByConf.keyColumns.toScala.toArray, groupByConf.sources.toScala.map(_.table).mkString(",")) @@ -241,20 +241,20 @@ class Analyzer(tableUtils: TableUtils, groupBy.outputSchema } if (silenceMode) { - println(s"""ANALYSIS completed for group_by/${name}.""".stripMargin) + logger.info(s"""ANALYSIS completed for group_by/${name}.""".stripMargin) } else { - println(s""" + logger.info(s""" |ANALYSIS for $name: |$analysis """.stripMargin) if (includeOutputTableName) - println(s""" + logger.info(s""" |----- OUTPUT TABLE NAME ----- |${groupByConf.metaData.outputTable} """.stripMargin) val keySchema = groupBy.keySchema.fields.map { field => s" ${field.name} => ${field.dataType}" } schema.fields.map { field => s" ${field.name} => ${field.fieldType}" } - println(s""" + logger.info(s""" |----- KEY SCHEMA ----- |${keySchema.mkString("\n")} |----- OUTPUT SCHEMA ----- @@ -275,16 +275,15 @@ class Analyzer(tableUtils: TableUtils, } def analyzeJoin(joinConf: api.Join, - enableHitter: Boolean = false, + skewDetection: Boolean = false, validateTablePermission: Boolean = true, validationAssert: Boolean = false): (Map[String, DataType], ListBuffer[AggregationMetadata]) = { val name = "joins/" + joinConf.metaData.name - println(s"""|Running join analysis for $name ...""".stripMargin) + logger.info(s"""|Running join analysis for $name ...\n""".stripMargin) // run SQL environment setups such as UDFs and JARs joinConf.setups.foreach(tableUtils.sql) - val (analysis, leftDf) = if (enableHitter) { - println() + val (analysis, leftDf) = if (skewDetection) { val leftDf = JoinUtils.leftDf(joinConf, range, tableUtils, allowEmpty = true).get val analysis = analyze(leftDf, joinConf.leftKeyCols, joinConf.left.table) (analysis, leftDf) @@ -313,14 +312,17 @@ class Analyzer(tableUtils: TableUtils, val rangeToFill = JoinUtils.getRangesToFill(joinConf.left, tableUtils, endDate, historicalBackfill = joinConf.historicalBackfill) - println(s"Join range to fill $rangeToFill") + logger.info(s"Join range to fill $rangeToFill") val unfilledRanges = tableUtils .unfilledRanges(joinConf.metaData.outputTable, rangeToFill, Some(Seq(joinConf.left.table))) .getOrElse(Seq.empty) joinConf.joinParts.toScala.foreach { part => val (aggMetadata, gbKeySchema) = - analyzeGroupBy(part.groupBy, part.fullPrefix, includeOutputTableName = true, enableHitter = enableHitter) + analyzeGroupBy(part.groupBy, + Option(part.prefix).map(_ + "_").getOrElse(""), + includeOutputTableName = true, + skewDetection = skewDetection) aggregationsMetadata ++= aggMetadata.map { aggMeta => AggregationMetadata(part.fullPrefix + "_" + aggMeta.name, aggMeta.columnType, @@ -330,7 +332,7 @@ class Analyzer(tableUtils: TableUtils, part.getGroupBy.getMetaData.getName) } // Run validation checks. - println(s""" + logger.info(s""" |left columns: ${leftDf.columns.mkString(", ")} |gb columns: ${gbKeySchema.keys.mkString(", ")} |""".stripMargin) @@ -351,9 +353,9 @@ class Analyzer(tableUtils: TableUtils, val rightSchema: Map[String, DataType] = aggregationsMetadata.map(aggregation => (aggregation.name, aggregation.columnType)).toMap if (silenceMode) { - println(s"""-- ANALYSIS completed for join/${joinConf.metaData.cleanName}. --""".stripMargin.blue) + logger.info(s"""-- ANALYSIS completed for join/${joinConf.metaData.cleanName}. 
--""".stripMargin.blue) } else { - println(s""" + logger.info(s""" |ANALYSIS for join/${joinConf.metaData.cleanName}: |$analysis |-- OUTPUT TABLE NAME -- @@ -363,38 +365,39 @@ class Analyzer(tableUtils: TableUtils, |-- RIGHT SIDE SCHEMA -- |${rightSchema.mkString("\n")} |-- END -- - |""".stripMargin) + |""".stripMargin.green) } - println(s"-- Validations for join/${joinConf.metaData.cleanName} --") + logger.info(s"-- Validations for join/${joinConf.metaData.cleanName} --") if (gbStartPartitions.nonEmpty) { - println( - "-- Following Group_Bys contains a startPartition. Please check if any startPartition will conflict with your backfill. --") + logger.info( + "-- Following GroupBy-s contains a startPartition. Please check if any startPartition will conflict with your backfill. --") gbStartPartitions.foreach { case (gbName, startPartitions) => - println(s" $gbName : ${startPartitions.mkString(",")}".yellow) + logger.info(s" $gbName : ${startPartitions.mkString(",")}".yellow) } } if (keysWithError.nonEmpty) { - println(s"-- Schema validation completed. Found ${keysWithError.size} errors".red) + logger.info(s"-- Schema validation completed. Found ${keysWithError.size} errors".red) val keyErrorSet: Set[(String, String)] = keysWithError.toSet - println(keyErrorSet.map { case (key, errorMsg) => s"$key => $errorMsg" }.mkString("\n ").yellow) + logger.info(keyErrorSet.map { case (key, errorMsg) => s"$key => $errorMsg" }.mkString("\n ").yellow) } if (noAccessTables.nonEmpty) { - println(s"-- Table permission check completed. Found permission errors in ${noAccessTables.size} tables --".red) - println(noAccessTables.mkString("\n ").yellow) + logger.info( + s"-- Table permission check completed. Found permission errors in ${noAccessTables.size} tables --".red) + logger.info(noAccessTables.mkString("\n ").yellow) } if (dataAvailabilityErrors.nonEmpty) { - println(s"-- Data availability check completed. Found issue in ${dataAvailabilityErrors.size} tables --".red) + logger.info(s"-- Data availability check completed. Found issue in ${dataAvailabilityErrors.size} tables --".red) dataAvailabilityErrors.foreach(error => - println(s" Group_By ${error._2} : Source Tables ${error._1} : Expected start ${error._3}".yellow)) + logger.info(s" Group_By ${error._2} : Source Tables ${error._1} : Expected start ${error._3}".yellow)) } if (keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty) { - println("-- Backfill validation completed. No errors found. --".green) + logger.info("-- Backfill validation completed. No errors found. 
--".green) } if (validationAssert) { @@ -418,9 +421,9 @@ class Analyzer(tableUtils: TableUtils, // validate the schema of the left and right side of the join and make sure the types match // return a map of keys and corresponding error message that failed validation - def runSchemaValidation(left: Map[String, DataType], - right: Map[String, DataType], - keyMapping: Map[String, String]): Map[String, String] = { + private def runSchemaValidation(left: Map[String, DataType], + right: Map[String, DataType], + keyMapping: Map[String, String]): Map[String, String] = { keyMapping.flatMap { case (_, leftKey) if !left.contains(leftKey) => Some(leftKey -> @@ -441,8 +444,8 @@ class Analyzer(tableUtils: TableUtils, // validate the table permissions for given list of tables // return a list of tables that the user doesn't have access to - def runTablePermissionValidation(sources: Set[String]): Set[String] = { - println(s"Validating ${sources.size} tables permissions ...") + private def runTablePermissionValidation(sources: Set[String]): Set[String] = { + logger.info(s"Validating ${sources.size} tables permissions ...") val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) //todo: handle offset-by-1 depending on temporal vs snapshot accuracy val partitionFilter = tableUtils.partitionSpec.minus(today, new Window(2, TimeUnit.DAYS)) @@ -459,7 +462,7 @@ class Analyzer(tableUtils: TableUtils, groupBy: api.GroupBy, unfilledRanges: Seq[PartitionRange]): List[(String, String, String)] = { if (unfilledRanges.isEmpty) { - println("No unfilled ranges found.") + logger.info("No unfilled ranges found.") List.empty } else { val firstUnfilledPartition = unfilledRanges.min.start @@ -479,14 +482,14 @@ class Analyzer(tableUtils: TableUtils, case (Events, Events, Accuracy.TEMPORAL) => tableUtils.partitionSpec.minus(firstUnfilledPartition, window) } - println( + logger.info( s"Checking data availability for group_by ${groupBy.metaData.name} ... 
Expected start partition: $expectedStart") if (groupBy.sources.toScala.exists(s => s.isCumulative)) { List.empty } else { val tableToPartitions = groupBy.sources.toScala.map { source => val table = source.table - println(s"Checking table $table for data availability ...") + logger.info(s"Checking table $table for data availability ...") val partitions = tableUtils.partitions(table) val startOpt = if (partitions.isEmpty) None else Some(partitions.min) val endOpt = if (partitions.isEmpty) None else Some(partitions.max) @@ -496,7 +499,7 @@ class Analyzer(tableUtils: TableUtils, val minPartition = if (allPartitions.isEmpty) None else Some(allPartitions.min) if (minPartition.isEmpty || minPartition.get > expectedStart) { - println(s""" + logger.info(s""" |Join needs data older than what is available for GroupBy: ${groupBy.metaData.name} |left-${leftDataModel.toString.low.yellow}, |right-${groupBy.dataModel.toString.low.yellow}, @@ -504,7 +507,7 @@ class Analyzer(tableUtils: TableUtils, |expected earliest available data partition: $expectedStart\n""".stripMargin.red) tableToPartitions.foreach { case (table, _, startOpt, endOpt) => - println( + logger.info( s"Table $table startPartition ${startOpt.getOrElse("empty")} endPartition ${endOpt.getOrElse("empty")}") } val tables = tableToPartitions.map(_._1) @@ -600,7 +603,7 @@ class Analyzer(tableUtils: TableUtils, .zip(values) .map { case (column, value) => - (column, Option(value).getOrElse("null").toString) + (column, value.toString) } .toMap } @@ -610,12 +613,12 @@ class Analyzer(tableUtils: TableUtils, case confPath: String => if (confPath.contains("/joins/")) { val joinConf = parseConf[api.Join](confPath) - analyzeJoin(joinConf, enableHitter = skewDetection) + analyzeJoin(joinConf, skewDetection = skewDetection) } else if (confPath.contains("/group_bys/")) { val groupByConf = parseConf[api.GroupBy](confPath) - analyzeGroupBy(groupByConf, enableHitter = skewDetection) + analyzeGroupBy(groupByConf, skewDetection = skewDetection) } - case groupByConf: api.GroupBy => analyzeGroupBy(groupByConf, enableHitter = skewDetection) - case joinConf: api.Join => analyzeJoin(joinConf, enableHitter = skewDetection) + case groupByConf: api.GroupBy => analyzeGroupBy(groupByConf, skewDetection = skewDetection) + case joinConf: api.Join => analyzeJoin(joinConf, skewDetection = skewDetection) } } diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index 7dbf7bd0b9..9a722ce160 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -431,24 +431,24 @@ object Driver { class Args extends Subcommand("analyze") with OfflineSubcommand { val startDate: ScallopOption[String] = opt[String](required = false, - descr = "Finds heavy hitters & time-distributions until a specified start date", + descr = "Finds skewed keys & time-distributions until a specified start date", default = None) - val count: ScallopOption[Int] = + val skewKeyCount: ScallopOption[Int] = opt[Int]( required = false, descr = - "Finds the specified number of heavy hitters approximately. The larger this number is the more accurate the analysis will be.", + "Finds the specified number of skewed keys. 
The larger this number is the more accurate the analysis will be.", default = Option(128) ) val sample: ScallopOption[Double] = opt[Double](required = false, descr = "Sampling ratio - what fraction of rows into incorporate into the heavy hitter estimate", default = Option(0.1)) - val enableHitter: ScallopOption[Boolean] = + val skewDetection: ScallopOption[Boolean] = opt[Boolean]( required = false, descr = - "enable skewed data analysis - whether to include the heavy hitter analysis, will only output schema if disabled", + "finds skewed keys if true else will only output schema and exit. Skew detection will take longer time.", default = Some(false) ) @@ -461,9 +461,9 @@ object Driver { args.confPath(), args.startDate.getOrElse(tableUtils.partitionSpec.shiftBackFromNow(3)), args.endDate(), - args.count(), + args.skewKeyCount(), args.sample(), - args.enableHitter()).run + args.skewDetection()).run } } diff --git a/spark/src/main/scala/ai/chronon/spark/Join.scala b/spark/src/main/scala/ai/chronon/spark/Join.scala index 84beeaa607..0360b2c1b0 100644 --- a/spark/src/main/scala/ai/chronon/spark/Join.scala +++ b/spark/src/main/scala/ai/chronon/spark/Join.scala @@ -72,7 +72,8 @@ class Join(joinConf: api.Join, skipFirstHole: Boolean = true, showDf: Boolean = false, selectedJoinParts: Option[List[String]] = None) - extends JoinBase(joinConf, endPartition, tableUtils, skipFirstHole, showDf, selectedJoinParts) { +// we copy the joinConfCloned to prevent modification of shared joinConf's in unit tests + extends JoinBase(joinConf.deepCopy(), endPartition, tableUtils, skipFirstHole, showDf, selectedJoinParts) { private implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec private def padFields(df: DataFrame, structType: sql.types.StructType): DataFrame = { @@ -190,7 +191,7 @@ class Join(joinConf: api.Join, } logger.info( - s"\n======= CoveringSet for Join ${joinConf.metaData.name} for PartitionRange(${leftRange.start}, ${leftRange.end}) =======\n") + s"\n======= CoveringSet for Join ${joinConfCloned.metaData.name} for PartitionRange(${leftRange.start}, ${leftRange.end}) =======\n") coveringSetsPerJoinPart.foreach { case (joinPartMetadata, coveringSets) => logger.info(s"Bootstrap sets for join part ${joinPartMetadata.joinPart.groupBy.metaData.name}") @@ -204,10 +205,10 @@ class Join(joinConf: api.Join, } private def getRightPartsData(leftRange: PartitionRange): Seq[(JoinPart, DataFrame)] = { - joinConf.joinParts.asScala.map { joinPart => - val partTable = joinConf.partOutputTable(joinPart) + joinConfCloned.joinParts.asScala.map { joinPart => + val partTable = joinConfCloned.partOutputTable(joinPart) val effectiveRange = - if (joinConf.left.dataModel != Entities && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) { + if (joinConfCloned.left.dataModel != Entities && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) { leftRange.shift(-1) } else { leftRange @@ -273,9 +274,9 @@ class Join(joinConf: api.Join, if (skipBloomFilter) { None } else { - val leftBlooms = joinConf.leftKeyCols.iterator + val leftBlooms = joinConfCloned.leftKeyCols.iterator .map { key => - key -> bootstrapDf.generateBloomFilter(key, leftRowCount, joinConf.left.table, leftRange) + key -> bootstrapDf.generateBloomFilter(key, leftRowCount, joinConfCloned.left.table, leftRange) } .toMap .asJava @@ -383,11 +384,11 @@ class Join(joinConf: api.Join, } private def applyDerivation(baseDf: DataFrame, bootstrapInfo: BootstrapInfo, leftColumns: Seq[String]): DataFrame = { - if (!joinConf.isSetDerivations || 
joinConf.derivations.isEmpty) { + if (!joinConfCloned.isSetDerivations || joinConfCloned.derivations.isEmpty) { return baseDf } - val projections = joinConf.derivations.toScala.derivationProjection(bootstrapInfo.baseValueNames) + val projections = joinConfCloned.derivations.toScala.derivationProjection(bootstrapInfo.baseValueNames) val projectionsMap = projections.toMap val baseOutputColumns = baseDf.columns.toSet @@ -440,7 +441,7 @@ class Join(joinConf: api.Join, val result = baseDf.select(finalOutputColumns: _*) if (showDf) { - logger.info(s"printing results for join: ${joinConf.metaData.name}") + logger.info(s"printing results for join: ${joinConfCloned.metaData.name}") result.prettyPrint() } result @@ -455,8 +456,8 @@ class Join(joinConf: api.Join, val contextualNames = bootstrapInfo.externalParts.filter(_.externalPart.isContextual).flatMap(_.keySchema).map(_.name) - val projections = if (joinConf.isSetDerivations) { - joinConf.derivations.toScala.derivationProjection(bootstrapInfo.baseValueNames).map(_._1) + val projections = if (joinConfCloned.isSetDerivations) { + joinConfCloned.derivations.toScala.derivationProjection(bootstrapInfo.baseValueNames).map(_._1) } else { Seq() } @@ -492,13 +493,13 @@ class Join(joinConf: api.Join, val startMillis = System.currentTimeMillis() // verify left table does not have reserved columns - validateReservedColumns(leftDf, joinConf.left.table, Seq(Constants.BootstrapHash, Constants.MatchedHashes)) + validateReservedColumns(leftDf, joinConfCloned.left.table, Seq(Constants.BootstrapHash, Constants.MatchedHashes)) tableUtils .unfilledRanges(bootstrapTable, range, skipFirstHole = skipFirstHole) .getOrElse(Seq()) .foreach(unfilledRange => { - val parts = Option(joinConf.bootstrapParts) + val parts = Option(joinConfCloned.bootstrapParts) .map(_.toScala) .getOrElse(Seq()) @@ -537,16 +538,16 @@ class Join(joinConf: api.Join, // include only necessary columns. in particular, // this excludes columns that are NOT part of Join's output (either from GB or external source) val includedColumns = bootstrapDf.columns - .filter(bootstrapInfo.fieldNames ++ part.keys(joinConf, tableUtils.partitionColumn) + .filter(bootstrapInfo.fieldNames ++ part.keys(joinConfCloned, tableUtils.partitionColumn) ++ Seq(Constants.BootstrapHash, tableUtils.partitionColumn)) .sorted bootstrapDf = bootstrapDf .select(includedColumns.map(col): _*) // TODO: allow customization of deduplication logic - .dropDuplicates(part.keys(joinConf, tableUtils.partitionColumn).toArray) + .dropDuplicates(part.keys(joinConfCloned, tableUtils.partitionColumn).toArray) - coalescedJoin(partialDf, bootstrapDf, part.keys(joinConf, tableUtils.partitionColumn)) + coalescedJoin(partialDf, bootstrapDf, part.keys(joinConfCloned, tableUtils.partitionColumn)) // as part of the left outer join process, we update and maintain matched_hashes for each record // that summarizes whether there is a join-match for each bootstrap source. 
// later on we use this information to decide whether we still need to re-run the backfill logic @@ -564,7 +565,7 @@ class Join(joinConf: api.Join, }) val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000) - logger.info(s"Finished computing bootstrap table ${joinConf.metaData.bootstrapTable} in $elapsedMins minutes") + logger.info(s"Finished computing bootstrap table ${joinConfCloned.metaData.bootstrapTable} in $elapsedMins minutes") tableUtils.scanDf(query = null, table = bootstrapTable, range = Some(range)) } diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala index 4d683974d0..3c4ee92cee 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala @@ -44,7 +44,7 @@ import java.util import scala.collection.JavaConverters._ import scala.collection.Seq -abstract class JoinBase(joinConf: api.Join, +abstract class JoinBase(val joinConfCloned: api.Join, endPartition: String, tableUtils: TableUtils, skipFirstHole: Boolean, @@ -52,29 +52,29 @@ abstract class JoinBase(joinConf: api.Join, selectedJoinParts: Option[Seq[String]] = None) { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) private implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec - assert(Option(joinConf.metaData.outputNamespace).nonEmpty, "output namespace could not be empty or null") - val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.JoinOffline, joinConf) - val outputTable: String = joinConf.metaData.outputTable + assert(Option(joinConfCloned.metaData.outputNamespace).nonEmpty, "output namespace could not be empty or null") + val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.JoinOffline, joinConfCloned) + val outputTable: String = joinConfCloned.metaData.outputTable // Used for parallelized JoinPart execution - val bootstrapTable: String = joinConf.metaData.bootstrapTable + val bootstrapTable: String = joinConfCloned.metaData.bootstrapTable // Get table properties from config - protected val confTableProps: Map[String, String] = Option(joinConf.metaData.tableProperties) + protected val confTableProps: Map[String, String] = Option(joinConfCloned.metaData.tableProperties) .map(_.asScala.toMap) .getOrElse(Map.empty[String, String]) private val gson = new Gson() // Combine tableProperties set on conf with encoded Join protected val tableProps: Map[String, String] = - confTableProps ++ Map(Constants.SemanticHashKey -> gson.toJson(joinConf.semanticHash.asJava)) + confTableProps ++ Map(Constants.SemanticHashKey -> gson.toJson(joinConfCloned.semanticHash.asJava)) def joinWithLeft(leftDf: DataFrame, rightDf: DataFrame, joinPart: JoinPart): DataFrame = { val partLeftKeys = joinPart.rightToLeft.values.toArray // compute join keys, besides the groupBy keys - like ds, ts etc., val additionalKeys: Seq[String] = { - if (joinConf.left.dataModel == Entities) { + if (joinConfCloned.left.dataModel == Entities) { Seq(tableUtils.partitionColumn) } else if (joinPart.groupBy.inferredAccuracy == Accuracy.TEMPORAL) { Seq(Constants.TimeColumn, tableUtils.partitionColumn) @@ -137,11 +137,11 @@ abstract class JoinBase(joinConf: api.Join, joinLevelBloomMapOpt: Option[util.Map[String, BloomFilter]], smallMode: Boolean = false): Option[DataFrame] = { - val partTable = joinConf.partOutputTable(joinPart) + val partTable = joinConfCloned.partOutputTable(joinPart) val partMetrics = Metrics.Context(metrics, joinPart) // in Events <> batch GB 
case, the partition dates are offset by 1 val shiftDays = - if (joinConf.left.dataModel == Events && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) { + if (joinConfCloned.left.dataModel == Events && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) { -1 } else { 0 @@ -153,18 +153,19 @@ abstract class JoinBase(joinConf: api.Join, // events | entities | snapshot => right part tables are not aligned - so scan by leftTimeRange // events | entities | temporal => right part tables are aligned - so scan by leftRange // entities | entities | snapshot => right part tables are aligned - so scan by leftRange - val rightRange = if (joinConf.left.dataModel == Events && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) { - leftTimeRangeOpt.get.shift(shiftDays) - } else { - leftRange - } + val rightRange = + if (joinConfCloned.left.dataModel == Events && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) { + leftTimeRangeOpt.get.shift(shiftDays) + } else { + leftRange + } try { val unfilledRanges = tableUtils .unfilledRanges( partTable, rightRange, - Some(Seq(joinConf.left.table)), + Some(Seq(joinConfCloned.left.table)), inputToOutputShift = shiftDays, // never skip hole during partTable's range determination logic because we don't want partTable // and joinTable to be out of sync. skipping behavior is already handled in the outer loop. @@ -210,7 +211,7 @@ abstract class JoinBase(joinConf: api.Join, } catch { case e: Exception => logger.error( - s"Error while processing groupBy: ${joinConf.metaData.name}/${joinPart.groupBy.getMetaData.getName}") + s"Error while processing groupBy: ${joinConfCloned.metaData.name}/${joinPart.groupBy.getMetaData.getName}") throw e } if (tableUtils.tableReachable(partTable)) { @@ -241,9 +242,9 @@ abstract class JoinBase(joinConf: api.Join, val rightBloomMap = if (skipBloom) { None } else { - JoinUtils.genBloomFilterIfNeeded(joinPart, joinConf, rowCount, unfilledRange, joinLevelBloomMapOpt) + JoinUtils.genBloomFilterIfNeeded(joinPart, joinConfCloned, rowCount, unfilledRange, joinLevelBloomMapOpt) } - val rightSkewFilter = joinConf.partSkewFilter(joinPart) + val rightSkewFilter = joinConfCloned.partSkewFilter(joinPart) def genGroupBy(partitionRange: PartitionRange) = GroupBy.from(joinPart.groupBy, partitionRange, @@ -262,7 +263,7 @@ abstract class JoinBase(joinConf: api.Join, timeRange } - val leftSkewFilter = joinConf.skewFilter(Some(joinPart.rightToLeft.values.toSeq)) + val leftSkewFilter = joinConfCloned.skewFilter(Some(joinPart.rightToLeft.values.toSeq)) // this is the second time we apply skew filter - but this filters only on the keys // relevant for this join part. 
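// Aside (illustration only; not a hunk from patch 12): JoinBase now receives joinConf.deepCopy()
// as joinConfCloned because a run can mutate the conf in place (for example overwriteTable on the
// left source further below), and a joinConf instance shared across unit tests would otherwise
// observe that mutation. A small sketch of the defensive-copy idea; FakeConf and
// DeepCopyExampleSpec are hypothetical names standing in for the Thrift-generated api.Join:
import org.scalatest.flatspec.AnyFlatSpec

final class FakeConf(var leftTable: String) {
  def deepCopy(): FakeConf = new FakeConf(leftTable) // mirrors the Thrift-generated deepCopy()
}

class DeepCopyExampleSpec extends AnyFlatSpec {
  it should "leave the shared conf untouched when the clone is mutated" in {
    val shared = new FakeConf("db.events")
    val cloned = shared.deepCopy()
    cloned.leftTable = "db.events_bootstrap" // job-local mutation on the clone only
    assert(shared.leftTable == "db.events")
  }
}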
lazy val skewFilteredLeft = leftSkewFilter @@ -301,7 +302,7 @@ abstract class JoinBase(joinConf: api.Join, date_format(renamedLeftRawDf.col(c), tableUtils.partitionFormat).as(c) case c => renamedLeftRawDf.col(c) }.toList: _*) - val rightDf = (joinConf.left.dataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match { + val rightDf = (joinConfCloned.left.dataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match { case (Entities, Events, _) => partitionRangeGroupBy.snapshotEvents(unfilledRange) case (Entities, Entities, _) => partitionRangeGroupBy.snapshotEntities case (Events, Events, Accuracy.SNAPSHOT) => @@ -323,7 +324,7 @@ abstract class JoinBase(joinConf: api.Join, rightDf } if (showDf) { - logger.info(s"printing results for joinPart: ${joinConf.metaData.name}::${joinPart.groupBy.metaData.name}") + logger.info(s"printing results for joinPart: ${joinConfCloned.metaData.name}::${joinPart.groupBy.metaData.name}") rightDfWithDerivations.prettyPrint() } Some(rightDfWithDerivations) @@ -340,23 +341,23 @@ abstract class JoinBase(joinConf: api.Join, private def getUnfilledRange(overrideStartPartition: Option[String], outputTable: String): (PartitionRange, Seq[PartitionRange]) = { - val rangeToFill = JoinUtils.getRangesToFill(joinConf.left, + val rangeToFill = JoinUtils.getRangesToFill(joinConfCloned.left, tableUtils, endPartition, overrideStartPartition, - joinConf.historicalBackfill) + joinConfCloned.historicalBackfill) logger.info(s"Left side range to fill $rangeToFill") (rangeToFill, tableUtils - .unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConf.left.table)), skipFirstHole = skipFirstHole) + .unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConfCloned.left.table)), skipFirstHole = skipFirstHole) .getOrElse(Seq.empty)) } def computeLeft(overrideStartPartition: Option[String] = None): Unit = { // Runs the left side query for a join and saves the output to a table, for reuse by joinPart // Computation in parallelized joinPart execution mode. - if (shouldRecomputeLeft(joinConf, bootstrapTable, tableUtils)) { + if (shouldRecomputeLeft(joinConfCloned, bootstrapTable, tableUtils)) { logger.info("Detected semantic change in left side of join, archiving left table for recomputation.") val archivedAtTs = Instant.now() tableUtils.archiveOrDropTableIfExists(bootstrapTable, Some(archivedAtTs)) @@ -368,12 +369,12 @@ abstract class JoinBase(joinConf: api.Join, logger.info("Range to fill already computed. 
Skipping query execution...") } else { // Register UDFs for the left part computation - joinConf.setups.foreach(tableUtils.sql) - val leftSchema = leftDf(joinConf, unfilledRanges.head, tableUtils, limit = Some(1)).map(df => df.schema) - val bootstrapInfo = BootstrapInfo.from(joinConf, rangeToFill, tableUtils, leftSchema) + joinConfCloned.setups.foreach(tableUtils.sql) + val leftSchema = leftDf(joinConfCloned, unfilledRanges.head, tableUtils, limit = Some(1)).map(df => df.schema) + val bootstrapInfo = BootstrapInfo.from(joinConfCloned, rangeToFill, tableUtils, leftSchema) logger.info(s"Running ranges: $unfilledRanges") unfilledRanges.foreach { unfilledRange => - val leftDf = JoinUtils.leftDf(joinConf, unfilledRange, tableUtils) + val leftDf = JoinUtils.leftDf(joinConfCloned, unfilledRange, tableUtils) if (leftDf.isDefined) { val leftTaggedDf = leftDf.get.addTimebasedColIfExists() computeBootstrapTable(leftTaggedDf, unfilledRange, bootstrapInfo) @@ -389,7 +390,7 @@ abstract class JoinBase(joinConf: api.Join, def computeFinal(overrideStartPartition: Option[String] = None): Unit = { // Utilizes the same tablesToRecompute check as the monolithic spark job, because if any joinPart changes, then so does the output table - if (tablesToRecompute(joinConf, outputTable, tableUtils).isEmpty) { + if (tablesToRecompute(joinConfCloned, outputTable, tableUtils).isEmpty) { logger.info("No semantic change detected, leaving output table in place.") } else { logger.info("Semantic changes detected, archiving output table.") @@ -402,11 +403,11 @@ abstract class JoinBase(joinConf: api.Join, if (unfilledRanges.isEmpty) { logger.info("Range to fill already computed. Skipping query execution...") } else { - val leftSchema = leftDf(joinConf, unfilledRanges.head, tableUtils, limit = Some(1)).map(df => df.schema) - val bootstrapInfo = BootstrapInfo.from(joinConf, rangeToFill, tableUtils, leftSchema) + val leftSchema = leftDf(joinConfCloned, unfilledRanges.head, tableUtils, limit = Some(1)).map(df => df.schema) + val bootstrapInfo = BootstrapInfo.from(joinConfCloned, rangeToFill, tableUtils, leftSchema) logger.info(s"Running ranges: $unfilledRanges") unfilledRanges.foreach { unfilledRange => - val leftDf = JoinUtils.leftDf(joinConf, unfilledRange, tableUtils) + val leftDf = JoinUtils.leftDf(joinConfCloned, unfilledRange, tableUtils) if (leftDf.isDefined) { computeFinalJoin(leftDf.get, unfilledRange, bootstrapInfo) } else { @@ -425,19 +426,19 @@ abstract class JoinBase(joinConf: api.Join, overrideStartPartition: Option[String] = None, useBootstrapForLeft: Boolean = false): Option[DataFrame] = { - assert(Option(joinConf.metaData.team).nonEmpty, - s"join.metaData.team needs to be set for join ${joinConf.metaData.name}") + assert(Option(joinConfCloned.metaData.team).nonEmpty, + s"join.metaData.team needs to be set for join ${joinConfCloned.metaData.name}") - joinConf.joinParts.asScala.foreach { jp => + joinConfCloned.joinParts.asScala.foreach { jp => assert(Option(jp.groupBy.metaData.team).nonEmpty, s"groupBy.metaData.team needs to be set for joinPart ${jp.groupBy.metaData.name}") } // Run validations before starting the job val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) - val analyzer = new Analyzer(tableUtils, joinConf, today, today, silenceMode = true) + val analyzer = new Analyzer(tableUtils, joinConfCloned, today, today, silenceMode = true) try { - analyzer.analyzeJoin(joinConf, validationAssert = true) + analyzer.analyzeJoin(joinConfCloned, validationAssert = true) 
metrics.gauge(Metrics.Name.validationSuccess, 1) logger.info("Join conf validation succeeded. No error found.") } catch { @@ -453,11 +454,11 @@ abstract class JoinBase(joinConf: api.Join, // First run command to archive tables that have changed semantically since the last run val archivedAtTs = Instant.now() // TODO: We should not archive the output table in the case of selected join parts mode - tablesToRecompute(joinConf, outputTable, tableUtils).foreach( + tablesToRecompute(joinConfCloned, outputTable, tableUtils).foreach( tableUtils.archiveOrDropTableIfExists(_, Some(archivedAtTs))) // Check semantic hash before overwriting left side - val source = joinConf.left + val source = joinConfCloned.left if (useBootstrapForLeft) { logger.info("Overwriting left side to use saved Bootstrap table...") source.overwriteTable(bootstrapTable) @@ -471,14 +472,14 @@ abstract class JoinBase(joinConf: api.Join, // OverrideStartPartition is used to replace the start partition of the join config. This is useful when // 1 - User would like to test run with different start partition // 2 - User has entity table which is cumulative and only want to run backfill for the latest partition - val rangeToFill = JoinUtils.getRangesToFill(joinConf.left, + val rangeToFill = JoinUtils.getRangesToFill(joinConfCloned.left, tableUtils, endPartition, overrideStartPartition, - joinConf.historicalBackfill) + joinConfCloned.historicalBackfill) logger.info(s"Join range to fill $rangeToFill") val unfilledRanges = tableUtils - .unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConf.left.table)), skipFirstHole = skipFirstHole) + .unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConfCloned.left.table)), skipFirstHole = skipFirstHole) .getOrElse(Seq.empty) def finalResult: DataFrame = tableUtils.scanDf(null, outputTable, range = Some(rangeToFill)) @@ -492,15 +493,16 @@ abstract class JoinBase(joinConf: api.Join, stepDays.map(unfilledRange.steps).getOrElse(Seq(unfilledRange)) } - val leftSchema = leftDf(joinConf, unfilledRanges.head, tableUtils, limit = Some(1)).map(df => df.schema) + val leftSchema = leftDf(joinConfCloned, unfilledRanges.head, tableUtils, limit = Some(1)).map(df => df.schema) // build bootstrap info once for the entire job - val bootstrapInfo = BootstrapInfo.from(joinConf, rangeToFill, tableUtils, leftSchema) + val bootstrapInfo = BootstrapInfo.from(joinConfCloned, rangeToFill, tableUtils, leftSchema) val wholeRange = PartitionRange(unfilledRanges.minBy(_.start).start, unfilledRanges.maxBy(_.end).end) val runSmallMode = { if (tableUtils.smallModelEnabled) { val thresholdCount = - leftDf(joinConf, wholeRange, tableUtils, limit = Some(tableUtils.smallModeNumRowsCutoff + 1)).get.count() + leftDf(joinConfCloned, wholeRange, tableUtils, limit = Some(tableUtils.smallModeNumRowsCutoff + 1)).get + .count() val result = thresholdCount <= tableUtils.smallModeNumRowsCutoff if (result) { logger.info(s"Counted $thresholdCount rows, running join in small mode.") @@ -526,7 +528,7 @@ abstract class JoinBase(joinConf: api.Join, val startMillis = System.currentTimeMillis() val progress = s"| [${index + 1}/${effectiveRanges.size}]" logger.info(s"Computing join for range: ${range.toString()} $progress") - leftDf(joinConf, range, tableUtils).map { leftDfInRange => + leftDf(joinConfCloned, range, tableUtils).map { leftDfInRange => if (showDf) leftDfInRange.prettyPrint() // set autoExpand = true to ensure backward compatibility due to column ordering changes val finalDf = computeRange(leftDfInRange, range, bootstrapInfo, 
runSmallMode, useBootstrapForLeft) diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala index a4c949d2ad..23d795e5f3 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala @@ -356,6 +356,10 @@ object JoinUtils { val collectedLeft = leftDf.collect() + // clone groupBy before modifying it to prevent concurrent modification + val groupByClone = joinPart.groupBy.deepCopy() + joinPart.setGroupBy(groupByClone) + joinPart.groupBy.sources.asScala.foreach { source => val selectMap = Option(source.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String]) val groupByKeyExpressions = groupByKeyNames.map { key => @@ -389,6 +393,8 @@ object JoinUtils { s"$groupByKeyExpression in (${valueSet.mkString(sep = ",")})" } .foreach { whereClause => + logger.info(s"Injecting where clause: $whereClause into groupBy: ${joinPart.groupBy.metaData.name}") + val currentWheres = Option(source.rootQuery.getWheres).getOrElse(new util.ArrayList[String]()) currentWheres.add(whereClause) source.rootQuery.setWheres(currentWheres) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index bf8b8a3b68..805b3c36f2 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -132,7 +132,10 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable true } catch { case ex: Exception => - logger.error(s"Couldn't load $tableName", ex) + logger.info(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red} + |Call path: + |${cleanStackTrace(ex).yellow} + |""".stripMargin) false } } @@ -370,17 +373,23 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable repartitionAndWrite(finalizedDf, tableName, saveMode, stats, sortByCols) } - def sql(query: String): DataFrame = { - val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000) + // retains only the invocations from chronon code. 
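+  // Illustrative example (hypothetical frames, not taken from a real run): a raw trace such as
+  //   at org.apache.spark.sql.Dataset.collect(Dataset.scala:3050)
+  //   at ai.chronon.spark.TableUtils.sql(TableUtils.scala:390)
+  //   at ai.chronon.spark.test.JoinTest.$anonfun$new$1(JoinTest.scala:120)
+  // collapses to only the chronon frames, with the spark package prefixes stripped:
+  //   TableUtils.sql(TableUtils.scala:390)
+  //   JoinTest.$anonfun$new$1(JoinTest.scala:120)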
+ private def cleanStackTrace(throwable: Throwable): String = { val sw = new StringWriter() val pw = new PrintWriter(sw) - new Throwable().printStackTrace(pw) + throwable.printStackTrace(pw) val stackTraceString = sw.toString - val stackTraceStringPretty = " " + stackTraceString + " " + stackTraceString .split("\n") .filter(_.contains("chronon")) .map(_.replace("at ai.chronon.spark.test.", "").replace("at ai.chronon.spark.", "").stripLeading()) .mkString("\n ") + } + + def sql(query: String): DataFrame = { + val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000) + + val stackTraceString = cleanStackTrace(new Throwable()) logger.info(s""" | ${"---- running query ----".highlight} @@ -389,7 +398,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable | | ---- call path ---- | - |$stackTraceStringPretty + |$stackTraceString | | ---- end ---- |""".stripMargin) @@ -807,7 +816,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable | ${wheres.mkString(",\n ").green} | partition filters: | ${rangeWheres.mkString(",\n ").green} - |""".stripMargin.yellow) + |""".stripMargin) if (selects.nonEmpty) df = df.selectExpr(selects: _*) if (wheres.nonEmpty) df = df.where(wheres.map(w => s"($w)").mkString(" AND ")) if (rangeWheres.nonEmpty) df = df.where(rangeWheres.map(w => s"($w)").mkString(" AND ")) diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index da0202dc74..342ac97055 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -28,12 +28,15 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.lit import org.junit.Assert.assertTrue +import org.scalatest.BeforeAndAfter import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.slf4j.Logger import org.slf4j.LoggerFactory -class AnalyzerTest extends AnyFlatSpec { +class AnalyzerTest extends AnyFlatSpec with BeforeAndAfter { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) private val tableUtils = TableUtils(spark) @@ -44,6 +47,7 @@ class AnalyzerTest extends AnyFlatSpec { private val namespace = "analyzer_test_ns" tableUtils.createDatabase(namespace) + private val viewsTable = s"$namespace.view_events_gb_table" private val viewsSource = getTestEventSource() it should "produce correct analyzer schema" in { @@ -72,8 +76,14 @@ class AnalyzerTest extends AnyFlatSpec { //run analyzer and validate output schema val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) val analyzerSchema = analyzer.analyzeJoin(joinConf)._1.map { case (k, v) => s"${k} => ${v}" }.toList.sorted + + val originalJoinConf = joinConf.deepCopy() + val join = new Join(joinConf = joinConf, endPartition = oneMonthAgo, tableUtils) val computed = join.computeJoin() + + originalJoinConf shouldBe joinConf // running a join should not modify the passed in conf + val expectedSchema = computed.schema.fields.map(field => s"${field.name} => ${field.dataType}").sorted logger.info("=== expected schema =====") logger.info(expectedSchema.mkString("\n")) @@ -104,6 +114,9 @@ class AnalyzerTest extends AnyFlatSpec { Builders.MetaData(name = 
"test_join_analyzer.item_type_mismatch", namespace = namespace, team = "chronon") ) + logger.info("=== views table ===") + tableUtils.sql(s"SELECT * FROM $viewsTable LIMIT 10").show() + intercept[AssertionError] { //run analyzer and validate output schema val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) @@ -143,6 +156,9 @@ class AnalyzerTest extends AnyFlatSpec { metaData = Builders.MetaData(name = "test_join_analyzer.item_validation", namespace = namespace, team = "chronon") ) + logger.info("=== views table ===") + tableUtils.sql(s"SELECT * FROM $viewsTable LIMIT 10").show() + intercept[AssertionError] { val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skewDetection = true) analyzer.analyzeJoin(joinConf, validationAssert = true) @@ -427,7 +443,6 @@ class AnalyzerTest extends AnyFlatSpec { Column("time_spent_ms", api.LongType, 5000) ) - val viewsTable = s"$namespace.view_events_gb_table" DataFrameGen.events(spark, viewsSchema, count = 1000, partitions = 200).drop("ts").save(viewsTable) Builders.Source.events( From f038b893a970012ec425228452e1a430ac76d0e5 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 15 Jan 2025 11:16:57 -0500 Subject: [PATCH 13/14] rename heavy hitters to skew keys --- docs/source/test_deploy_serve/Test.md | 4 +- .../scala/ai/chronon/spark/Analyzer.scala | 40 +++++++++---------- .../main/scala/ai/chronon/spark/Driver.scala | 2 +- .../spark/stats/drift/Expressions.scala | 2 +- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/docs/source/test_deploy_serve/Test.md b/docs/source/test_deploy_serve/Test.md index 89556ecd90..36650ee428 100644 --- a/docs/source/test_deploy_serve/Test.md +++ b/docs/source/test_deploy_serve/Test.md @@ -47,12 +47,12 @@ Please note that these validations will also be executed as a prerequisite check ``` # run the analyzer -run.py --mode=analyze --conf=production/joins/ --enable-hitter +run.py --mode=analyze --conf=production/joins/ --skew-detection ``` Optional parameters: -`--endable-hitter`: enable skewed data analysis - include the heavy hitter analysis in output, only output schema if not specified +`--skew-detection`: enable skewed data analysis - include the frequent key analysis in output, only output schema if not specified `--start-date` : Finds heavy hitters & time-distributions for a specified start date. 
Default 3 days prior to "today" diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index ab37932fd9..e94fdb571c 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -93,29 +93,29 @@ class Analyzer(tableUtils: TableUtils, @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) // include ts into heavy hitter analysis - useful to surface timestamps that have wrong units // include total approx row count - so it is easy to understand the percentage of skewed data - def heavyHittersWithTsAndCount(df: DataFrame, - keys: Array[String], - frequentItemMapSize: Int = 1024, - sampleFraction: Double = 0.1): Array[(String, Array[(String, Long)])] = { + def skewKeysWithTsAndCount(df: DataFrame, + keys: Array[String], + frequentItemMapSize: Int = 1024, + sampleFraction: Double = 0.1): Array[(String, Array[(String, Long)])] = { val baseDf = df.withColumn("total_count", lit("rows")) val baseKeys = keys :+ "total_count" if (df.schema.fieldNames.contains(Constants.TimeColumn)) { - heavyHitters(baseDf.withColumn("ts_year", from_unixtime(col("ts") / 1000, "yyyy")), - baseKeys :+ "ts_year", - frequentItemMapSize, - sampleFraction) + skewKeys(baseDf.withColumn("ts_year", from_unixtime(col("ts") / 1000, "yyyy")), + baseKeys :+ "ts_year", + frequentItemMapSize, + sampleFraction) } else { - heavyHitters(baseDf, baseKeys, frequentItemMapSize, sampleFraction) + skewKeys(baseDf, baseKeys, frequentItemMapSize, sampleFraction) } } - // Uses a variant Misra-Gries heavy hitter algorithm from Data Sketches to find topK most frequent items in data - // frame. The result is a Array of tuples of (column names, array of tuples of (heavy hitter keys, counts)) + // Uses a variant of the Misra-Gries frequent-items algorithm from Data Sketches to find the topK most frequent items in a data + // frame. The result is an Array of tuples of (column names, array of tuples of (frequent keys, counts)) // [(keyCol1, [(key1: count1) ...]), (keyCol2, [...]), ....]
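The Misra-Gries idea referenced in the comment above can be summarized with a short, self-contained sketch. This is illustrative only: the production code delegates to the DataSketches frequent-items implementation rather than a hand-rolled counter like this one.

```scala
// Plain-Scala sketch of Misra-Gries counting: keep at most `mapSize` counters;
// when the map is full and an untracked item arrives, decrement every counter and evict zeros.
// Stored counts under-estimate true counts by at most n / (mapSize + 1) over n updates.
def misraGries[T](items: Iterator[T], mapSize: Int): Map[T, Long] = {
  val counters = scala.collection.mutable.Map.empty[T, Long]
  items.foreach { item =>
    if (counters.contains(item)) counters(item) += 1
    else if (counters.size < mapSize) counters(item) = 1L
    else counters.keys.toList.foreach { k =>
      counters(k) -= 1
      if (counters(k) == 0L) counters.remove(k)
    }
  }
  counters.toMap
}

// e.g. misraGries(Iterator("a", "a", "b", "a", "c", "a"), mapSize = 2) retains "a" with count 3
```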
- def heavyHitters(df: DataFrame, - frequentItemKeys: Array[String], - frequentItemMapSize: Int = 1024, - sampleFraction: Double = 0.1): Array[(String, Array[(String, Long)])] = { + def skewKeys(df: DataFrame, + frequentItemKeys: Array[String], + frequentItemMapSize: Int = 1024, + sampleFraction: Double = 0.1): Array[(String, Array[(String, Long)])] = { assert(frequentItemKeys.nonEmpty, "No column arrays specified for frequent items summary") // convert all keys into string val stringifiedCols = frequentItemKeys.map { col => @@ -160,13 +160,13 @@ class Analyzer(tableUtils: TableUtils, } private val range = PartitionRange(startDate, endDate)(tableUtils.partitionSpec) - // returns with heavy hitter analysis for the specified keys + // returns the frequent key analysis for the specified keys def analyze(df: DataFrame, keys: Array[String], sourceTable: String): String = { - val result = heavyHittersWithTsAndCount(df, keys, count, sample) - val header = s"Analyzing heavy-hitters from table $sourceTable over columns: [${keys.mkString(", ")}]" + val result = skewKeysWithTsAndCount(df, keys, count, sample) + val header = s"Analyzing frequent keys from table $sourceTable over columns: [${keys.mkString(", ")}]" val colPrints = result.flatMap { - case (col, heavyHitters) => - Seq(s" $col") ++ heavyHitters.map { case (name, count) => s" $name: $count" } + case (col, skewKeys) => - Seq(s" $col") ++ skewKeys.map { case (name, count) => s" $name: $count" } + Seq(s" $col") ++ skewKeys.map { case (name, count) => s" $name: $count" } } (header +: colPrints).mkString("\n") } diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index 9a722ce160..9c35d9bc44 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -442,7 +442,7 @@ object Driver { ) val sample: ScallopOption[Double] = opt[Double](required = false, - descr = "Sampling ratio - what fraction of rows into incorporate into the heavy hitter estimate", + descr = "Sampling ratio - what fraction of rows to incorporate into skew key detection", default = Option(0.1)) val skewDetection: ScallopOption[Boolean] = opt[Boolean]( diff --git a/spark/src/main/scala/ai/chronon/spark/stats/drift/Expressions.scala b/spark/src/main/scala/ai/chronon/spark/stats/drift/Expressions.scala index b67b5e13cc..ec61aee3f7 100644 --- a/spark/src/main/scala/ai/chronon/spark/stats/drift/Expressions.scala +++ b/spark/src/main/scala/ai/chronon/spark/stats/drift/Expressions.scala @@ -216,7 +216,7 @@ object Expressions { }) // TODO: deal with map keys - as histogram - high cardinality keys vs low cardinality?
- // TODO: heavy hitters - top_k via approx_histogram + // TODO: frequent key - top_k via approx_histogram case types.MapType(_, vType, _) => se(Inp.cLen, Agg.ptile, MetricName.lengthPercentiles) ++ // length drift se(Inp.mapVals, Agg.arrNulls, MetricName.innerNullCount) ++ From 919730cfb24d7cd237308022608850e3f0d6fb9c Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 15 Jan 2025 12:25:23 -0500 Subject: [PATCH 14/14] isolate catalyst util test --- .github/workflows/test_scala_no_spark.yaml | 4 ++++ build.sbt | 2 +- .../ai/chronon/online/test/CatalystUtilHiveUDFTest.scala | 4 +++- .../ai/chronon/online}/test/TaggedFilterSuite.scala | 9 ++------- .../test/scala/ai/chronon/spark/test/FetcherTest.scala | 1 + .../src/test/scala/ai/chronon/spark/test/JoinTest.scala | 4 +++- .../test/scala/ai/chronon/spark/test/MutationsTest.scala | 1 + 7 files changed, 15 insertions(+), 10 deletions(-) rename {spark/src/test/scala/ai/chronon/spark => online/src/test/scala/ai/chronon/online}/test/TaggedFilterSuite.scala (83%) diff --git a/.github/workflows/test_scala_no_spark.yaml b/.github/workflows/test_scala_no_spark.yaml index d892e75ac6..2a908baf71 100644 --- a/.github/workflows/test_scala_no_spark.yaml +++ b/.github/workflows/test_scala_no_spark.yaml @@ -59,6 +59,10 @@ jobs: run: | sbt "++ 2.12.18 online/test" + - name: Run Catalyst Util tests + run: | + sbt "++ 2.12.18 online/testOnly -- -n catalystUtilHiveUdfTest" + - name: Run api tests run: | sbt "++ 2.12.18 api/test" diff --git a/build.sbt b/build.sbt index d51b25b53e..5b4d19fbe5 100644 --- a/build.sbt +++ b/build.sbt @@ -193,7 +193,7 @@ val sparkBaseSettings: Seq[Setting[_]] = Seq( ) ++ addArtifact(assembly / artifact, assembly) lazy val spark = project - .dependsOn(aggregator.%("compile->compile;test->test"), online) + .dependsOn(aggregator.%("compile->compile;test->test"), online.%("compile->compile;test->test")) .settings( sparkBaseSettings, crossScalaVersions := supportedVersions, diff --git a/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala b/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala index 5f4233f545..cb715307e7 100644 --- a/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala @@ -4,7 +4,7 @@ import ai.chronon.online.CatalystUtil import org.junit.Assert.assertEquals import org.scalatest.flatspec.AnyFlatSpec -class CatalystUtilHiveUDFTest extends AnyFlatSpec with CatalystUtilTestSparkSQLStructs { +class CatalystUtilHiveUDFTest extends AnyFlatSpec with CatalystUtilTestSparkSQLStructs with TaggedFilterSuite { it should "hive ud fs via setups should work" in { val setups = Seq( @@ -21,4 +21,6 @@ class CatalystUtilHiveUDFTest extends AnyFlatSpec with CatalystUtilTestSparkSQLS assertEquals(res.get("a"), Int.MaxValue - 1) assertEquals(res.get("b"), "hello123") } + + override def tagName: String = "catalystUtilHiveUdfTest" } diff --git a/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala b/online/src/test/scala/ai/chronon/online/test/TaggedFilterSuite.scala similarity index 83% rename from spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala rename to online/src/test/scala/ai/chronon/online/test/TaggedFilterSuite.scala index 46e76d866b..609fc04e02 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/TaggedFilterSuite.scala +++ b/online/src/test/scala/ai/chronon/online/test/TaggedFilterSuite.scala @@ -1,11 +1,6 @@ -package ai.chronon.spark.test +package 
ai.chronon.online.test -import org.scalatest.Args -import org.scalatest.Filter -import org.scalatest.Status -import org.scalatest.SucceededStatus -import org.scalatest.Suite -import org.scalatest.SuiteMixin +import org.scalatest._ /** * SuiteMixin that skips execution of the tests in a suite if the tests are not triggered diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala index b7b87b9753..345edfcfcf 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala @@ -34,6 +34,7 @@ import ai.chronon.online.MetadataDirWalker import ai.chronon.online.MetadataEndPoint import ai.chronon.online.MetadataStore import ai.chronon.online.SparkConversions +import ai.chronon.online.test.TaggedFilterSuite import ai.chronon.spark.Extensions._ import ai.chronon.spark.stats.ConsistencyJob import ai.chronon.spark.utils.MockApi diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala index 4462ed1ea4..c09fc556a3 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala @@ -28,6 +28,7 @@ import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.api.StringType import ai.chronon.api.TimeUnit import ai.chronon.api.Window +import ai.chronon.online.test.TaggedFilterSuite import ai.chronon.spark.Extensions._ import ai.chronon.spark._ import org.apache.spark.rdd.RDD @@ -53,11 +54,12 @@ object TestRow { } } } + // Run as follows: sbt "spark/testOnly -- -n jointest" class JoinTest extends AnyFlatSpec with TaggedFilterSuite { val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true) - private implicit val tableUtils = TableTestUtils(spark) + private implicit val tableUtils: TableTestUtils = TableTestUtils(spark) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) diff --git a/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala index f85c4f8c5c..3603249ea7 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala @@ -23,6 +23,7 @@ import ai.chronon.api.Builders import ai.chronon.api.Operation import ai.chronon.api.TimeUnit import ai.chronon.api.Window +import ai.chronon.online.test.TaggedFilterSuite import ai.chronon.spark.Comparison import ai.chronon.spark.Extensions._ import ai.chronon.spark.Join
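For context on the `TaggedFilterSuite` mixin being relocated above (only its package and import lines appear in this hunk), the tag-gating idea can be sketched roughly as follows. This is a simplified, assumed version, not the actual Chronon trait:

```scala
import org.scalatest._

// Simplified sketch: run the suite only when its tag was explicitly requested,
// e.g. via `sbt "online/testOnly -- -n catalystUtilHiveUdfTest"`.
trait TagGatedSuite extends SuiteMixin { self: Suite =>
  def tagName: String

  abstract override def run(testName: Option[String], args: Args): Status = {
    val requested = args.filter.tagsToInclude.exists(_.contains(tagName))
    if (requested) super.run(testName, args)
    else SucceededStatus // silently skip when the tag was not passed with -n
  }
}
```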