diff --git a/docker-init/demo/Dockerfile b/docker-init/demo/Dockerfile new file mode 100644 index 0000000000..72bd835415 --- /dev/null +++ b/docker-init/demo/Dockerfile @@ -0,0 +1,35 @@ +FROM apache/spark:3.5.3-scala2.12-java17-ubuntu + +# Switch to root to install Java 17 +USER root + +# Install Amazon Corretto 17 +RUN apt-get update && \ + apt-get install -y wget software-properties-common gnupg2 && \ + wget -O- https://apt.corretto.aws/corretto.key --https-only | apt-key add - && \ + add-apt-repository 'deb https://apt.corretto.aws stable main' && \ + apt-get update && \ + apt-get install -y java-17-amazon-corretto-jdk && \ + update-alternatives --set java /usr/lib/jvm/java-17-amazon-corretto/bin/java && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Create directory and set appropriate permissions +RUN mkdir -p /opt/chronon/jars && \ + chown -R 185:185 /opt/chronon && \ + chmod 755 /opt/chronon/jars + +# Set JAVA_HOME +ENV JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto +ENV PATH=$PATH:$JAVA_HOME/bin + +# Switch back to spark user +USER 185 + +# Set environment variables for Spark classpath +ENV SPARK_CLASSPATH="/opt/spark/jars/*" +ENV SPARK_DIST_CLASSPATH="/opt/spark/jars/*" +ENV SPARK_EXTRA_CLASSPATH="/opt/spark/jars/*:/opt/chronon/jars/*" +ENV HADOOP_CLASSPATH="/opt/spark/jars/*" + +CMD ["tail", "-f", "/dev/null"] \ No newline at end of file diff --git a/docker-init/demo/README.md b/docker-init/demo/README.md new file mode 100644 index 0000000000..1d4cc951ad --- /dev/null +++ b/docker-init/demo/README.md @@ -0,0 +1,2 @@ +Run build.sh once. After that, on each iteration of the chronon code, rebuild and rerun with +sbt spark/assembly followed by run.sh. \ No newline at end of file diff --git a/docker-init/demo/build.sh b/docker-init/demo/build.sh new file mode 100644 index 0000000000..5627dac2f5 --- /dev/null +++ b/docker-init/demo/build.sh @@ -0,0 +1 @@ +docker build -t obs . \ No newline at end of file diff --git a/docker-init/demo/run.sh b/docker-init/demo/run.sh new file mode 100755 index 0000000000..64007c8135 --- /dev/null +++ b/docker-init/demo/run.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Stop and remove existing container +if docker ps -a | grep -q spark-app; then + docker stop spark-app || echo "Failed to stop container" + docker rm spark-app || echo "Failed to remove container" +fi + +CHRONON_JAR_PATH="${CHRONON_JAR_PATH:-$HOME/repos/chronon/spark/target/scala-2.12}" + +if [ ! 
-d "$CHRONON_JAR_PATH" ]; then + echo "Error: JAR directory not found: $CHRONON_JAR_PATH" + exit 1 +fi + +# Run new container +docker run -d \ + --name spark-app \ + -v "$CHRONON_JAR_PATH":/opt/chronon/jars \ + obs + +# Submit with increased memory +docker exec spark-app \ + /opt/spark/bin/spark-submit \ + --master "local[*]" \ + --driver-memory 8g \ + --conf "spark.driver.maxResultSize=6g" \ + --conf "spark.driver.memory=8g" \ + --driver-class-path "/opt/spark/jars/*:/opt/chronon/jars/*" \ + --conf "spark.driver.host=localhost" \ + --conf "spark.driver.bindAddress=0.0.0.0" \ + --class ai.chronon.spark.scripts.ObservabilityDemo \ + /opt/chronon/jars/spark-assembly-0.1.0-SNAPSHOT.jar \ No newline at end of file diff --git a/docker-init/generate_anomalous_data.py b/docker-init/generate_anomalous_data.py index b28ef95ec8..b8143febba 100644 --- a/docker-init/generate_anomalous_data.py +++ b/docker-init/generate_anomalous_data.py @@ -5,7 +5,7 @@ from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType, StringType, TimestampType, BooleanType # Initialize Spark session -spark = SparkSession.builder.appName("FraudClassificationSchema").getOrCreate() +spark = SparkSession.builder.appName("FraudClassificationSchema").config("spark.log.level", "WARN").getOrCreate() def time_to_value(t, base_value, amplitude, noise_level, scale=1): if scale is None: diff --git a/docker-init/start.sh b/docker-init/start.sh index 37b12bbf73..9f8b39d9f1 100644 --- a/docker-init/start.sh +++ b/docker-init/start.sh @@ -1,7 +1,13 @@ #!/bin/bash + +start_time=$(date +%s) if ! python3.8 generate_anomalous_data.py; then echo "Error: Failed to generate anomalous data" >&2 exit 1 +else + end_time=$(date +%s) + elapsed_time=$((end_time - start_time)) + echo "Anomalous data generated successfully! Took $elapsed_time seconds." fi @@ -11,7 +17,6 @@ if [[ ! -f $SPARK_JAR ]] || [[ ! -f $CLOUD_AWS_JAR ]]; then exit 1 fi - # Load up metadata into DynamoDB echo "Loading metadata.." if ! java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then @@ -32,9 +37,11 @@ fi echo "DynamoDB Table created successfully!" -# Load up summary data into DynamoDB -echo "Loading Summary.." -if ! java --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED \ +start_time=$(date +%s) + +if ! java \ + --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \ + --add-opens=java.base/sun.security.action=ALL-UNNAMED \ -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver summarize-and-upload \ --online-jar=$CLOUD_AWS_JAR \ --online-class=$ONLINE_CLASS \ @@ -43,8 +50,11 @@ if ! java --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun --time-column=transaction_time; then echo "Error: Failed to load summary data into DynamoDB" >&2 exit 1 +else + end_time=$(date +%s) + elapsed_time=$((end_time - start_time)) + echo "Summary load completed successfully! Took $elapsed_time seconds." fi -echo "Summary load completed successfully!" 
# Add these java options as without them we hit the below error: # throws java.lang.ClassFormatError accessible: module java.base does not "opens java.lang" to unnamed module @36328710 diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index 3eb11421af..9b8baa4d5e 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -74,7 +74,7 @@ class DriftStore(kvStore: KVStore, def getSummaries(joinConf: api.Join, startMs: Option[Long], endMs: Option[Long], - columnPrefix: Option[String] = None): Future[Seq[TileSummaryInfo]] = { + columnPrefix: Option[String]): Future[Seq[TileSummaryInfo]] = { val serializer: TSerializer = compactSerializer val tileKeyMap = tileKeysForJoin(joinConf, columnPrefix) diff --git a/spark/src/main/resources/logback.xml b/spark/src/main/resources/logback.xml new file mode 100644 index 0000000000..1af9c0bc3f --- /dev/null +++ b/spark/src/main/resources/logback.xml @@ -0,0 +1,12 @@ +<configuration> + <appender name="console" class="ch.qos.logback.core.ConsoleAppender"> + <encoder> + <pattern>[%date] {%logger{0}} %level - %message%n</pattern> + </encoder> + </appender> + + <root level="INFO"> + <appender-ref ref="console"/> + </root> + +</configuration> \ No newline at end of file diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index fc9a5c0f0d..c465ab581c 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -147,7 +147,9 @@ object Driver { protected def buildSparkSession(): SparkSession = { if (localTableMapping.nonEmpty) { - val localSession = SparkSessionBuilder.build(subcommandName(), local = true, localWarehouseLocation.toOption) + val localSession = SparkSessionBuilder.build(subcommandName(), + local = true, + localWarehouseLocation = localWarehouseLocation.toOption) localTableMapping.foreach { case (table, filePath) => val file = new File(filePath) diff --git a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala index e6e83b7409..7c6db6276a 100644 --- a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala +++ b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala @@ -34,6 +34,7 @@ object SparkSessionBuilder { // we would want to share locally generated warehouse during CI testing def build(name: String, local: Boolean = false, + hiveSupport: Boolean = true, localWarehouseLocation: Option[String] = None, additionalConfig: Option[Map[String, String]] = None, enforceKryoSerializer: Boolean = true): SparkSession = { @@ -44,7 +45,10 @@ object SparkSessionBuilder { var baseBuilder = SparkSession .builder() .appName(name) - .enableHiveSupport() + + if (hiveSupport) baseBuilder = baseBuilder.enableHiveSupport() + + baseBuilder = baseBuilder .config("spark.sql.session.timeZone", "UTC") //otherwise overwrite will delete ALL partitions, not just the ones it touches .config("spark.sql.sources.partitionOverwriteMode", "dynamic") diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index e71eeef746..2ce15c16b4 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -76,7 +76,7 @@ case class TableUtils(sparkSession: SparkSession) { sparkSession.conf.get("spark.chronon.backfill.small_mode_cutoff", "5000").toInt val backfillValidationEnforced: Boolean = sparkSession.conf.get("spark.chronon.backfill.validation.enabled", 
"true").toBoolean - // Threshold to control whether or not to use bloomfilter on join backfill. If the backfill row approximate count is under this threshold, we will use bloomfilter. + // Threshold to control whether to use bloomfilter on join backfill. If the backfill row approximate count is under this threshold, we will use bloomfilter. // default threshold is 100K rows val bloomFilterThreshold: Long = sparkSession.conf.get("spark.chronon.backfill.bloomfilter.threshold", "1000000").toLong diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala new file mode 100644 index 0000000000..66694a1e19 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala @@ -0,0 +1,192 @@ +package ai.chronon.spark.scripts + +import ai.chronon +import ai.chronon.api.ColorPrinter.ColorString +import ai.chronon.api.Constants +import ai.chronon.api.DriftMetric +import ai.chronon.api.Extensions.MetadataOps +import ai.chronon.api.PartitionSpec +import ai.chronon.api.TileDriftSeries +import ai.chronon.api.TileSummarySeries +import ai.chronon.api.Window +import ai.chronon.online.KVStore +import ai.chronon.online.stats.DriftStore +import ai.chronon.spark.SparkSessionBuilder +import ai.chronon.spark.TableUtils +import ai.chronon.spark.stats.drift.Summarizer +import ai.chronon.spark.stats.drift.SummaryUploader +import ai.chronon.spark.stats.drift.scripts.PrepareData +import ai.chronon.spark.utils.InMemoryKvStore +import ai.chronon.spark.utils.MockApi +import org.rogach.scallop.ScallopConf +import org.rogach.scallop.ScallopOption +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.util.concurrent.TimeUnit +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.ScalaJavaConversions.IteratorOps + +object ObservabilityDemo { + @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) + + def time(message: String)(block: => Unit): Unit = { + logger.info(s"$message..".yellow) + val start = System.currentTimeMillis() + block + val end = System.currentTimeMillis() + logger.info(s"$message took ${end - start} ms".green) + } + + class Conf(arguments: Seq[String]) extends ScallopConf(arguments) { + val startDs: ScallopOption[String] = opt[String]( + name = "start-ds", + default = Some("2023-01-01"), + descr = "Start date in YYYY-MM-DD format" + ) + + val endDs: ScallopOption[String] = opt[String]( + name = "end-ds", + default = Some("2023-02-30"), + descr = "End date in YYYY-MM-DD format" + ) + + val rowCount: ScallopOption[Int] = opt[Int]( + name = "row-count", + default = Some(700000), + descr = "Number of rows to generate" + ) + + val namespace: ScallopOption[String] = opt[String]( + name = "namespace", + default = Some("observability_demo"), + descr = "Namespace for the demo" + ) + + verify() + } + + def main(args: Array[String]): Unit = { + + val config = new Conf(args) + val startDs = config.startDs() + val endDs = config.endDs() + val rowCount = config.rowCount() + val namespace = config.namespace() + + val spark = SparkSessionBuilder.build(namespace, local = true) + implicit val tableUtils: TableUtils = TableUtils(spark) + tableUtils.createDatabase(namespace) + + // generate anomalous data (join output) + val prepareData = PrepareData(namespace) + val join = prepareData.generateAnomalousFraudJoin + + time("Preparing data") { + val df = prepareData.generateFraudSampleData(rowCount, startDs, endDs, join.metaData.loggedTable) 
+ df.show(10, truncate = false) + } + + time("Summarizing data") { + // compute summary table and packed table (for uploading) + Summarizer.compute(join.metaData, ds = endDs, useLogs = true) + } + + val packedTable = join.metaData.packedSummaryTable + // mock api impl for online fetching and uploading + val kvStoreFunc: () => KVStore = () => { + // cannot reuse the variable - or serialization error + val result = InMemoryKvStore.build(namespace, () => null) + result + } + val api = new MockApi(kvStoreFunc, namespace) + + // create necessary tables in kvstore + val kvStore = api.genKvStore + kvStore.create(Constants.MetadataDataset) + kvStore.create(Constants.TiledSummaryDataset) + + // upload join conf + api.buildFetcher().putJoinConf(join) + + time("Uploading summaries") { + val uploader = new SummaryUploader(tableUtils.loadTable(packedTable), api) + uploader.run() + } + + // test drift store methods + val driftStore = new DriftStore(api.genKvStore) + + // TODO: Wire up drift store into hub and create an endpoint + + // fetch keys + val tileKeys = driftStore.tileKeysForJoin(join) + val tileKeysSimple = tileKeys.mapValues(_.map(_.column).toSeq) + tileKeysSimple.foreach { case (k, v) => logger.info(s"$k -> [${v.mkString(", ")}]") } + + // fetch summaries + val startMs = PartitionSpec.daily.epochMillis(startDs) + val endMs = PartitionSpec.daily.epochMillis(endDs) + val summariesFuture = driftStore.getSummaries(join, Some(startMs), Some(endMs), None) + val summaries = Await.result(summariesFuture, Duration.create(10, TimeUnit.SECONDS)) + logger.info(summaries.toString()) + + var driftSeries: Seq[TileDriftSeries] = null + // fetch drift series + time("Fetching drift series") { + val driftSeriesFuture = driftStore.getDriftSeries( + join.metaData.nameToFilePath, + DriftMetric.JENSEN_SHANNON, + lookBack = new Window(7, chronon.api.TimeUnit.DAYS), + startMs, + endMs + ) + driftSeries = Await.result(driftSeriesFuture.get, Duration.create(10, TimeUnit.SECONDS)) + } + + val (nulls, totals) = driftSeries.iterator.foldLeft(0 -> 0) { + case ((nulls, total), s) => + val currentNulls = s.getPercentileDriftSeries.iterator().toScala.count(_ == null) + val currentCount = s.getPercentileDriftSeries.size() + (nulls + currentNulls, total + currentCount) + } + + logger.info(s"""drift totals: $totals + |drift nulls: $nulls + |""".stripMargin.red) + + logger.info("Drift series fetched successfully".green) + + var summarySeries: Seq[TileSummarySeries] = null + + time("Fetching summary series") { + val summarySeriesFuture = driftStore.getSummarySeries( + join.metaData.nameToFilePath, + startMs, + endMs + ) + summarySeries = Await.result(summarySeriesFuture.get, Duration.create(10, TimeUnit.SECONDS)) + } + + val (summaryNulls, summaryTotals) = summarySeries.iterator.foldLeft(0 -> 0) { + case ((nulls, total), s) => + if (s.getPercentiles == null) { + (nulls + 1) -> (total + 1) + } else { + val currentNulls = s.getPercentiles.iterator().toScala.count(_ == null) + val currentCount = s.getPercentiles.size() + (nulls + currentNulls, total + currentCount) + } + } + + println(s"""summary ptile totals: $summaryTotals + |summary ptile nulls: $summaryNulls + |""".stripMargin) + + logger.info("Summary series fetched successfully".green) + + spark.stop() + System.exit(0) + } +} diff --git a/spark/src/test/scala/ai/chronon/spark/test/stats/drift/PrepareData.scala b/spark/src/main/scala/ai/chronon/spark/stats/drift/scripts/PrepareData.scala similarity index 85% rename from 
spark/src/test/scala/ai/chronon/spark/test/stats/drift/PrepareData.scala rename to spark/src/main/scala/ai/chronon/spark/stats/drift/scripts/PrepareData.scala index 522de0c5c1..2fba407302 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/stats/drift/PrepareData.scala +++ b/spark/src/main/scala/ai/chronon/spark/stats/drift/scripts/PrepareData.scala @@ -1,4 +1,4 @@ -package ai.chronon.spark.test.stats.drift +package ai.chronon.spark.stats.drift.scripts import ai.chronon.api import ai.chronon.api.Builders @@ -21,18 +21,13 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.date_format import org.apache.spark.sql.functions.from_unixtime -import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types._ import org.apache.spark.sql.{Row => SRow} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.nio.file.Paths -import java.time.Duration -import java.time.LocalDate -import java.time.LocalDateTime -import java.time.LocalTime -import java.time.ZoneOffset +import java.time._ import java.time.format.DateTimeFormatter import java.time.temporal.ChronoUnit import scala.collection.mutable.ListBuffer @@ -52,7 +47,13 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { val merchant_source = Builders.Source.entities( query = Builders.Query( - selects = Seq("merchant_id", "account_age", "zipcode", "is_big_merchant", "country", "account_type", "preferred_language").map(s => s->s).toMap + selects = Seq("merchant_id", + "account_age", + "zipcode", + "is_big_merchant", + "country", + "account_type", + "preferred_language").map(s => s -> s).toMap ), snapshotTable = "data.merchants" ) @@ -63,13 +64,13 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { ) def createTransactionSource(key: String): api.Source = { - Builders.Source.events( - query = Builders.Query( - selects = Seq(key, "transaction_amount", "transaction_type").map(s => s->s).toMap, - timeColumn = "transaction_time" - ), - table = "data.txn_events" - ) + Builders.Source.events( + query = Builders.Query( + selects = Seq(key, "transaction_amount", "transaction_type").map(s => s -> s).toMap, + timeColumn = "transaction_time" + ), + table = "data.txn_events" + ) } def createTxnGroupBy(source: api.Source, key: String, name: String): api.GroupBy = { @@ -110,7 +111,14 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { val userSource = Builders.Source.entities( query = Builders.Query( - selects = Seq("user_id", "account_age", "account_balance", "credit_score", "number_of_devices", "country", "account_type", "preferred_language").map(s => s->s).toMap + selects = Seq("user_id", + "account_age", + "account_balance", + "credit_score", + "number_of_devices", + "country", + "account_type", + "preferred_language").map(s => s -> s).toMap ), snapshotTable = "data.users" ) @@ -121,12 +129,12 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { ) // TODO: this is inconsistent with the defn of userSource above - but to maintain portability - we will keep it as is - val joinUserSource = Builders.Source.events( + val joinUserSource = Builders.Source.events( query = Builders.Query( - selects = Seq("user_id", "ts").map(s => s->s).toMap, + selects = Seq("user_id", "ts").map(s => s -> s).toMap, timeColumn = "ts" ), - table = "data.users", + table = "data.users" ) val driftSpec = new DriftSpec() @@ -152,7 +160,11 @@ case class 
PrepareData(namespace: String)(implicit tableUtils: TableUtils) { ) } - def timeToValue(t: LocalTime, baseValue: Double, amplitude: Double, noiseLevel: Double, scale: Double = 1.0): java.lang.Double = { + def timeToValue(t: LocalTime, + baseValue: Double, + amplitude: Double, + noiseLevel: Double, + scale: Double = 1.0): java.lang.Double = { if (scale == 0) null else { val hours = t.getHour + t.getMinute / 60.0 + t.getSecond / 3600.0 @@ -170,7 +182,9 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { } } - def generateNonOverlappingWindows(startDate: LocalDate, endDate: LocalDate, numWindows: Int): List[(LocalDate, LocalDate)] = { + def generateNonOverlappingWindows(startDate: LocalDate, + endDate: LocalDate, + numWindows: Int): List[(LocalDate, LocalDate)] = { val totalDays = ChronoUnit.DAYS.between(startDate, endDate).toInt val windowLengths = List.fill(numWindows)(RandomUtils.between(3, 8)) val maxGap = totalDays - windowLengths.sum @@ -194,19 +208,15 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { windows.toList } - - - case class DataWithTime(ts: LocalDateTime, value: java.lang.Double) case class TimeSeriesWithAnomalies(dataWithTime: Array[DataWithTime], nullWindow: (LocalDate, LocalDate), spikeWindow: (LocalDate, LocalDate)) def generateTimeseriesWithAnomalies(numSamples: Int = 1000, - baseValue: Double = 100, - amplitude: Double = 50, - noiseLevel: Double = 10 - ): TimeSeriesWithAnomalies = { + baseValue: Double = 100, + amplitude: Double = 50, + noiseLevel: Double = 10): TimeSeriesWithAnomalies = { val startDate = LocalDate.of(2023, 1, 1) val endDate = LocalDate.of(2023, 12, 31) @@ -245,17 +255,14 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { } } - private val fraudFields = Array( // join.source - txn_events StructField("user_id", IntegerType, nullable = true), StructField("merchant_id", IntegerType, nullable = true), - // Contextual - 3 StructField("transaction_amount", DoubleType, nullable = true), StructField("transaction_time", LongType, nullable = true), StructField("transaction_type", StringType, nullable = true), - // Transactions agg'd by user - 5 (txn_events) StructField("transaction_amount_average", DoubleType, nullable = true).prefix(txnByUser), StructField("transaction_amount_count_1h", IntegerType, nullable = true).prefix(txnByUser), @@ -264,7 +271,6 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { StructField("transaction_amount_count_30d", IntegerType, nullable = true).prefix(txnByUser), StructField("transaction_amount_count_365d", IntegerType, nullable = true).prefix(txnByUser), StructField("transaction_amount_sum_1h", DoubleType, nullable = true).prefix(txnByUser), - // Transactions agg'd by merchant - 7 (txn_events) StructField("transaction_amount_average", DoubleType, nullable = true).prefix(txnByMerchant), StructField("transaction_amount_count_1h", IntegerType, nullable = true).prefix(txnByMerchant), @@ -273,7 +279,6 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { StructField("transaction_amount_count_30d", IntegerType, nullable = true).prefix(txnByMerchant), StructField("transaction_amount_count_365d", IntegerType, nullable = true).prefix(txnByMerchant), StructField("transaction_amount_sum_1h", DoubleType, nullable = true).prefix(txnByMerchant), - // User features (dim_user) – 7 StructField("account_age", IntegerType, nullable = true).prefix(dimUser), StructField("account_balance", DoubleType, nullable = 
true).prefix(dimUser), @@ -282,7 +287,6 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { StructField("country", StringType, nullable = true).prefix(dimUser), StructField("account_type", IntegerType, nullable = true).prefix(dimUser), StructField("preferred_language", StringType, nullable = true).prefix(dimUser), - // merchant features (dim_merchant) – 4 StructField("account_age", IntegerType, nullable = true).prefix(dimMerchant), StructField("zipcode", IntegerType, nullable = true).prefix(dimMerchant), @@ -291,7 +295,6 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { StructField("country", StringType, nullable = true).prefix(dimMerchant), StructField("account_type", IntegerType, nullable = true).prefix(dimMerchant), StructField("preferred_language", StringType, nullable = true).prefix(dimMerchant), - // derived features - transactions_last_year / account_age - 1 StructField("transaction_frequency_last_year", DoubleType, nullable = true) ) @@ -324,7 +327,7 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { val timeDelta = Duration.between(startDate, endDate).dividedBy(numSamples) - val anomalyWindows = generateNonOverlappingWindows(startDate.toLocalDate, endDate.toLocalDate, 2) + val anomalyWindows = generateNonOverlappingWindows(startDate.toLocalDate, endDate.toLocalDate, 2) // Generate base values val transactionAmount = generateTimeseriesWithAnomalies(numSamples, 100, 50, 10) @@ -339,7 +342,7 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { val transactionTime = startDate.plus(timeDelta.multipliedBy(i)) val merchantId = Random.nextInt(250) + 1 - if(i % 100000 == 0) { + if (i % 100000 == 0) { println(s"Generated $i/$numSamples rows of data.") } @@ -348,9 +351,9 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { val isSlowDrift = transactionTime.isAfter(anomalyWindows(1)._1.atStartOfDay) && transactionTime.isBefore(anomalyWindows(1)._2.atTime(23, 59)) - val driftFactor = if(isFastDrift) 10 else if(isSlowDrift) 1.05 else 1.0 + val driftFactor = if (isFastDrift) 10 else if (isSlowDrift) 1.05 else 1.0 - def genTuple(lastHour: java.lang.Double): (Integer,Integer,Integer,Integer,Integer) = { + def genTuple(lastHour: java.lang.Double): (Integer, Integer, Integer, Integer, Integer) = { lastHour match { case x if x == null => (null, null, null, null, null) case x => @@ -373,15 +376,20 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { val userAccountAge = Random.nextInt(3650) + 1 + val (adjustedUserLastHour, + adjustedUserLastDay, + adjustedUserLastWeek, + adjustedUserLastMonth, + adjustedUserLastYear) = genTuple(userLastHourList.dataWithTime(i).value) + val (adjustedMerchantLastHour, + adjustedMerchantLastDay, + adjustedMerchantLastWeek, + adjustedMerchantLastMonth, + adjustedMerchantLastYear) = genTuple(merchantLastHourList.dataWithTime(i).value) - val (adjustedUserLastHour, adjustedUserLastDay, adjustedUserLastWeek, adjustedUserLastMonth, adjustedUserLastYear) - = genTuple(userLastHourList.dataWithTime(i).value) - - val (adjustedMerchantLastHour, adjustedMerchantLastDay, adjustedMerchantLastWeek, adjustedMerchantLastMonth, adjustedMerchantLastYear) - = genTuple(merchantLastHourList.dataWithTime(i).value) - - val arr = Array(Random.nextInt(100) + 1, + val arr = Array( + Random.nextInt(100) + 1, merchantId, transactionAmount.dataWithTime(i).value, transactionTime.toEpochSecond(ZoneOffset.UTC) * 1000, @@ -425,22 +433,21 
@@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { val spark = tableUtils.sparkSession val rdd: RDD[SRow] = spark.sparkContext.parallelize(data) val df = spark.createDataFrame(rdd, fraudSchema) - val dfWithTimeConvention = df.withColumn(Constants.TimeColumn, col("transaction_time")) - .withColumn(tableUtils.partitionColumn, date_format(from_unixtime(col(Constants.TimeColumn) / 1000), tableUtils.partitionSpec.format)) + val dfWithTimeConvention = df + .withColumn(Constants.TimeColumn, col("transaction_time")) + .withColumn(tableUtils.partitionColumn, + date_format(from_unixtime(col(Constants.TimeColumn) / 1000), tableUtils.partitionSpec.format)) dfWithTimeConvention.save(outputTable) println(s"Successfully wrote fraud data to table. ${outputTable.yellow}") - dfWithTimeConvention } - def isWithinWindow(date: LocalDate, window: (LocalDate, LocalDate)): Boolean = { !date.isBefore(window._1) && !date.isAfter(window._2) } - // dummy code below to write to spark def expandTilde(path: String): String = { if (path.startsWith("~" + java.io.File.separator)) { @@ -467,4 +474,4 @@ case class PrepareData(namespace: String)(implicit tableUtils: TableUtils) { val prettyJsonString = gson.toJson(jsonObject) prettyJsonString } -} \ No newline at end of file +} diff --git a/spark/src/test/scala/ai/chronon/spark/test/InMemoryKvStore.scala b/spark/src/main/scala/ai/chronon/spark/utils/InMemoryKvStore.scala similarity index 98% rename from spark/src/test/scala/ai/chronon/spark/test/InMemoryKvStore.scala rename to spark/src/main/scala/ai/chronon/spark/utils/InMemoryKvStore.scala index 9a3b69116a..6b571e5465 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/InMemoryKvStore.scala +++ b/spark/src/main/scala/ai/chronon/spark/utils/InMemoryKvStore.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package ai.chronon.spark.test +package ai.chronon.spark.utils import ai.chronon.api.Constants import ai.chronon.online.KVStore @@ -118,7 +118,7 @@ class InMemoryKvStore(tableUtils: () => TableUtils) extends KVStore with Seriali } override def create(dataset: String): Unit = { - database.computeIfAbsent(dataset, _ => new ConcurrentHashMap[Key, VersionedData]) + database.computeIfAbsent(dataset, _ => new ConcurrentHashMap[Key, VersionedData]) } def show(): Unit = { diff --git a/spark/src/test/scala/ai/chronon/spark/test/InMemoryStream.scala b/spark/src/main/scala/ai/chronon/spark/utils/InMemoryStream.scala similarity index 99% rename from spark/src/test/scala/ai/chronon/spark/test/InMemoryStream.scala rename to spark/src/main/scala/ai/chronon/spark/utils/InMemoryStream.scala index 0754f209d8..d3e091d575 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/InMemoryStream.scala +++ b/spark/src/main/scala/ai/chronon/spark/utils/InMemoryStream.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package ai.chronon.spark.test +package ai.chronon.spark.utils import ai.chronon.api.Constants import ai.chronon.api.GroupBy diff --git a/spark/src/test/scala/ai/chronon/spark/test/MockApi.scala b/spark/src/main/scala/ai/chronon/spark/utils/MockApi.scala similarity index 99% rename from spark/src/test/scala/ai/chronon/spark/test/MockApi.scala rename to spark/src/main/scala/ai/chronon/spark/utils/MockApi.scala index b270570c20..e5f9203645 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MockApi.scala +++ b/spark/src/main/scala/ai/chronon/spark/utils/MockApi.scala @@ -14,14 +14,13 @@ * limitations under the License. 
*/ -package ai.chronon.spark.test +package ai.chronon.spark.utils import ai.chronon.api.Constants import ai.chronon.api.Extensions.GroupByOps import ai.chronon.api.Extensions.SourceOps import ai.chronon.api.StructType import ai.chronon.online.Fetcher.Response -import ai.chronon.online.Serde import ai.chronon.online._ import ai.chronon.spark.Extensions._ import ai.chronon.spark.TableUtils diff --git a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala index 038826ad25..6af7150d1a 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala @@ -26,6 +26,7 @@ import ai.chronon.online.Fetcher.Request import ai.chronon.online.MetadataStore import ai.chronon.online.SparkConversions import ai.chronon.spark.Extensions._ +import ai.chronon.spark.utils.MockApi import ai.chronon.spark.{Join => _, _} import junit.framework.TestCase import org.apache.spark.sql.DataFrame diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala index 5d6d487462..de0282ba2b 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala @@ -19,6 +19,7 @@ import ai.chronon.api.Constants.MetadataDataset import ai.chronon.api._ import ai.chronon.online.Fetcher.Request import ai.chronon.spark.LoggingSchema +import ai.chronon.spark.utils.MockApi import org.junit.Assert._ import org.junit.Test diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala index 0abee28e8b..1e5c8b4eff 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala @@ -35,6 +35,7 @@ import ai.chronon.online.MetadataStore import ai.chronon.online.SparkConversions import ai.chronon.spark.Extensions._ import ai.chronon.spark.stats.ConsistencyJob +import ai.chronon.spark.utils.MockApi import ai.chronon.spark.{Join => _, _} import com.google.gson.GsonBuilder import org.apache.spark.sql.DataFrame diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala index d1213e4889..42c85adff9 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala @@ -25,6 +25,7 @@ import ai.chronon.spark.Extensions.DataframeOps import ai.chronon.spark.GroupByUpload import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils +import ai.chronon.spark.utils.MockApi import com.google.gson.Gson import org.apache.spark.sql.SparkSession import org.junit.Assert.assertEquals diff --git a/spark/src/test/scala/ai/chronon/spark/test/JavaFetcherTest.java b/spark/src/test/scala/ai/chronon/spark/test/JavaFetcherTest.java index 2bbefaede3..fad6dd688b 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JavaFetcherTest.java +++ b/spark/src/test/scala/ai/chronon/spark/test/JavaFetcherTest.java @@ -22,6 +22,8 @@ import ai.chronon.online.Fetcher; import ai.chronon.spark.TableUtils; import ai.chronon.spark.SparkSessionBuilder; +import ai.chronon.spark.utils.InMemoryKvStore; +import ai.chronon.spark.utils.MockApi; import com.google.gson.Gson; import 
org.apache.spark.sql.SparkSession; import org.junit.Test; @@ -39,7 +41,7 @@ public class JavaFetcherTest { String namespace = "java_fetcher_test"; - SparkSession session = SparkSessionBuilder.build(namespace, true, scala.Option.apply(null), scala.Option.apply(null), true); + SparkSession session = SparkSessionBuilder.build(namespace, true, true, scala.Option.apply(null), scala.Option.apply(null), true); TableUtils tu = new TableUtils(session); InMemoryKvStore kvStore = new InMemoryKvStore(func(() -> tu)); MockApi mockApi = new MockApi(func(() -> kvStore), "java_fetcher_test"); diff --git a/spark/src/test/scala/ai/chronon/spark/test/LocalDataLoaderTest.scala b/spark/src/test/scala/ai/chronon/spark/test/LocalDataLoaderTest.scala index b9b4310a86..07769d8f9c 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/LocalDataLoaderTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/LocalDataLoaderTest.scala @@ -35,7 +35,7 @@ object LocalDataLoaderTest { val spark: SparkSession = SparkSessionBuilder.build( "LocalDataLoaderTest", local = true, - Some(tmpDir.getPath)) + localWarehouseLocation = Some(tmpDir.getPath)) @AfterClass def teardown(): Unit = { diff --git a/spark/src/test/scala/ai/chronon/spark/test/LocalTableExporterTest.scala b/spark/src/test/scala/ai/chronon/spark/test/LocalTableExporterTest.scala index fb099e7a2b..7ccfa92ce9 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/LocalTableExporterTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/LocalTableExporterTest.scala @@ -41,7 +41,7 @@ import java.io.File object LocalTableExporterTest { val tmpDir: File = Files.createTempDir() - val spark: SparkSession = SparkSessionBuilder.build("LocalTableExporterTest", local = true, Some(tmpDir.getPath)) + val spark: SparkSession = SparkSessionBuilder.build("LocalTableExporterTest", local = true, localWarehouseLocation = Some(tmpDir.getPath)) @AfterClass def teardown(): Unit = { diff --git a/spark/src/test/scala/ai/chronon/spark/test/OnlineUtils.scala b/spark/src/test/scala/ai/chronon/spark/test/OnlineUtils.scala index fe73ad641a..3335ab4f46 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/OnlineUtils.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/OnlineUtils.scala @@ -31,6 +31,9 @@ import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import ai.chronon.spark.streaming.GroupBy import ai.chronon.spark.streaming.JoinSourceRunner +import ai.chronon.spark.utils.InMemoryKvStore +import ai.chronon.spark.utils.InMemoryStream +import ai.chronon.spark.utils.MockApi import org.apache.spark.sql.SparkSession import org.apache.spark.sql.streaming.Trigger diff --git a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala index 2a00b6197e..dd86f5f0eb 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala @@ -26,6 +26,8 @@ import ai.chronon.spark.LogFlattenerJob import ai.chronon.spark.LoggingSchema import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils +import ai.chronon.spark.utils.InMemoryKvStore +import ai.chronon.spark.utils.MockApi import junit.framework.TestCase import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row diff --git a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionUtils.scala b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionUtils.scala index 339a26aa7d..ab1b819e37 100644 --- 
a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionUtils.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionUtils.scala @@ -19,6 +19,7 @@ package ai.chronon.spark.test import ai.chronon.spark.LogUtils import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils +import ai.chronon.spark.utils.MockApi object SchemaEvolutionUtils { def runLogSchemaGroupBy(mockApi: MockApi, ds: String, backfillStartDate: String): Unit = { diff --git a/spark/src/test/scala/ai/chronon/spark/test/StreamingTest.scala b/spark/src/test/scala/ai/chronon/spark/test/StreamingTest.scala index 749f2dd9e4..963ca4dc80 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/StreamingTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/StreamingTest.scala @@ -27,6 +27,7 @@ import ai.chronon.api.Window import ai.chronon.online.MetadataStore import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.StreamingTest.buildInMemoryKvStore +import ai.chronon.spark.utils.InMemoryKvStore import ai.chronon.spark.{Join => _, _} import junit.framework.TestCase import org.apache.spark.sql.SparkSession diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala index f5fae6fc0e..4e9e8a0ffa 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala @@ -23,9 +23,9 @@ import ai.chronon.online.Fetcher.Request import ai.chronon.online.MetadataStore import ai.chronon.spark.Extensions.DataframeOps import ai.chronon.spark._ -import ai.chronon.spark.test.MockApi import ai.chronon.spark.test.OnlineUtils import ai.chronon.spark.test.SchemaEvolutionUtils +import ai.chronon.spark.utils.MockApi import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ import org.junit.Assert.assertEquals diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala index 28ea481763..2bb7807b71 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala @@ -25,9 +25,9 @@ import ai.chronon.spark.Extensions._ import ai.chronon.spark.LogFlattenerJob import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils -import ai.chronon.spark.test.MockApi import ai.chronon.spark.test.OnlineUtils import ai.chronon.spark.test.SchemaEvolutionUtils +import ai.chronon.spark.utils.MockApi import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ import org.junit.Assert.assertEquals diff --git a/spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala b/spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala index 99405c8aad..0584ed7da4 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala @@ -13,8 +13,9 @@ import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.drift.Summarizer import ai.chronon.spark.stats.drift.SummaryUploader -import ai.chronon.spark.test.InMemoryKvStore -import ai.chronon.spark.test.MockApi +import ai.chronon.spark.stats.drift.scripts.PrepareData +import ai.chronon.spark.utils.InMemoryKvStore +import ai.chronon.spark.utils.MockApi import 
org.apache.spark.sql.SparkSession import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -88,7 +89,7 @@ class DriftTest extends AnyFlatSpec with Matchers { // fetch summaries val startMs = PartitionSpec.daily.epochMillis("2023-01-01") val endMs = PartitionSpec.daily.epochMillis("2023-01-29") - val summariesFuture = driftStore.getSummaries(join, Some(startMs), Some(endMs)) + val summariesFuture = driftStore.getSummaries(join, Some(startMs), Some(endMs), None) val summaries = Await.result(summariesFuture, Duration.create(10, TimeUnit.SECONDS)) println(summaries)
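
Note on running the demo introduced above: the ObservabilityDemo options defined via ScallopConf (start-ds, end-ds, row-count, namespace) can be appended after the assembly jar in the spark-submit command from docker-init/demo/run.sh. A minimal sketch, assuming the spark-app container and jar path created by run.sh; the dates and row count here are illustrative values, not required defaults:

# Invoke the demo with explicit options; arguments after the jar go to ObservabilityDemo's Conf.
docker exec spark-app \
  /opt/spark/bin/spark-submit \
  --master "local[*]" \
  --driver-memory 8g \
  --class ai.chronon.spark.scripts.ObservabilityDemo \
  /opt/chronon/jars/spark-assembly-0.1.0-SNAPSHOT.jar \
  --start-ds 2023-01-01 \
  --end-ds 2023-01-29 \
  --row-count 100000 \
  --namespace observability_demo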