From 76518c003538c79e2f397530171bcbb528319f64 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Mon, 29 Sep 2025 14:03:08 -0700 Subject: [PATCH 01/13] Add and validate offline groupby option for external source --- .../scala/ai/chronon/api/Extensions.scala | 87 +++++++++++++++ .../ai/chronon/api/test/ExtensionsTest.scala | 101 +++++++++++++++++- api/thrift/api.thrift | 2 + .../scala/ai/chronon/spark/Analyzer.scala | 13 ++- 4 files changed, 199 insertions(+), 4 deletions(-) diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index 873b9bacea..ca4aa51761 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -686,6 +686,83 @@ object Extensions { lazy val valueFields: Array[StructField] = schemaFields(externalSource.valueSchema) def isContextualSource: Boolean = externalSource.metadata.name == Constants.ContextualSourceName + + /** + * Validates schema compatibility between ExternalSource and its offlineGroupBy. + * This ensures that online and offline serving will produce consistent results. + * + * @return Sequence of error messages, empty if no errors + */ + def validateOfflineGroupBy(): Seq[String] = Option(externalSource.offlineGroupBy) + .map(_ => validateKeySchemaCompatibility() ++ validateValueSchemaCompatibility()) + .getOrElse(Seq.empty) + + private def validateKeySchemaCompatibility(): Seq[String] = { + val errors = scala.collection.mutable.ListBuffer[String]() + + if (externalSource.keySchema == null) { + errors += s"ExternalSource ${externalSource.metadata.name} keySchema cannot be null when offlineGroupBy is specified" + return errors + } + + if (externalSource.offlineGroupBy.keyColumns == null || externalSource.offlineGroupBy.keyColumns.isEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns cannot be null or empty" + return errors + } + + val externalKeyFields = keyFields + val groupByKeyColumns = externalSource.offlineGroupBy.keyColumns.toScala.toSet + + // Extract field names from external source key schema + val externalKeyNames = externalKeyFields.map(_.name).toSet + + // Validate that GroupBy has key columns that match ExternalSource key schema + val missingKeys = externalKeyNames -- groupByKeyColumns + val extraKeys = groupByKeyColumns -- externalKeyNames + + if (missingKeys.nonEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} key schema contains columns [${missingKeys.mkString(", ")}] " + + s"that are not present in offlineGroupBy keyColumns [${groupByKeyColumns.mkString(", ")}]. " + + s"All ExternalSource key columns must be present in the GroupBy key columns." + } + + if (extraKeys.nonEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns contain [${extraKeys.mkString(", ")}] " + + s"that are not present in ExternalSource keySchema [${externalKeyNames.mkString(", ")}]. " + + s"GroupBy key columns cannot contain keys not defined in ExternalSource keySchema." 
+ } + + errors + } + + private def validateValueSchemaCompatibility(): Seq[String] = { + val errors = scala.collection.mutable.ListBuffer[String]() + + if (externalSource.valueSchema == null) { + errors += s"ExternalSource ${externalSource.metadata.name} valueSchema cannot be null when offlineGroupBy is specified" + return errors + } + + val externalValueFields = valueFields + val externalValueNames = externalValueFields.map(_.name).toSet + + // For GroupBy value schema, we need to derive the output schema from aggregations + val groupByValueColumns = externalSource.offlineGroupBy.valueColumns.toSet + + // Check that ExternalSource value schema fields are compatible with GroupBy output + val missingValueColumns = externalValueNames -- groupByValueColumns + + if (missingValueColumns.nonEmpty) { + // This is an error because ExternalSource valueSchema must be compatible with GroupBy output + // to ensure consistency between online and offline serving + errors += s"ExternalSource ${externalSource.metadata.name} valueSchema contains columns [${missingValueColumns.mkString(", ")}] " + + s"that are not present in offlineGroupBy output columns [${groupByValueColumns.mkString(", ")}]. " + + s"This indicates schema incompatibility between online and offline serving. " + + s"Please ensure ExternalSource valueSchema matches the expected output of the GroupBy aggregations." + } + + errors + } } object KeyMappingHelper { @@ -951,6 +1028,16 @@ object Extensions { .getOrElse(Seq.empty) } + /** + * Validates all ExternalSources in this Join's onlineExternalParts. + * This ensures schema compatibility between ExternalSources and their offlineGroupBy configurations. + * + * @return Sequence of error messages, empty if no errors + */ + def validateExternalSources(): Seq[String] = Option(join.onlineExternalParts) + .map(_.toScala.flatMap(_.source.validateOfflineGroupBy())) + .getOrElse(Seq.empty) + def isProduction: Boolean = join.getMetaData.isProduction def team: String = join.getMetaData.getTeam diff --git a/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala b/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala index 22f91d6d22..51839af2d5 100644 --- a/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala +++ b/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala @@ -16,7 +16,7 @@ package ai.chronon.api.test -import ai.chronon.api.{Accuracy, Builders, Constants, GroupBy} +import ai.chronon.api.{Accuracy, Builders, Constants, GroupBy, StringType, DoubleType, StructType, StructField} import org.junit.Test import ai.chronon.api.Extensions._ import org.junit.Assert.{assertEquals, assertFalse, assertNotEquals, assertTrue} @@ -248,4 +248,103 @@ class ExtensionsTest { assertEquals(join1.semanticHash(excludeTopic = true), join2.semanticHash(excludeTopic = true)) assertEquals(join1.semanticHash(excludeTopic = false), join2.semanticHash(excludeTopic = false)) } + + @Test + def testExternalSourceValidationWithMatchingSchemas(): Unit = { + // Create compatible schemas using the correct DataType objects + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + // Create a query and source for the GroupBy + val query = Builders.Query(selects = Map("feature_value" -> "value")) + val source = Builders.Source.events(query, "test.table") + + // Create GroupBy with matching key columns and sources + val groupBy = Builders.GroupBy( + keyColumns = Seq("user_id"), + sources = 
Seq(source) + ) + + // Create ExternalSource with compatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + + // Manually set the offlineGroupBy field since the builder doesn't support it yet + externalSource.setOfflineGroupBy(groupBy) + + // This should return no errors + val errors = externalSource.validateOfflineGroupBy() + assertTrue(s"Expected no errors, but got: ${errors.mkString(", ")}", errors.isEmpty) + } + + @Test + def testExternalSourceValidationWithMismatchedKeySchemas(): Unit = { + // Create key schema with different fields + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + // Create a query and source for the GroupBy + val query = Builders.Query(selects = Map("feature_value" -> "feature_value")) + val source = Builders.Source.events(query, "test.table") + + // Create GroupBy with different key columns + val groupBy = Builders.GroupBy( + keyColumns = Seq("different_key"), // Mismatched key column + sources = Seq(source) + ) + + // Create ExternalSource with incompatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + externalSource.setOfflineGroupBy(groupBy) + + // This should return validation errors + val errors = externalSource.validateOfflineGroupBy() + assertFalse("Expected validation errors for mismatched key schemas", errors.isEmpty) + assertTrue("Error should mention key schema mismatch", + errors.exists(_.contains("key schema contains columns"))) + } + + @Test + def testExternalSourceValidationWithMismatchedValueSchemas(): Unit = { + // Create compatible key schema but mismatched value schema + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + // Create a source for the GroupBy - this is needed for valueColumns to work + val query = Builders.Query(selects = Map("different_feature" -> "different_feature")) + val source = Builders.Source.events(query, "test.table") + + // Create GroupBy with different value columns + val groupBy = Builders.GroupBy( + keyColumns = Seq("user_id"), + sources = Seq(source) + ) + + // Create ExternalSource with incompatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + externalSource.setOfflineGroupBy(groupBy) + + // This should return validation errors + val errors = externalSource.validateOfflineGroupBy() + assertFalse("Expected validation errors for mismatched value schemas", errors.isEmpty) + assertTrue("Error should mention value schema mismatch", + errors.exists(_.contains("valueSchema contains columns"))) + } + + @Test + def testExternalSourceValidationWithNullOfflineGroupBy(): Unit = { + // Create ExternalSource without offlineGroupBy + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + // Don't set offlineGroupBy (it remains null) + + // This should return no errors (validation should be skipped) + val errors = 
externalSource.validateOfflineGroupBy() + assertTrue("Expected no errors when offlineGroupBy is null", errors.isEmpty) + } } diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift index 21c484bf6c..03e3b16a4c 100644 --- a/api/thrift/api.thrift +++ b/api/thrift/api.thrift @@ -121,6 +121,8 @@ struct ExternalSource { 2: optional TDataType keySchema 3: optional TDataType valueSchema 4: optional ExternalSourceFactoryConfig factoryConfig + // GroupBy to be used for offline backfill - enables PITC offline computation + 5: optional GroupBy offlineGroupBy } /** diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index 403ce2ca94..4c3ac9e331 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -373,6 +373,8 @@ class Analyzer(tableUtils: TableUtils, val noAccessTables = mutable.Set[String]() ++= leftNoAccessTables // Pair of (table name, group_by name, expected_start) which indicate that the table no not have data available for the required group_by val dataAvailabilityErrors: ListBuffer[(String, String, String)] = ListBuffer.empty[(String, String, String)] + // ExternalSource schema validation errors + val externalSourceErrors: ListBuffer[String] = ListBuffer.empty[String] val rangeToFill = JoinUtils.getRangesToFill(joinConf.left, tableUtils, endDate, historicalBackfill = joinConf.historicalBackfill) @@ -423,6 +425,9 @@ class Analyzer(tableUtils: TableUtils, gbStartPartitions += (part.groupBy.metaData.name -> gbStartPartition) } if (joinConf.onlineExternalParts != null) { + // Validate ExternalSource schemas if they have offlineGroupBy configured + externalSourceErrors ++= joinConf.validateExternalSources() + joinConf.onlineExternalParts.toScala.foreach { part => joinIntermediateValuesMetadata ++= part.source.valueFields.map { field => AggregationMetadata(part.fullName + "_" + field.name, @@ -507,7 +512,7 @@ class Analyzer(tableUtils: TableUtils, logger.info(s"$gbName : ${startPartitions.mkString(",")}") } } - if (keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty) { + if (keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty && externalSourceErrors.isEmpty) { logger.info("----- Backfill validation completed. No errors found. -----") } else { logger.info(s"----- Schema validation completed. Found ${keysWithError.size} errors") @@ -519,6 +524,8 @@ class Analyzer(tableUtils: TableUtils, logger.info(s"---- Data availability check completed. Found issue in ${dataAvailabilityErrors.size} tables ----") dataAvailabilityErrors.foreach(error => logger.info(s"Group_By ${error._2} : Source Tables ${error._1} : Expected start ${error._3}")) + logger.info(s"---- ExternalSource schema validation completed. Found ${externalSourceErrors.size} errors ----") + externalSourceErrors.foreach(error => logger.info(error)) } if (validationAssert) { @@ -526,12 +533,12 @@ class Analyzer(tableUtils: TableUtils, // For joins with bootstrap_parts, do not assert on data availability errors, as bootstrap can cover them // Only print out the errors as a warning assert( - keysWithError.isEmpty && noAccessTables.isEmpty, + keysWithError.isEmpty && noAccessTables.isEmpty && externalSourceErrors.isEmpty, "ERROR: Join validation failed. Please check error message for details." 
) } else { assert( - keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty, + keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty && externalSourceErrors.isEmpty, "ERROR: Join validation failed. Please check error message for details." ) } From a0f0ae1a746b6f524274d5c58a3f4981b06dfda4 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Mon, 29 Sep 2025 14:58:01 -0700 Subject: [PATCH 02/13] Fix scala fmt error --- .../scala/ai/chronon/api/Extensions.scala | 40 ++++++++++--------- .../scala/ai/chronon/spark/Analyzer.scala | 4 +- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index ca4aa51761..e21aa669d6 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -688,14 +688,15 @@ object Extensions { def isContextualSource: Boolean = externalSource.metadata.name == Constants.ContextualSourceName /** - * Validates schema compatibility between ExternalSource and its offlineGroupBy. - * This ensures that online and offline serving will produce consistent results. - * - * @return Sequence of error messages, empty if no errors - */ - def validateOfflineGroupBy(): Seq[String] = Option(externalSource.offlineGroupBy) - .map(_ => validateKeySchemaCompatibility() ++ validateValueSchemaCompatibility()) - .getOrElse(Seq.empty) + * Validates schema compatibility between ExternalSource and its offlineGroupBy. + * This ensures that online and offline serving will produce consistent results. + * + * @return Sequence of error messages, empty if no errors + */ + def validateOfflineGroupBy(): Seq[String] = + Option(externalSource.offlineGroupBy) + .map(_ => validateKeySchemaCompatibility() ++ validateValueSchemaCompatibility()) + .getOrElse(Seq.empty) private def validateKeySchemaCompatibility(): Seq[String] = { val errors = scala.collection.mutable.ListBuffer[String]() @@ -727,7 +728,8 @@ object Extensions { } if (extraKeys.nonEmpty) { - errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns contain [${extraKeys.mkString(", ")}] " + + errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns contain [${extraKeys + .mkString(", ")}] " + s"that are not present in ExternalSource keySchema [${externalKeyNames.mkString(", ")}]. " + s"GroupBy key columns cannot contain keys not defined in ExternalSource keySchema." } @@ -755,7 +757,8 @@ object Extensions { if (missingValueColumns.nonEmpty) { // This is an error because ExternalSource valueSchema must be compatible with GroupBy output // to ensure consistency between online and offline serving - errors += s"ExternalSource ${externalSource.metadata.name} valueSchema contains columns [${missingValueColumns.mkString(", ")}] " + + errors += s"ExternalSource ${externalSource.metadata.name} valueSchema contains columns [${missingValueColumns + .mkString(", ")}] " + s"that are not present in offlineGroupBy output columns [${groupByValueColumns.mkString(", ")}]. " + s"This indicates schema incompatibility between online and offline serving. " + s"Please ensure ExternalSource valueSchema matches the expected output of the GroupBy aggregations." @@ -1029,14 +1032,15 @@ object Extensions { } /** - * Validates all ExternalSources in this Join's onlineExternalParts. - * This ensures schema compatibility between ExternalSources and their offlineGroupBy configurations. 
- * - * @return Sequence of error messages, empty if no errors - */ - def validateExternalSources(): Seq[String] = Option(join.onlineExternalParts) - .map(_.toScala.flatMap(_.source.validateOfflineGroupBy())) - .getOrElse(Seq.empty) + * Validates all ExternalSources in this Join's onlineExternalParts. + * This ensures schema compatibility between ExternalSources and their offlineGroupBy configurations. + * + * @return Sequence of error messages, empty if no errors + */ + def validateExternalSources(): Seq[String] = + Option(join.onlineExternalParts) + .map(_.toScala.flatMap(_.source.validateOfflineGroupBy())) + .getOrElse(Seq.empty) def isProduction: Boolean = join.getMetaData.isProduction diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index 4c3ac9e331..2c1121aa8c 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -512,7 +512,9 @@ class Analyzer(tableUtils: TableUtils, logger.info(s"$gbName : ${startPartitions.mkString(",")}") } } - if (keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty && externalSourceErrors.isEmpty) { + if ( + keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty && externalSourceErrors.isEmpty + ) { logger.info("----- Backfill validation completed. No errors found. -----") } else { logger.info(s"----- Schema validation completed. Found ${keysWithError.size} errors") From 5bb15dd4ab25175cb7f56ed0e93a10047b5cb076 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Mon, 29 Sep 2025 19:37:44 -0700 Subject: [PATCH 03/13] Join and unit tests --- .../ai/chronon/spark/BootstrapInfo.scala | 44 +- .../spark/test/BootstrapInfoTest.scala | 418 ++++++++++++++++++ .../test/ExternalSourceBackfillTest.scala | 399 +++++++++++++++++ 3 files changed, 856 insertions(+), 5 deletions(-) create mode 100644 spark/src/test/scala/ai/chronon/spark/test/BootstrapInfoTest.scala create mode 100644 spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala diff --git a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala index e89ab84041..4828c9e1f8 100644 --- a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala +++ b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala @@ -83,9 +83,13 @@ object BootstrapInfo { // Enrich each join part with the expected output schema logger.info(s"\nCreating BootstrapInfo for GroupBys for Join ${joinConf.metaData.name}") - var joinParts: Seq[JoinPartMetadata] = Option(joinConf.joinParts.toScala) - .getOrElse(Seq.empty) - .map(part => { + + // Combine regular JoinParts with converted ExternalParts before processing + val regularJoinParts = Option(joinConf.joinParts.toScala).getOrElse(Seq.empty) + val convertedJoinParts = convertExternalPartsToJoinParts(joinConf) + val allJoinParts = regularJoinParts ++ convertedJoinParts + + var joinParts: Seq[JoinPartMetadata] = allJoinParts.map(part => { // set computeDependency to False as we compute dependency upstream val gb = GroupBy.from(part.groupBy, range, tableUtils, computeDependency) val keySchema = SparkConversions @@ -117,10 +121,11 @@ object BootstrapInfo { JoinPartMetadata(part, keySchema, valueSchema, Map.empty) // will be populated below }) - // Enrich each external part with the expected output schema - logger.info(s"\nCreating BootstrapInfo for ExternalParts for Join ${joinConf.metaData.name}") + // Enrich online only 
external parts with the expected output schema + logger.info(s"\nCreating BootstrapInfo for online-only ExternalParts for Join ${joinConf.metaData.name}") val externalParts: Seq[ExternalPartMetadata] = Option(joinConf.onlineExternalParts.toScala) .getOrElse(Seq.empty) + .filter(_.source.offlineGroupBy == null) // Only online-only ExternalParts .map(part => ExternalPartMetadata(part, part.keySchemaFull, part.valueSchemaFull)) val leftFields = leftSchema @@ -355,4 +360,33 @@ object BootstrapInfo { bootstrapInfo } + + /** + * Converts ExternalParts with offlineGroupBy to JoinParts for parallel processing. + * This allows offline-capable ExternalParts to be processed alongside regular JoinParts + * in the same parallel execution pool, gaining benefits from bloom filter optimization, + * small mode optimization, and bootstrap coverage analysis. + * + * @param joinConf Join configuration containing ExternalParts to convert + * @return Sequence of JoinParts converted from offline-capable ExternalParts + */ + private def convertExternalPartsToJoinParts(joinConf: api.Join): Seq[JoinPart] = { + logger.info(s"\nConverting offline-capable ExternalParts to JoinParts for Join ${joinConf.metaData.name}") + + Option(joinConf.onlineExternalParts.toScala) + .getOrElse(Seq.empty) + .filter(_.source.offlineGroupBy != null) // Only offline-capable ExternalParts + .map { externalPart => + // Convert ExternalPart to synthetic JoinPart + val syntheticJoinPart = new api.JoinPart() + syntheticJoinPart.setGroupBy(externalPart.source.offlineGroupBy) + if (externalPart.keyMapping != null) { + syntheticJoinPart.setKeyMapping(externalPart.keyMapping) + } + if (externalPart.prefix != null) { + syntheticJoinPart.setPrefix(externalPart.prefix) + } + syntheticJoinPart + } + } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/BootstrapInfoTest.scala b/spark/src/test/scala/ai/chronon/spark/test/BootstrapInfoTest.scala new file mode 100644 index 0000000000..7fa7a9fe66 --- /dev/null +++ b/spark/src/test/scala/ai/chronon/spark/test/BootstrapInfoTest.scala @@ -0,0 +1,418 @@ +/* + * Copyright (C) 2023 The Chronon Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.chronon.spark.test + +import ai.chronon.aggregator.test.Column +import ai.chronon.api.Extensions._ +import ai.chronon.api._ +import ai.chronon.spark.Extensions._ +import ai.chronon.spark._ +import org.apache.spark.sql.SparkSession +import org.junit.Assert._ +import org.junit.Test + +import scala.util.Random + +class BootstrapInfoTest { + val spark: SparkSession = SparkSessionBuilder.build("BootstrapInfoTest", local = true) + private val tableUtils = TableUtils(spark) + private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) + + @Test + def testExternalSourceWithOfflineGroupByConversion(): Unit = { + val suffix = "external_backfill_" + Random.alphanumeric.take(6).mkString + val namespace = s"test_namespace_$suffix" + tableUtils.createDatabase(namespace) + + // Create test data for the GroupBy source table + val groupByColumns = List( + Column("user_id", StringType, 100), + Column("feature_value", LongType, 1000) + ) + + val groupByTable = s"$namespace.user_features" + spark.sql(s"DROP TABLE IF EXISTS $groupByTable") + DataFrameGen.events(spark, groupByColumns, 1000, partitions = 50).save(groupByTable) + + // Create the GroupBy for offline backfill + val offlineGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "feature_value" -> "feature_value"), + timeColumn = "ts" + ), + table = groupByTable + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.SUM, + inputColumn = "feature_value", + windows = Seq(new Window(7, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = s"user_features_gb_$suffix", namespace = namespace), + accuracy = Accuracy.TEMPORAL + ) + + // Create ExternalSource with offline GroupBy + val externalSource = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"external_with_offline_$suffix"), + keySchema = StructType("external_keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("external_values", Array(StructField("feature_value_sum_7d", LongType))) + ) + externalSource.setOfflineGroupBy(offlineGroupBy) + + // Create a simple left source for the join + val leftColumns = List( + Column("user_id", StringType, 100), + Column("request_id", StringType, 100) + ) + + val leftTable = s"$namespace.requests" + spark.sql(s"DROP TABLE IF EXISTS $leftTable") + DataFrameGen.events(spark, leftColumns, 500, partitions = 30).save(leftTable) + + // Create Join with ExternalPart that has offline GroupBy + val join = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "request_id" -> "request_id"), + timeColumn = "ts" + ), + table = leftTable + ), + externalParts = Seq( + Builders.ExternalPart( + externalSource, + prefix = "ext" + ) + ), + metaData = Builders.MetaData(name = s"test_join_$suffix", namespace = namespace) + ) + + // Test BootstrapInfo conversion logic + val endPartition = today + val range = PartitionRange(monthAgo, endPartition)(tableUtils) + val bootstrapInfo = BootstrapInfo.from( + joinConf = join, + range = range, + tableUtils = tableUtils, + leftSchema = None, + computeDependency = true + ) + + // Verify that ExternalPart with offline GroupBy was converted to JoinPart + val totalJoinParts = bootstrapInfo.joinParts.length + assertTrue("Should have at least one JoinPart after conversion", 
totalJoinParts > 0) + + // Verify that the converted JoinPart has the expected GroupBy + val convertedJoinPart = bootstrapInfo.joinParts.find(_.joinPart.groupBy.metaData.name == offlineGroupBy.metaData.name) + assertTrue("Should find converted JoinPart with matching GroupBy name", convertedJoinPart.isDefined) + + // Verify that online-only external parts are still tracked separately + assertTrue("Should have no online-only external parts in this test", bootstrapInfo.externalParts.isEmpty) + + // Verify schema compatibility + val joinPartMeta = convertedJoinPart.get + assertEquals("Key schema should match", 1, joinPartMeta.keySchema.length) + assertEquals("Key field should be user_id", "user_id", joinPartMeta.keySchema.head.name) + + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } + + @Test + def testExternalSourceOnlineOnlyBehavior(): Unit = { + val suffix = "online_only_" + Random.alphanumeric.take(6).mkString + val namespace = s"test_namespace_$suffix" + tableUtils.createDatabase(namespace) + + // Create ExternalSource without offline GroupBy (online-only) + val onlineOnlyExternalSource = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"online_only_external_$suffix"), + keySchema = StructType("online_keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("online_values", Array(StructField("online_feature", LongType))) + ) + // Note: No offlineGroupBy set, so this remains online-only + + // Create a simple left source + val leftColumns = List( + Column("user_id", StringType, 100), + Column("request_id", StringType, 100) + ) + + val leftTable = s"$namespace.requests" + spark.sql(s"DROP TABLE IF EXISTS $leftTable") + DataFrameGen.events(spark, leftColumns, 500, partitions = 30).save(leftTable) + + // Create Join with online-only ExternalPart + val join = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "request_id" -> "request_id"), + timeColumn = "ts" + ), + table = leftTable + ), + externalParts = Seq( + Builders.ExternalPart( + onlineOnlyExternalSource, + prefix = "online" + ) + ), + metaData = Builders.MetaData(name = s"online_only_join_$suffix", namespace = namespace) + ) + + // Test BootstrapInfo with online-only external part + val endPartition = today + val range = PartitionRange(monthAgo, endPartition)(tableUtils) + val bootstrapInfo = BootstrapInfo.from( + joinConf = join, + range = range, + tableUtils = tableUtils, + leftSchema = None, + computeDependency = true + ) + + // Verify that online-only ExternalPart was NOT converted to JoinPart + assertEquals("Should have no JoinParts from conversion", 0, bootstrapInfo.joinParts.length) + + // Verify that online-only external part is tracked in externalParts + assertEquals("Should have one online-only external part", 1, bootstrapInfo.externalParts.length) + + val externalPartMeta = bootstrapInfo.externalParts.head + assertEquals("External part name should match", onlineOnlyExternalSource.metadata.name, + externalPartMeta.externalPart.source.metadata.name) + + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } + + @Test + def testMixedExternalPartsConversion(): Unit = { + val suffix = "mixed_" + Random.alphanumeric.take(6).mkString + val namespace = s"test_namespace_$suffix" + tableUtils.createDatabase(namespace) + + // Create test data for the GroupBy source table + val groupByColumns = List( + Column("user_id", StringType, 100), + Column("feature_value", LongType, 1000) + ) + + val groupByTable = 
s"$namespace.user_features" + spark.sql(s"DROP TABLE IF EXISTS $groupByTable") + DataFrameGen.events(spark, groupByColumns, 1000, partitions = 50).save(groupByTable) + + // Create GroupBy for offline backfill + val offlineGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "feature_value" -> "feature_value"), + timeColumn = "ts" + ), + table = groupByTable + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.SUM, + inputColumn = "feature_value" + ) + ), + metaData = Builders.MetaData(name = s"offline_gb_$suffix", namespace = namespace), + accuracy = Accuracy.TEMPORAL + ) + + // Create ExternalSource with offline GroupBy + val externalSourceWithOffline = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"external_with_offline_$suffix"), + keySchema = StructType("offline_keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("offline_values", Array(StructField("feature_value_sum", LongType))) + ) + externalSourceWithOffline.setOfflineGroupBy(offlineGroupBy) + + // Create online-only ExternalSource + val externalSourceOnlineOnly = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"external_online_only_$suffix"), + keySchema = StructType("online_keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("online_values", Array(StructField("online_feature", LongType))) + ) + // No offlineGroupBy set + + // Create left source + val leftColumns = List( + Column("user_id", StringType, 100), + Column("request_id", StringType, 100) + ) + + val leftTable = s"$namespace.requests" + spark.sql(s"DROP TABLE IF EXISTS $leftTable") + DataFrameGen.events(spark, leftColumns, 500, partitions = 30).save(leftTable) + + // Create Join with both types of ExternalParts + val join = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "request_id" -> "request_id"), + timeColumn = "ts" + ), + table = leftTable + ), + externalParts = Seq( + Builders.ExternalPart( + externalSourceWithOffline, + prefix = "offline" + ), + Builders.ExternalPart( + externalSourceOnlineOnly, + prefix = "online" + ) + ), + metaData = Builders.MetaData(name = s"mixed_join_$suffix", namespace = namespace) + ) + + // Test BootstrapInfo with mixed external parts + val endPartition = today + val range = PartitionRange(monthAgo, endPartition)(tableUtils) + val bootstrapInfo = BootstrapInfo.from( + joinConf = join, + range = range, + tableUtils = tableUtils, + leftSchema = None, + computeDependency = true + ) + + // Verify that offline-capable ExternalPart was converted to JoinPart + assertEquals("Should have one JoinPart from conversion", 1, bootstrapInfo.joinParts.length) + val convertedJoinPart = bootstrapInfo.joinParts.head + assertEquals("Converted JoinPart should have matching GroupBy name", + offlineGroupBy.metaData.name, convertedJoinPart.joinPart.groupBy.metaData.name) + + // Verify that online-only ExternalPart is tracked separately + assertEquals("Should have one online-only external part", 1, bootstrapInfo.externalParts.length) + val onlineExternalPart = bootstrapInfo.externalParts.head + assertEquals("Online external part should have matching name", + externalSourceOnlineOnly.metadata.name, onlineExternalPart.externalPart.source.metadata.name) + + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } + + @Test + def testExternalPartKeyMappingPreservation(): 
Unit = { + val suffix = "keymapping_" + Random.alphanumeric.take(6).mkString + val namespace = s"test_namespace_$suffix" + tableUtils.createDatabase(namespace) + + // Create test data for the GroupBy source table + val groupByColumns = List( + Column("internal_user_id", StringType, 100), + Column("feature_value", LongType, 1000) + ) + + val groupByTable = s"$namespace.user_features" + spark.sql(s"DROP TABLE IF EXISTS $groupByTable") + DataFrameGen.events(spark, groupByColumns, 1000, partitions = 50).save(groupByTable) + + // Create GroupBy with internal_user_id as key + val offlineGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("internal_user_id" -> "internal_user_id", "feature_value" -> "feature_value"), + timeColumn = "ts" + ), + table = groupByTable + ) + ), + keyColumns = Seq("internal_user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.AVERAGE, + inputColumn = "feature_value" + ) + ), + metaData = Builders.MetaData(name = s"keymapping_gb_$suffix", namespace = namespace), + accuracy = Accuracy.TEMPORAL + ) + + // Create ExternalSource with key mapping + val externalSource = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"external_with_keymapping_$suffix"), + keySchema = StructType("external_keys", Array(StructField("internal_user_id", StringType))), + valueSchema = StructType("external_values", Array(StructField("feature_value_avg", DoubleType))) + ) + externalSource.setOfflineGroupBy(offlineGroupBy) + + // Create left source with external_user_id + val leftColumns = List( + Column("external_user_id", StringType, 100), + Column("request_id", StringType, 100) + ) + + val leftTable = s"$namespace.requests" + spark.sql(s"DROP TABLE IF EXISTS $leftTable") + DataFrameGen.events(spark, leftColumns, 500, partitions = 30).save(leftTable) + + // Create Join with key mapping from external_user_id to internal_user_id + val join = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("external_user_id" -> "external_user_id", "request_id" -> "request_id"), + timeColumn = "ts" + ), + table = leftTable + ), + externalParts = Seq( + Builders.ExternalPart( + externalSource, + keyMapping = Map("external_user_id" -> "internal_user_id"), + prefix = "mapped" + ) + ), + metaData = Builders.MetaData(name = s"keymapping_join_$suffix", namespace = namespace) + ) + + // Test BootstrapInfo preserves key mapping + val endPartition = today + val range = PartitionRange(monthAgo, endPartition)(tableUtils) + val bootstrapInfo = BootstrapInfo.from( + joinConf = join, + range = range, + tableUtils = tableUtils, + leftSchema = None, + computeDependency = true + ) + + // Verify conversion occurred + assertEquals("Should have one converted JoinPart", 1, bootstrapInfo.joinParts.length) + + val convertedJoinPart = bootstrapInfo.joinParts.head.joinPart + assertNotNull("Key mapping should be preserved", convertedJoinPart.keyMapping) + assertEquals("Key mapping should map external_user_id to internal_user_id", + "internal_user_id", convertedJoinPart.keyMapping.get("external_user_id")) + + // Verify prefix is preserved + assertEquals("Prefix should be preserved", "mapped", convertedJoinPart.prefix) + + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } +} \ No newline at end of file diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala new file mode 100644 index 
0000000000..75847f49ba --- /dev/null +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala @@ -0,0 +1,399 @@ +/* + * Copyright (C) 2023 The Chronon Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.chronon.spark.test + +import ai.chronon.aggregator.test.Column +import ai.chronon.api.Extensions._ +import ai.chronon.api.{Accuracy, Builders, DoubleType, LongType, Operation, StringType, StructField, StructType, TimeUnit, Window} +import ai.chronon.spark.Extensions._ +import ai.chronon.spark._ +import org.apache.spark.sql.SparkSession +import org.junit.Assert._ +import org.junit.Test + +import scala.util.Random + +class ExternalSourceBackfillTest { + val spark: SparkSession = SparkSessionBuilder.build("ExternalSourceBackfillTest", local = true) + private val tableUtils = TableUtils(spark) + private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) + private val yearAgo = tableUtils.partitionSpec.minus(today, new Window(365, TimeUnit.DAYS)) + + @Test + def testExternalSourceBackfillComputeJoin(): Unit = { + val spark: SparkSession = + SparkSessionBuilder.build("ExternalSourceBackfillTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + val tableUtils = TableUtils(spark) + val namespace = "test_namespace_ext_backfill" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create user transaction data for offline GroupBy + val transactionColumns = List( + Column("user_id", StringType, 100), + Column("amount", LongType, 1000), + Column("transaction_type", StringType, 5) + ) + + val transactionTable = s"$namespace.user_transactions" + spark.sql(s"DROP TABLE IF EXISTS $transactionTable") + DataFrameGen.events(spark, transactionColumns, 2000, partitions = 100).save(transactionTable) + + // Create offline GroupBy for external source + val offlineGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "amount" -> "amount"), + timeColumn = "ts" + ), + table = transactionTable + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.SUM, + inputColumn = "amount", + windows = Seq(new Window(30, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = s"user_transaction_features_$namespace", namespace = namespace), + accuracy = Accuracy.TEMPORAL + ) + + // Create ExternalSource with offline GroupBy + val externalSource = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"external_transaction_features_$namespace"), + keySchema = StructType("external_keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("external_values", Array(StructField("amount_sum_30d", LongType))) + ) + externalSource.setOfflineGroupBy(offlineGroupBy) + + // Create left source (user events to join against) + val userEventColumns = List( + 
Column("user_id", StringType, 100), + Column("event_type", StringType, 10), + Column("session_id", StringType, 200) + ) + + val userEventTable = s"$namespace.user_events" + spark.sql(s"DROP TABLE IF EXISTS $userEventTable") + DataFrameGen.events(spark, userEventColumns, 1000, partitions = 50).save(userEventTable) + + // Create Join configuration with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "event_type" -> "event_type", "session_id" -> "session_id"), + timeColumn = "ts" + ), + table = userEventTable + ), + externalParts = Seq( + Builders.ExternalPart( + externalSource, + prefix = "ext" + ) + ), + metaData = Builders.MetaData(name = s"test_external_join_$namespace", namespace = namespace) + ) + + // Run analyzer to ensure GroupBy tables are created + val analyzer = new Analyzer(tableUtils, joinConf, monthAgo, today) + analyzer.run() + + // Create Join and compute + val endPartition = monthAgo + val join = new Join(joinConf = joinConf, endPartition = endPartition, tableUtils) + val computed = join.computeJoin(Some(10)) + + // Verify results + assertNotNull("Computed result should not be null", computed) + assertTrue("Result should have rows", computed.count() > 0) + + // Verify that external source columns are present + val columns = computed.columns.toSet + assertTrue("Should contain left source columns", columns.contains("user_id")) + assertTrue("Should contain left source columns", columns.contains("event_type")) + assertTrue("Should contain left source columns", columns.contains("session_id")) + assertTrue("Should contain external source prefixed columns", + columns.exists(_.startsWith("ext_"))) + + // Show results for debugging + println("=== External Source Backfill Join Results ===") + computed.show(20, truncate = false) + println(s"Total rows: ${computed.count()}") + println(s"Columns: ${computed.columns.mkString(", ")}") + + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } + + @Test + def testMixedExternalAndJoinParts(): Unit = { + val spark: SparkSession = + SparkSessionBuilder.build("ExternalSourceBackfillTest_Mixed" + "_" + Random.alphanumeric.take(6).mkString, local = true) + val tableUtils = TableUtils(spark) + val namespace = "test_namespace_mixed" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create transaction data for external source GroupBy + val transactionColumns = List( + Column("user_id", StringType, 100), + Column("purchase_amount", LongType, 5000) + ) + + val transactionTable = s"$namespace.purchase_transactions" + spark.sql(s"DROP TABLE IF EXISTS $transactionTable") + DataFrameGen.events(spark, transactionColumns, 1500, partitions = 80).save(transactionTable) + + // Create session data for regular JoinPart GroupBy + val sessionColumns = List( + Column("user_id", StringType, 100), + Column("session_duration", LongType, 7200) + ) + + val sessionTable = s"$namespace.user_sessions" + spark.sql(s"DROP TABLE IF EXISTS $sessionTable") + DataFrameGen.events(spark, sessionColumns, 1200, partitions = 60).save(sessionTable) + + // Create GroupBy for external source (purchases) + val purchaseGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "purchase_amount" -> "purchase_amount"), + timeColumn = "ts" + ), + table = transactionTable + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = 
Operation.AVERAGE, + inputColumn = "purchase_amount", + windows = Seq(new Window(7, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = s"purchase_features_$namespace", namespace = namespace), + accuracy = Accuracy.TEMPORAL + ) + + // Create GroupBy for regular JoinPart (sessions) + val sessionGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "session_duration" -> "session_duration"), + timeColumn = "ts" + ), + table = sessionTable + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.COUNT, + inputColumn = "session_duration", + windows = Seq(new Window(14, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = s"session_features_$namespace", namespace = namespace), + accuracy = Accuracy.TEMPORAL + ) + + // Create ExternalSource with offline GroupBy + val externalSource = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"external_purchase_features_$namespace"), + keySchema = StructType("external_keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("external_values", Array(StructField("purchase_amount_avg_7d", DoubleType))) + ) + externalSource.setOfflineGroupBy(purchaseGroupBy) + + // Create left source + val userActivityColumns = List( + Column("user_id", StringType, 100), + Column("page_views", LongType, 50) + ) + + val userActivityTable = s"$namespace.user_activity" + spark.sql(s"DROP TABLE IF EXISTS $userActivityTable") + DataFrameGen.events(spark, userActivityColumns, 800, partitions = 40).save(userActivityTable) + + // Create Join with both ExternalPart and regular JoinPart + val joinConf = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "page_views" -> "page_views"), + timeColumn = "ts" + ), + table = userActivityTable + ), + joinParts = Seq( + Builders.JoinPart( + groupBy = sessionGroupBy, + prefix = "session" + ) + ), + externalParts = Seq( + Builders.ExternalPart( + externalSource, + prefix = "purchase" + ) + ), + metaData = Builders.MetaData(name = s"test_mixed_join_$namespace", namespace = namespace) + ) + + // Run analyzer to ensure all GroupBy tables are created + val analyzer = new Analyzer(tableUtils, joinConf, monthAgo, today) + analyzer.run() + + // Create Join and compute + val endPartition = monthAgo + val join = new Join(joinConf = joinConf, endPartition = endPartition, tableUtils) + val computed = join.computeJoin(Some(10)) + + // Verify results + assertNotNull("Computed result should not be null", computed) + assertTrue("Result should have rows", computed.count() > 0) + + // Verify that both regular JoinPart and ExternalPart columns are present + val columns = computed.columns.toSet + assertTrue("Should contain left source columns", columns.contains("user_id")) + assertTrue("Should contain left source columns", columns.contains("page_views")) + assertTrue("Should contain regular JoinPart prefixed columns", + columns.exists(_.startsWith("session_"))) + assertTrue("Should contain external source prefixed columns", + columns.exists(_.startsWith("purchase_"))) + + // Show results for debugging + println("=== Mixed External and JoinPart Results ===") + computed.show(20, truncate = false) + println(s"Total rows: ${computed.count()}") + println(s"Columns: ${computed.columns.mkString(", ")}") + + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } + + @Test + def 
testExternalSourceBackfillWithKeyMapping(): Unit = { + val spark: SparkSession = + SparkSessionBuilder.build("ExternalSourceBackfillTest_KeyMapping" + "_" + Random.alphanumeric.take(6).mkString, local = true) + val tableUtils = TableUtils(spark) + val namespace = "test_namespace_keymapping" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create feature data with internal_user_id + val featureColumns = List( + Column("internal_user_id", StringType, 100), + Column("feature_score", LongType, 100) + ) + + val featureTable = s"$namespace.user_features" + spark.sql(s"DROP TABLE IF EXISTS $featureTable") + DataFrameGen.events(spark, featureColumns, 1000, partitions = 50).save(featureTable) + + // Create GroupBy using internal_user_id + val featureGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("internal_user_id" -> "internal_user_id", "feature_score" -> "feature_score"), + timeColumn = "ts" + ), + table = featureTable + ) + ), + keyColumns = Seq("internal_user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.MAX, + inputColumn = "feature_score", + windows = Seq(new Window(30, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = s"feature_gb_$namespace", namespace = namespace), + accuracy = Accuracy.TEMPORAL + ) + + // Create ExternalSource that expects internal_user_id + val externalSource = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"external_features_$namespace"), + keySchema = StructType("external_keys", Array(StructField("internal_user_id", StringType))), + valueSchema = StructType("external_values", Array(StructField("feature_score_max_30d", LongType))) + ) + externalSource.setOfflineGroupBy(featureGroupBy) + + // Create left source with external_user_id + val requestColumns = List( + Column("external_user_id", StringType, 100), + Column("request_type", StringType, 5) + ) + + val requestTable = s"$namespace.user_requests" + spark.sql(s"DROP TABLE IF EXISTS $requestTable") + DataFrameGen.events(spark, requestColumns, 600, partitions = 30).save(requestTable) + + // Create Join with key mapping from external_user_id to internal_user_id + val joinConf = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("external_user_id" -> "external_user_id", "request_type" -> "request_type"), + timeColumn = "ts" + ), + table = requestTable + ), + externalParts = Seq( + Builders.ExternalPart( + externalSource, + keyMapping = Map("external_user_id" -> "internal_user_id"), + prefix = "mapped" + ) + ), + metaData = Builders.MetaData(name = s"test_keymapping_join_$namespace", namespace = namespace) + ) + + // Run analyzer to ensure GroupBy tables are created + val analyzer = new Analyzer(tableUtils, joinConf, monthAgo, today) + analyzer.run() + + // Create Join and compute + val endPartition = monthAgo + val join = new Join(joinConf = joinConf, endPartition = endPartition, tableUtils) + val computed = join.computeJoin(Some(10)) + + // Verify results + assertNotNull("Computed result should not be null", computed) + assertTrue("Result should have rows", computed.count() > 0) + + // Verify column structure + val columns = computed.columns.toSet + assertTrue("Should contain external_user_id from left", columns.contains("external_user_id")) + assertTrue("Should contain request_type from left", columns.contains("request_type")) + assertTrue("Should contain mapped external columns", + 
columns.exists(_.startsWith("mapped_"))) + + // Show results for debugging + println("=== Key Mapping External Source Results ===") + computed.show(20, truncate = false) + println(s"Total rows: ${computed.count()}") + println(s"Columns: ${computed.columns.mkString(", ")}") + + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } +} \ No newline at end of file From 3db338b5947612e0ae2d11edcaaacde083c5ece8 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Tue, 30 Sep 2025 09:47:00 -0700 Subject: [PATCH 04/13] Compute join including external groupby --- .../scala/ai/chronon/api/Extensions.scala | 50 ++- .../scala/ai/chronon/spark/Analyzer.scala | 4 +- .../ai/chronon/spark/BootstrapInfo.scala | 34 +- .../scala/ai/chronon/spark/JoinBase.scala | 2 +- .../spark/test/BootstrapInfoTest.scala | 418 ------------------ .../test/ExternalSourceBackfillTest.scala | 4 +- 6 files changed, 50 insertions(+), 462 deletions(-) delete mode 100644 spark/src/test/scala/ai/chronon/spark/test/BootstrapInfoTest.scala diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index e21aa669d6..0940a3c526 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -963,7 +963,7 @@ object Extensions { private[api] def baseSemanticHash: Map[String, String] = { val leftHash = ThriftJsonCodec.md5Digest(join.left) logger.info(s"Join Left Object: ${ThriftJsonCodec.toJsonStr(join.left)}") - val partHashes = join.joinParts.toScala.map { jp => partOutputTable(jp) -> jp.groupBy.semanticHash }.toMap + val partHashes = join.getCombinedJoinParts.map { jp => partOutputTable(jp) -> jp.groupBy.semanticHash }.toMap val derivedHashMap = Option(join.derivations) .map { derivations => val derivedHash = @@ -991,7 +991,7 @@ object Extensions { } cleanTopicInSource(join.left) - join.getJoinParts.toScala.foreach(_.groupBy.sources.toScala.foreach(cleanTopicInSource)) + join.getCombinedJoinParts.foreach(_.groupBy.sources.toScala.foreach(cleanTopicInSource)) join } @@ -1099,9 +1099,47 @@ object Extensions { } def setups: Seq[String] = - (join.left.query.setupsSeq ++ join.joinParts.toScala + (join.left.query.setupsSeq ++ join.getCombinedJoinParts .flatMap(_.groupBy.setups)).distinct + /** + * Converts offline-capable ExternalParts to JoinParts for unified processing during backfill. + * This enables external sources with offlineGroupBy to participate in offline computation + * while maintaining compatibility with existing join processing logic. + * + * @return Sequence of JoinParts converted from offline-capable ExternalParts + */ + private def getExternalJoinParts: Seq[JoinPart] = { + Option(join.onlineExternalParts) + .map(_.toScala) + .getOrElse(Seq.empty) + .filter(_.source.offlineGroupBy != null) // Only offline-capable ExternalParts + .map { externalPart => + // Convert ExternalPart to synthetic JoinPart + val syntheticJoinPart = new JoinPart() + syntheticJoinPart.setGroupBy(externalPart.source.offlineGroupBy) + if (externalPart.keyMapping != null) { + syntheticJoinPart.setKeyMapping(externalPart.keyMapping) + } + if (externalPart.prefix != null) { + syntheticJoinPart.setPrefix(externalPart.prefix) + } + syntheticJoinPart + } + } + + /** + * Get all join parts including both regular joinParts and external join parts. + * This provides a unified view of all join parts for processing. 
+ * + * @return Sequence containing all JoinParts (regular + converted external) + */ + def getCombinedJoinParts: Seq[JoinPart] = { + val regularJoinParts = Option(join.joinParts).map(_.toScala).getOrElse(Seq.empty) + val externalJoinParts = getExternalJoinParts + regularJoinParts ++ externalJoinParts + } + def copyForVersioningComparison(): Join = { // When we compare previous-run join to current join to detect changes requiring table migration // these are the fields that should be checked to not have accidental recomputes @@ -1115,10 +1153,8 @@ object Extensions { } lazy val joinPartOps: Seq[JoinPartOps] = - Option(join.joinParts) - .getOrElse(new util.ArrayList[JoinPart]()) - .toScala - .toSeq + Option(join.getCombinedJoinParts) + .getOrElse(Seq.empty[JoinPart]) .map(new JoinPartOps(_)) def logFullValues: Boolean = true // TODO: supports opt-out in the future diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index 2c1121aa8c..edf05ac474 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -390,7 +390,7 @@ class Analyzer(tableUtils: TableUtils, ) .getOrElse(Seq.empty) - joinConf.joinParts.toScala.foreach { part => + joinConf.getCombinedJoinParts.foreach { part => val analyzeGroupByResult = analyzeGroupBy( part.groupBy, @@ -442,7 +442,7 @@ class Analyzer(tableUtils: TableUtils, val rightSchema = joinIntermediateValuesMetadata.map(aggregation => (aggregation.name, aggregation.columnType)) - val keyColumns: List[String] = joinConf.joinParts.toScala + val keyColumns: List[String] = joinConf.getCombinedJoinParts.toList .flatMap(joinPart => { val keyCols: Seq[String] = joinPart.groupBy.keyColumns.toScala if (joinPart.keyMapping == null) keyCols diff --git a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala index 4828c9e1f8..a503bcb00b 100644 --- a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala +++ b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala @@ -84,10 +84,8 @@ object BootstrapInfo { // Enrich each join part with the expected output schema logger.info(s"\nCreating BootstrapInfo for GroupBys for Join ${joinConf.metaData.name}") - // Combine regular JoinParts with converted ExternalParts before processing - val regularJoinParts = Option(joinConf.joinParts.toScala).getOrElse(Seq.empty) - val convertedJoinParts = convertExternalPartsToJoinParts(joinConf) - val allJoinParts = regularJoinParts ++ convertedJoinParts + // Get all join parts including both regular and external join parts + val allJoinParts = joinConf.getCombinedJoinParts var joinParts: Seq[JoinPartMetadata] = allJoinParts.map(part => { // set computeDependency to False as we compute dependency upstream @@ -361,32 +359,4 @@ object BootstrapInfo { bootstrapInfo } - /** - * Converts ExternalParts with offlineGroupBy to JoinParts for parallel processing. - * This allows offline-capable ExternalParts to be processed alongside regular JoinParts - * in the same parallel execution pool, gaining benefits from bloom filter optimization, - * small mode optimization, and bootstrap coverage analysis. 
- * - * @param joinConf Join configuration containing ExternalParts to convert - * @return Sequence of JoinParts converted from offline-capable ExternalParts - */ - private def convertExternalPartsToJoinParts(joinConf: api.Join): Seq[JoinPart] = { - logger.info(s"\nConverting offline-capable ExternalParts to JoinParts for Join ${joinConf.metaData.name}") - - Option(joinConf.onlineExternalParts.toScala) - .getOrElse(Seq.empty) - .filter(_.source.offlineGroupBy != null) // Only offline-capable ExternalParts - .map { externalPart => - // Convert ExternalPart to synthetic JoinPart - val syntheticJoinPart = new api.JoinPart() - syntheticJoinPart.setGroupBy(externalPart.source.offlineGroupBy) - if (externalPart.keyMapping != null) { - syntheticJoinPart.setKeyMapping(externalPart.keyMapping) - } - if (externalPart.prefix != null) { - syntheticJoinPart.setPrefix(externalPart.prefix) - } - syntheticJoinPart - } - } } diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala index 31becd8579..9c35d3c3b3 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala @@ -474,7 +474,7 @@ abstract class JoinBase(joinConf: api.Join, assert(Option(joinConf.metaData.team).nonEmpty, s"join.metaData.team needs to be set for join ${joinConf.metaData.name}") - joinConf.joinParts.toScala.foreach { jp => + joinConf.getCombinedJoinParts.foreach { jp => assert(Option(jp.groupBy.metaData.team).nonEmpty, s"groupBy.metaData.team needs to be set for joinPart ${jp.groupBy.metaData.name}") } diff --git a/spark/src/test/scala/ai/chronon/spark/test/BootstrapInfoTest.scala b/spark/src/test/scala/ai/chronon/spark/test/BootstrapInfoTest.scala deleted file mode 100644 index 7fa7a9fe66..0000000000 --- a/spark/src/test/scala/ai/chronon/spark/test/BootstrapInfoTest.scala +++ /dev/null @@ -1,418 +0,0 @@ -/* - * Copyright (C) 2023 The Chronon Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package ai.chronon.spark.test - -import ai.chronon.aggregator.test.Column -import ai.chronon.api.Extensions._ -import ai.chronon.api._ -import ai.chronon.spark.Extensions._ -import ai.chronon.spark._ -import org.apache.spark.sql.SparkSession -import org.junit.Assert._ -import org.junit.Test - -import scala.util.Random - -class BootstrapInfoTest { - val spark: SparkSession = SparkSessionBuilder.build("BootstrapInfoTest", local = true) - private val tableUtils = TableUtils(spark) - private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) - private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) - - @Test - def testExternalSourceWithOfflineGroupByConversion(): Unit = { - val suffix = "external_backfill_" + Random.alphanumeric.take(6).mkString - val namespace = s"test_namespace_$suffix" - tableUtils.createDatabase(namespace) - - // Create test data for the GroupBy source table - val groupByColumns = List( - Column("user_id", StringType, 100), - Column("feature_value", LongType, 1000) - ) - - val groupByTable = s"$namespace.user_features" - spark.sql(s"DROP TABLE IF EXISTS $groupByTable") - DataFrameGen.events(spark, groupByColumns, 1000, partitions = 50).save(groupByTable) - - // Create the GroupBy for offline backfill - val offlineGroupBy = Builders.GroupBy( - sources = Seq( - Builders.Source.events( - query = Builders.Query( - selects = Map("user_id" -> "user_id", "feature_value" -> "feature_value"), - timeColumn = "ts" - ), - table = groupByTable - ) - ), - keyColumns = Seq("user_id"), - aggregations = Seq( - Builders.Aggregation( - operation = Operation.SUM, - inputColumn = "feature_value", - windows = Seq(new Window(7, TimeUnit.DAYS)) - ) - ), - metaData = Builders.MetaData(name = s"user_features_gb_$suffix", namespace = namespace), - accuracy = Accuracy.TEMPORAL - ) - - // Create ExternalSource with offline GroupBy - val externalSource = Builders.ExternalSource( - metadata = Builders.MetaData(name = s"external_with_offline_$suffix"), - keySchema = StructType("external_keys", Array(StructField("user_id", StringType))), - valueSchema = StructType("external_values", Array(StructField("feature_value_sum_7d", LongType))) - ) - externalSource.setOfflineGroupBy(offlineGroupBy) - - // Create a simple left source for the join - val leftColumns = List( - Column("user_id", StringType, 100), - Column("request_id", StringType, 100) - ) - - val leftTable = s"$namespace.requests" - spark.sql(s"DROP TABLE IF EXISTS $leftTable") - DataFrameGen.events(spark, leftColumns, 500, partitions = 30).save(leftTable) - - // Create Join with ExternalPart that has offline GroupBy - val join = Builders.Join( - left = Builders.Source.events( - query = Builders.Query( - selects = Map("user_id" -> "user_id", "request_id" -> "request_id"), - timeColumn = "ts" - ), - table = leftTable - ), - externalParts = Seq( - Builders.ExternalPart( - externalSource, - prefix = "ext" - ) - ), - metaData = Builders.MetaData(name = s"test_join_$suffix", namespace = namespace) - ) - - // Test BootstrapInfo conversion logic - val endPartition = today - val range = PartitionRange(monthAgo, endPartition)(tableUtils) - val bootstrapInfo = BootstrapInfo.from( - joinConf = join, - range = range, - tableUtils = tableUtils, - leftSchema = None, - computeDependency = true - ) - - // Verify that ExternalPart with offline GroupBy was converted to JoinPart - val totalJoinParts = bootstrapInfo.joinParts.length - assertTrue("Should have at least one JoinPart after conversion", 
totalJoinParts > 0) - - // Verify that the converted JoinPart has the expected GroupBy - val convertedJoinPart = bootstrapInfo.joinParts.find(_.joinPart.groupBy.metaData.name == offlineGroupBy.metaData.name) - assertTrue("Should find converted JoinPart with matching GroupBy name", convertedJoinPart.isDefined) - - // Verify that online-only external parts are still tracked separately - assertTrue("Should have no online-only external parts in this test", bootstrapInfo.externalParts.isEmpty) - - // Verify schema compatibility - val joinPartMeta = convertedJoinPart.get - assertEquals("Key schema should match", 1, joinPartMeta.keySchema.length) - assertEquals("Key field should be user_id", "user_id", joinPartMeta.keySchema.head.name) - - spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") - } - - @Test - def testExternalSourceOnlineOnlyBehavior(): Unit = { - val suffix = "online_only_" + Random.alphanumeric.take(6).mkString - val namespace = s"test_namespace_$suffix" - tableUtils.createDatabase(namespace) - - // Create ExternalSource without offline GroupBy (online-only) - val onlineOnlyExternalSource = Builders.ExternalSource( - metadata = Builders.MetaData(name = s"online_only_external_$suffix"), - keySchema = StructType("online_keys", Array(StructField("user_id", StringType))), - valueSchema = StructType("online_values", Array(StructField("online_feature", LongType))) - ) - // Note: No offlineGroupBy set, so this remains online-only - - // Create a simple left source - val leftColumns = List( - Column("user_id", StringType, 100), - Column("request_id", StringType, 100) - ) - - val leftTable = s"$namespace.requests" - spark.sql(s"DROP TABLE IF EXISTS $leftTable") - DataFrameGen.events(spark, leftColumns, 500, partitions = 30).save(leftTable) - - // Create Join with online-only ExternalPart - val join = Builders.Join( - left = Builders.Source.events( - query = Builders.Query( - selects = Map("user_id" -> "user_id", "request_id" -> "request_id"), - timeColumn = "ts" - ), - table = leftTable - ), - externalParts = Seq( - Builders.ExternalPart( - onlineOnlyExternalSource, - prefix = "online" - ) - ), - metaData = Builders.MetaData(name = s"online_only_join_$suffix", namespace = namespace) - ) - - // Test BootstrapInfo with online-only external part - val endPartition = today - val range = PartitionRange(monthAgo, endPartition)(tableUtils) - val bootstrapInfo = BootstrapInfo.from( - joinConf = join, - range = range, - tableUtils = tableUtils, - leftSchema = None, - computeDependency = true - ) - - // Verify that online-only ExternalPart was NOT converted to JoinPart - assertEquals("Should have no JoinParts from conversion", 0, bootstrapInfo.joinParts.length) - - // Verify that online-only external part is tracked in externalParts - assertEquals("Should have one online-only external part", 1, bootstrapInfo.externalParts.length) - - val externalPartMeta = bootstrapInfo.externalParts.head - assertEquals("External part name should match", onlineOnlyExternalSource.metadata.name, - externalPartMeta.externalPart.source.metadata.name) - - spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") - } - - @Test - def testMixedExternalPartsConversion(): Unit = { - val suffix = "mixed_" + Random.alphanumeric.take(6).mkString - val namespace = s"test_namespace_$suffix" - tableUtils.createDatabase(namespace) - - // Create test data for the GroupBy source table - val groupByColumns = List( - Column("user_id", StringType, 100), - Column("feature_value", LongType, 1000) - ) - - val groupByTable = 
s"$namespace.user_features" - spark.sql(s"DROP TABLE IF EXISTS $groupByTable") - DataFrameGen.events(spark, groupByColumns, 1000, partitions = 50).save(groupByTable) - - // Create GroupBy for offline backfill - val offlineGroupBy = Builders.GroupBy( - sources = Seq( - Builders.Source.events( - query = Builders.Query( - selects = Map("user_id" -> "user_id", "feature_value" -> "feature_value"), - timeColumn = "ts" - ), - table = groupByTable - ) - ), - keyColumns = Seq("user_id"), - aggregations = Seq( - Builders.Aggregation( - operation = Operation.SUM, - inputColumn = "feature_value" - ) - ), - metaData = Builders.MetaData(name = s"offline_gb_$suffix", namespace = namespace), - accuracy = Accuracy.TEMPORAL - ) - - // Create ExternalSource with offline GroupBy - val externalSourceWithOffline = Builders.ExternalSource( - metadata = Builders.MetaData(name = s"external_with_offline_$suffix"), - keySchema = StructType("offline_keys", Array(StructField("user_id", StringType))), - valueSchema = StructType("offline_values", Array(StructField("feature_value_sum", LongType))) - ) - externalSourceWithOffline.setOfflineGroupBy(offlineGroupBy) - - // Create online-only ExternalSource - val externalSourceOnlineOnly = Builders.ExternalSource( - metadata = Builders.MetaData(name = s"external_online_only_$suffix"), - keySchema = StructType("online_keys", Array(StructField("user_id", StringType))), - valueSchema = StructType("online_values", Array(StructField("online_feature", LongType))) - ) - // No offlineGroupBy set - - // Create left source - val leftColumns = List( - Column("user_id", StringType, 100), - Column("request_id", StringType, 100) - ) - - val leftTable = s"$namespace.requests" - spark.sql(s"DROP TABLE IF EXISTS $leftTable") - DataFrameGen.events(spark, leftColumns, 500, partitions = 30).save(leftTable) - - // Create Join with both types of ExternalParts - val join = Builders.Join( - left = Builders.Source.events( - query = Builders.Query( - selects = Map("user_id" -> "user_id", "request_id" -> "request_id"), - timeColumn = "ts" - ), - table = leftTable - ), - externalParts = Seq( - Builders.ExternalPart( - externalSourceWithOffline, - prefix = "offline" - ), - Builders.ExternalPart( - externalSourceOnlineOnly, - prefix = "online" - ) - ), - metaData = Builders.MetaData(name = s"mixed_join_$suffix", namespace = namespace) - ) - - // Test BootstrapInfo with mixed external parts - val endPartition = today - val range = PartitionRange(monthAgo, endPartition)(tableUtils) - val bootstrapInfo = BootstrapInfo.from( - joinConf = join, - range = range, - tableUtils = tableUtils, - leftSchema = None, - computeDependency = true - ) - - // Verify that offline-capable ExternalPart was converted to JoinPart - assertEquals("Should have one JoinPart from conversion", 1, bootstrapInfo.joinParts.length) - val convertedJoinPart = bootstrapInfo.joinParts.head - assertEquals("Converted JoinPart should have matching GroupBy name", - offlineGroupBy.metaData.name, convertedJoinPart.joinPart.groupBy.metaData.name) - - // Verify that online-only ExternalPart is tracked separately - assertEquals("Should have one online-only external part", 1, bootstrapInfo.externalParts.length) - val onlineExternalPart = bootstrapInfo.externalParts.head - assertEquals("Online external part should have matching name", - externalSourceOnlineOnly.metadata.name, onlineExternalPart.externalPart.source.metadata.name) - - spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") - } - - @Test - def testExternalPartKeyMappingPreservation(): 
Unit = { - val suffix = "keymapping_" + Random.alphanumeric.take(6).mkString - val namespace = s"test_namespace_$suffix" - tableUtils.createDatabase(namespace) - - // Create test data for the GroupBy source table - val groupByColumns = List( - Column("internal_user_id", StringType, 100), - Column("feature_value", LongType, 1000) - ) - - val groupByTable = s"$namespace.user_features" - spark.sql(s"DROP TABLE IF EXISTS $groupByTable") - DataFrameGen.events(spark, groupByColumns, 1000, partitions = 50).save(groupByTable) - - // Create GroupBy with internal_user_id as key - val offlineGroupBy = Builders.GroupBy( - sources = Seq( - Builders.Source.events( - query = Builders.Query( - selects = Map("internal_user_id" -> "internal_user_id", "feature_value" -> "feature_value"), - timeColumn = "ts" - ), - table = groupByTable - ) - ), - keyColumns = Seq("internal_user_id"), - aggregations = Seq( - Builders.Aggregation( - operation = Operation.AVERAGE, - inputColumn = "feature_value" - ) - ), - metaData = Builders.MetaData(name = s"keymapping_gb_$suffix", namespace = namespace), - accuracy = Accuracy.TEMPORAL - ) - - // Create ExternalSource with key mapping - val externalSource = Builders.ExternalSource( - metadata = Builders.MetaData(name = s"external_with_keymapping_$suffix"), - keySchema = StructType("external_keys", Array(StructField("internal_user_id", StringType))), - valueSchema = StructType("external_values", Array(StructField("feature_value_avg", DoubleType))) - ) - externalSource.setOfflineGroupBy(offlineGroupBy) - - // Create left source with external_user_id - val leftColumns = List( - Column("external_user_id", StringType, 100), - Column("request_id", StringType, 100) - ) - - val leftTable = s"$namespace.requests" - spark.sql(s"DROP TABLE IF EXISTS $leftTable") - DataFrameGen.events(spark, leftColumns, 500, partitions = 30).save(leftTable) - - // Create Join with key mapping from external_user_id to internal_user_id - val join = Builders.Join( - left = Builders.Source.events( - query = Builders.Query( - selects = Map("external_user_id" -> "external_user_id", "request_id" -> "request_id"), - timeColumn = "ts" - ), - table = leftTable - ), - externalParts = Seq( - Builders.ExternalPart( - externalSource, - keyMapping = Map("external_user_id" -> "internal_user_id"), - prefix = "mapped" - ) - ), - metaData = Builders.MetaData(name = s"keymapping_join_$suffix", namespace = namespace) - ) - - // Test BootstrapInfo preserves key mapping - val endPartition = today - val range = PartitionRange(monthAgo, endPartition)(tableUtils) - val bootstrapInfo = BootstrapInfo.from( - joinConf = join, - range = range, - tableUtils = tableUtils, - leftSchema = None, - computeDependency = true - ) - - // Verify conversion occurred - assertEquals("Should have one converted JoinPart", 1, bootstrapInfo.joinParts.length) - - val convertedJoinPart = bootstrapInfo.joinParts.head.joinPart - assertNotNull("Key mapping should be preserved", convertedJoinPart.keyMapping) - assertEquals("Key mapping should map external_user_id to internal_user_id", - "internal_user_id", convertedJoinPart.keyMapping.get("external_user_id")) - - // Verify prefix is preserved - assertEquals("Prefix should be preserved", "mapped", convertedJoinPart.prefix) - - spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") - } -} \ No newline at end of file diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala index 
75847f49ba..5bc8156b65 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala
@@ -221,7 +221,7 @@ class ExternalSourceBackfillTest {
     val externalSource = Builders.ExternalSource(
       metadata = Builders.MetaData(name = s"external_purchase_features_$namespace"),
       keySchema = StructType("external_keys", Array(StructField("user_id", StringType))),
-      valueSchema = StructType("external_values", Array(StructField("purchase_amount_avg_7d", DoubleType)))
+      valueSchema = StructType("external_values", Array(StructField("purchase_amount_average_7d", DoubleType)))
     )
     externalSource.setOfflineGroupBy(purchaseGroupBy)

@@ -306,7 +306,7 @@ class ExternalSourceBackfillTest {

     val featureTable = s"$namespace.user_features"
     spark.sql(s"DROP TABLE IF EXISTS $featureTable")
-    DataFrameGen.events(spark, featureColumns, 1000, partitions = 50).save(featureTable)
+    DataFrameGen.events(spark, featureColumns, 2000, partitions = 60).save(featureTable)

     // Create GroupBy using internal_user_id
     val featureGroupBy = Builders.GroupBy(

From 0f487a008d462e8e128a96ad62977af3f24c7863 Mon Sep 17 00:00:00 2001
From: Hui Cheng
Date: Tue, 30 Sep 2025 13:25:39 -0700
Subject: [PATCH 05/13] Refactor: update function name and move unit tests

---
 .../scala/ai/chronon/api/Extensions.scala    | 127 +++---------------
 .../ai/chronon/api/test/ExtensionsTest.scala |  98 --------------
 .../scala/ai/chronon/spark/Analyzer.scala    |  99 +++++++++++++-
 .../ai/chronon/spark/BootstrapInfo.scala     |  62 ++++-----
 .../scala/ai/chronon/spark/JoinBase.scala    |   2 +-
 .../ai/chronon/spark/test/AnalyzerTest.scala | 105 ++++++++++++++-
 .../test/ExternalSourceBackfillTest.scala    |   1 -
 7 files changed, 251 insertions(+), 243 deletions(-)

diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala
index 0940a3c526..cb407df463 100644
--- a/api/src/main/scala/ai/chronon/api/Extensions.scala
+++ b/api/src/main/scala/ai/chronon/api/Extensions.scala
@@ -686,86 +686,6 @@ object Extensions {
     lazy val valueFields: Array[StructField] = schemaFields(externalSource.valueSchema)

     def isContextualSource: Boolean = externalSource.metadata.name == Constants.ContextualSourceName
-
-    /**
-     * Validates schema compatibility between ExternalSource and its offlineGroupBy.
-     * This ensures that online and offline serving will produce consistent results.
- * - * @return Sequence of error messages, empty if no errors - */ - def validateOfflineGroupBy(): Seq[String] = - Option(externalSource.offlineGroupBy) - .map(_ => validateKeySchemaCompatibility() ++ validateValueSchemaCompatibility()) - .getOrElse(Seq.empty) - - private def validateKeySchemaCompatibility(): Seq[String] = { - val errors = scala.collection.mutable.ListBuffer[String]() - - if (externalSource.keySchema == null) { - errors += s"ExternalSource ${externalSource.metadata.name} keySchema cannot be null when offlineGroupBy is specified" - return errors - } - - if (externalSource.offlineGroupBy.keyColumns == null || externalSource.offlineGroupBy.keyColumns.isEmpty) { - errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns cannot be null or empty" - return errors - } - - val externalKeyFields = keyFields - val groupByKeyColumns = externalSource.offlineGroupBy.keyColumns.toScala.toSet - - // Extract field names from external source key schema - val externalKeyNames = externalKeyFields.map(_.name).toSet - - // Validate that GroupBy has key columns that match ExternalSource key schema - val missingKeys = externalKeyNames -- groupByKeyColumns - val extraKeys = groupByKeyColumns -- externalKeyNames - - if (missingKeys.nonEmpty) { - errors += s"ExternalSource ${externalSource.metadata.name} key schema contains columns [${missingKeys.mkString(", ")}] " + - s"that are not present in offlineGroupBy keyColumns [${groupByKeyColumns.mkString(", ")}]. " + - s"All ExternalSource key columns must be present in the GroupBy key columns." - } - - if (extraKeys.nonEmpty) { - errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns contain [${extraKeys - .mkString(", ")}] " + - s"that are not present in ExternalSource keySchema [${externalKeyNames.mkString(", ")}]. " + - s"GroupBy key columns cannot contain keys not defined in ExternalSource keySchema." - } - - errors - } - - private def validateValueSchemaCompatibility(): Seq[String] = { - val errors = scala.collection.mutable.ListBuffer[String]() - - if (externalSource.valueSchema == null) { - errors += s"ExternalSource ${externalSource.metadata.name} valueSchema cannot be null when offlineGroupBy is specified" - return errors - } - - val externalValueFields = valueFields - val externalValueNames = externalValueFields.map(_.name).toSet - - // For GroupBy value schema, we need to derive the output schema from aggregations - val groupByValueColumns = externalSource.offlineGroupBy.valueColumns.toSet - - // Check that ExternalSource value schema fields are compatible with GroupBy output - val missingValueColumns = externalValueNames -- groupByValueColumns - - if (missingValueColumns.nonEmpty) { - // This is an error because ExternalSource valueSchema must be compatible with GroupBy output - // to ensure consistency between online and offline serving - errors += s"ExternalSource ${externalSource.metadata.name} valueSchema contains columns [${missingValueColumns - .mkString(", ")}] " + - s"that are not present in offlineGroupBy output columns [${groupByValueColumns.mkString(", ")}]. " + - s"This indicates schema incompatibility between online and offline serving. " + - s"Please ensure ExternalSource valueSchema matches the expected output of the GroupBy aggregations." 
- } - - errors - } } object KeyMappingHelper { @@ -963,7 +883,9 @@ object Extensions { private[api] def baseSemanticHash: Map[String, String] = { val leftHash = ThriftJsonCodec.md5Digest(join.left) logger.info(s"Join Left Object: ${ThriftJsonCodec.toJsonStr(join.left)}") - val partHashes = join.getCombinedJoinParts.map { jp => partOutputTable(jp) -> jp.groupBy.semanticHash }.toMap + val partHashes = join.getRegularAndExternalJoinParts.map { jp => + partOutputTable(jp) -> jp.groupBy.semanticHash + }.toMap val derivedHashMap = Option(join.derivations) .map { derivations => val derivedHash = @@ -991,7 +913,7 @@ object Extensions { } cleanTopicInSource(join.left) - join.getCombinedJoinParts.foreach(_.groupBy.sources.toScala.foreach(cleanTopicInSource)) + join.getRegularAndExternalJoinParts.foreach(_.groupBy.sources.toScala.foreach(cleanTopicInSource)) join } @@ -1031,17 +953,6 @@ object Extensions { .getOrElse(Seq.empty) } - /** - * Validates all ExternalSources in this Join's onlineExternalParts. - * This ensures schema compatibility between ExternalSources and their offlineGroupBy configurations. - * - * @return Sequence of error messages, empty if no errors - */ - def validateExternalSources(): Seq[String] = - Option(join.onlineExternalParts) - .map(_.toScala.flatMap(_.source.validateOfflineGroupBy())) - .getOrElse(Seq.empty) - def isProduction: Boolean = join.getMetaData.isProduction def team: String = join.getMetaData.getTeam @@ -1099,21 +1010,21 @@ object Extensions { } def setups: Seq[String] = - (join.left.query.setupsSeq ++ join.getCombinedJoinParts + (join.left.query.setupsSeq ++ join.getRegularAndExternalJoinParts .flatMap(_.groupBy.setups)).distinct /** - * Converts offline-capable ExternalParts to JoinParts for unified processing during backfill. - * This enables external sources with offlineGroupBy to participate in offline computation - * while maintaining compatibility with existing join processing logic. - * - * @return Sequence of JoinParts converted from offline-capable ExternalParts - */ + * Converts offline-capable ExternalParts to JoinParts for unified processing during backfill. + * This enables external sources with offlineGroupBy to participate in offline computation + * while maintaining compatibility with existing join processing logic. + * + * @return Sequence of JoinParts converted from offline-capable ExternalParts + */ private def getExternalJoinParts: Seq[JoinPart] = { Option(join.onlineExternalParts) .map(_.toScala) .getOrElse(Seq.empty) - .filter(_.source.offlineGroupBy != null) // Only offline-capable ExternalParts + .filter(_.source.offlineGroupBy != null) // Only offline-capable ExternalParts .map { externalPart => // Convert ExternalPart to synthetic JoinPart val syntheticJoinPart = new JoinPart() @@ -1129,12 +1040,12 @@ object Extensions { } /** - * Get all join parts including both regular joinParts and external join parts. - * This provides a unified view of all join parts for processing. - * - * @return Sequence containing all JoinParts (regular + converted external) - */ - def getCombinedJoinParts: Seq[JoinPart] = { + * Get all join parts including both regular joinParts and external join parts. + * This provides a unified view of all join parts for processing. 
+ * + * @return Sequence containing all JoinParts (regular + converted external) + */ + def getRegularAndExternalJoinParts: Seq[JoinPart] = { val regularJoinParts = Option(join.joinParts).map(_.toScala).getOrElse(Seq.empty) val externalJoinParts = getExternalJoinParts regularJoinParts ++ externalJoinParts @@ -1153,7 +1064,7 @@ object Extensions { } lazy val joinPartOps: Seq[JoinPartOps] = - Option(join.getCombinedJoinParts) + Option(join.getRegularAndExternalJoinParts) .getOrElse(Seq.empty[JoinPart]) .map(new JoinPartOps(_)) diff --git a/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala b/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala index 51839af2d5..1d6e09a01e 100644 --- a/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala +++ b/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala @@ -249,102 +249,4 @@ class ExtensionsTest { assertEquals(join1.semanticHash(excludeTopic = false), join2.semanticHash(excludeTopic = false)) } - @Test - def testExternalSourceValidationWithMatchingSchemas(): Unit = { - // Create compatible schemas using the correct DataType objects - val keySchema = StructType("key", Array(StructField("user_id", StringType))) - val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) - - // Create a query and source for the GroupBy - val query = Builders.Query(selects = Map("feature_value" -> "value")) - val source = Builders.Source.events(query, "test.table") - - // Create GroupBy with matching key columns and sources - val groupBy = Builders.GroupBy( - keyColumns = Seq("user_id"), - sources = Seq(source) - ) - - // Create ExternalSource with compatible schemas - val metadata = Builders.MetaData(name = "test_external_source") - val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) - - // Manually set the offlineGroupBy field since the builder doesn't support it yet - externalSource.setOfflineGroupBy(groupBy) - - // This should return no errors - val errors = externalSource.validateOfflineGroupBy() - assertTrue(s"Expected no errors, but got: ${errors.mkString(", ")}", errors.isEmpty) - } - - @Test - def testExternalSourceValidationWithMismatchedKeySchemas(): Unit = { - // Create key schema with different fields - val keySchema = StructType("key", Array(StructField("user_id", StringType))) - val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) - - // Create a query and source for the GroupBy - val query = Builders.Query(selects = Map("feature_value" -> "feature_value")) - val source = Builders.Source.events(query, "test.table") - - // Create GroupBy with different key columns - val groupBy = Builders.GroupBy( - keyColumns = Seq("different_key"), // Mismatched key column - sources = Seq(source) - ) - - // Create ExternalSource with incompatible schemas - val metadata = Builders.MetaData(name = "test_external_source") - val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) - externalSource.setOfflineGroupBy(groupBy) - - // This should return validation errors - val errors = externalSource.validateOfflineGroupBy() - assertFalse("Expected validation errors for mismatched key schemas", errors.isEmpty) - assertTrue("Error should mention key schema mismatch", - errors.exists(_.contains("key schema contains columns"))) - } - - @Test - def testExternalSourceValidationWithMismatchedValueSchemas(): Unit = { - // Create compatible key schema but mismatched value schema - val keySchema = StructType("key", Array(StructField("user_id", 
StringType))) - val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) - - // Create a source for the GroupBy - this is needed for valueColumns to work - val query = Builders.Query(selects = Map("different_feature" -> "different_feature")) - val source = Builders.Source.events(query, "test.table") - - // Create GroupBy with different value columns - val groupBy = Builders.GroupBy( - keyColumns = Seq("user_id"), - sources = Seq(source) - ) - - // Create ExternalSource with incompatible schemas - val metadata = Builders.MetaData(name = "test_external_source") - val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) - externalSource.setOfflineGroupBy(groupBy) - - // This should return validation errors - val errors = externalSource.validateOfflineGroupBy() - assertFalse("Expected validation errors for mismatched value schemas", errors.isEmpty) - assertTrue("Error should mention value schema mismatch", - errors.exists(_.contains("valueSchema contains columns"))) - } - - @Test - def testExternalSourceValidationWithNullOfflineGroupBy(): Unit = { - // Create ExternalSource without offlineGroupBy - val keySchema = StructType("key", Array(StructField("user_id", StringType))) - val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) - - val metadata = Builders.MetaData(name = "test_external_source") - val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) - // Don't set offlineGroupBy (it remains null) - - // This should return no errors (validation should be skipped) - val errors = externalSource.validateOfflineGroupBy() - assertTrue("Expected no errors when offlineGroupBy is null", errors.isEmpty) - } } diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index edf05ac474..09cc953d89 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -390,7 +390,7 @@ class Analyzer(tableUtils: TableUtils, ) .getOrElse(Seq.empty) - joinConf.getCombinedJoinParts.foreach { part => + joinConf.getRegularAndExternalJoinParts.foreach { part => val analyzeGroupByResult = analyzeGroupBy( part.groupBy, @@ -426,7 +426,7 @@ class Analyzer(tableUtils: TableUtils, } if (joinConf.onlineExternalParts != null) { // Validate ExternalSource schemas if they have offlineGroupBy configured - externalSourceErrors ++= joinConf.validateExternalSources() + externalSourceErrors ++= runExternalSourceCheck(joinConf) joinConf.onlineExternalParts.toScala.foreach { part => joinIntermediateValuesMetadata ++= part.source.valueFields.map { field => @@ -442,7 +442,7 @@ class Analyzer(tableUtils: TableUtils, val rightSchema = joinIntermediateValuesMetadata.map(aggregation => (aggregation.name, aggregation.columnType)) - val keyColumns: List[String] = joinConf.getCombinedJoinParts.toList + val keyColumns: List[String] = joinConf.getRegularAndExternalJoinParts.toList .flatMap(joinPart => { val keyCols: Seq[String] = joinPart.groupBy.keyColumns.toScala if (joinPart.keyMapping == null) keyCols @@ -807,6 +807,99 @@ class Analyzer(tableUtils: TableUtils, analyzeGroupByResult } + /** + * Validates schema compatibility between ExternalSource and its offlineGroupBy. + * This ensures that online and offline serving will produce consistent results. 
+ * + * @param externalSource The external source to validate + * @return Sequence of error messages, empty if no errors + */ + def validateOfflineGroupBy(externalSource: api.ExternalSource): Seq[String] = + Option(externalSource.offlineGroupBy) + .map(_ => validateKeySchemaCompatibility(externalSource) ++ validateValueSchemaCompatibility(externalSource)) + .getOrElse(Seq.empty) + + private def validateKeySchemaCompatibility(externalSource: api.ExternalSource): Seq[String] = { + val errors = scala.collection.mutable.ListBuffer[String]() + + if (externalSource.keySchema == null) { + errors += s"ExternalSource ${externalSource.metadata.name} keySchema cannot be null when offlineGroupBy is specified" + return errors + } + + if (externalSource.offlineGroupBy.keyColumns == null || externalSource.offlineGroupBy.keyColumns.isEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns cannot be null or empty" + return errors + } + + val externalKeyFields = externalSource.keyFields + val groupByKeyColumns = externalSource.offlineGroupBy.keyColumns.toScala.toSet + + // Extract field names from external source key schema + val externalKeyNames = externalKeyFields.map(_.name).toSet + + // Validate that GroupBy has key columns that match ExternalSource key schema + val missingKeys = externalKeyNames -- groupByKeyColumns + val extraKeys = groupByKeyColumns -- externalKeyNames + + if (missingKeys.nonEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} key schema contains columns [${missingKeys.mkString(", ")}] " + + s"that are not present in offlineGroupBy keyColumns [${groupByKeyColumns.mkString(", ")}]. " + + s"All ExternalSource key columns must be present in the GroupBy key columns." + } + + if (extraKeys.nonEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns contain [${extraKeys + .mkString(", ")}] " + + s"that are not present in ExternalSource keySchema [${externalKeyNames.mkString(", ")}]. " + + s"GroupBy key columns cannot contain keys not defined in ExternalSource keySchema." + } + + errors + } + + private def validateValueSchemaCompatibility(externalSource: api.ExternalSource): Seq[String] = { + val errors = scala.collection.mutable.ListBuffer[String]() + + if (externalSource.valueSchema == null) { + errors += s"ExternalSource ${externalSource.metadata.name} valueSchema cannot be null when offlineGroupBy is specified" + return errors + } + + val externalValueFields = externalSource.valueFields + val externalValueNames = externalValueFields.map(_.name).toSet + + // For GroupBy value schema, we need to derive the output schema from aggregations + val groupByValueColumns = externalSource.offlineGroupBy.valueColumns.toSet + + // Check that ExternalSource value schema fields are compatible with GroupBy output + val missingValueColumns = externalValueNames -- groupByValueColumns + + if (missingValueColumns.nonEmpty) { + // This is an error because ExternalSource valueSchema must be compatible with GroupBy output + // to ensure consistency between online and offline serving + errors += s"ExternalSource ${externalSource.metadata.name} valueSchema contains columns [${missingValueColumns + .mkString(", ")}] " + + s"that are not present in offlineGroupBy output columns [${groupByValueColumns.mkString(", ")}]. " + + s"This indicates schema incompatibility between online and offline serving. " + + s"Please ensure ExternalSource valueSchema matches the expected output of the GroupBy aggregations." 
+ } + + errors + } + + /** + * Validates all ExternalSources in this Join's onlineExternalParts. + * This ensures schema compatibility between ExternalSources and their offlineGroupBy configurations. + * + * @param joinConf The join configuration to validate + * @return Sequence of error messages, empty if no errors + */ + def runExternalSourceCheck(joinConf: api.Join): Seq[String] = + Option(joinConf.onlineExternalParts) + .map(_.toScala.flatMap(part => validateOfflineGroupBy(part.source))) + .getOrElse(Seq.empty) + def run(exportSchema: Boolean = false): Unit = conf match { case confPath: String => diff --git a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala index a503bcb00b..9089e1ae8f 100644 --- a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala +++ b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala @@ -85,45 +85,45 @@ object BootstrapInfo { logger.info(s"\nCreating BootstrapInfo for GroupBys for Join ${joinConf.metaData.name}") // Get all join parts including both regular and external join parts - val allJoinParts = joinConf.getCombinedJoinParts + val allJoinParts = joinConf.getRegularAndExternalJoinParts var joinParts: Seq[JoinPartMetadata] = allJoinParts.map(part => { - // set computeDependency to False as we compute dependency upstream - val gb = GroupBy.from(part.groupBy, range, tableUtils, computeDependency) - val keySchema = SparkConversions - .toChrononSchema(gb.keySchema) - .map(field => StructField(part.rightToLeft(field._1), field._2)) - - Analyzer.validateAvroCompatibility(tableUtils, gb, part.groupBy) - - val keyAndPartitionFields = - gb.keySchema.fields ++ Seq(org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)) - // todo: this change is only valid for offline use case - // we need to revisit logic for the logging part to make sure the derived columns are also logged - // to make bootstrap continue to work - val outputSchema = if (part.groupBy.hasDerivations) { - val sparkSchema = { - StructType(SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ keyAndPartitionFields) - } - val dummyOutputDf = tableUtils.sparkSession - .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema) - val finalOutputColumns = part.groupBy.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq - val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*) - val columns = SparkConversions.toChrononSchema( - StructType(derivedDummyOutputDf.schema.filterNot(keyAndPartitionFields.contains))) - api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2))) - } else { - gb.outputSchema + // set computeDependency to False as we compute dependency upstream + val gb = GroupBy.from(part.groupBy, range, tableUtils, computeDependency) + val keySchema = SparkConversions + .toChrononSchema(gb.keySchema) + .map(field => StructField(part.rightToLeft(field._1), field._2)) + + Analyzer.validateAvroCompatibility(tableUtils, gb, part.groupBy) + + val keyAndPartitionFields = + gb.keySchema.fields ++ Seq(org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)) + // todo: this change is only valid for offline use case + // we need to revisit logic for the logging part to make sure the derived columns are also logged + // to make bootstrap continue to work + val outputSchema = if (part.groupBy.hasDerivations) { + val sparkSchema = { + 
StructType(SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ keyAndPartitionFields) } - val valueSchema = outputSchema.fields.map(part.constructJoinPartSchema) - JoinPartMetadata(part, keySchema, valueSchema, Map.empty) // will be populated below - }) + val dummyOutputDf = tableUtils.sparkSession + .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema) + val finalOutputColumns = part.groupBy.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq + val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*) + val columns = SparkConversions.toChrononSchema( + StructType(derivedDummyOutputDf.schema.filterNot(keyAndPartitionFields.contains))) + api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2))) + } else { + gb.outputSchema + } + val valueSchema = outputSchema.fields.map(part.constructJoinPartSchema) + JoinPartMetadata(part, keySchema, valueSchema, Map.empty) // will be populated below + }) // Enrich online only external parts with the expected output schema logger.info(s"\nCreating BootstrapInfo for online-only ExternalParts for Join ${joinConf.metaData.name}") val externalParts: Seq[ExternalPartMetadata] = Option(joinConf.onlineExternalParts.toScala) .getOrElse(Seq.empty) - .filter(_.source.offlineGroupBy == null) // Only online-only ExternalParts + .filter(_.source.offlineGroupBy == null) // Only online-only ExternalParts .map(part => ExternalPartMetadata(part, part.keySchemaFull, part.valueSchemaFull)) val leftFields = leftSchema diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala index 9c35d3c3b3..5aa209a965 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala @@ -474,7 +474,7 @@ abstract class JoinBase(joinConf: api.Join, assert(Option(joinConf.metaData.team).nonEmpty, s"join.metaData.team needs to be set for join ${joinConf.metaData.name}") - joinConf.getCombinedJoinParts.foreach { jp => + joinConf.getRegularAndExternalJoinParts.foreach { jp => assert(Option(jp.groupBy.metaData.team).nonEmpty, s"groupBy.metaData.team needs to be set for joinPart ${jp.groupBy.metaData.name}") } diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index b008c36a09..c070d37dd8 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -26,7 +26,7 @@ import ai.chronon.spark.catalog.TableUtils import ai.chronon.spark.{Analyzer, Join, SparkSessionBuilder} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.{col, lit, to_json} -import org.junit.Assert.{assertEquals, assertTrue} +import org.junit.Assert.{assertEquals, assertFalse, assertTrue} import org.junit.Test import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{never, spy, verify, when} @@ -909,4 +909,107 @@ class AnalyzerTest { analyzer.analyzeGroupBy(tableGroupBy) } + @Test + def testExternalSourceValidationWithMatchingSchemas(): Unit = { + // Create compatible schemas using the correct DataType objects + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + // Create a query and source for the GroupBy + val query = Builders.Query(selects = Map("feature_value" -> "value")) + val source = 
Builders.Source.events(query, "test.table") + + // Create GroupBy with matching key columns and sources + val groupBy = Builders.GroupBy( + keyColumns = Seq("user_id"), + sources = Seq(source) + ) + + // Create ExternalSource with compatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + + // Manually set the offlineGroupBy field since the builder doesn't support it yet + externalSource.setOfflineGroupBy(groupBy) + + // Create analyzer instance and call validation + val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) + val errors = analyzer.validateOfflineGroupBy(externalSource) + assertTrue(s"Expected no errors, but got: ${errors.mkString(", ")}", errors.isEmpty) + } + + @Test + def testExternalSourceValidationWithMismatchedKeySchemas(): Unit = { + // Create key schema with different fields + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + // Create a query and source for the GroupBy + val query = Builders.Query(selects = Map("feature_value" -> "feature_value")) + val source = Builders.Source.events(query, "test.table") + + // Create GroupBy with different key columns + val groupBy = Builders.GroupBy( + keyColumns = Seq("different_key"), // Mismatched key column + sources = Seq(source) + ) + + // Create ExternalSource with incompatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + externalSource.setOfflineGroupBy(groupBy) + + // This should return validation errors + val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) + val errors = analyzer.validateOfflineGroupBy(externalSource) + assertFalse("Expected validation errors for mismatched key schemas", errors.isEmpty) + assertTrue("Error should mention key schema mismatch", + errors.exists(_.contains("key schema contains columns"))) + } + + @Test + def testExternalSourceValidationWithMismatchedValueSchemas(): Unit = { + // Create compatible key schema but mismatched value schema + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + // Create a source for the GroupBy - this is needed for valueColumns to work + val query = Builders.Query(selects = Map("different_feature" -> "different_feature")) + val source = Builders.Source.events(query, "test.table") + + // Create GroupBy with different value columns + val groupBy = Builders.GroupBy( + keyColumns = Seq("user_id"), + sources = Seq(source) + ) + + // Create ExternalSource with incompatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + externalSource.setOfflineGroupBy(groupBy) + + // This should return validation errors + val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) + val errors = analyzer.validateOfflineGroupBy(externalSource) + assertFalse("Expected validation errors for mismatched value schemas", errors.isEmpty) + assertTrue("Error should mention value schema mismatch", + errors.exists(_.contains("valueSchema contains columns"))) + } + + @Test + def testExternalSourceValidationWithNullOfflineGroupBy(): Unit = { + // Create ExternalSource 
without offlineGroupBy + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + // Don't set offlineGroupBy (it remains null) + + // This should return no errors (validation should be skipped) + val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) + val errors = analyzer.validateOfflineGroupBy(externalSource) + assertTrue("Expected no errors when offlineGroupBy is null", errors.isEmpty) + } + } diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala index 5bc8156b65..32c552a7dd 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala @@ -32,7 +32,6 @@ class ExternalSourceBackfillTest { private val tableUtils = TableUtils(spark) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) - private val yearAgo = tableUtils.partitionSpec.minus(today, new Window(365, TimeUnit.DAYS)) @Test def testExternalSourceBackfillComputeJoin(): Unit = { From 6e5064df88e929c628900caf1fa743df9931caae Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Tue, 30 Sep 2025 20:42:56 -0700 Subject: [PATCH 06/13] Update extension class --- .../scala/ai/chronon/api/Extensions.scala | 6 ++- .../ai/chronon/api/test/ExtensionsTest.scala | 3 +- .../scala/ai/chronon/spark/Analyzer.scala | 33 +++++++++-------- .../ai/chronon/spark/test/AnalyzerTest.scala | 37 ++++++++++++++----- .../test/ExternalSourceBackfillTest.scala | 28 ++++++++++++++ 5 files changed, 79 insertions(+), 28 deletions(-) diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index cb407df463..a1e032a01b 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -1064,8 +1064,10 @@ object Extensions { } lazy val joinPartOps: Seq[JoinPartOps] = - Option(join.getRegularAndExternalJoinParts) - .getOrElse(Seq.empty[JoinPart]) + Option(join.joinParts) + .getOrElse(new util.ArrayList[JoinPart]()) + .toScala + .toSeq .map(new JoinPartOps(_)) def logFullValues: Boolean = true // TODO: supports opt-out in the future diff --git a/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala b/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala index 1d6e09a01e..22f91d6d22 100644 --- a/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala +++ b/api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala @@ -16,7 +16,7 @@ package ai.chronon.api.test -import ai.chronon.api.{Accuracy, Builders, Constants, GroupBy, StringType, DoubleType, StructType, StructField} +import ai.chronon.api.{Accuracy, Builders, Constants, GroupBy} import org.junit.Test import ai.chronon.api.Extensions._ import org.junit.Assert.{assertEquals, assertFalse, assertNotEquals, assertTrue} @@ -248,5 +248,4 @@ class ExtensionsTest { assertEquals(join1.semanticHash(excludeTopic = true), join2.semanticHash(excludeTopic = true)) assertEquals(join1.semanticHash(excludeTopic = false), join2.semanticHash(excludeTopic = false)) } - } diff --git 
a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index 09cc953d89..2609ff1357 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -808,15 +808,15 @@ class Analyzer(tableUtils: TableUtils, } /** - * Validates schema compatibility between ExternalSource and its offlineGroupBy. + * Validates schema compatibility between ExternalPart and its offlineGroupBy. * This ensures that online and offline serving will produce consistent results. * - * @param externalSource The external source to validate + * @param externalPart The external part to validate * @return Sequence of error messages, empty if no errors */ - def validateOfflineGroupBy(externalSource: api.ExternalSource): Seq[String] = - Option(externalSource.offlineGroupBy) - .map(_ => validateKeySchemaCompatibility(externalSource) ++ validateValueSchemaCompatibility(externalSource)) + def validateOfflineGroupBy(externalPart: api.ExternalPart): Seq[String] = + Option(externalPart.source.offlineGroupBy) + .map(_ => validateKeySchemaCompatibility(externalPart.source) ++ validateValueSchemaCompatibility(externalPart)) .getOrElse(Seq.empty) private def validateKeySchemaCompatibility(externalSource: api.ExternalSource): Seq[String] = { @@ -858,29 +858,32 @@ class Analyzer(tableUtils: TableUtils, errors } - private def validateValueSchemaCompatibility(externalSource: api.ExternalSource): Seq[String] = { + private def validateValueSchemaCompatibility(externalPart: api.ExternalPart): Seq[String] = { val errors = scala.collection.mutable.ListBuffer[String]() - if (externalSource.valueSchema == null) { - errors += s"ExternalSource ${externalSource.metadata.name} valueSchema cannot be null when offlineGroupBy is specified" + if (externalPart.source.valueSchema == null) { + errors += s"ExternalSource ${externalPart.source.metadata.name} valueSchema cannot be null when offlineGroupBy is specified" return errors } - val externalValueFields = externalSource.valueFields + val externalValueFields = externalPart.valueSchemaFull val externalValueNames = externalValueFields.map(_.name).toSet - // For GroupBy value schema, we need to derive the output schema from aggregations - val groupByValueColumns = externalSource.offlineGroupBy.valueColumns.toSet + // External features use full names that include the prefix: ext_[prefix_]sourceName_fieldName + // The offlineGroupBy must define derivations to produce features with matching names. + // For example, if the external part has fullName="ext_prefix_source" and valueField="feature_value", + // the final feature name will be "ext_prefix_source_feature_value", which must match a derived column name. 
+ val groupByDerivedColumns = externalPart.source.offlineGroupBy.derivationsScala.map(_.name).toSet // Check that ExternalSource value schema fields are compatible with GroupBy output - val missingValueColumns = externalValueNames -- groupByValueColumns + val missingValueColumns = externalValueNames -- groupByDerivedColumns if (missingValueColumns.nonEmpty) { // This is an error because ExternalSource valueSchema must be compatible with GroupBy output // to ensure consistency between online and offline serving - errors += s"ExternalSource ${externalSource.metadata.name} valueSchema contains columns [${missingValueColumns + errors += s"ExternalSource ${externalPart.source.metadata.name} valueSchema contains columns [${missingValueColumns .mkString(", ")}] " + - s"that are not present in offlineGroupBy output columns [${groupByValueColumns.mkString(", ")}]. " + + s"that are not present in offlineGroupBy derived output columns [${groupByDerivedColumns.mkString(", ")}]. " + s"This indicates schema incompatibility between online and offline serving. " + s"Please ensure ExternalSource valueSchema matches the expected output of the GroupBy aggregations." } @@ -897,7 +900,7 @@ class Analyzer(tableUtils: TableUtils, */ def runExternalSourceCheck(joinConf: api.Join): Seq[String] = Option(joinConf.onlineExternalParts) - .map(_.toScala.flatMap(part => validateOfflineGroupBy(part.source))) + .map(_.toScala.flatMap(part => validateOfflineGroupBy(part))) .getOrElse(Seq.empty) def run(exportSchema: Boolean = false): Unit = diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index c070d37dd8..3dcad860e7 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -18,7 +18,6 @@ package ai.chronon.spark.test import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.Builders.Query import ai.chronon.api.Extensions.MetadataOps import ai.chronon.api._ import ai.chronon.spark.Extensions._ @@ -919,10 +918,14 @@ class AnalyzerTest { val query = Builders.Query(selects = Map("feature_value" -> "value")) val source = Builders.Source.events(query, "test.table") - // Create GroupBy with matching key columns and sources + // Create GroupBy with matching key columns, sources, and derivation to match external feature name + // The external part will have fullName "ext_test_external_source", so the feature will be "ext_test_external_source_feature_value" val groupBy = Builders.GroupBy( keyColumns = Seq("user_id"), - sources = Seq(source) + sources = Seq(source), + derivations = Seq( + Builders.Derivation(name = "ext_test_external_source_feature_value", expression = "feature_value") + ) ) // Create ExternalSource with compatible schemas @@ -932,9 +935,12 @@ class AnalyzerTest { // Manually set the offlineGroupBy field since the builder doesn't support it yet externalSource.setOfflineGroupBy(groupBy) + // Wrap in ExternalPart + val externalPart = Builders.ExternalPart(externalSource = externalSource) + // Create analyzer instance and call validation val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) - val errors = analyzer.validateOfflineGroupBy(externalSource) + val errors = analyzer.validateOfflineGroupBy(externalPart) assertTrue(s"Expected no errors, but got: ${errors.mkString(", ")}", errors.isEmpty) } @@ -959,9 +965,12 @@ class AnalyzerTest { val externalSource = 
Builders.ExternalSource(metadata, keySchema, valueSchema) externalSource.setOfflineGroupBy(groupBy) + // Wrap in ExternalPart + val externalPart = Builders.ExternalPart(externalSource = externalSource) + // This should return validation errors val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) - val errors = analyzer.validateOfflineGroupBy(externalSource) + val errors = analyzer.validateOfflineGroupBy(externalPart) assertFalse("Expected validation errors for mismatched key schemas", errors.isEmpty) assertTrue("Error should mention key schema mismatch", errors.exists(_.contains("key schema contains columns"))) @@ -977,10 +986,14 @@ class AnalyzerTest { val query = Builders.Query(selects = Map("different_feature" -> "different_feature")) val source = Builders.Source.events(query, "test.table") - // Create GroupBy with different value columns + // Create GroupBy with different derived column name that doesn't match external feature name + // The external part expects "ext_test_external_source_feature_value" but GroupBy produces "wrong_name" val groupBy = Builders.GroupBy( keyColumns = Seq("user_id"), - sources = Seq(source) + sources = Seq(source), + derivations = Seq( + Builders.Derivation(name = "wrong_name", expression = "different_feature") + ) ) // Create ExternalSource with incompatible schemas @@ -988,9 +1001,12 @@ class AnalyzerTest { val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) externalSource.setOfflineGroupBy(groupBy) + // Wrap in ExternalPart + val externalPart = Builders.ExternalPart(externalSource = externalSource) + // This should return validation errors val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) - val errors = analyzer.validateOfflineGroupBy(externalSource) + val errors = analyzer.validateOfflineGroupBy(externalPart) assertFalse("Expected validation errors for mismatched value schemas", errors.isEmpty) assertTrue("Error should mention value schema mismatch", errors.exists(_.contains("valueSchema contains columns"))) @@ -1006,9 +1022,12 @@ class AnalyzerTest { val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) // Don't set offlineGroupBy (it remains null) + // Wrap in ExternalPart + val externalPart = Builders.ExternalPart(externalSource = externalSource) + // This should return no errors (validation should be skipped) val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) - val errors = analyzer.validateOfflineGroupBy(externalSource) + val errors = analyzer.validateOfflineGroupBy(externalPart) assertTrue("Expected no errors when offlineGroupBy is null", errors.isEmpty) } diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala index 32c552a7dd..97f1bc5415 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala @@ -71,6 +71,16 @@ class ExternalSourceBackfillTest { windows = Seq(new Window(30, TimeUnit.DAYS)) ) ), + derivations = Seq( + // External part will have fullName "ext_ext_external_transaction_features_XXX" + // (Constants.ExternalPrefix + "_" + prefix + "_" + metadata.name) + // So we need to derive column with that prefix + field name + Builders.Derivation.star(), // Keep all base aggregation columns + Builders.Derivation( + name = 
s"ext_ext_external_transaction_features_${namespace}_amount_sum_30d", + expression = "amount_sum_30d" + ) + ), metaData = Builders.MetaData(name = s"user_transaction_features_$namespace", namespace = namespace), accuracy = Accuracy.TEMPORAL ) @@ -189,6 +199,15 @@ class ExternalSourceBackfillTest { windows = Seq(new Window(7, TimeUnit.DAYS)) ) ), + derivations = Seq( + // External part will have fullName "ext_purchase_external_purchase_features_XXX" + // So we need to derive column with that prefix + field name + Builders.Derivation.star(), // Keep all base aggregation columns + Builders.Derivation( + name = s"ext_purchase_external_purchase_features_${namespace}_purchase_amount_average_7d", + expression = "purchase_amount_average_7d" + ) + ), metaData = Builders.MetaData(name = s"purchase_features_$namespace", namespace = namespace), accuracy = Accuracy.TEMPORAL ) @@ -326,6 +345,15 @@ class ExternalSourceBackfillTest { windows = Seq(new Window(30, TimeUnit.DAYS)) ) ), + derivations = Seq( + // External part will have fullName "ext_mapped_external_features_XXX" + // So we need to derive column with that prefix + field name + Builders.Derivation.star(), // Keep all base aggregation columns + Builders.Derivation( + name = s"ext_mapped_external_features_${namespace}_feature_score_max_30d", + expression = "feature_score_max_30d" + ) + ), metaData = Builders.MetaData(name = s"feature_gb_$namespace", namespace = namespace), accuracy = Accuracy.TEMPORAL ) From 5d03aa11b584350ace1946db896619543153084f Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Wed, 1 Oct 2025 19:02:06 -0700 Subject: [PATCH 07/13] Ensure feature full name matches --- .../scala/ai/chronon/api/Extensions.scala | 25 ++- .../scala/ai/chronon/spark/Analyzer.scala | 76 ++++++--- .../ai/chronon/spark/test/AnalyzerTest.scala | 147 ++++++++++++++---- .../test/ExternalSourceBackfillTest.scala | 54 +++---- 4 files changed, 213 insertions(+), 89 deletions(-) diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index a1e032a01b..228b83de67 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -745,9 +745,18 @@ object Extensions { } implicit class JoinPartOps(joinPart: JoinPart) extends JoinPart(joinPart) { - lazy val fullPrefix = (Option(prefix) ++ Some(groupBy.getMetaData.cleanName)).mkString("_") + lazy val fullPrefix = { + Option(groupBy.getMetaData.customJsonLookUp("customizedFullPrefix")) + .map(_.toString) + .getOrElse((Option(prefix) ++ Some(groupBy.getMetaData.cleanName)).mkString("_")) + } lazy val leftToRight: Map[String, String] = rightToLeft.map { case (key, value) => value -> key } + lazy val isExternal: Boolean = Option(groupBy.getMetaData.customJsonLookUp("isExternal")) + .map(_.toString) + .getOrElse("false") + .toBoolean + def valueColumns: Seq[String] = joinPart.groupBy.valueColumns.map(fullPrefix + "_" + _) def rightToLeft: Map[String, String] = { @@ -1026,9 +1035,21 @@ object Extensions { .getOrElse(Seq.empty) .filter(_.source.offlineGroupBy != null) // Only offline-capable ExternalParts .map { externalPart => + // Set customJson with fullPrefix override + val offlineGroupBy = externalPart.source.offlineGroupBy.deepCopy() + val existingCustomJson = Option(offlineGroupBy.metaData.customJson).getOrElse("{}") + + val mapper = new ObjectMapper() + val typeRef = new TypeReference[java.util.HashMap[String, Object]]() {} + val customJsonMap: java.util.Map[String, Object] = 
mapper.readValue(existingCustomJson, typeRef) + customJsonMap.put("customizedFullPrefix", externalPart.fullName) + customJsonMap.put("isExternal", "true") // Mark as external for JoinPartOps.isExternal + val updatedCustomJson = mapper.writeValueAsString(customJsonMap) + offlineGroupBy.metaData.setCustomJson(updatedCustomJson) + // Convert ExternalPart to synthetic JoinPart val syntheticJoinPart = new JoinPart() - syntheticJoinPart.setGroupBy(externalPart.source.offlineGroupBy) + syntheticJoinPart.setGroupBy(offlineGroupBy) if (externalPart.keyMapping != null) { syntheticJoinPart.setKeyMapping(externalPart.keyMapping) } diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index 2609ff1357..ef9c637b25 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -368,6 +368,7 @@ class Analyzer(tableUtils: TableUtils, val leftSchema = leftDf.schema.fields .map(field => (field.name, SparkConversions.toChrononType(field.name, field.dataType))) val joinIntermediateValuesMetadata = ListBuffer[AggregationMetadata]() + val externalGroupByMetadata = ListBuffer[AggregationMetadata]() val keysWithError: ListBuffer[(String, String)] = ListBuffer.empty[(String, String)] val gbStartPartitions = mutable.Map[String, List[String]]() val noAccessTables = mutable.Set[String]() ++= leftNoAccessTables @@ -400,7 +401,9 @@ class Analyzer(tableUtils: TableUtils, skipTimestampCheck = skipTimestampCheck || leftNoAccessTables.nonEmpty, validateTablePermission = validateTablePermission ) - joinIntermediateValuesMetadata ++= analyzeGroupByResult.outputMetadata.map { aggMeta => + + val target = if (!part.isExternal) joinIntermediateValuesMetadata else externalGroupByMetadata + target ++= analyzeGroupByResult.outputMetadata.map { aggMeta => AggregationMetadata(part.fullPrefix + "_" + aggMeta.name, aggMeta.columnType, aggMeta.operation, @@ -408,6 +411,7 @@ class Analyzer(tableUtils: TableUtils, aggMeta.inputColumn, part.getGroupBy.getMetaData.getName) } + // Run validation checks. keysWithError ++= runSchemaValidation(leftSchema.toMap, analyzeGroupByResult.keySchema.toMap, part.rightToLeft) val subPartitionFilters = @@ -424,9 +428,12 @@ class Analyzer(tableUtils: TableUtils, if (gbStartPartition.nonEmpty) gbStartPartitions += (part.groupBy.metaData.name -> gbStartPartition) } + + val externalGroupBySchema = externalGroupByMetadata.map(aggregation => (aggregation.name, aggregation.columnType)) + if (joinConf.onlineExternalParts != null) { // Validate ExternalSource schemas if they have offlineGroupBy configured - externalSourceErrors ++= runExternalSourceCheck(joinConf) + externalSourceErrors ++= runExternalSourceCheck(joinConf, externalGroupBySchema) joinConf.onlineExternalParts.toScala.foreach { part => joinIntermediateValuesMetadata ++= part.source.valueFields.map { field => @@ -812,11 +819,15 @@ class Analyzer(tableUtils: TableUtils, * This ensures that online and offline serving will produce consistent results. 
* * @param externalPart The external part to validate + * @param externalGroupBySchema The schema of the external groupBy features from offline computation * @return Sequence of error messages, empty if no errors */ - def validateOfflineGroupBy(externalPart: api.ExternalPart): Seq[String] = + def validateOfflineGroupBy(externalPart: api.ExternalPart, + externalGroupBySchema: Seq[(String, DataType)]): Seq[String] = Option(externalPart.source.offlineGroupBy) - .map(_ => validateKeySchemaCompatibility(externalPart.source) ++ validateValueSchemaCompatibility(externalPart)) + .map(_ => + validateKeySchemaCompatibility(externalPart.source) ++ validateValueSchemaCompatibility(externalPart, + externalGroupBySchema)) .getOrElse(Seq.empty) private def validateKeySchemaCompatibility(externalSource: api.ExternalSource): Seq[String] = { @@ -858,7 +869,8 @@ class Analyzer(tableUtils: TableUtils, errors } - private def validateValueSchemaCompatibility(externalPart: api.ExternalPart): Seq[String] = { + private def validateValueSchemaCompatibility(externalPart: api.ExternalPart, + externalGroupBySchema: Seq[(String, DataType)]): Seq[String] = { val errors = scala.collection.mutable.ListBuffer[String]() if (externalPart.source.valueSchema == null) { @@ -866,26 +878,41 @@ class Analyzer(tableUtils: TableUtils, return errors } + // Get expected schema from ExternalPart (what online expects) val externalValueFields = externalPart.valueSchemaFull - val externalValueNames = externalValueFields.map(_.name).toSet + val expectedSchema = externalValueFields.map(field => (field.name, field.fieldType)).toMap - // External features use full names that include the prefix: ext_[prefix_]sourceName_fieldName - // The offlineGroupBy must define derivations to produce features with matching names. - // For example, if the external part has fullName="ext_prefix_source" and valueField="feature_value", - // the final feature name will be "ext_prefix_source_feature_value", which must match a derived column name. - val groupByDerivedColumns = externalPart.source.offlineGroupBy.derivationsScala.map(_.name).toSet - - // Check that ExternalSource value schema fields are compatible with GroupBy output - val missingValueColumns = externalValueNames -- groupByDerivedColumns + // Get actual schema from offline GroupBy computation (what offline produces) + val prefix = externalPart.fullName + "_" + val actualSchema = externalGroupBySchema + .filter(_._1.startsWith(prefix)) + .toMap - if (missingValueColumns.nonEmpty) { - // This is an error because ExternalSource valueSchema must be compatible with GroupBy output - // to ensure consistency between online and offline serving - errors += s"ExternalSource ${externalPart.source.metadata.name} valueSchema contains columns [${missingValueColumns - .mkString(", ")}] " + - s"that are not present in offlineGroupBy derived output columns [${groupByDerivedColumns.mkString(", ")}]. " + - s"This indicates schema incompatibility between online and offline serving. " + - s"Please ensure ExternalSource valueSchema matches the expected output of the GroupBy aggregations." + // Check for fields in expected but not in actual (missing fields) + val missingFields = expectedSchema.keySet -- actualSchema.keySet + if (missingFields.nonEmpty) { + errors += s"ExternalSource ${externalPart.source.metadata.name} offline GroupBy is missing value fields: " + + s"[${missingFields.mkString(", ")}]. These fields are defined in the ExternalSource valueSchema " + + s"but are not produced by the offlineGroupBy." 
+ } + + // Check for fields in actual but not in expected (extra fields) + val extraFields = actualSchema.keySet -- expectedSchema.keySet + if (extraFields.nonEmpty) { + errors += s"ExternalSource ${externalPart.source.metadata.name} offline GroupBy produces extra value fields: " + + s"[${extraFields.mkString(", ")}]. These fields are not defined in the ExternalSource valueSchema." + } + + // Check for type mismatches in common fields + val commonFields = expectedSchema.keySet.intersect(actualSchema.keySet) + commonFields.foreach { fieldName => + val expectedType = expectedSchema(fieldName) + val actualType = actualSchema(fieldName) + if (expectedType != actualType) { + errors += s"ExternalSource ${externalPart.source.metadata.name} field '$fieldName' has type mismatch: " + + s"expected ${DataType.toString(expectedType)} (from ExternalSource valueSchema) " + + s"but offline GroupBy produces ${DataType.toString(actualType)}" + } } errors @@ -896,11 +923,12 @@ class Analyzer(tableUtils: TableUtils, * This ensures schema compatibility between ExternalSources and their offlineGroupBy configurations. * * @param joinConf The join configuration to validate + * @param externalGroupBySchema The schema of the external groupBy features from offline computation * @return Sequence of error messages, empty if no errors */ - def runExternalSourceCheck(joinConf: api.Join): Seq[String] = + def runExternalSourceCheck(joinConf: api.Join, externalGroupBySchema: Seq[(String, DataType)]): Seq[String] = Option(joinConf.onlineExternalParts) - .map(_.toScala.flatMap(part => validateOfflineGroupBy(part))) + .map(_.toScala.flatMap(part => validateOfflineGroupBy(part, externalGroupBySchema))) .getOrElse(Seq.empty) def run(exportSchema: Boolean = false): Unit = diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index 3dcad860e7..f581772357 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -25,7 +25,7 @@ import ai.chronon.spark.catalog.TableUtils import ai.chronon.spark.{Analyzer, Join, SparkSessionBuilder} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.{col, lit, to_json} -import org.junit.Assert.{assertEquals, assertFalse, assertTrue} +import org.junit.Assert.{assertEquals, assertTrue} import org.junit.Test import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{never, spy, verify, when} @@ -910,13 +910,28 @@ class AnalyzerTest { @Test def testExternalSourceValidationWithMatchingSchemas(): Unit = { + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) + val tableUtils = TableUtils(spark) + val namespace = "analyzer_test_ns" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create left side table + val leftSchema = List(Column("user_id", api.StringType, 100)) + val leftTable = s"$namespace.left_table" + DataFrameGen.events(spark, leftSchema, 10, partitions = 5).save(leftTable) + // Create compatible schemas using the correct DataType objects val keySchema = StructType("key", Array(StructField("user_id", StringType))) val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + // Create right side table for GroupBy + val rightSchema = List(Column("user_id", api.StringType, 100), Column("value", api.DoubleType, 100)) + val rightTable = s"$namespace.right_table" + DataFrameGen.events(spark, 
rightSchema, 100, partitions = 10).save(rightTable) + // Create a query and source for the GroupBy - val query = Builders.Query(selects = Map("feature_value" -> "value")) - val source = Builders.Source.events(query, "test.table") + val query = Builders.Query(selects = Map("feature_value" -> "value"), startPartition = oneYearAgo) + val source = Builders.Source.events(query, rightTable) // Create GroupBy with matching key columns, sources, and derivation to match external feature name // The external part will have fullName "ext_test_external_source", so the feature will be "ext_test_external_source_feature_value" @@ -924,8 +939,9 @@ class AnalyzerTest { keyColumns = Seq("user_id"), sources = Seq(source), derivations = Seq( - Builders.Derivation(name = "ext_test_external_source_feature_value", expression = "feature_value") - ) + Builders.Derivation(name = "feature_value", expression = "feature_value") + ), + metaData = Builders.MetaData(name = "test_external_gb", namespace = namespace) ) // Create ExternalSource with compatible schemas @@ -938,26 +954,50 @@ class AnalyzerTest { // Wrap in ExternalPart val externalPart = Builders.ExternalPart(externalSource = externalSource) - // Create analyzer instance and call validation - val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) - val errors = analyzer.validateOfflineGroupBy(externalPart) - assertTrue(s"Expected no errors, but got: ${errors.mkString(", ")}", errors.isEmpty) + // Create Join with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = oneMonthAgo), table = leftTable), + externalParts = Seq(externalPart), + metaData = Builders.MetaData(name = "test_join_external", namespace = namespace, team = "chronon") + ) + + // Create analyzer instance and call analyzeJoin + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skipTimestampCheck = true, validateTablePermission = false) + val result = analyzer.analyzeJoin(joinConf) + // If no exception is thrown, validation passed + assertTrue("Expected successful validation", result != null) } - @Test + @Test(expected = classOf[java.lang.AssertionError]) def testExternalSourceValidationWithMismatchedKeySchemas(): Unit = { + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) + val tableUtils = TableUtils(spark) + val namespace = "analyzer_test_ns" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create left side table + val leftSchema = List(Column("user_id", api.StringType, 100)) + val leftTable = s"$namespace.left_table" + DataFrameGen.events(spark, leftSchema, 10, partitions = 5).save(leftTable) + // Create key schema with different fields val keySchema = StructType("key", Array(StructField("user_id", StringType))) val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + // Create right side table for GroupBy + val rightSchema = List(Column("different_key", api.StringType, 100), Column("feature_value", api.DoubleType, 100)) + val rightTable = s"$namespace.right_table" + DataFrameGen.events(spark, rightSchema, 100, partitions = 10).save(rightTable) + // Create a query and source for the GroupBy - val query = Builders.Query(selects = Map("feature_value" -> "feature_value")) - val source = Builders.Source.events(query, "test.table") + val query = Builders.Query(selects = Map("feature_value" -> "feature_value"), startPartition = oneYearAgo) + val source = Builders.Source.events(query, 
rightTable) // Create GroupBy with different key columns val groupBy = Builders.GroupBy( keyColumns = Seq("different_key"), // Mismatched key column - sources = Seq(source) + sources = Seq(source), + metaData = Builders.MetaData(name = "test_external_gb_mismatch", namespace = namespace) ) // Create ExternalSource with incompatible schemas @@ -968,23 +1008,42 @@ class AnalyzerTest { // Wrap in ExternalPart val externalPart = Builders.ExternalPart(externalSource = externalSource) - // This should return validation errors - val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) - val errors = analyzer.validateOfflineGroupBy(externalPart) - assertFalse("Expected validation errors for mismatched key schemas", errors.isEmpty) - assertTrue("Error should mention key schema mismatch", - errors.exists(_.contains("key schema contains columns"))) + // Create Join with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = oneMonthAgo), table = leftTable), + externalParts = Seq(externalPart), + metaData = Builders.MetaData(name = "test_join_external_key_mismatch", namespace = namespace, team = "chronon") + ) + + // This should throw AssertionError due to validation errors + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skipTimestampCheck = true, validateTablePermission = false) + analyzer.analyzeJoin(joinConf, validationAssert = true) } - @Test + @Test(expected = classOf[java.lang.AssertionError]) def testExternalSourceValidationWithMismatchedValueSchemas(): Unit = { + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) + val tableUtils = TableUtils(spark) + val namespace = "analyzer_test_ns" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create left side table + val leftSchema = List(Column("user_id", api.StringType, 100)) + val leftTable = s"$namespace.left_table" + DataFrameGen.events(spark, leftSchema, 10, partitions = 5).save(leftTable) + // Create compatible key schema but mismatched value schema val keySchema = StructType("key", Array(StructField("user_id", StringType))) val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + // Create right side table for GroupBy + val rightSchema = List(Column("user_id", api.StringType, 100), Column("different_feature", api.DoubleType, 100)) + val rightTable = s"$namespace.right_table" + DataFrameGen.events(spark, rightSchema, 100, partitions = 10).save(rightTable) + // Create a source for the GroupBy - this is needed for valueColumns to work - val query = Builders.Query(selects = Map("different_feature" -> "different_feature")) - val source = Builders.Source.events(query, "test.table") + val query = Builders.Query(selects = Map("different_feature" -> "different_feature"), startPartition = oneYearAgo) + val source = Builders.Source.events(query, rightTable) // Create GroupBy with different derived column name that doesn't match external feature name // The external part expects "ext_test_external_source_feature_value" but GroupBy produces "wrong_name" @@ -993,7 +1052,8 @@ class AnalyzerTest { sources = Seq(source), derivations = Seq( Builders.Derivation(name = "wrong_name", expression = "different_feature") - ) + ), + metaData = Builders.MetaData(name = "test_external_gb_value_mismatch", namespace = namespace) ) // Create ExternalSource with incompatible schemas @@ -1004,16 +1064,30 @@ class AnalyzerTest { // Wrap in ExternalPart val externalPart = 
Builders.ExternalPart(externalSource = externalSource) - // This should return validation errors - val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) - val errors = analyzer.validateOfflineGroupBy(externalPart) - assertFalse("Expected validation errors for mismatched value schemas", errors.isEmpty) - assertTrue("Error should mention value schema mismatch", - errors.exists(_.contains("valueSchema contains columns"))) + // Create Join with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = oneMonthAgo), table = leftTable), + externalParts = Seq(externalPart), + metaData = Builders.MetaData(name = "test_join_external_value_mismatch", namespace = namespace, team = "chronon") + ) + + // This should throw AssertionError due to validation errors + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skipTimestampCheck = true, validateTablePermission = false) + analyzer.analyzeJoin(joinConf, validationAssert = true) } @Test def testExternalSourceValidationWithNullOfflineGroupBy(): Unit = { + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) + val tableUtils = TableUtils(spark) + val namespace = "analyzer_test_ns" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create left side table + val leftSchema = List(Column("user_id", api.StringType, 100)) + val leftTable = s"$namespace.left_table" + DataFrameGen.events(spark, leftSchema, 10, partitions = 5).save(leftTable) + // Create ExternalSource without offlineGroupBy val keySchema = StructType("key", Array(StructField("user_id", StringType))) val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) @@ -1025,10 +1099,17 @@ class AnalyzerTest { // Wrap in ExternalPart val externalPart = Builders.ExternalPart(externalSource = externalSource) - // This should return no errors (validation should be skipped) - val analyzer = new Analyzer(dummyTableUtils, externalSource, oneMonthAgo, today) - val errors = analyzer.validateOfflineGroupBy(externalPart) - assertTrue("Expected no errors when offlineGroupBy is null", errors.isEmpty) + // Create Join with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = oneMonthAgo), table = leftTable), + externalParts = Seq(externalPart), + metaData = Builders.MetaData(name = "test_join_external_null_gb", namespace = namespace, team = "chronon") + ) + + // This should not throw - validation should be skipped when offlineGroupBy is null + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skipTimestampCheck = true, validateTablePermission = false) + val result = analyzer.analyzeJoin(joinConf) + assertTrue("Expected successful validation when offlineGroupBy is null", result != null) } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala index 97f1bc5415..1a614a09b2 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala @@ -38,7 +38,7 @@ class ExternalSourceBackfillTest { val spark: SparkSession = SparkSessionBuilder.build("ExternalSourceBackfillTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) val tableUtils = TableUtils(spark) - val namespace = "test_namespace_ext_backfill" + "_" + 
Random.alphanumeric.take(6).mkString + val namespace = "test_namespace_" + Random.alphanumeric.take(6).mkString tableUtils.createDatabase(namespace) // Create user transaction data for offline GroupBy @@ -72,24 +72,21 @@ class ExternalSourceBackfillTest { ) ), derivations = Seq( - // External part will have fullName "ext_ext_external_transaction_features_XXX" - // (Constants.ExternalPrefix + "_" + prefix + "_" + metadata.name) - // So we need to derive column with that prefix + field name Builders.Derivation.star(), // Keep all base aggregation columns Builders.Derivation( - name = s"ext_ext_external_transaction_features_${namespace}_amount_sum_30d", + name = s"es_amount", expression = "amount_sum_30d" ) ), - metaData = Builders.MetaData(name = s"user_transaction_features_$namespace", namespace = namespace), + metaData = Builders.MetaData(name = s"gb_amount", namespace = namespace), accuracy = Accuracy.TEMPORAL ) // Create ExternalSource with offline GroupBy val externalSource = Builders.ExternalSource( - metadata = Builders.MetaData(name = s"external_transaction_features_$namespace"), + metadata = Builders.MetaData(name = s"test_external_source"), keySchema = StructType("external_keys", Array(StructField("user_id", StringType))), - valueSchema = StructType("external_values", Array(StructField("amount_sum_30d", LongType))) + valueSchema = StructType("external_values", Array(StructField("es_amount", LongType))) ) externalSource.setOfflineGroupBy(offlineGroupBy) @@ -113,13 +110,14 @@ class ExternalSourceBackfillTest { ), table = userEventTable ), + joinParts = Seq(), externalParts = Seq( Builders.ExternalPart( externalSource, - prefix = "ext" + prefix = "txn" ) ), - metaData = Builders.MetaData(name = s"test_external_join_$namespace", namespace = namespace) + metaData = Builders.MetaData(name = s"test_external_part", namespace = namespace) ) // Run analyzer to ensure GroupBy tables are created @@ -157,7 +155,7 @@ class ExternalSourceBackfillTest { val spark: SparkSession = SparkSessionBuilder.build("ExternalSourceBackfillTest_Mixed" + "_" + Random.alphanumeric.take(6).mkString, local = true) val tableUtils = TableUtils(spark) - val namespace = "test_namespace_mixed" + "_" + Random.alphanumeric.take(6).mkString + val namespace = "test_namespace_" + Random.alphanumeric.take(6).mkString tableUtils.createDatabase(namespace) // Create transaction data for external source GroupBy @@ -200,15 +198,13 @@ class ExternalSourceBackfillTest { ) ), derivations = Seq( - // External part will have fullName "ext_purchase_external_purchase_features_XXX" - // So we need to derive column with that prefix + field name Builders.Derivation.star(), // Keep all base aggregation columns Builders.Derivation( - name = s"ext_purchase_external_purchase_features_${namespace}_purchase_amount_average_7d", + name = s"purchase_amount", expression = "purchase_amount_average_7d" ) ), - metaData = Builders.MetaData(name = s"purchase_features_$namespace", namespace = namespace), + metaData = Builders.MetaData(name = s"gb_purchase", namespace = namespace), accuracy = Accuracy.TEMPORAL ) @@ -231,15 +227,15 @@ class ExternalSourceBackfillTest { windows = Seq(new Window(14, TimeUnit.DAYS)) ) ), - metaData = Builders.MetaData(name = s"session_features_$namespace", namespace = namespace), + metaData = Builders.MetaData(name = s"gb_session", namespace = namespace), accuracy = Accuracy.TEMPORAL ) // Create ExternalSource with offline GroupBy val externalSource = Builders.ExternalSource( - metadata = Builders.MetaData(name = 
s"external_purchase_features_$namespace"), + metadata = Builders.MetaData(name = s"es_purchase"), keySchema = StructType("external_keys", Array(StructField("user_id", StringType))), - valueSchema = StructType("external_values", Array(StructField("purchase_amount_average_7d", DoubleType))) + valueSchema = StructType("external_values", Array(StructField("purchase_amount", DoubleType))) ) externalSource.setOfflineGroupBy(purchaseGroupBy) @@ -274,7 +270,7 @@ class ExternalSourceBackfillTest { prefix = "purchase" ) ), - metaData = Builders.MetaData(name = s"test_mixed_join_$namespace", namespace = namespace) + metaData = Builders.MetaData(name = s"test_mixed_join", namespace = namespace) ) // Run analyzer to ensure all GroupBy tables are created @@ -292,18 +288,18 @@ class ExternalSourceBackfillTest { // Verify that both regular JoinPart and ExternalPart columns are present val columns = computed.columns.toSet + println(s"Columns: ${computed.columns.mkString(", ")}") assertTrue("Should contain left source columns", columns.contains("user_id")) assertTrue("Should contain left source columns", columns.contains("page_views")) assertTrue("Should contain regular JoinPart prefixed columns", columns.exists(_.startsWith("session_"))) assertTrue("Should contain external source prefixed columns", - columns.exists(_.startsWith("purchase_"))) + columns.exists(_.endsWith("purchase_amount"))) // Show results for debugging println("=== Mixed External and JoinPart Results ===") computed.show(20, truncate = false) println(s"Total rows: ${computed.count()}") - println(s"Columns: ${computed.columns.mkString(", ")}") spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") } @@ -313,7 +309,7 @@ class ExternalSourceBackfillTest { val spark: SparkSession = SparkSessionBuilder.build("ExternalSourceBackfillTest_KeyMapping" + "_" + Random.alphanumeric.take(6).mkString, local = true) val tableUtils = TableUtils(spark) - val namespace = "test_namespace_keymapping" + "_" + Random.alphanumeric.take(6).mkString + val namespace = "test_namespace_" + Random.alphanumeric.take(6).mkString tableUtils.createDatabase(namespace) // Create feature data with internal_user_id @@ -346,23 +342,21 @@ class ExternalSourceBackfillTest { ) ), derivations = Seq( - // External part will have fullName "ext_mapped_external_features_XXX" - // So we need to derive column with that prefix + field name Builders.Derivation.star(), // Keep all base aggregation columns Builders.Derivation( - name = s"ext_mapped_external_features_${namespace}_feature_score_max_30d", + name = s"feature_score", expression = "feature_score_max_30d" ) ), - metaData = Builders.MetaData(name = s"feature_gb_$namespace", namespace = namespace), + metaData = Builders.MetaData(name = "gb_feature", namespace = namespace), accuracy = Accuracy.TEMPORAL ) // Create ExternalSource that expects internal_user_id val externalSource = Builders.ExternalSource( - metadata = Builders.MetaData(name = s"external_features_$namespace"), + metadata = Builders.MetaData(name = "es_feature"), keySchema = StructType("external_keys", Array(StructField("internal_user_id", StringType))), - valueSchema = StructType("external_values", Array(StructField("feature_score_max_30d", LongType))) + valueSchema = StructType("external_values", Array(StructField("feature_score", LongType))) ) externalSource.setOfflineGroupBy(featureGroupBy) @@ -392,7 +386,7 @@ class ExternalSourceBackfillTest { prefix = "mapped" ) ), - metaData = Builders.MetaData(name = s"test_keymapping_join_$namespace", namespace = namespace) + 
metaData = Builders.MetaData(name = s"test_keymapping_join", namespace = namespace) ) // Run analyzer to ensure GroupBy tables are created @@ -413,7 +407,7 @@ class ExternalSourceBackfillTest { assertTrue("Should contain external_user_id from left", columns.contains("external_user_id")) assertTrue("Should contain request_type from left", columns.contains("request_type")) assertTrue("Should contain mapped external columns", - columns.exists(_.startsWith("mapped_"))) + columns.exists(_.startsWith("ext_mapped_"))) // Show results for debugging println("=== Key Mapping External Source Results ===") From d76c4b208c38fecc56b482640cb85e99c61a9dc3 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Wed, 8 Oct 2025 17:56:53 -0700 Subject: [PATCH 08/13] Address review comments --- .../scala/ai/chronon/api/Extensions.scala | 30 +++----- .../ai/chronon/api/ExternalJoinPart.scala | 9 +++ .../scala/ai/chronon/spark/Analyzer.scala | 4 +- .../scala/ai/chronon/spark/JoinUtils.scala | 2 +- .../test/ExternalSourceBackfillTest.scala | 71 ++++++++++++------- 5 files changed, 66 insertions(+), 50 deletions(-) create mode 100644 api/src/main/scala/ai/chronon/api/ExternalJoinPart.scala diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index 228b83de67..3453fd3bad 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -704,10 +704,11 @@ object Extensions { } implicit class ExternalPartOps(externalPart: ExternalPart) extends ExternalPart(externalPart) { - lazy val fullName: String = + lazy val fullName: String = { Constants.ExternalPrefix + "_" + Option(externalPart.prefix).map(_ + "_").getOrElse("") + externalPart.source.metadata.name.sanitize + } def apply(query: Map[String, Any], flipped: Map[String, String], right_keys: Seq[String]): Map[String, AnyRef] = { val rightToLeft = right_keys.map(k => k -> flipped.getOrElse(k, k)) @@ -746,17 +747,15 @@ object Extensions { implicit class JoinPartOps(joinPart: JoinPart) extends JoinPart(joinPart) { lazy val fullPrefix = { - Option(groupBy.getMetaData.customJsonLookUp("customizedFullPrefix")) - .map(_.toString) - .getOrElse((Option(prefix) ++ Some(groupBy.getMetaData.cleanName)).mkString("_")) + joinPart match { + case part: ExternalJoinPart => + part.externalJoinFullPrefix + case _ => + (Option(prefix) ++ Some(groupBy.getMetaData.cleanName)).mkString("_") + } } lazy val leftToRight: Map[String, String] = rightToLeft.map { case (key, value) => value -> key } - lazy val isExternal: Boolean = Option(groupBy.getMetaData.customJsonLookUp("isExternal")) - .map(_.toString) - .getOrElse("false") - .toBoolean - def valueColumns: Seq[String] = joinPart.groupBy.valueColumns.map(fullPrefix + "_" + _) def rightToLeft: Map[String, String] = { @@ -1029,7 +1028,7 @@ object Extensions { * * @return Sequence of JoinParts converted from offline-capable ExternalParts */ - private def getExternalJoinParts: Seq[JoinPart] = { + private def getExternalJoinParts: Seq[ExternalJoinPart] = { Option(join.onlineExternalParts) .map(_.toScala) .getOrElse(Seq.empty) @@ -1037,15 +1036,6 @@ object Extensions { .map { externalPart => // Set customJson with fullPrefix override val offlineGroupBy = externalPart.source.offlineGroupBy.deepCopy() - val existingCustomJson = Option(offlineGroupBy.metaData.customJson).getOrElse("{}") - - val mapper = new ObjectMapper() - val typeRef = new TypeReference[java.util.HashMap[String, Object]]() {} - val customJsonMap: java.util.Map[String, 
Object] = mapper.readValue(existingCustomJson, typeRef) - customJsonMap.put("customizedFullPrefix", externalPart.fullName) - customJsonMap.put("isExternal", "true") // Mark as external for JoinPartOps.isExternal - val updatedCustomJson = mapper.writeValueAsString(customJsonMap) - offlineGroupBy.metaData.setCustomJson(updatedCustomJson) // Convert ExternalPart to synthetic JoinPart val syntheticJoinPart = new JoinPart() @@ -1056,7 +1046,7 @@ object Extensions { if (externalPart.prefix != null) { syntheticJoinPart.setPrefix(externalPart.prefix) } - syntheticJoinPart + new ExternalJoinPart(syntheticJoinPart, externalPart.fullName) } } diff --git a/api/src/main/scala/ai/chronon/api/ExternalJoinPart.scala b/api/src/main/scala/ai/chronon/api/ExternalJoinPart.scala new file mode 100644 index 0000000000..24f50c23fd --- /dev/null +++ b/api/src/main/scala/ai/chronon/api/ExternalJoinPart.scala @@ -0,0 +1,9 @@ +package ai.chronon.api + +class ExternalJoinPart(joinPart: JoinPart, fullPrefix: String) extends JoinPart(joinPart) { + lazy val externalJoinFullPrefix: String = fullPrefix + + override def deepCopy(): JoinPart = { + new ExternalJoinPart(joinPart.deepCopy(), externalJoinFullPrefix) + } +} diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index ef9c637b25..a1fd715cf8 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -19,7 +19,7 @@ package ai.chronon.spark import ai.chronon.api import ai.chronon.api.DataModel.{DataModel, Entities, Events} import ai.chronon.api.Extensions._ -import ai.chronon.api.{Accuracy, AggregationPart, Constants, DataType, TimeUnit, Window} +import ai.chronon.api.{Accuracy, AggregationPart, Constants, DataType, ExternalJoinPart, TimeUnit, Window} import ai.chronon.online.SparkConversions import ai.chronon.spark.Driver.parseConf import ai.chronon.spark.Extensions.StructTypeOps @@ -402,7 +402,7 @@ class Analyzer(tableUtils: TableUtils, validateTablePermission = validateTablePermission ) - val target = if (!part.isExternal) joinIntermediateValuesMetadata else externalGroupByMetadata + val target = if (!part.isInstanceOf[ExternalJoinPart]) joinIntermediateValuesMetadata else externalGroupByMetadata target ++= analyzeGroupByResult.outputMetadata.map { aggMeta => AggregationMetadata(part.fullPrefix + "_" + aggMeta.name, aggMeta.columnType, diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala index 46689897a7..48e049f98a 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala @@ -295,6 +295,7 @@ object JoinUtils { def injectKeyFilter(leftDf: DataFrame, originalJoinPart: api.JoinPart): api.JoinPart = { // make a copy of the original joinPart to avoid accumulating the key filters into the same object + // IMPORTANT: Preserve ExternalJoinPart type if present val joinPart = originalJoinPart.deepCopy() // Modifies the joinPart to inject the key filter into the where Clause of GroupBys by hardcoding the keyset val groupByKeyNames = joinPart.groupBy.getKeyColumns.toScala @@ -355,5 +356,4 @@ object JoinUtils { .filterNot(col => filter.contains(col)) df.drop(columnsToDrop: _*) } - } diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala index 1a614a09b2..89d7e09422 100644 --- 
a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala @@ -42,15 +42,16 @@ class ExternalSourceBackfillTest { tableUtils.createDatabase(namespace) // Create user transaction data for offline GroupBy + // IMPORTANT: Use same cardinality as left source to ensure key overlap val transactionColumns = List( - Column("user_id", StringType, 100), + Column("user_id", StringType, 100), // Same cardinality as userEventColumns Column("amount", LongType, 1000), Column("transaction_type", StringType, 5) ) val transactionTable = s"$namespace.user_transactions" spark.sql(s"DROP TABLE IF EXISTS $transactionTable") - DataFrameGen.events(spark, transactionColumns, 2000, partitions = 100).save(transactionTable) + DataFrameGen.events(spark, transactionColumns, 5000, partitions = 100).save(transactionTable) // Create offline GroupBy for external source val offlineGroupBy = Builders.GroupBy( @@ -79,7 +80,7 @@ class ExternalSourceBackfillTest { ) ), metaData = Builders.MetaData(name = s"gb_amount", namespace = namespace), - accuracy = Accuracy.TEMPORAL + accuracy = Accuracy.SNAPSHOT ) // Create ExternalSource with offline GroupBy @@ -141,11 +142,14 @@ class ExternalSourceBackfillTest { assertTrue("Should contain external source prefixed columns", columns.exists(_.startsWith("ext_"))) - // Show results for debugging - println("=== External Source Backfill Join Results ===") - computed.show(20, truncate = false) - println(s"Total rows: ${computed.count()}") - println(s"Columns: ${computed.columns.mkString(", ")}") + // Verify that external source columns have non-null data + val externalColumns = computed.columns.filter(_.startsWith("ext_")) + assertTrue("Should have at least one external column", externalColumns.nonEmpty) + externalColumns.foreach { col => + val nonNullCount = computed.filter(s"$col IS NOT NULL").count() + assertTrue(s"External column $col should have non-null values (found $nonNullCount non-null rows)", + nonNullCount > 0) + } spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") } @@ -155,6 +159,8 @@ class ExternalSourceBackfillTest { val spark: SparkSession = SparkSessionBuilder.build("ExternalSourceBackfillTest_Mixed" + "_" + Random.alphanumeric.take(6).mkString, local = true) val tableUtils = TableUtils(spark) + val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) val namespace = "test_namespace_" + Random.alphanumeric.take(6).mkString tableUtils.createDatabase(namespace) @@ -205,7 +211,7 @@ class ExternalSourceBackfillTest { ) ), metaData = Builders.MetaData(name = s"gb_purchase", namespace = namespace), - accuracy = Accuracy.TEMPORAL + accuracy = Accuracy.SNAPSHOT ) // Create GroupBy for regular JoinPart (sessions) @@ -228,7 +234,7 @@ class ExternalSourceBackfillTest { ) ), metaData = Builders.MetaData(name = s"gb_session", namespace = namespace), - accuracy = Accuracy.TEMPORAL + accuracy = Accuracy.SNAPSHOT ) // Create ExternalSource with offline GroupBy @@ -288,7 +294,6 @@ class ExternalSourceBackfillTest { // Verify that both regular JoinPart and ExternalPart columns are present val columns = computed.columns.toSet - println(s"Columns: ${computed.columns.mkString(", ")}") assertTrue("Should contain left source columns", columns.contains("user_id")) assertTrue("Should contain left source columns", columns.contains("page_views")) assertTrue("Should contain regular JoinPart prefixed 
columns", @@ -296,10 +301,23 @@ class ExternalSourceBackfillTest { assertTrue("Should contain external source prefixed columns", columns.exists(_.endsWith("purchase_amount"))) - // Show results for debugging - println("=== Mixed External and JoinPart Results ===") - computed.show(20, truncate = false) - println(s"Total rows: ${computed.count()}") + // Verify that external source columns have non-null data + val externalColumns = computed.columns.filter(col => col.startsWith("ext_") || col.contains("purchase_amount")) + assertTrue("Should have at least one external column", externalColumns.nonEmpty) + externalColumns.foreach { col => + val nonNullCount = computed.filter(s"$col IS NOT NULL").count() + assertTrue(s"External column $col should have non-null values (found $nonNullCount non-null rows)", + nonNullCount > 0) + } + + // Verify that regular JoinPart columns have non-null data + val joinPartColumns = computed.columns.filter(_.startsWith("session_")) + assertTrue("Should have at least one JoinPart column", joinPartColumns.nonEmpty) + joinPartColumns.foreach { col => + val nonNullCount = computed.filter(s"$col IS NOT NULL").count() + assertTrue(s"JoinPart column $col should have non-null values (found $nonNullCount non-null rows)", + nonNullCount > 0) + } spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") } @@ -309,6 +327,8 @@ class ExternalSourceBackfillTest { val spark: SparkSession = SparkSessionBuilder.build("ExternalSourceBackfillTest_KeyMapping" + "_" + Random.alphanumeric.take(6).mkString, local = true) val tableUtils = TableUtils(spark) + val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) val namespace = "test_namespace_" + Random.alphanumeric.take(6).mkString tableUtils.createDatabase(namespace) @@ -320,7 +340,9 @@ class ExternalSourceBackfillTest { val featureTable = s"$namespace.user_features" spark.sql(s"DROP TABLE IF EXISTS $featureTable") - DataFrameGen.events(spark, featureColumns, 2000, partitions = 60).save(featureTable) + // Generate 135 partitions to ensure we have enough data for 30-day window + 1-day shift + monthAgo (30 days) + buffer + // Need to cover: today back to (monthAgo - 30-day window - 1-day shift) = ~91 days + buffer + DataFrameGen.events(spark, featureColumns, 2000, partitions = 135).save(featureTable) // Create GroupBy using internal_user_id val featureGroupBy = Builders.GroupBy( @@ -349,7 +371,7 @@ class ExternalSourceBackfillTest { ) ), metaData = Builders.MetaData(name = "gb_feature", namespace = namespace), - accuracy = Accuracy.TEMPORAL + accuracy = Accuracy.SNAPSHOT ) // Create ExternalSource that expects internal_user_id @@ -368,7 +390,9 @@ class ExternalSourceBackfillTest { val requestTable = s"$namespace.user_requests" spark.sql(s"DROP TABLE IF EXISTS $requestTable") - DataFrameGen.events(spark, requestColumns, 600, partitions = 30).save(requestTable) + // Generate 60 partitions for the left table to limit the backfill window + // Feature table needs 135 partitions to support 60-day backfill + 30-day window + buffer + DataFrameGen.events(spark, requestColumns, 600, partitions = 60).save(requestTable) // Create Join with key mapping from external_user_id to internal_user_id val joinConf = Builders.Join( @@ -389,8 +413,8 @@ class ExternalSourceBackfillTest { metaData = Builders.MetaData(name = s"test_keymapping_join", namespace = namespace) ) - // Run analyzer to ensure GroupBy tables are created - val analyzer = new Analyzer(tableUtils, 
joinConf, monthAgo, today) + // Run analyzer to ensure GroupBy tables are created (skip validation for test) + val analyzer = new Analyzer(tableUtils, joinConf, monthAgo, today, validateTablePermission = false, skipTimestampCheck = true) analyzer.run() // Create Join and compute @@ -408,13 +432,6 @@ class ExternalSourceBackfillTest { assertTrue("Should contain request_type from left", columns.contains("request_type")) assertTrue("Should contain mapped external columns", columns.exists(_.startsWith("ext_mapped_"))) - - // Show results for debugging - println("=== Key Mapping External Source Results ===") - computed.show(20, truncate = false) - println(s"Total rows: ${computed.count()}") - println(s"Columns: ${computed.columns.mkString(", ")}") - spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") } } \ No newline at end of file From f51fd747cde80b5bcc4df3a18be97f266c47e997 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Thu, 9 Oct 2025 09:00:31 -0700 Subject: [PATCH 09/13] Add online fetching test --- .../scala/ai/chronon/spark/JoinUtils.scala | 2 +- .../spark/test/ExternalSourcesTest.scala | 165 ++++++++++++++++++ 2 files changed, 166 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala index 48e049f98a..46689897a7 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala @@ -295,7 +295,6 @@ object JoinUtils { def injectKeyFilter(leftDf: DataFrame, originalJoinPart: api.JoinPart): api.JoinPart = { // make a copy of the original joinPart to avoid accumulating the key filters into the same object - // IMPORTANT: Preserve ExternalJoinPart type if present val joinPart = originalJoinPart.deepCopy() // Modifies the joinPart to inject the key filter into the where Clause of GroupBys by hardcoding the keyset val groupByKeyNames = joinPart.groupBy.getKeyColumns.toScala @@ -356,4 +355,5 @@ object JoinUtils { .filterNot(col => filter.contains(col)) df.drop(columnsToDrop: _*) } + } diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala index 659fcab36d..5a2c06e693 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala @@ -227,6 +227,171 @@ class ExternalSourcesTest { assertEquals(numbers, (7 until 10).toSet) } + @Test + def testExternalSourceWithOfflineGroupBy(): Unit = { + // Create an offline GroupBy for the external source + val offlineGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "score" -> "score"), + timeColumn = "ts" + ), + table = "offline_table" + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.SUM, + inputColumn = "score", + windows = Seq(new Window(7, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = "offline_gb", namespace = "test"), + accuracy = Accuracy.SNAPSHOT + ) + + // Create a regular GroupBy for comparison + val regularGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "activity_count" -> "activity_count"), + timeColumn = "ts" + ), + table = "regular_table" + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.COUNT, 
+ inputColumn = "activity_count", + windows = Seq(new Window(1, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = "regular_gb", namespace = "test"), + accuracy = Accuracy.SNAPSHOT + ) + + // Create factory configuration + val factoryConfig = new ExternalSourceFactoryConfig() + factoryConfig.setFactoryName("test-online-factory") + factoryConfig.setFactoryParams(Map("multiplier" -> "10").toJava) + + // Create external source WITH both offlineGroupBy and factory config + val externalSourceWithOffline = Builders.ExternalSource( + metadata = Builders.MetaData(name = "external_with_offline"), + keySchema = StructType("keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("values", Array(StructField("score", LongType))) + ) + externalSourceWithOffline.setOfflineGroupBy(offlineGroupBy) + externalSourceWithOffline.setFactoryConfig(factoryConfig) + + val namespace = "offline_test" + val join = Builders.Join( + left = Builders.Source.events( + Builders.Query(selects = Map("user_id" -> "user_id")), + table = "non_existent_table" + ), + joinParts = Seq( + Builders.JoinPart( + groupBy = regularGroupBy, + prefix = "regular" + ) + ), + externalParts = Seq( + Builders.ExternalPart( + externalSourceWithOffline, + prefix = "offline" + ) + ), + metaData = Builders.MetaData(name = "test/offline_join", namespace = namespace, team = "chronon") + ) + + // Setup MockApi with factory registration + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("offline_test") + val mockApi = new MockApi(kvStoreFunc, "offline_test") + + // Register a test factory that returns specific values + // This proves online fetching uses the factory, NOT the offline GroupBy + mockApi.externalRegistry.addFactory("test-online-factory", new TestOnlineFactory()) + + val fetcher = mockApi.buildFetcher(true) + fetcher.kvStore.create(ChrononMetadataKey) + fetcher.putJoinConf(join) + + // Create test requests + val requests = Seq( + Request(join.metaData.name, Map("user_id" -> "user_1")), + Request(join.metaData.name, Map("user_id" -> "user_2")) + ) + + val responsesF = fetcher.fetchJoin(requests) + val responses = Await.result(responsesF, Duration(10, SECONDS)) + + // Verify responses came from the factory for external source, not from offline GroupBy + // This is the key test: even though offlineGroupBy is configured, online serving uses the factory + responses.foreach { response => + assertTrue("Response should be successful", response.values.isSuccess) + val responseMap = response.values.get + val keys = responseMap.keysIterator.toSet + + // Should have external source column from factory + assertTrue("Should contain external source column", keys.contains("ext_offline_external_with_offline_score")) + + // Should have regular GroupBy column + assertTrue("Should contain regular GroupBy column", keys.exists(_.startsWith("regular_"))) + + // Verify external source data comes from factory (100 or 200) + // This is the core assertion: values come from ExternalSourceFactory, NOT from offlineGroupBy + val score = responseMap("ext_offline_external_with_offline_score").asInstanceOf[Long] + assertTrue("Score should be from factory (100 or 200), proving online uses factory not offlineGroupBy", + score == 100L || score == 200L) + } + + // Verify both users got their expected scores from the factory + val scores = responses.map(_.values.get("ext_offline_external_with_offline_score").asInstanceOf[Long]).toSet + assertEquals("Both factory-generated scores should be present", Set(100L, 200L), scores) + + // 
Additional verification: Confirm the join has both regular GroupBy and external parts + assertEquals("Join should have 1 regular join part", 1, join.joinParts.size()) + assertEquals("Join should have 1 external part", 1, join.onlineExternalParts.size()) + + // Verify the external part has offlineGroupBy configured + val externalPart = join.onlineExternalParts.get(0) + assertNotNull("External source should have offlineGroupBy", externalPart.source.offlineGroupBy) + assertNotNull("External source should have factory config", externalPart.source.factoryConfig) + } + + // Test factory implementation that returns different values than what offline GroupBy would produce + class TestOnlineFactory extends ai.chronon.online.ExternalSourceFactory { + import ai.chronon.online.Fetcher.{Request, Response} + import scala.concurrent.Future + import scala.util.Success + + override def createExternalSourceHandler( + externalSource: ai.chronon.api.ExternalSource): ai.chronon.online.ExternalSourceHandler = { + new ai.chronon.online.ExternalSourceHandler { + override def fetch(requests: scala.collection.Seq[Request]): Future[scala.collection.Seq[Response]] = { + val responses = requests.map { request => + val userId = request.keys("user_id").asInstanceOf[String] + // Return deterministic values based on user_id that would be different from offline GroupBy + val score = userId match { + case "user_1" => 100L + case "user_2" => 200L + case _ => 999L + } + val result: Map[String, AnyRef] = Map("score" -> Long.box(score)) + Response(request = request, values = Success(result)) + } + Future.successful(responses) + } + } + } + } + // Test factory implementation for the factory-based registration test class TestExternalSourceFactory extends ai.chronon.online.ExternalSourceFactory { import ai.chronon.online.Fetcher.{Request, Response} From 0f4a1eaf714f3703a553fbeb312e8b7d26a740c9 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Mon, 13 Oct 2025 15:15:24 -0700 Subject: [PATCH 10/13] Add python part --- api/py/ai/chronon/join.py | 5 ++ ...join_with_derivations_on_external_parts.py | 7 +- ...join_with_derivations_on_external_parts.v1 | 65 ++++++++++++++++++- api/py/test/test_compile.py | 45 +++++++++++++ 4 files changed, 117 insertions(+), 5 deletions(-) diff --git a/api/py/ai/chronon/join.py b/api/py/ai/chronon/join.py index f98a9446d7..1c55a71973 100644 --- a/api/py/ai/chronon/join.py +++ b/api/py/ai/chronon/join.py @@ -140,6 +140,7 @@ def ExternalSource( custom_json: Optional[str] = None, factory_name: Optional[str] = None, factory_params: Optional[Dict[str, str]] = None, + offline_group_by: Optional[api.GroupBy] = None, ) -> api.ExternalSource: """ External sources are online only data sources. During fetching, using @@ -180,6 +181,9 @@ def ExternalSource( creating the external source handler. :param factory_params: Optional parameters to pass to the factory when creating the handler. + :param offline_group_by: Optional GroupBy configuration to be used for + offline backfill computation. When provided, enables point-in-time + correct (PITC) offline computation for the external source. 
""" assert name != "contextual", "Please use `ContextualSource`" @@ -193,6 +197,7 @@ def ExternalSource( keySchema=DataType.STRUCT(f"ext_{name}_keys", *key_fields), valueSchema=DataType.STRUCT(f"ext_{name}_values", *value_fields), factoryConfig=factory_config, + offlineGroupBy=offline_group_by, ) diff --git a/api/py/test/sample/joins/sample_team/sample_join_with_derivations_on_external_parts.py b/api/py/test/sample/joins/sample_team/sample_join_with_derivations_on_external_parts.py index c9ac5fb541..2940dfdbbe 100644 --- a/api/py/test/sample/joins/sample_team/sample_join_with_derivations_on_external_parts.py +++ b/api/py/test/sample/joins/sample_team/sample_join_with_derivations_on_external_parts.py @@ -52,14 +52,15 @@ name="test_external_source", team="chronon", key_fields=[ - ("key", DataType.LONG) + ("group_by_subject", DataType.STRING) ], value_fields=[ ("value_str", DataType.STRING), ("value_long", DataType.LONG), ("value_bool", DataType.BOOLEAN) - ] - ) + ], + offline_group_by=event_sample_group_by.v1, + ), ), ExternalPart( ContextualSource( diff --git a/api/py/test/sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1 b/api/py/test/sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1 index 60b1f210cf..cddba0cc8f 100644 --- a/api/py/test/sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1 +++ b/api/py/test/sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1 @@ -172,9 +172,9 @@ "kind": 13, "params": [ { - "name": "key", + "name": "group_by_subject", "dataType": { - "kind": 4 + "kind": 7 } } ], @@ -203,6 +203,67 @@ } ], "name": "ext_test_external_source_values" + }, + "offlineGroupBy": { + "metaData": { + "name": "sample_team.event_sample_group_by.v1", + "online": 1, + "customJson": "{\"lag\": 0, \"groupby_tags\": {\"TO_DEPRECATE\": true}, \"column_tags\": {\"event_sum_7d\": {\"DETAILED_TYPE\": \"CONTINUOUS\"}}}", + "dependencies": [ + "{\"name\": \"wait_for_sample_namespace.sample_table_group_by_ds\", \"spec\": \"sample_namespace.sample_table_group_by/ds={{ ds }}\", \"start\": \"2021-04-09\", \"end\": null}" + ], + "tableProperties": { + "source": "chronon" + }, + "outputNamespace": "sample_namespace", + "team": "sample_team", + "offlineSchedule": "@daily" + }, + "sources": [ + { + "events": { + "table": "sample_namespace.sample_table_group_by", + "query": { + "selects": { + "event": "event_expr", + "group_by_subject": "group_by_expr" + }, + "startPartition": "2021-04-09", + "timeColumn": "ts", + "setups": [] + } + } + } + ], + "keyColumns": [ + "group_by_subject" + ], + "aggregations": [ + { + "inputColumn": "event", + "operation": 7, + "argMap": {}, + "windows": [ + { + "length": 7, + "timeUnit": 1 + } + ] + }, + { + "inputColumn": "event", + "operation": 7, + "argMap": {} + }, + { + "inputColumn": "event", + "operation": 12, + "argMap": { + "k": "200", + "percentiles": "[0.99, 0.95, 0.5]" + } + } + ] } } }, diff --git a/api/py/test/test_compile.py b/api/py/test/test_compile.py index 29782ed14f..3621891194 100644 --- a/api/py/test/test_compile.py +++ b/api/py/test/test_compile.py @@ -399,3 +399,48 @@ def test_compile_inline_group_by(): join = json2thrift(file.read(), Join) assert len(join.joinParts) == 1 assert join.joinParts[0].groupBy.metaData.team == "unit_test" + + +def test_compile_external_source_with_offline_group_by(): + """ + Test that compiling a join with an external source that has an offlineGroupBy + correctly materializes the 
offlineGroupBy in the external source. + """ + runner = CliRunner() + input_path = "joins/sample_team/sample_join_with_derivations_on_external_parts.py" + result = _invoke_cli_with_params(runner, input_path) + assert result.exit_code == 0 + + # Verify the compiled join contains the external source with offlineGroupBy + path = "sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1" + full_file_path = _get_full_file_path(path) + _assert_file_exists(full_file_path, f"Expected {os.path.basename(path)} to be materialized, but it was not.") + + with open(full_file_path, "r") as file: + join = json2thrift(file.read(), Join) + + # Verify the join has online external parts + assert join.onlineExternalParts is not None, "Expected onlineExternalParts to be present" + assert len(join.onlineExternalParts) > 0, "Expected at least one external part" + + # Find the external source with offlineGroupBy + external_source_with_offline_gb = None + for external_part in join.onlineExternalParts: + if external_part.source.metadata.name == "test_external_source": + external_source_with_offline_gb = external_part.source + break + + assert external_source_with_offline_gb is not None, "Expected to find test_external_source" + + # Verify the offlineGroupBy is present and has the expected properties + assert external_source_with_offline_gb.offlineGroupBy is not None, ( + "Expected offlineGroupBy to be present in test_external_source" + ) + + offline_gb = external_source_with_offline_gb.offlineGroupBy + assert offline_gb.keyColumns == ["group_by_subject"], f"Expected key columns to be ['group_by_subject'], got {offline_gb.keyColumns}" + assert offline_gb.aggregations is not None, "Expected aggregations to be present" + assert len(offline_gb.aggregations) == 3, f"Expected 3 aggregations, got {len(offline_gb.aggregations)}" + assert offline_gb.metaData.outputNamespace == "sample_namespace", ( + f"Expected output namespace to be 'sample_namespace', got {offline_gb.metaData.outputNamespace}" + ) From 327a0099b8c2ed9388044d75ba3f876d94849301 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Mon, 13 Oct 2025 16:43:43 -0700 Subject: [PATCH 11/13] Update test --- .../ai/chronon/spark/test/ExternalSourceBackfillTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala index 89d7e09422..9f124ab2da 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala @@ -17,10 +17,10 @@ package ai.chronon.spark.test import ai.chronon.aggregator.test.Column -import ai.chronon.api.Extensions._ import ai.chronon.api.{Accuracy, Builders, DoubleType, LongType, Operation, StringType, StructField, StructType, TimeUnit, Window} import ai.chronon.spark.Extensions._ import ai.chronon.spark._ +import ai.chronon.spark.catalog.TableUtils import org.apache.spark.sql.SparkSession import org.junit.Assert._ import org.junit.Test From 4c9e70bd0315c7790ecd1c7395e69f0e8a0beb2c Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Tue, 14 Oct 2025 09:35:43 -0700 Subject: [PATCH 12/13] Update test cases to use separate kv store --- .../spark/test/ChainingFetcherTest.scala | 2 +- .../spark/test/ExternalSourcesTest.scala | 6 ++-- .../chronon/spark/test/FetchStatsTest.scala | 2 +- .../ai/chronon/spark/test/FetcherTest.scala | 30 +++++++++++-------- 
.../spark/test/GroupByUploadTest.scala | 4 +-- .../spark/test/MetadataStoreTest.scala | 4 +-- .../spark/test/ModelTransformsTest.scala | 2 +- .../spark/test/SchemaEvolutionTest.scala | 4 +-- .../spark/test/bootstrap/DerivationTest.scala | 2 +- .../test/bootstrap/LogBootstrapTest.scala | 2 +- 10 files changed, 32 insertions(+), 26 deletions(-) diff --git a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala index b389f6ad37..ed03a947e4 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala @@ -212,7 +212,7 @@ class ChainingFetcherTest extends TestCase { def executeFetch(joinConf: api.Join, endDs: String, namespace: String): (DataFrame, Seq[Row]) = { implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) implicit val tableUtils: TableUtils = TableUtils(spark) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("ChainingFetcherTest") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore(s"ChainingFetcherTest_$namespace") val inMemoryKvStore = kvStoreFunc() val mockApi = new MockApi(kvStoreFunc, namespace) diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala index 5a2c06e693..d326ccf036 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala @@ -91,7 +91,7 @@ class ExternalSourcesTest { ) // put this join into kv store - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("external_test") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("ExternalSourcesTest_testFetch") val mockApi = new MockApi(kvStoreFunc, "external_test") val fetcher = mockApi.buildFetcher(true) fetcher.kvStore.create(ChrononMetadataKey) @@ -197,7 +197,7 @@ class ExternalSourcesTest { ) // Setup MockApi with factory registration - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("factory_test") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("ExternalSourcesTest_testFactoryBasedExternalSources") val mockApi = new MockApi(kvStoreFunc, "factory_test") // Register a test factory that creates handlers dynamically @@ -311,7 +311,7 @@ class ExternalSourcesTest { ) // Setup MockApi with factory registration - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("offline_test") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("ExternalSourcesTest_testExternalSourceWithOfflineGroupBy") val mockApi = new MockApi(kvStoreFunc, "offline_test") // Register a test factory that returns specific values diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala index ef91b3a243..f80836405c 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala @@ -124,7 +124,7 @@ class FetchStatsTest extends TestCase { joinJob.computeJoin() // Load some data. 
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetchStatsTest") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetchStatsTest_testFetchStats") val inMemoryKvStore = kvStoreFunc() val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000) inMemoryKvStore.create(ChrononMetadataKey) diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala index fea7c0c00e..ca58d46959 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala @@ -535,11 +535,12 @@ class FetcherTest extends TestCase { endDs: String, namespace: String, consistencyCheck: Boolean, - dropDsOnWrite: Boolean): Unit = { + dropDsOnWrite: Boolean, + sessionName: String): Unit = { implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) val spark: SparkSession = createSparkSession() val tableUtils = TableUtils(spark) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore(sessionName) val inMemoryKvStore = kvStoreFunc() val mockApi = new MockApi(kvStoreFunc, namespace) @@ -668,7 +669,8 @@ class FetcherTest extends TestCase { def testTemporalFetchJoinDeterministic(): Unit = { val namespace = "deterministic_fetch" val joinConf = generateMutationData(namespace) - compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) + compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true, + sessionName = "FetcherTest_testTemporalFetchJoinDeterministic") } def testTemporalFetchJoinDerivation(): Unit = { @@ -682,7 +684,8 @@ class FetcherTest extends TestCase { ) joinConf.setDerivations(derivations.toJava) - compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) + compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true, + sessionName = "FetcherTest_testTemporalFetchJoinDerivation") } def testTemporalFetchJoinDerivationRenameOnly(): Unit = { @@ -692,7 +695,8 @@ class FetcherTest extends TestCase { Seq(Builders.Derivation.star(), Builders.Derivation(name = "listing_id_renamed", expression = "listing_id")) joinConf.setDerivations(derivations.toJava) - compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) + compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true, + sessionName = "FetcherTest_testTemporalFetchJoinDerivationRenameOnly") } def testTemporalFetchJoinGenerated(): Unit = { @@ -702,13 +706,15 @@ class FetcherTest extends TestCase { dummyTableUtils.partitionSpec.at(System.currentTimeMillis()), namespace, consistencyCheck = true, - dropDsOnWrite = false) + dropDsOnWrite = false, + sessionName = "FetcherTest_testTemporalFetchJoinGenerated") } def testTemporalTiledFetchJoinDeterministic(): Unit = { val namespace = "deterministic_tiled_fetch" val joinConf = generateEventOnlyData(namespace, groupByCustomJson = Some("{\"enable_tiling\": true}")) - compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) + compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, 
dropDsOnWrite = true, + sessionName = "FetcherTest_testTemporalTiledFetchJoinDeterministic") } // test soft-fail on missing keys @@ -717,7 +723,7 @@ class FetcherTest extends TestCase { val namespace = "empty_request" val joinConf = generateRandomData(namespace, 5, 5) implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest#empty_request") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest_testEmptyRequest") val inMemoryKvStore = kvStoreFunc() val mockApi = new MockApi(kvStoreFunc, namespace) @@ -743,7 +749,7 @@ class FetcherTest extends TestCase { val joinConf = generateMutationData(namespace, Some(spark)) val endDs = "2021-04-10" val tableUtils = TableUtils(spark) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest_testTemporalFetchGroupByNonExistKey") val inMemoryKvStore = kvStoreFunc() val mockApi = new MockApi(kvStoreFunc, namespace) @transient lazy val fetcher = mockApi.buildFetcher(debug = false) @@ -770,7 +776,7 @@ class FetcherTest extends TestCase { implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) val kvStoreFunc = () => - OnlineUtils.buildInMemoryKVStore("FetcherTest#test_kv_store_partial_failure", hardFailureOnInvalidDataset = true) + OnlineUtils.buildInMemoryKVStore("FetcherTest_testKVStorePartialFailure", hardFailureOnInvalidDataset = true) val inMemoryKvStore = kvStoreFunc() val mockApi = new MockApi(kvStoreFunc, namespace) @@ -798,7 +804,7 @@ class FetcherTest extends TestCase { val groupByConf = joinConf.joinParts.toScala.head.groupBy val endDs = "2021-04-10" val tableUtils = TableUtils(spark) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest_testGroupByServingInfoTtlCacheRefresh") OnlineUtils.serve(tableUtils, kvStoreFunc(), kvStoreFunc, namespace, endDs, groupByConf, dropDsOnWrite = true) val spyKvStore = spy(kvStoreFunc()) @@ -833,7 +839,7 @@ class FetcherTest extends TestCase { val joinConf = generateMutationData(namespace, Some(spark)) val endDs = "2021-04-10" val tableUtils = TableUtils(spark) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest_testJoinConfTtlCacheRefresh") val inMemoryKvStore = kvStoreFunc() joinConf.joinParts.toScala.foreach(jp => OnlineUtils.serve(tableUtils, inMemoryKvStore, kvStoreFunc, namespace, endDs, jp.groupBy, dropDsOnWrite = true)) diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala index a6e9283fff..a3e844e4af 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala @@ -304,9 +304,9 @@ class GroupByUploadTest { ) ) - val kvStore = OnlineUtils.buildInMemoryKVStore("chaining_test") + val kvStore = OnlineUtils.buildInMemoryKVStore("GroupByUploadTest_listingRatingCategoryJoinSourceTest") val endDs = "2023-08-15" - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("chaining_test") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("GroupByUploadTest_listingRatingCategoryJoinSourceTest") // DO-NOT-SET debug=true here since the streaming job won't put 
data into kv store joinConf.joinParts.toScala.foreach(jp => diff --git a/spark/src/test/scala/ai/chronon/spark/test/MetadataStoreTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MetadataStoreTest.scala index be3be67486..2cb4bb684f 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MetadataStoreTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MetadataStoreTest.scala @@ -24,7 +24,7 @@ class MetadataStoreTest extends TestCase { val acceptedEndPoints = List(MetadataEndPoint.ConfByKeyEndPointName, MetadataEndPoint.NameByTeamEndPointName) def testMetadataStoreSingleFile(): Unit = { - val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("FetcherTest") + val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("MetadataStoreTest_testMetadataStoreSingleFile") val singleFileDataSet = ChrononMetadataKey val singleFileMetaDataSet = NameByTeamEndPointName val singleFileMetadataStore = new MetadataStore(inMemoryKvStore, singleFileDataSet, timeoutMillis = 10000) @@ -57,7 +57,7 @@ class MetadataStoreTest extends TestCase { } def testMetadataStoreDirectory(): Unit = { - val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("FetcherTest") + val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("MetadataStoreTest_testMetadataStoreDirectory") val directoryDataSetDataSet = ChrononMetadataKey val directoryMetadataDataSet = NameByTeamEndPointName val directoryMetadataStore = new MetadataStore(inMemoryKvStore, directoryDataSetDataSet, timeoutMillis = 10000) diff --git a/spark/src/test/scala/ai/chronon/spark/test/ModelTransformsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ModelTransformsTest.scala index a54a7caba7..6871cd5531 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ModelTransformsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ModelTransformsTest.scala @@ -85,7 +85,7 @@ class ModelTransformsTest { val spark = createSparkSession() val tableUtils = TableUtils(spark) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("testModelTransformSimple") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore(s"ModelTransformsTest_$namespace") val inMemoryKvStore = kvStoreFunc() val defaultEmbedding = Array(0.1, 0.2, 0.3, 0.4) val defaultMap = Map("embedding" -> defaultEmbedding.asInstanceOf[AnyRef]) diff --git a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala index e5a67bfb2d..6f9dd13b10 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala @@ -382,7 +382,7 @@ class SchemaEvolutionTest extends TestCase { assert(joinSuiteV1.joinConf.metaData.name == joinSuiteV2.joinConf.metaData.name, message = "Schema evolution can only be tested on changes of the SAME join") val tableUtils: TableUtils = TableUtils(spark) - val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore(namespace) + val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore(s"SchemaEvolutionTest_$namespace") val mockApi = new MockApi(() => inMemoryKvStore, namespace) inMemoryKvStore.create(ChrononMetadataKey) val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000) @@ -520,7 +520,7 @@ class SchemaEvolutionTest extends TestCase { val joinTestSuite = createStructJoin(namespace) val tableUtils: TableUtils = TableUtils(spark) - val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore(namespace) + val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("SchemaEvolutionTest_testStructFeatures") 
val mockApi = new MockApi(() => inMemoryKvStore, namespace) inMemoryKvStore.create(ChrononMetadataKey) val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000) diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala index 05ca50fce3..b8be60d5f4 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala @@ -428,7 +428,7 @@ class DerivationTest { ) // Init artifacts to run online fetching and logging - val kvStore = OnlineUtils.buildInMemoryKVStore(namespace) + val kvStore = OnlineUtils.buildInMemoryKVStore(s"DerivationTest_$namespace") val mockApi = new MockApi(() => kvStore, namespace) OnlineUtils.serve(tableUtils, kvStore, () => kvStore, namespace, endDs, groupBy) val fetcher = mockApi.buildFetcher(debug = true) diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala index 51d99e64ac..51c71da0e1 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala @@ -104,7 +104,7 @@ class LogBootstrapTest { val joinV2 = createBootstrapJoin(baseJoinV2) // Init artifacts to run online fetching and logging - val kvStore = OnlineUtils.buildInMemoryKVStore(namespace) + val kvStore = OnlineUtils.buildInMemoryKVStore("LogBootstrapTest_testBootstrap") val mockApi = new MockApi(() => kvStore, namespace) val endDs = spark.table(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0) OnlineUtils.serve(tableUtils, kvStore, () => kvStore, namespace, endDs, groupBy) From 232b9695d7e985c2a0a3a064c3969da978b97bd0 Mon Sep 17 00:00:00 2001 From: Hui Cheng Date: Thu, 16 Oct 2025 16:27:38 -0700 Subject: [PATCH 13/13] Update Analyzer and compile --- api/py/ai/chronon/join.py | 68 +++++++++++++------ api/py/ai/chronon/repo/compile.py | 26 +++++++ .../scala/ai/chronon/spark/Analyzer.scala | 57 +++++++++++++++- .../ai/chronon/spark/test/AnalyzerTest.scala | 19 +++++- 4 files changed, 143 insertions(+), 27 deletions(-) diff --git a/api/py/ai/chronon/join.py b/api/py/ai/chronon/join.py index 1c55a71973..74e170b47f 100644 --- a/api/py/ai/chronon/join.py +++ b/api/py/ai/chronon/join.py @@ -57,33 +57,14 @@ def JoinPart( JoinPart specifies how the left side of a join, or the query in online setting, would join with the right side components like GroupBys. """ - # used for reset for next run - import_copy = __builtins__["__import__"] - # get group_by's module info from garbage collector - gc.collect() - group_by_module_name = None - for ref in gc.get_referrers(group_by): - if "__name__" in ref and ref["__name__"].startswith("group_bys"): - group_by_module_name = ref["__name__"] - break - if group_by_module_name: - logging.debug("group_by's module info from garbage collector {}".format(group_by_module_name)) - group_by_module = importlib.import_module(group_by_module_name) - __builtins__["__import__"] = eo.import_module_set_name(group_by_module, api.GroupBy) - else: - if not group_by.metaData.name: - logging.error("No group_by file or custom group_by name found") - raise ValueError( - "[GroupBy] Must specify a group_by name if group_by is not defined in separate file. " - "You may pass it in via GroupBy.name. 
\n" - ) + # Automatically set the GroupBy name if not already set + _auto_set_group_by_name(group_by, context="JoinPart") + if key_mapping: utils.check_contains(key_mapping.values(), group_by.keyColumns, "key", group_by.metaData.name) join_part = api.JoinPart(groupBy=group_by, keyMapping=key_mapping, prefix=prefix) join_part.tags = tags - # reset before next run - __builtins__["__import__"] = import_copy return join_part @@ -188,6 +169,10 @@ def ExternalSource( """ assert name != "contextual", "Please use `ContextualSource`" + # Automatically set the name for offline_group_by if not already set + if offline_group_by is not None: + _auto_set_group_by_name(offline_group_by, context="ExternalSource") + factory_config = None if factory_name is not None or factory_params is not None: factory_config = api.ExternalSourceFactoryConfig(factoryName=factory_name, factoryParams=factory_params) @@ -659,3 +644,42 @@ def Join( derivations=derivations, modelTransforms=model_transforms, ) + +def _auto_set_group_by_name(group_by: api.GroupBy, context: str = "GroupBy") -> None: + """ + Automatically set the GroupBy name by finding its source module using garbage collection. + This is used by both JoinPart and ExternalSource to automatically name GroupBys. + + :param group_by: The GroupBy object to set the name for + :param context: Context string for error messages (e.g., "JoinPart", "ExternalSource") + """ + if group_by.metaData.name: + # Name already set, nothing to do + return + + # Save and restore __import__ to preserve original behavior + import_copy = __builtins__["__import__"] + + try: + # Use garbage collector to find the module where this GroupBy was defined + gc.collect() + group_by_module_name = None + for ref in gc.get_referrers(group_by): + if "__name__" in ref and ref["__name__"].startswith("group_bys"): + group_by_module_name = ref["__name__"] + break + + if group_by_module_name: + logging.debug(f"{context}: group_by's module info from garbage collector {group_by_module_name}") + group_by_module = importlib.import_module(group_by_module_name) + __builtins__["__import__"] = eo.import_module_set_name(group_by_module, api.GroupBy) + else: + if not group_by.metaData.name: + logging.error(f"{context}: No group_by file or custom group_by name found") + raise ValueError( + f"[{context}] Must specify a group_by name if group_by is not defined in separate file. " + "You can set it via GroupBy(metaData=MetaData(name='team.file.variable_name'))" + ) + finally: + # Reset before next run + __builtins__["__import__"] = import_copy diff --git a/api/py/ai/chronon/repo/compile.py b/api/py/ai/chronon/repo/compile.py index a44fdb0cdf..57665b4191 100755 --- a/api/py/ai/chronon/repo/compile.py +++ b/api/py/ai/chronon/repo/compile.py @@ -136,6 +136,7 @@ def extract_and_convert(chronon_root, input_path, output_root, debug, force_over # In case of join, we need to materialize the following underlying group_bys # 1. group_bys whose online param is set # 2. group_bys whose backfill_start_date param is set. + # 3. 
offline group_bys from external parts (always materialized) if obj_class is Join: online_group_bys = {} offline_backfill_enabled_group_bys = {} @@ -148,6 +149,10 @@ def extract_and_convert(chronon_root, input_path, output_root, debug, force_over else: offline_gbs.append(jp.groupBy.metaData.name) + # Extract and always materialize online GroupBys from external parts + external_offline_gbs = _extract_external_part_offline_group_bys(obj, team_name, teams_path) + online_group_bys.update(external_offline_gbs) + _print_debug_info(list(online_group_bys.keys()), "Online Groupbys", log_level) _print_debug_info( list(offline_backfill_enabled_group_bys.keys()), "Offline Groupbys With Backfill Enabled", log_level @@ -445,6 +450,27 @@ def _set_templated_values(obj, cls, teams_path, team_name): obj.metaData.dependencies = [__fill_template(dep, obj, namespace) for dep in obj.metaData.dependencies] +def _extract_external_part_offline_group_bys(join_obj: api.Join, team_name: str, teams_path: str): + """ + Extract offline GroupBys from external parts in a Join. + Sets proper metadata (name, team, namespace) for each offline GroupBy. + Returns a dictionary of {groupby_name: groupby_object}. + """ + external_offline_gbs = {} + + if not join_obj.onlineExternalParts: + return external_offline_gbs + + for external_part in join_obj.onlineExternalParts: + if not external_part.source or not external_part.source.offlineGroupBy: + continue + + offline_gb = external_part.source.offlineGroupBy + external_offline_gbs[offline_gb.metaData.name] = offline_gb + + return external_offline_gbs + + def _write_obj( full_output_root: str, validator: ChrononRepoValidator, diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index a1fd715cf8..383412427e 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -19,7 +19,7 @@ package ai.chronon.spark import ai.chronon.api import ai.chronon.api.DataModel.{DataModel, Entities, Events} import ai.chronon.api.Extensions._ -import ai.chronon.api.{Accuracy, AggregationPart, Constants, DataType, ExternalJoinPart, TimeUnit, Window} +import ai.chronon.api.{Accuracy, AggregationPart, Constants, DataKind, DataType, ExternalJoinPart, TDataType, TimeUnit, Window} import ai.chronon.online.SparkConversions import ai.chronon.spark.Driver.parseConf import ai.chronon.spark.Extensions.StructTypeOps @@ -869,6 +869,56 @@ class Analyzer(tableUtils: TableUtils, errors } + /** + * Recursively clears struct names from a DataType to enable name-agnostic comparison. + * This is needed because TDataType.equals() compares the name field. + * + * @param dataType The DataType to process + * @return A TDataType with all struct names cleared + */ + private def clearStructNames(dataType: DataType): TDataType = { + val tDataType = DataType.toTDataType(dataType) + + // Clear the name if this is a struct type + if (tDataType.getKind == DataKind.STRUCT) { + tDataType.unsetName() + } + + // Recursively clear names in nested types + if (tDataType.isSetParams) { + val params = tDataType.getParams + params.forEach { param => + if (param.isSetDataType) { + val nestedType = DataType.fromTDataType(param.getDataType) + val clearedNested = clearStructNames(nestedType) + param.setDataType(clearedNested) + } + } + } + + tDataType + } + + /** + * Deep comparison of two DataType objects with logging. + * For StructType, ignores the name and only compares fields. 
+ * Uses TDataType conversion for proper deep comparison. + * + * @param expected The expected DataType + * @param actual The actual DataType + * @return true if the types match, false otherwise + */ + private def compareDataTypes(expected: DataType, actual: DataType): Boolean = { + // Convert to TDataType and clear struct names for comparison + val expectedTDataType = clearStructNames(expected) + val actualTDataType = clearStructNames(actual) + + logger.error(s"Expected TDataType: $expectedTDataType") + logger.error(s"Actual TDataType: $actualTDataType") + + expectedTDataType.equals(actualTDataType) + } + private def validateValueSchemaCompatibility(externalPart: api.ExternalPart, externalGroupBySchema: Seq[(String, DataType)]): Seq[String] = { val errors = scala.collection.mutable.ListBuffer[String]() @@ -903,12 +953,13 @@ class Analyzer(tableUtils: TableUtils, s"[${extraFields.mkString(", ")}]. These fields are not defined in the ExternalSource valueSchema." } - // Check for type mismatches in common fields + // Check for type mismatches in common fields using deep comparison val commonFields = expectedSchema.keySet.intersect(actualSchema.keySet) commonFields.foreach { fieldName => val expectedType = expectedSchema(fieldName) val actualType = actualSchema(fieldName) - if (expectedType != actualType) { + logger.error(s"=== Validating field '$fieldName' for ExternalSource ${externalPart.source.metadata.name} ===") + if (!compareDataTypes(expectedType, actualType)) { errors += s"ExternalSource ${externalPart.source.metadata.name} field '$fieldName' has type mismatch: " + s"expected ${DataType.toString(expectedType)} (from ExternalSource valueSchema) " + s"but offline GroupBy produces ${DataType.toString(actualType)}" diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index f581772357..7dc771ad12 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -922,7 +922,20 @@ class AnalyzerTest { // Create compatible schemas using the correct DataType objects val keySchema = StructType("key", Array(StructField("user_id", StringType))) - val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + val valueSchema = StructType("value", Array( + StructField("feature_value", DoubleType), + StructField("list_value", + api.ListType( + api.StructType("contradiction", + Array( + StructField("reason", api.StringType), + StructField("standardRule", api.StringType), + StructField("additionalRule", api.StringType) + ) + ) + ) + ) + )) // Create right side table for GroupBy val rightSchema = List(Column("user_id", api.StringType, 100), Column("value", api.DoubleType, 100)) @@ -939,7 +952,9 @@ class AnalyzerTest { keyColumns = Seq("user_id"), sources = Seq(source), derivations = Seq( - Builders.Derivation(name = "feature_value", expression = "feature_value") + Builders.Derivation(name = "feature_value", expression = "feature_value"), + Builders.Derivation(name = "list_value", + expression = "array(named_struct('reason', 'test_reason', 'standardRule', 'test_standard', 'additionalRule', 'test_additional'))") ), metaData = Builders.MetaData(name = "test_external_gb", namespace = namespace) )
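// Illustrative sketch (not part of the patch): why the name-agnostic comparison in
// Analyzer.compareDataTypes is needed. The ExternalSource declares its value struct under
// one name while the offline GroupBy's derived schema carries a different, generated name,
// and TDataType.equals() also compares that name, so a naive equality check would report a
// false mismatch. The struct names below are made up for the example.
import ai.chronon.api.{DataType, StringType, StructField, StructType}

val declared = StructType("ext_test_external_source_values",
  Array(StructField("reason", StringType), StructField("standardRule", StringType)))
val derived = StructType("derived_value_struct", // hypothetical name produced on the GroupBy side
  Array(StructField("reason", StringType), StructField("standardRule", StringType)))

val left = DataType.toTDataType(declared)
val right = DataType.toTDataType(derived)
assert(!left.equals(right)) // unequal only because the struct names differ
left.unsetName()
right.unsetName()
assert(left.equals(right)) // field-by-field the schemas match, which is what the validation cares about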