@@ -303,8 +303,8 @@ class SchemaEvolutionTest extends AnyFlatSpec {
     SchemaEvolutionUtils.runLogSchemaGroupBy(mockApi, offlineDs, "2022-10-01")
     val flattenerJob = new LogFlattenerJob(spark, joinConf, offlineDs, mockApi.logTable, mockApi.schemaTable)
     flattenerJob.buildLogTable()
-    val flattenedDf = spark
-      .table(joinConf.metaData.loggedTable)
+    val flattenedDf = tableUtils
+      .loadTable(joinConf.metaData.loggedTable)
       .where(col(tableUtils.partitionColumn) === offlineDs)
     assertEquals(2, flattenedDf.count())
     assertTrue(
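
Across these hunks the change is mechanical: every direct spark.table(...) read is routed through tableUtils.loadTable(...). The diff does not show TableUtils itself, so the following is only a sketch of the assumed shape of that helper — a single seam for table reads where catalog- or format-specific handling can live — not the repo's actual implementation:

import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical sketch; the real TableUtils in this repo may differ.
case class TableUtils(sparkSession: SparkSession) {
  val partitionColumn: String = "ds"

  // One entry point for reads: callers write tableUtils.loadTable(name)
  // instead of spark.table(name), so table-resolution logic stays here.
  def loadTable(tableName: String): DataFrame =
    sparkSession.table(tableName)
}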
@@ -195,8 +195,8 @@ object TableUtilsFormatTest {
     val returnedPartitions = tableUtils.partitions(tableName)
     assertTrue(returnedPartitions.toSet == Set(ds1, ds2))
 
-    val dataRead1 = spark.table(tableName).where(col("ds") === ds1)
-    val dataRead2 = spark.table(tableName).where(col("ds") === ds2)
+    val dataRead1 = tableUtils.loadTable(tableName).where(col("ds") === ds1)
+    val dataRead2 = tableUtils.loadTable(tableName).where(col("ds") === ds2)
     assertTrue(dataRead1.columns.length == dataRead2.columns.length)
 
     val totalColumnsCount = (df1.schema.fieldNames.toSet ++ df2.schema.fieldNames.toSet).size
@@ -100,8 +100,8 @@ class TableUtilsTest extends AnyFlatSpec {
 
     tableUtils.insertPartitions(df2, tableName, autoExpand = true)
 
-    val dataRead1 = spark.table(tableName).where(col("ds") === ds1)
-    val dataRead2 = spark.table(tableName).where(col("ds") === ds2)
+    val dataRead1 = tableUtils.loadTable(tableName).where(col("ds") === ds1)
+    val dataRead2 = tableUtils.loadTable(tableName).where(col("ds") === ds2)
     assertTrue(dataRead1.columns.length == dataRead2.columns.length)
 
     val totalColumnsCount = (df1.schema.fieldNames.toSet ++ df2.schema.fieldNames.toSet).size
@@ -28,9 +28,12 @@ import ai.chronon.spark.test.{OnlineUtils, SchemaEvolutionUtils}
 import ai.chronon.spark.utils.MockApi
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions._
-import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertFalse
+import org.junit.Assert.assertTrue
 import org.scalatest.flatspec.AnyFlatSpec
-import org.slf4j.{Logger, LoggerFactory}
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
 
 import scala.concurrent.Await
 import scala.concurrent.duration.Duration
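
Hunks like this one (and the ModularJoinTest imports further down) expand brace-grouped imports into one import per line. If that style is enforced by scalafmt, a config along these lines would produce it automatically — an assumption, since the formatter config is not part of this diff:

// .scalafmt.conf (hypothetical)
rewrite.rules = [Imports]
rewrite.imports.expand = true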
@@ -20,7 +20,6 @@ import ai.chronon.api.Extensions._
 import ai.chronon.api.ScalaJavaConversions._
 import ai.chronon.api._
 import ai.chronon.online.fetcher.Fetcher.Request
-import ai.chronon.online.fetcher.MetadataStore
 import ai.chronon.spark.Comparison
 import ai.chronon.spark.Extensions._
 import ai.chronon.spark.LogFlattenerJob
@@ -110,16 +109,16 @@ class LogBootstrapTest extends AnyFlatSpec {
     // Init artifacts to run online fetching and logging
     val kvStore = OnlineUtils.buildInMemoryKVStore(namespace)
     val mockApi = new MockApi(() => kvStore, namespace)
-    val endDs = spark.table(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
+    val endDs = tableUtils.loadTable(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
     OnlineUtils.serve(tableUtils, kvStore, () => kvStore, namespace, endDs, groupBy)
     val fetcher = mockApi.buildFetcher(debug = true)
 
     val metadataStore = fetcher.metadataStore
     kvStore.create(Constants.MetadataDataset)
     metadataStore.putJoinConf(joinV1)
 
-    val requests = spark
-      .table(queryTable)
+    val requests = tableUtils
+      .loadTable(queryTable)
       .where(col(tableUtils.partitionColumn) === endDs)
       .where(col("user").isNotNull and col("request_id").isNotNull)
       .select("user", "request_id", "ts")
@@ -148,7 +147,7 @@ class LogBootstrapTest extends AnyFlatSpec {
     val flattenerJob = new LogFlattenerJob(spark, joinV1, endDs, mockApi.logTable, mockApi.schemaTable)
     flattenerJob.buildLogTable()
 
-    val logDf = spark.table(joinV1.metaData.loggedTable)
+    val logDf = tableUtils.loadTable(joinV1.metaData.loggedTable)
     assertEquals(logDf.count(), responses.length)
 
     val baseJoinJob = new ai.chronon.spark.Join(baseJoinV2, endDs, tableUtils)
@@ -47,8 +47,8 @@ class TableBootstrapTest extends AnyFlatSpec {
                  samplePercent: Double = 0.8,
                  partitionCol: String = "ds"): (BootstrapPart, DataFrame) = {
     val bootstrapTable = s"$namespace.$tableName"
-    val preSampleBootstrapDf = spark
-      .table(queryTable)
+    val preSampleBootstrapDf = tableUtils
+      .loadTable(queryTable)
       .select(
         col("request_id"),
         (rand() * 30000)
@@ -175,7 +175,7 @@ class TableBootstrapTest extends AnyFlatSpec {
     tableUtils.createDatabase(namespace)
 
     val queryTable = BootstrapUtils.buildQuery(namespace, spark)
-    val endDs = spark.table(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
+    val endDs = tableUtils.loadTable(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
 
     val joinPart = Builders.JoinPart(groupBy = BootstrapUtils.buildGroupBy(namespace, spark))
     val derivations = Seq(
@@ -57,7 +57,7 @@ class FeatureWithLabelJoinTest extends AnyFlatSpec {
     logger.info(" == First Run Label version 2022-10-30 == ")
     prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier(null, tableUtils.partitionColumn))
       .show()
-    val featureDf = tableUtils.sparkSession.table(joinConf.metaData.outputTable)
+    val featureDf = tableUtils.loadTable(joinConf.metaData.outputTable)
     logger.info(" == Features == ")
     featureDf.show()
     val computed = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView}")
@@ -141,7 +141,7 @@ class FeatureWithLabelJoinTest extends AnyFlatSpec {
     logger.info(" == Label DF == ")
     prefixColumnName(labelDf, exceptions = labelJoinConf.rowIdentifier(null, tableUtils.partitionColumn))
       .show()
-    val featureDf = tableUtils.sparkSession.table(joinConf.metaData.outputTable)
+    val featureDf = tableUtils.loadTable(joinConf.metaData.outputTable)
     logger.info(" == Features DF == ")
     featureDf.show()
     val computed = tableUtils.sql(s"select * from ${joinConf.metaData.outputFinalView}")
@@ -77,9 +77,9 @@ class JoinTest extends AnyFlatSpec {
     )
     val data = spark.createDataFrame(rows) toDF ("ds", "value")
     data.write.mode(SaveMode.Overwrite).format("hive").partitionBy("ds").saveAsTable(f"${namespace}.table")
-    assertEquals(spark.table(f"${namespace}.table").as[TestRow].collect().toList.sorted, rows.sorted)
+    assertEquals(tableUtils.loadTable(f"${namespace}.table").as[TestRow].collect().toList.sorted, rows.sorted)
 
-    spark.table(f"${namespace}.table").show(truncate = false)
+    tableUtils.loadTable(f"${namespace}.table").show(truncate = false)
 
     val dynamicPartitions = List(
       TestRow("4", "y"),
@@ -92,14 +92,15 @@
       .mode(SaveMode.Overwrite)
       .insertInto(f"${namespace}.table")
 
-    spark.table(f"${namespace}.table").show(truncate = false)
+    tableUtils.loadTable(f"${namespace}.table").show(truncate = false)
 
     val updatedExpected =
       (rows.map((r) => r.ds -> r.value).toMap ++ dynamicPartitions.map((r) => r.ds -> r.value).toMap).map {
         case (k, v) => TestRow(k, v)
       }.toList
 
-    assertEquals(updatedExpected.sorted, spark.table(f"${namespace}.table").as[TestRow].collect().toList.sorted)
+    assertEquals(updatedExpected.sorted,
+                 tableUtils.loadTable(f"${namespace}.table").as[TestRow].collect().toList.sorted)
 
   }

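The second half of this test relies on Spark's dynamic partition overwrite: with spark.sql.sources.partitionOverwriteMode set to dynamic, an overwriting insertInto replaces only the partitions present in the incoming data instead of truncating the whole table. A minimal standalone sketch of that behavior (session config, database, and table names are illustrative, not taken from this repo):

import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession
  .builder()
  .master("local[*]")
  // "dynamic" overwrites only the partitions present in the new data;
  // the default "static" mode would first drop every existing partition.
  .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
  .enableHiveSupport()
  .getOrCreate()

import spark.implicits._
spark.sql("CREATE DATABASE IF NOT EXISTS db")

// Seed a table partitioned by ds with partitions ds=1 and ds=2.
Seq(("a", "1"), ("b", "2")).toDF("value", "ds")
  .write.mode(SaveMode.Overwrite).format("hive").partitionBy("ds").saveAsTable("db.t")

// Overwrite only ds=2; ds=1 survives under dynamic mode. Note that insertInto
// is positional, so columns are ordered (data columns, then partition column)
// to match the table schema.
Seq(("b2", "2")).toDF("value", "ds")
  .write.mode(SaveMode.Overwrite).insertInto("db.t")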
@@ -4,18 +4,16 @@ import ai.chronon.aggregator.test.Column
 import ai.chronon.api
 import ai.chronon.api.Extensions._
 import ai.chronon.api._
-import ai.chronon.api.PartitionRange
-import ai.chronon.orchestration.{
-  BootstrapJobArgs,
-  JoinDerivationJobArgs,
-  JoinPartJobArgs,
-  MergeJobArgs,
-  SourceJobArgs,
-  SourceWithFilter
-}
+import ai.chronon.orchestration.BootstrapJobArgs
+import ai.chronon.orchestration.JoinDerivationJobArgs
+import ai.chronon.orchestration.JoinPartJobArgs
+import ai.chronon.orchestration.MergeJobArgs
+import ai.chronon.orchestration.SourceJobArgs
+import ai.chronon.orchestration.SourceWithFilter
 import ai.chronon.spark.Extensions._
 import ai.chronon.spark._
-import ai.chronon.spark.test.{DataFrameGen, TableTestUtils}
+import ai.chronon.spark.test.DataFrameGen
+import ai.chronon.spark.test.TableTestUtils
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions._
 import org.junit.Assert._
@@ -129,8 +127,8 @@ class ModularJoinTest extends AnyFlatSpec {
     // Make bootstrap part and table
     val bootstrapSourceTable = s"$namespace.bootstrap"
     val bootstrapCol = "unit_test_user_transactions_amount_dollars_sum_10d"
-    spark
-      .table(queryTable)
+    tableUtils
+      .loadTable(queryTable)
       .select(
         col("user"),
         col("ts"),