diff --git a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala
index 7c020c3153..8314ba8324 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala
@@ -303,8 +303,8 @@ class SchemaEvolutionTest extends AnyFlatSpec {
     SchemaEvolutionUtils.runLogSchemaGroupBy(mockApi, offlineDs, "2022-10-01")
     val flattenerJob = new LogFlattenerJob(spark, joinConf, offlineDs, mockApi.logTable, mockApi.schemaTable)
     flattenerJob.buildLogTable()
-    val flattenedDf = spark
-      .table(joinConf.metaData.loggedTable)
+    val flattenedDf = tableUtils
+      .loadTable(joinConf.metaData.loggedTable)
       .where(col(tableUtils.partitionColumn) === offlineDs)
     assertEquals(2, flattenedDf.count())
     assertTrue(
diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala
index 992b0777a4..44010e8727 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala
@@ -195,8 +195,8 @@ object TableUtilsFormatTest {
     val returnedPartitions = tableUtils.partitions(tableName)
     assertTrue(returnedPartitions.toSet == Set(ds1, ds2))
 
-    val dataRead1 = spark.table(tableName).where(col("ds") === ds1)
-    val dataRead2 = spark.table(tableName).where(col("ds") === ds2)
+    val dataRead1 = tableUtils.loadTable(tableName).where(col("ds") === ds1)
+    val dataRead2 = tableUtils.loadTable(tableName).where(col("ds") === ds2)
     assertTrue(dataRead1.columns.length == dataRead2.columns.length)
 
     val totalColumnsCount = (df1.schema.fieldNames.toSet ++ df2.schema.fieldNames.toSet).size
diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
index a1df7f2f85..012073a501 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
@@ -100,8 +100,8 @@ class TableUtilsTest extends AnyFlatSpec {
 
     tableUtils.insertPartitions(df2, tableName, autoExpand = true)
 
-    val dataRead1 = spark.table(tableName).where(col("ds") === ds1)
-    val dataRead2 = spark.table(tableName).where(col("ds") === ds2)
+    val dataRead1 = tableUtils.loadTable(tableName).where(col("ds") === ds1)
+    val dataRead2 = tableUtils.loadTable(tableName).where(col("ds") === ds2)
     assertTrue(dataRead1.columns.length == dataRead2.columns.length)
 
     val totalColumnsCount = (df1.schema.fieldNames.toSet ++ df2.schema.fieldNames.toSet).size
diff --git a/spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationTest.scala b/spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationTest.scala
index a579855132..3e1cf21210 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationTest.scala
@@ -28,9 +28,12 @@ import ai.chronon.spark.test.{OnlineUtils, SchemaEvolutionUtils}
 import ai.chronon.spark.utils.MockApi
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions._
-import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertFalse
+import org.junit.Assert.assertTrue
 import org.scalatest.flatspec.AnyFlatSpec
-import org.slf4j.{Logger, LoggerFactory}
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
 
 import scala.concurrent.Await
 import scala.concurrent.duration.Duration
diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala
index ad8d0a0a2d..18ec9d1a67 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala
@@ -20,7 +20,6 @@ import ai.chronon.api.Extensions._
 import ai.chronon.api.ScalaJavaConversions._
 import ai.chronon.api._
 import ai.chronon.online.fetcher.Fetcher.Request
-import ai.chronon.online.fetcher.MetadataStore
 import ai.chronon.spark.Comparison
 import ai.chronon.spark.Extensions._
 import ai.chronon.spark.LogFlattenerJob
@@ -110,7 +109,7 @@ class LogBootstrapTest extends AnyFlatSpec {
     // Init artifacts to run online fetching and logging
     val kvStore = OnlineUtils.buildInMemoryKVStore(namespace)
     val mockApi = new MockApi(() => kvStore, namespace)
-    val endDs = spark.table(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
+    val endDs = tableUtils.loadTable(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
     OnlineUtils.serve(tableUtils, kvStore, () => kvStore, namespace, endDs, groupBy)
     val fetcher = mockApi.buildFetcher(debug = true)
 
@@ -118,8 +117,8 @@
     kvStore.create(Constants.MetadataDataset)
     metadataStore.putJoinConf(joinV1)
 
-    val requests = spark
-      .table(queryTable)
+    val requests = tableUtils
+      .loadTable(queryTable)
       .where(col(tableUtils.partitionColumn) === endDs)
       .where(col("user").isNotNull and col("request_id").isNotNull)
       .select("user", "request_id", "ts")
@@ -148,7 +147,7 @@
 
     val flattenerJob = new LogFlattenerJob(spark, joinV1, endDs, mockApi.logTable, mockApi.schemaTable)
     flattenerJob.buildLogTable()
-    val logDf = spark.table(joinV1.metaData.loggedTable)
+    val logDf = tableUtils.loadTable(joinV1.metaData.loggedTable)
     assertEquals(logDf.count(), responses.length)
 
     val baseJoinJob = new ai.chronon.spark.Join(baseJoinV2, endDs, tableUtils)
diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala
index dd77469527..a7b6e7ad7c 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala
@@ -47,8 +47,8 @@ class TableBootstrapTest extends AnyFlatSpec {
                          samplePercent: Double = 0.8,
                          partitionCol: String = "ds"): (BootstrapPart, DataFrame) = {
     val bootstrapTable = s"$namespace.$tableName"
-    val preSampleBootstrapDf = spark
-      .table(queryTable)
+    val preSampleBootstrapDf = tableUtils
+      .loadTable(queryTable)
       .select(
         col("request_id"),
         (rand() * 30000)
@@ -175,7 +175,7 @@
 
     tableUtils.createDatabase(namespace)
     val queryTable = BootstrapUtils.buildQuery(namespace, spark)
-    val endDs = spark.table(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
+    val endDs = tableUtils.loadTable(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
 
     val joinPart = Builders.JoinPart(groupBy = BootstrapUtils.buildGroupBy(namespace, spark))
     val derivations = Seq(
diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/FeatureWithLabelJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/FeatureWithLabelJoinTest.scala
index 6dc15cf329..ea5c50c92c 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/join/FeatureWithLabelJoinTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/join/FeatureWithLabelJoinTest.scala
@@ -57,7 +57,7 @@ class FeatureWithLabelJoinTest extends AnyFlatSpec {
     logger.info(" == First Run Label version 2022-10-30 == ")
     prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier(null, tableUtils.partitionColumn))
       .show()
-    val featureDf = tableUtils.sparkSession.table(joinConf.metaData.outputTable)
+    val featureDf = tableUtils.loadTable(joinConf.metaData.outputTable)
     logger.info(" == Features == ")
     featureDf.show()
     val computed = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView}")
@@ -141,7 +141,7 @@
     logger.info(" == Label DF == ")
     prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier(null, tableUtils.partitionColumn))
       .show()
-    val featureDf = tableUtils.sparkSession.table(joinConf.metaData.outputTable)
+    val featureDf = tableUtils.loadTable(joinConf.metaData.outputTable)
     logger.info(" == Features DF == ")
     featureDf.show()
     val computed = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView}")
diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/JoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/JoinTest.scala
index 1204956a63..44044da154 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/join/JoinTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/join/JoinTest.scala
@@ -77,9 +77,9 @@ class JoinTest extends AnyFlatSpec {
     )
     val data = spark.createDataFrame(rows) toDF ("ds", "value")
     data.write.mode(SaveMode.Overwrite).format("hive").partitionBy("ds").saveAsTable(f"${namespace}.table")
-    assertEquals(spark.table(f"${namespace}.table").as[TestRow].collect().toList.sorted, rows.sorted)
+    assertEquals(tableUtils.loadTable(f"${namespace}.table").as[TestRow].collect().toList.sorted, rows.sorted)
 
-    spark.table(f"${namespace}.table").show(truncate = false)
+    tableUtils.loadTable(f"${namespace}.table").show(truncate = false)
 
     val dynamicPartitions = List(
       TestRow("4", "y"),
@@ -92,14 +92,15 @@
       .mode(SaveMode.Overwrite)
       .insertInto(f"${namespace}.table")
 
-    spark.table(f"${namespace}.table").show(truncate = false)
+    tableUtils.loadTable(f"${namespace}.table").show(truncate = false)
 
     val updatedExpected =
       (rows.map((r) => r.ds -> r.value).toMap ++
         dynamicPartitions.map((r) => r.ds -> r.value).toMap).map {
         case (k, v) => TestRow(k, v)
       }.toList
 
-    assertEquals(updatedExpected.sorted, spark.table(f"${namespace}.table").as[TestRow].collect().toList.sorted)
+    assertEquals(updatedExpected.sorted,
+                 tableUtils.loadTable(f"${namespace}.table").as[TestRow].collect().toList.sorted)
   }
 
diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/ModularJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/ModularJoinTest.scala
index ec978e5e5e..566fce39f9 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/join/ModularJoinTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/join/ModularJoinTest.scala
@@ -4,18 +4,16 @@ import ai.chronon.aggregator.test.Column
 import ai.chronon.api
 import ai.chronon.api.Extensions._
 import ai.chronon.api._
-import ai.chronon.api.PartitionRange
-import ai.chronon.orchestration.{
-  BootstrapJobArgs,
-  JoinDerivationJobArgs,
-  JoinPartJobArgs,
-  MergeJobArgs,
-  SourceJobArgs,
-  SourceWithFilter
-}
+import ai.chronon.orchestration.BootstrapJobArgs
+import ai.chronon.orchestration.JoinDerivationJobArgs
+import ai.chronon.orchestration.JoinPartJobArgs
+import ai.chronon.orchestration.MergeJobArgs
+import ai.chronon.orchestration.SourceJobArgs
+import ai.chronon.orchestration.SourceWithFilter
 import ai.chronon.spark.Extensions._
 import ai.chronon.spark._
-import ai.chronon.spark.test.{DataFrameGen, TableTestUtils}
+import ai.chronon.spark.test.DataFrameGen
+import ai.chronon.spark.test.TableTestUtils
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions._
 import org.junit.Assert._
@@ -129,8 +127,8 @@ class ModularJoinTest extends AnyFlatSpec {
     // Make bootstrap part and table
     val bootstrapSourceTable = s"$namespace.bootstrap"
     val bootstrapCol = "unit_test_user_transactions_amount_dollars_sum_10d"
-    spark
-      .table(queryTable)
+    tableUtils
+      .loadTable(queryTable)
       .select(
         col("user"),
         col("ts"),
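
Every hunk in this patch makes the same mechanical substitution: test reads go through tableUtils.loadTable(...) instead of spark.table(...) or tableUtils.sparkSession.table(...). A minimal before/after sketch of the pattern, reusing names from the tests above (tableUtils, queryTable, endDs); that loadTable returns a plain DataFrame is assumed from its usage in these hunks:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col

    // Before: resolve the table directly through the SparkSession catalog.
    def readPartitionOld(queryTable: String, endDs: String): DataFrame =
      spark.table(queryTable).where(col(tableUtils.partitionColumn) === endDs)

    // After: resolve the table through TableUtils; the rest of the chain is identical.
    def readPartitionNew(queryTable: String, endDs: String): DataFrame =
      tableUtils.loadTable(queryTable).where(col(tableUtils.partitionColumn) === endDs)

Because both return a DataFrame, downstream operators (.where, .select, .show, .as[T]) are untouched by the migration.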