diff --git a/api/py/ai/chronon/join.py b/api/py/ai/chronon/join.py
index f98a9446d7..74e170b47f 100644
--- a/api/py/ai/chronon/join.py
+++ b/api/py/ai/chronon/join.py
@@ -57,33 +57,14 @@ def JoinPart(
     JoinPart specifies how the left side of a join, or the query in online setting, would join with the
     right side components like GroupBys.
     """
-    # used for reset for next run
-    import_copy = __builtins__["__import__"]
-    # get group_by's module info from garbage collector
-    gc.collect()
-    group_by_module_name = None
-    for ref in gc.get_referrers(group_by):
-        if "__name__" in ref and ref["__name__"].startswith("group_bys"):
-            group_by_module_name = ref["__name__"]
-            break
-    if group_by_module_name:
-        logging.debug("group_by's module info from garbage collector {}".format(group_by_module_name))
-        group_by_module = importlib.import_module(group_by_module_name)
-        __builtins__["__import__"] = eo.import_module_set_name(group_by_module, api.GroupBy)
-    else:
-        if not group_by.metaData.name:
-            logging.error("No group_by file or custom group_by name found")
-            raise ValueError(
-                "[GroupBy] Must specify a group_by name if group_by is not defined in separate file. "
-                "You may pass it in via GroupBy.name. \n"
-            )
+    # Automatically set the GroupBy name if not already set
+    _auto_set_group_by_name(group_by, context="JoinPart")
+
     if key_mapping:
         utils.check_contains(key_mapping.values(), group_by.keyColumns, "key", group_by.metaData.name)
 
     join_part = api.JoinPart(groupBy=group_by, keyMapping=key_mapping, prefix=prefix)
     join_part.tags = tags
-    # reset before next run
-    __builtins__["__import__"] = import_copy
     return join_part
 
 
@@ -140,6 +121,7 @@ def ExternalSource(
     custom_json: Optional[str] = None,
     factory_name: Optional[str] = None,
     factory_params: Optional[Dict[str, str]] = None,
+    offline_group_by: Optional[api.GroupBy] = None,
 ) -> api.ExternalSource:
     """
     External sources are online only data sources. During fetching, using
@@ -180,10 +162,17 @@
         creating the external source handler.
     :param factory_params: Optional parameters to pass to the factory when
         creating the handler.
+    :param offline_group_by: Optional GroupBy configuration to be used for
+        offline backfill computation. When provided, enables point-in-time
+        correct (PITC) offline computation for the external source.
     """
     assert name != "contextual", "Please use `ContextualSource`"
 
+    # Automatically set the name for offline_group_by if not already set
+    if offline_group_by is not None:
+        _auto_set_group_by_name(offline_group_by, context="ExternalSource")
+
     factory_config = None
     if factory_name is not None or factory_params is not None:
         factory_config = api.ExternalSourceFactoryConfig(factoryName=factory_name, factoryParams=factory_params)
@@ -193,6 +182,7 @@ def ExternalSource(
         keySchema=DataType.STRUCT(f"ext_{name}_keys", *key_fields),
         valueSchema=DataType.STRUCT(f"ext_{name}_values", *value_fields),
         factoryConfig=factory_config,
+        offlineGroupBy=offline_group_by,
     )
 
 
@@ -654,3 +644,42 @@ def Join(
         derivations=derivations,
         modelTransforms=model_transforms,
     )
+
+def _auto_set_group_by_name(group_by: api.GroupBy, context: str = "GroupBy") -> None:
+    """
+    Automatically set the GroupBy name by finding its source module using garbage collection.
+    This is used by both JoinPart and ExternalSource to automatically name GroupBys.
+
+    :param group_by: The GroupBy object to set the name for
+    :param context: Context string for error messages (e.g., "JoinPart", "ExternalSource")
+    """
+    if group_by.metaData.name:
+        # Name already set, nothing to do
+        return
+
+    # Save and restore __import__ to preserve original behavior
+    import_copy = __builtins__["__import__"]
+
+    try:
+        # Use garbage collector to find the module where this GroupBy was defined
+        gc.collect()
+        group_by_module_name = None
+        for ref in gc.get_referrers(group_by):
+            if "__name__" in ref and ref["__name__"].startswith("group_bys"):
+                group_by_module_name = ref["__name__"]
+                break
+
+        if group_by_module_name:
+            logging.debug(f"{context}: group_by's module info from garbage collector {group_by_module_name}")
+            group_by_module = importlib.import_module(group_by_module_name)
+            __builtins__["__import__"] = eo.import_module_set_name(group_by_module, api.GroupBy)
+        else:
+            if not group_by.metaData.name:
+                logging.error(f"{context}: No group_by file or custom group_by name found")
+                raise ValueError(
+                    f"[{context}] Must specify a group_by name if group_by is not defined in separate file. "
+                    "You can set it via GroupBy(metaData=MetaData(name='team.file.variable_name'))"
+                )
+    finally:
+        # Reset before next run
+        __builtins__["__import__"] = import_copy
diff --git a/api/py/ai/chronon/repo/compile.py b/api/py/ai/chronon/repo/compile.py
index a44fdb0cdf..57665b4191 100755
--- a/api/py/ai/chronon/repo/compile.py
+++ b/api/py/ai/chronon/repo/compile.py
@@ -136,6 +136,7 @@ def extract_and_convert(chronon_root, input_path, output_root, debug, force_over
     # In case of join, we need to materialize the following underlying group_bys
     # 1. group_bys whose online param is set
     # 2. group_bys whose backfill_start_date param is set.
+    # 3. offline group_bys from external parts (always materialized)
     if obj_class is Join:
         online_group_bys = {}
         offline_backfill_enabled_group_bys = {}
@@ -148,6 +149,10 @@ def extract_and_convert(chronon_root, input_path, output_root, debug, force_over
             else:
                 offline_gbs.append(jp.groupBy.metaData.name)
 
+        # Extract and always materialize offline GroupBys from external parts
+        external_offline_gbs = _extract_external_part_offline_group_bys(obj, team_name, teams_path)
+        online_group_bys.update(external_offline_gbs)
+
         _print_debug_info(list(online_group_bys.keys()), "Online Groupbys", log_level)
         _print_debug_info(
             list(offline_backfill_enabled_group_bys.keys()), "Offline Groupbys With Backfill Enabled", log_level
@@ -445,6 +450,27 @@ def _set_templated_values(obj, cls, teams_path, team_name):
         obj.metaData.dependencies = [__fill_template(dep, obj, namespace) for dep in obj.metaData.dependencies]
 
 
+def _extract_external_part_offline_group_bys(join_obj: api.Join, team_name: str, teams_path: str):
+    """
+    Extract offline GroupBys from external parts in a Join.
+    Sets proper metadata (name, team, namespace) for each offline GroupBy.
+    Returns a dictionary of {groupby_name: groupby_object}.
+ """ + external_offline_gbs = {} + + if not join_obj.onlineExternalParts: + return external_offline_gbs + + for external_part in join_obj.onlineExternalParts: + if not external_part.source or not external_part.source.offlineGroupBy: + continue + + offline_gb = external_part.source.offlineGroupBy + external_offline_gbs[offline_gb.metaData.name] = offline_gb + + return external_offline_gbs + + def _write_obj( full_output_root: str, validator: ChrononRepoValidator, diff --git a/api/py/test/sample/joins/sample_team/sample_join_with_derivations_on_external_parts.py b/api/py/test/sample/joins/sample_team/sample_join_with_derivations_on_external_parts.py index c9ac5fb541..2940dfdbbe 100644 --- a/api/py/test/sample/joins/sample_team/sample_join_with_derivations_on_external_parts.py +++ b/api/py/test/sample/joins/sample_team/sample_join_with_derivations_on_external_parts.py @@ -52,14 +52,15 @@ name="test_external_source", team="chronon", key_fields=[ - ("key", DataType.LONG) + ("group_by_subject", DataType.STRING) ], value_fields=[ ("value_str", DataType.STRING), ("value_long", DataType.LONG), ("value_bool", DataType.BOOLEAN) - ] - ) + ], + offline_group_by=event_sample_group_by.v1, + ), ), ExternalPart( ContextualSource( diff --git a/api/py/test/sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1 b/api/py/test/sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1 index 60b1f210cf..cddba0cc8f 100644 --- a/api/py/test/sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1 +++ b/api/py/test/sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1 @@ -172,9 +172,9 @@ "kind": 13, "params": [ { - "name": "key", + "name": "group_by_subject", "dataType": { - "kind": 4 + "kind": 7 } } ], @@ -203,6 +203,67 @@ } ], "name": "ext_test_external_source_values" + }, + "offlineGroupBy": { + "metaData": { + "name": "sample_team.event_sample_group_by.v1", + "online": 1, + "customJson": "{\"lag\": 0, \"groupby_tags\": {\"TO_DEPRECATE\": true}, \"column_tags\": {\"event_sum_7d\": {\"DETAILED_TYPE\": \"CONTINUOUS\"}}}", + "dependencies": [ + "{\"name\": \"wait_for_sample_namespace.sample_table_group_by_ds\", \"spec\": \"sample_namespace.sample_table_group_by/ds={{ ds }}\", \"start\": \"2021-04-09\", \"end\": null}" + ], + "tableProperties": { + "source": "chronon" + }, + "outputNamespace": "sample_namespace", + "team": "sample_team", + "offlineSchedule": "@daily" + }, + "sources": [ + { + "events": { + "table": "sample_namespace.sample_table_group_by", + "query": { + "selects": { + "event": "event_expr", + "group_by_subject": "group_by_expr" + }, + "startPartition": "2021-04-09", + "timeColumn": "ts", + "setups": [] + } + } + } + ], + "keyColumns": [ + "group_by_subject" + ], + "aggregations": [ + { + "inputColumn": "event", + "operation": 7, + "argMap": {}, + "windows": [ + { + "length": 7, + "timeUnit": 1 + } + ] + }, + { + "inputColumn": "event", + "operation": 7, + "argMap": {} + }, + { + "inputColumn": "event", + "operation": 12, + "argMap": { + "k": "200", + "percentiles": "[0.99, 0.95, 0.5]" + } + } + ] } } }, diff --git a/api/py/test/test_compile.py b/api/py/test/test_compile.py index 29782ed14f..3621891194 100644 --- a/api/py/test/test_compile.py +++ b/api/py/test/test_compile.py @@ -399,3 +399,48 @@ def test_compile_inline_group_by(): join = json2thrift(file.read(), Join) assert len(join.joinParts) == 1 assert join.joinParts[0].groupBy.metaData.team == "unit_test" + + +def 
test_compile_external_source_with_offline_group_by(): + """ + Test that compiling a join with an external source that has an offlineGroupBy + correctly materializes the offlineGroupBy in the external source. + """ + runner = CliRunner() + input_path = "joins/sample_team/sample_join_with_derivations_on_external_parts.py" + result = _invoke_cli_with_params(runner, input_path) + assert result.exit_code == 0 + + # Verify the compiled join contains the external source with offlineGroupBy + path = "sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1" + full_file_path = _get_full_file_path(path) + _assert_file_exists(full_file_path, f"Expected {os.path.basename(path)} to be materialized, but it was not.") + + with open(full_file_path, "r") as file: + join = json2thrift(file.read(), Join) + + # Verify the join has online external parts + assert join.onlineExternalParts is not None, "Expected onlineExternalParts to be present" + assert len(join.onlineExternalParts) > 0, "Expected at least one external part" + + # Find the external source with offlineGroupBy + external_source_with_offline_gb = None + for external_part in join.onlineExternalParts: + if external_part.source.metadata.name == "test_external_source": + external_source_with_offline_gb = external_part.source + break + + assert external_source_with_offline_gb is not None, "Expected to find test_external_source" + + # Verify the offlineGroupBy is present and has the expected properties + assert external_source_with_offline_gb.offlineGroupBy is not None, ( + "Expected offlineGroupBy to be present in test_external_source" + ) + + offline_gb = external_source_with_offline_gb.offlineGroupBy + assert offline_gb.keyColumns == ["group_by_subject"], f"Expected key columns to be ['group_by_subject'], got {offline_gb.keyColumns}" + assert offline_gb.aggregations is not None, "Expected aggregations to be present" + assert len(offline_gb.aggregations) == 3, f"Expected 3 aggregations, got {len(offline_gb.aggregations)}" + assert offline_gb.metaData.outputNamespace == "sample_namespace", ( + f"Expected output namespace to be 'sample_namespace', got {offline_gb.metaData.outputNamespace}" + ) diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index 873b9bacea..3453fd3bad 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -704,10 +704,11 @@ object Extensions { } implicit class ExternalPartOps(externalPart: ExternalPart) extends ExternalPart(externalPart) { - lazy val fullName: String = + lazy val fullName: String = { Constants.ExternalPrefix + "_" + Option(externalPart.prefix).map(_ + "_").getOrElse("") + externalPart.source.metadata.name.sanitize + } def apply(query: Map[String, Any], flipped: Map[String, String], right_keys: Seq[String]): Map[String, AnyRef] = { val rightToLeft = right_keys.map(k => k -> flipped.getOrElse(k, k)) @@ -745,7 +746,14 @@ object Extensions { } implicit class JoinPartOps(joinPart: JoinPart) extends JoinPart(joinPart) { - lazy val fullPrefix = (Option(prefix) ++ Some(groupBy.getMetaData.cleanName)).mkString("_") + lazy val fullPrefix = { + joinPart match { + case part: ExternalJoinPart => + part.externalJoinFullPrefix + case _ => + (Option(prefix) ++ Some(groupBy.getMetaData.cleanName)).mkString("_") + } + } lazy val leftToRight: Map[String, String] = rightToLeft.map { case (key, value) => value -> key } def valueColumns: Seq[String] = 
joinPart.groupBy.valueColumns.map(fullPrefix + "_" + _) @@ -883,7 +891,9 @@ object Extensions { private[api] def baseSemanticHash: Map[String, String] = { val leftHash = ThriftJsonCodec.md5Digest(join.left) logger.info(s"Join Left Object: ${ThriftJsonCodec.toJsonStr(join.left)}") - val partHashes = join.joinParts.toScala.map { jp => partOutputTable(jp) -> jp.groupBy.semanticHash }.toMap + val partHashes = join.getRegularAndExternalJoinParts.map { jp => + partOutputTable(jp) -> jp.groupBy.semanticHash + }.toMap val derivedHashMap = Option(join.derivations) .map { derivations => val derivedHash = @@ -911,7 +921,7 @@ object Extensions { } cleanTopicInSource(join.left) - join.getJoinParts.toScala.foreach(_.groupBy.sources.toScala.foreach(cleanTopicInSource)) + join.getRegularAndExternalJoinParts.foreach(_.groupBy.sources.toScala.foreach(cleanTopicInSource)) join } @@ -1008,9 +1018,50 @@ object Extensions { } def setups: Seq[String] = - (join.left.query.setupsSeq ++ join.joinParts.toScala + (join.left.query.setupsSeq ++ join.getRegularAndExternalJoinParts .flatMap(_.groupBy.setups)).distinct + /** + * Converts offline-capable ExternalParts to JoinParts for unified processing during backfill. + * This enables external sources with offlineGroupBy to participate in offline computation + * while maintaining compatibility with existing join processing logic. + * + * @return Sequence of JoinParts converted from offline-capable ExternalParts + */ + private def getExternalJoinParts: Seq[ExternalJoinPart] = { + Option(join.onlineExternalParts) + .map(_.toScala) + .getOrElse(Seq.empty) + .filter(_.source.offlineGroupBy != null) // Only offline-capable ExternalParts + .map { externalPart => + // Set customJson with fullPrefix override + val offlineGroupBy = externalPart.source.offlineGroupBy.deepCopy() + + // Convert ExternalPart to synthetic JoinPart + val syntheticJoinPart = new JoinPart() + syntheticJoinPart.setGroupBy(offlineGroupBy) + if (externalPart.keyMapping != null) { + syntheticJoinPart.setKeyMapping(externalPart.keyMapping) + } + if (externalPart.prefix != null) { + syntheticJoinPart.setPrefix(externalPart.prefix) + } + new ExternalJoinPart(syntheticJoinPart, externalPart.fullName) + } + } + + /** + * Get all join parts including both regular joinParts and external join parts. + * This provides a unified view of all join parts for processing. 
+ * + * @return Sequence containing all JoinParts (regular + converted external) + */ + def getRegularAndExternalJoinParts: Seq[JoinPart] = { + val regularJoinParts = Option(join.joinParts).map(_.toScala).getOrElse(Seq.empty) + val externalJoinParts = getExternalJoinParts + regularJoinParts ++ externalJoinParts + } + def copyForVersioningComparison(): Join = { // When we compare previous-run join to current join to detect changes requiring table migration // these are the fields that should be checked to not have accidental recomputes diff --git a/api/src/main/scala/ai/chronon/api/ExternalJoinPart.scala b/api/src/main/scala/ai/chronon/api/ExternalJoinPart.scala new file mode 100644 index 0000000000..24f50c23fd --- /dev/null +++ b/api/src/main/scala/ai/chronon/api/ExternalJoinPart.scala @@ -0,0 +1,9 @@ +package ai.chronon.api + +class ExternalJoinPart(joinPart: JoinPart, fullPrefix: String) extends JoinPart(joinPart) { + lazy val externalJoinFullPrefix: String = fullPrefix + + override def deepCopy(): JoinPart = { + new ExternalJoinPart(joinPart.deepCopy(), externalJoinFullPrefix) + } +} diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift index 21c484bf6c..03e3b16a4c 100644 --- a/api/thrift/api.thrift +++ b/api/thrift/api.thrift @@ -121,6 +121,8 @@ struct ExternalSource { 2: optional TDataType keySchema 3: optional TDataType valueSchema 4: optional ExternalSourceFactoryConfig factoryConfig + // GroupBy to be used for offline backfill - enables PITC offline computation + 5: optional GroupBy offlineGroupBy } /** diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index 403ce2ca94..383412427e 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -19,7 +19,7 @@ package ai.chronon.spark import ai.chronon.api import ai.chronon.api.DataModel.{DataModel, Entities, Events} import ai.chronon.api.Extensions._ -import ai.chronon.api.{Accuracy, AggregationPart, Constants, DataType, TimeUnit, Window} +import ai.chronon.api.{Accuracy, AggregationPart, Constants, DataKind, DataType, ExternalJoinPart, TDataType, TimeUnit, Window} import ai.chronon.online.SparkConversions import ai.chronon.spark.Driver.parseConf import ai.chronon.spark.Extensions.StructTypeOps @@ -368,11 +368,14 @@ class Analyzer(tableUtils: TableUtils, val leftSchema = leftDf.schema.fields .map(field => (field.name, SparkConversions.toChrononType(field.name, field.dataType))) val joinIntermediateValuesMetadata = ListBuffer[AggregationMetadata]() + val externalGroupByMetadata = ListBuffer[AggregationMetadata]() val keysWithError: ListBuffer[(String, String)] = ListBuffer.empty[(String, String)] val gbStartPartitions = mutable.Map[String, List[String]]() val noAccessTables = mutable.Set[String]() ++= leftNoAccessTables // Pair of (table name, group_by name, expected_start) which indicate that the table no not have data available for the required group_by val dataAvailabilityErrors: ListBuffer[(String, String, String)] = ListBuffer.empty[(String, String, String)] + // ExternalSource schema validation errors + val externalSourceErrors: ListBuffer[String] = ListBuffer.empty[String] val rangeToFill = JoinUtils.getRangesToFill(joinConf.left, tableUtils, endDate, historicalBackfill = joinConf.historicalBackfill) @@ -388,7 +391,7 @@ class Analyzer(tableUtils: TableUtils, ) .getOrElse(Seq.empty) - joinConf.joinParts.toScala.foreach { part => + joinConf.getRegularAndExternalJoinParts.foreach { part => 
val analyzeGroupByResult = analyzeGroupBy( part.groupBy, @@ -398,7 +401,9 @@ class Analyzer(tableUtils: TableUtils, skipTimestampCheck = skipTimestampCheck || leftNoAccessTables.nonEmpty, validateTablePermission = validateTablePermission ) - joinIntermediateValuesMetadata ++= analyzeGroupByResult.outputMetadata.map { aggMeta => + + val target = if (!part.isInstanceOf[ExternalJoinPart]) joinIntermediateValuesMetadata else externalGroupByMetadata + target ++= analyzeGroupByResult.outputMetadata.map { aggMeta => AggregationMetadata(part.fullPrefix + "_" + aggMeta.name, aggMeta.columnType, aggMeta.operation, @@ -406,6 +411,7 @@ class Analyzer(tableUtils: TableUtils, aggMeta.inputColumn, part.getGroupBy.getMetaData.getName) } + // Run validation checks. keysWithError ++= runSchemaValidation(leftSchema.toMap, analyzeGroupByResult.keySchema.toMap, part.rightToLeft) val subPartitionFilters = @@ -422,7 +428,13 @@ class Analyzer(tableUtils: TableUtils, if (gbStartPartition.nonEmpty) gbStartPartitions += (part.groupBy.metaData.name -> gbStartPartition) } + + val externalGroupBySchema = externalGroupByMetadata.map(aggregation => (aggregation.name, aggregation.columnType)) + if (joinConf.onlineExternalParts != null) { + // Validate ExternalSource schemas if they have offlineGroupBy configured + externalSourceErrors ++= runExternalSourceCheck(joinConf, externalGroupBySchema) + joinConf.onlineExternalParts.toScala.foreach { part => joinIntermediateValuesMetadata ++= part.source.valueFields.map { field => AggregationMetadata(part.fullName + "_" + field.name, @@ -437,7 +449,7 @@ class Analyzer(tableUtils: TableUtils, val rightSchema = joinIntermediateValuesMetadata.map(aggregation => (aggregation.name, aggregation.columnType)) - val keyColumns: List[String] = joinConf.joinParts.toScala + val keyColumns: List[String] = joinConf.getRegularAndExternalJoinParts.toList .flatMap(joinPart => { val keyCols: Seq[String] = joinPart.groupBy.keyColumns.toScala if (joinPart.keyMapping == null) keyCols @@ -507,7 +519,9 @@ class Analyzer(tableUtils: TableUtils, logger.info(s"$gbName : ${startPartitions.mkString(",")}") } } - if (keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty) { + if ( + keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty && externalSourceErrors.isEmpty + ) { logger.info("----- Backfill validation completed. No errors found. -----") } else { logger.info(s"----- Schema validation completed. Found ${keysWithError.size} errors") @@ -519,6 +533,8 @@ class Analyzer(tableUtils: TableUtils, logger.info(s"---- Data availability check completed. Found issue in ${dataAvailabilityErrors.size} tables ----") dataAvailabilityErrors.foreach(error => logger.info(s"Group_By ${error._2} : Source Tables ${error._1} : Expected start ${error._3}")) + logger.info(s"---- ExternalSource schema validation completed. Found ${externalSourceErrors.size} errors ----") + externalSourceErrors.foreach(error => logger.info(error)) } if (validationAssert) { @@ -526,12 +542,12 @@ class Analyzer(tableUtils: TableUtils, // For joins with bootstrap_parts, do not assert on data availability errors, as bootstrap can cover them // Only print out the errors as a warning assert( - keysWithError.isEmpty && noAccessTables.isEmpty, + keysWithError.isEmpty && noAccessTables.isEmpty && externalSourceErrors.isEmpty, "ERROR: Join validation failed. Please check error message for details." 
) } else { assert( - keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty, + keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty && externalSourceErrors.isEmpty, "ERROR: Join validation failed. Please check error message for details." ) } @@ -798,6 +814,174 @@ class Analyzer(tableUtils: TableUtils, analyzeGroupByResult } + /** + * Validates schema compatibility between ExternalPart and its offlineGroupBy. + * This ensures that online and offline serving will produce consistent results. + * + * @param externalPart The external part to validate + * @param externalGroupBySchema The schema of the external groupBy features from offline computation + * @return Sequence of error messages, empty if no errors + */ + def validateOfflineGroupBy(externalPart: api.ExternalPart, + externalGroupBySchema: Seq[(String, DataType)]): Seq[String] = + Option(externalPart.source.offlineGroupBy) + .map(_ => + validateKeySchemaCompatibility(externalPart.source) ++ validateValueSchemaCompatibility(externalPart, + externalGroupBySchema)) + .getOrElse(Seq.empty) + + private def validateKeySchemaCompatibility(externalSource: api.ExternalSource): Seq[String] = { + val errors = scala.collection.mutable.ListBuffer[String]() + + if (externalSource.keySchema == null) { + errors += s"ExternalSource ${externalSource.metadata.name} keySchema cannot be null when offlineGroupBy is specified" + return errors + } + + if (externalSource.offlineGroupBy.keyColumns == null || externalSource.offlineGroupBy.keyColumns.isEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns cannot be null or empty" + return errors + } + + val externalKeyFields = externalSource.keyFields + val groupByKeyColumns = externalSource.offlineGroupBy.keyColumns.toScala.toSet + + // Extract field names from external source key schema + val externalKeyNames = externalKeyFields.map(_.name).toSet + + // Validate that GroupBy has key columns that match ExternalSource key schema + val missingKeys = externalKeyNames -- groupByKeyColumns + val extraKeys = groupByKeyColumns -- externalKeyNames + + if (missingKeys.nonEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} key schema contains columns [${missingKeys.mkString(", ")}] " + + s"that are not present in offlineGroupBy keyColumns [${groupByKeyColumns.mkString(", ")}]. " + + s"All ExternalSource key columns must be present in the GroupBy key columns." + } + + if (extraKeys.nonEmpty) { + errors += s"ExternalSource ${externalSource.metadata.name} offlineGroupBy keyColumns contain [${extraKeys + .mkString(", ")}] " + + s"that are not present in ExternalSource keySchema [${externalKeyNames.mkString(", ")}]. " + + s"GroupBy key columns cannot contain keys not defined in ExternalSource keySchema." + } + + errors + } + + /** + * Recursively clears struct names from a DataType to enable name-agnostic comparison. + * This is needed because TDataType.equals() compares the name field. 
+ * + * @param dataType The DataType to process + * @return A TDataType with all struct names cleared + */ + private def clearStructNames(dataType: DataType): TDataType = { + val tDataType = DataType.toTDataType(dataType) + + // Clear the name if this is a struct type + if (tDataType.getKind == DataKind.STRUCT) { + tDataType.unsetName() + } + + // Recursively clear names in nested types + if (tDataType.isSetParams) { + val params = tDataType.getParams + params.forEach { param => + if (param.isSetDataType) { + val nestedType = DataType.fromTDataType(param.getDataType) + val clearedNested = clearStructNames(nestedType) + param.setDataType(clearedNested) + } + } + } + + tDataType + } + + /** + * Deep comparison of two DataType objects with logging. + * For StructType, ignores the name and only compares fields. + * Uses TDataType conversion for proper deep comparison. + * + * @param expected The expected DataType + * @param actual The actual DataType + * @return true if the types match, false otherwise + */ + private def compareDataTypes(expected: DataType, actual: DataType): Boolean = { + // Convert to TDataType and clear struct names for comparison + val expectedTDataType = clearStructNames(expected) + val actualTDataType = clearStructNames(actual) + + logger.error(s"Expected TDataType: $expectedTDataType") + logger.error(s"Actual TDataType: $actualTDataType") + + expectedTDataType.equals(actualTDataType) + } + + private def validateValueSchemaCompatibility(externalPart: api.ExternalPart, + externalGroupBySchema: Seq[(String, DataType)]): Seq[String] = { + val errors = scala.collection.mutable.ListBuffer[String]() + + if (externalPart.source.valueSchema == null) { + errors += s"ExternalSource ${externalPart.source.metadata.name} valueSchema cannot be null when offlineGroupBy is specified" + return errors + } + + // Get expected schema from ExternalPart (what online expects) + val externalValueFields = externalPart.valueSchemaFull + val expectedSchema = externalValueFields.map(field => (field.name, field.fieldType)).toMap + + // Get actual schema from offline GroupBy computation (what offline produces) + val prefix = externalPart.fullName + "_" + val actualSchema = externalGroupBySchema + .filter(_._1.startsWith(prefix)) + .toMap + + // Check for fields in expected but not in actual (missing fields) + val missingFields = expectedSchema.keySet -- actualSchema.keySet + if (missingFields.nonEmpty) { + errors += s"ExternalSource ${externalPart.source.metadata.name} offline GroupBy is missing value fields: " + + s"[${missingFields.mkString(", ")}]. These fields are defined in the ExternalSource valueSchema " + + s"but are not produced by the offlineGroupBy." + } + + // Check for fields in actual but not in expected (extra fields) + val extraFields = actualSchema.keySet -- expectedSchema.keySet + if (extraFields.nonEmpty) { + errors += s"ExternalSource ${externalPart.source.metadata.name} offline GroupBy produces extra value fields: " + + s"[${extraFields.mkString(", ")}]. These fields are not defined in the ExternalSource valueSchema." 
+ } + + // Check for type mismatches in common fields using deep comparison + val commonFields = expectedSchema.keySet.intersect(actualSchema.keySet) + commonFields.foreach { fieldName => + val expectedType = expectedSchema(fieldName) + val actualType = actualSchema(fieldName) + logger.error(s"=== Validating field '$fieldName' for ExternalSource ${externalPart.source.metadata.name} ===") + if (!compareDataTypes(expectedType, actualType)) { + errors += s"ExternalSource ${externalPart.source.metadata.name} field '$fieldName' has type mismatch: " + + s"expected ${DataType.toString(expectedType)} (from ExternalSource valueSchema) " + + s"but offline GroupBy produces ${DataType.toString(actualType)}" + } + } + + errors + } + + /** + * Validates all ExternalSources in this Join's onlineExternalParts. + * This ensures schema compatibility between ExternalSources and their offlineGroupBy configurations. + * + * @param joinConf The join configuration to validate + * @param externalGroupBySchema The schema of the external groupBy features from offline computation + * @return Sequence of error messages, empty if no errors + */ + def runExternalSourceCheck(joinConf: api.Join, externalGroupBySchema: Seq[(String, DataType)]): Seq[String] = + Option(joinConf.onlineExternalParts) + .map(_.toScala.flatMap(part => validateOfflineGroupBy(part, externalGroupBySchema))) + .getOrElse(Seq.empty) + def run(exportSchema: Boolean = false): Unit = conf match { case confPath: String => diff --git a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala index e89ab84041..9089e1ae8f 100644 --- a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala +++ b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala @@ -83,44 +83,47 @@ object BootstrapInfo { // Enrich each join part with the expected output schema logger.info(s"\nCreating BootstrapInfo for GroupBys for Join ${joinConf.metaData.name}") - var joinParts: Seq[JoinPartMetadata] = Option(joinConf.joinParts.toScala) - .getOrElse(Seq.empty) - .map(part => { - // set computeDependency to False as we compute dependency upstream - val gb = GroupBy.from(part.groupBy, range, tableUtils, computeDependency) - val keySchema = SparkConversions - .toChrononSchema(gb.keySchema) - .map(field => StructField(part.rightToLeft(field._1), field._2)) - - Analyzer.validateAvroCompatibility(tableUtils, gb, part.groupBy) - - val keyAndPartitionFields = - gb.keySchema.fields ++ Seq(org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)) - // todo: this change is only valid for offline use case - // we need to revisit logic for the logging part to make sure the derived columns are also logged - // to make bootstrap continue to work - val outputSchema = if (part.groupBy.hasDerivations) { - val sparkSchema = { - StructType(SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ keyAndPartitionFields) - } - val dummyOutputDf = tableUtils.sparkSession - .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema) - val finalOutputColumns = part.groupBy.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq - val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*) - val columns = SparkConversions.toChrononSchema( - StructType(derivedDummyOutputDf.schema.filterNot(keyAndPartitionFields.contains))) - api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2))) - } else { - gb.outputSchema + + // Get all join parts including 
both regular and external join parts + val allJoinParts = joinConf.getRegularAndExternalJoinParts + + var joinParts: Seq[JoinPartMetadata] = allJoinParts.map(part => { + // set computeDependency to False as we compute dependency upstream + val gb = GroupBy.from(part.groupBy, range, tableUtils, computeDependency) + val keySchema = SparkConversions + .toChrononSchema(gb.keySchema) + .map(field => StructField(part.rightToLeft(field._1), field._2)) + + Analyzer.validateAvroCompatibility(tableUtils, gb, part.groupBy) + + val keyAndPartitionFields = + gb.keySchema.fields ++ Seq(org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)) + // todo: this change is only valid for offline use case + // we need to revisit logic for the logging part to make sure the derived columns are also logged + // to make bootstrap continue to work + val outputSchema = if (part.groupBy.hasDerivations) { + val sparkSchema = { + StructType(SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ keyAndPartitionFields) } - val valueSchema = outputSchema.fields.map(part.constructJoinPartSchema) - JoinPartMetadata(part, keySchema, valueSchema, Map.empty) // will be populated below - }) + val dummyOutputDf = tableUtils.sparkSession + .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema) + val finalOutputColumns = part.groupBy.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq + val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*) + val columns = SparkConversions.toChrononSchema( + StructType(derivedDummyOutputDf.schema.filterNot(keyAndPartitionFields.contains))) + api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2))) + } else { + gb.outputSchema + } + val valueSchema = outputSchema.fields.map(part.constructJoinPartSchema) + JoinPartMetadata(part, keySchema, valueSchema, Map.empty) // will be populated below + }) - // Enrich each external part with the expected output schema - logger.info(s"\nCreating BootstrapInfo for ExternalParts for Join ${joinConf.metaData.name}") + // Enrich online only external parts with the expected output schema + logger.info(s"\nCreating BootstrapInfo for online-only ExternalParts for Join ${joinConf.metaData.name}") val externalParts: Seq[ExternalPartMetadata] = Option(joinConf.onlineExternalParts.toScala) .getOrElse(Seq.empty) + .filter(_.source.offlineGroupBy == null) // Only online-only ExternalParts .map(part => ExternalPartMetadata(part, part.keySchemaFull, part.valueSchemaFull)) val leftFields = leftSchema @@ -355,4 +358,5 @@ object BootstrapInfo { bootstrapInfo } + } diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala index 31becd8579..5aa209a965 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala @@ -474,7 +474,7 @@ abstract class JoinBase(joinConf: api.Join, assert(Option(joinConf.metaData.team).nonEmpty, s"join.metaData.team needs to be set for join ${joinConf.metaData.name}") - joinConf.joinParts.toScala.foreach { jp => + joinConf.getRegularAndExternalJoinParts.foreach { jp => assert(Option(jp.groupBy.metaData.team).nonEmpty, s"groupBy.metaData.team needs to be set for joinPart ${jp.groupBy.metaData.name}") } diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index b008c36a09..7dc771ad12 100644 --- 
a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -18,7 +18,6 @@ package ai.chronon.spark.test import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.Builders.Query import ai.chronon.api.Extensions.MetadataOps import ai.chronon.api._ import ai.chronon.spark.Extensions._ @@ -909,4 +908,223 @@ class AnalyzerTest { analyzer.analyzeGroupBy(tableGroupBy) } + @Test + def testExternalSourceValidationWithMatchingSchemas(): Unit = { + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) + val tableUtils = TableUtils(spark) + val namespace = "analyzer_test_ns" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create left side table + val leftSchema = List(Column("user_id", api.StringType, 100)) + val leftTable = s"$namespace.left_table" + DataFrameGen.events(spark, leftSchema, 10, partitions = 5).save(leftTable) + + // Create compatible schemas using the correct DataType objects + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array( + StructField("feature_value", DoubleType), + StructField("list_value", + api.ListType( + api.StructType("contradiction", + Array( + StructField("reason", api.StringType), + StructField("standardRule", api.StringType), + StructField("additionalRule", api.StringType) + ) + ) + ) + ) + )) + + // Create right side table for GroupBy + val rightSchema = List(Column("user_id", api.StringType, 100), Column("value", api.DoubleType, 100)) + val rightTable = s"$namespace.right_table" + DataFrameGen.events(spark, rightSchema, 100, partitions = 10).save(rightTable) + + // Create a query and source for the GroupBy + val query = Builders.Query(selects = Map("feature_value" -> "value"), startPartition = oneYearAgo) + val source = Builders.Source.events(query, rightTable) + + // Create GroupBy with matching key columns, sources, and derivation to match external feature name + // The external part will have fullName "ext_test_external_source", so the feature will be "ext_test_external_source_feature_value" + val groupBy = Builders.GroupBy( + keyColumns = Seq("user_id"), + sources = Seq(source), + derivations = Seq( + Builders.Derivation(name = "feature_value", expression = "feature_value"), + Builders.Derivation(name = "list_value", + expression = "array(named_struct('reason', 'test_reason', 'standardRule', 'test_standard', 'additionalRule', 'test_additional'))") + ), + metaData = Builders.MetaData(name = "test_external_gb", namespace = namespace) + ) + + // Create ExternalSource with compatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + + // Manually set the offlineGroupBy field since the builder doesn't support it yet + externalSource.setOfflineGroupBy(groupBy) + + // Wrap in ExternalPart + val externalPart = Builders.ExternalPart(externalSource = externalSource) + + // Create Join with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = oneMonthAgo), table = leftTable), + externalParts = Seq(externalPart), + metaData = Builders.MetaData(name = "test_join_external", namespace = namespace, team = "chronon") + ) + + // Create analyzer instance and call analyzeJoin + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skipTimestampCheck = true, 
validateTablePermission = false) + val result = analyzer.analyzeJoin(joinConf) + // If no exception is thrown, validation passed + assertTrue("Expected successful validation", result != null) + } + + @Test(expected = classOf[java.lang.AssertionError]) + def testExternalSourceValidationWithMismatchedKeySchemas(): Unit = { + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) + val tableUtils = TableUtils(spark) + val namespace = "analyzer_test_ns" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create left side table + val leftSchema = List(Column("user_id", api.StringType, 100)) + val leftTable = s"$namespace.left_table" + DataFrameGen.events(spark, leftSchema, 10, partitions = 5).save(leftTable) + + // Create key schema with different fields + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + // Create right side table for GroupBy + val rightSchema = List(Column("different_key", api.StringType, 100), Column("feature_value", api.DoubleType, 100)) + val rightTable = s"$namespace.right_table" + DataFrameGen.events(spark, rightSchema, 100, partitions = 10).save(rightTable) + + // Create a query and source for the GroupBy + val query = Builders.Query(selects = Map("feature_value" -> "feature_value"), startPartition = oneYearAgo) + val source = Builders.Source.events(query, rightTable) + + // Create GroupBy with different key columns + val groupBy = Builders.GroupBy( + keyColumns = Seq("different_key"), // Mismatched key column + sources = Seq(source), + metaData = Builders.MetaData(name = "test_external_gb_mismatch", namespace = namespace) + ) + + // Create ExternalSource with incompatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + externalSource.setOfflineGroupBy(groupBy) + + // Wrap in ExternalPart + val externalPart = Builders.ExternalPart(externalSource = externalSource) + + // Create Join with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = oneMonthAgo), table = leftTable), + externalParts = Seq(externalPart), + metaData = Builders.MetaData(name = "test_join_external_key_mismatch", namespace = namespace, team = "chronon") + ) + + // This should throw AssertionError due to validation errors + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skipTimestampCheck = true, validateTablePermission = false) + analyzer.analyzeJoin(joinConf, validationAssert = true) + } + + @Test(expected = classOf[java.lang.AssertionError]) + def testExternalSourceValidationWithMismatchedValueSchemas(): Unit = { + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) + val tableUtils = TableUtils(spark) + val namespace = "analyzer_test_ns" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create left side table + val leftSchema = List(Column("user_id", api.StringType, 100)) + val leftTable = s"$namespace.left_table" + DataFrameGen.events(spark, leftSchema, 10, partitions = 5).save(leftTable) + + // Create compatible key schema but mismatched value schema + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + // Create right side table for GroupBy + 
val rightSchema = List(Column("user_id", api.StringType, 100), Column("different_feature", api.DoubleType, 100)) + val rightTable = s"$namespace.right_table" + DataFrameGen.events(spark, rightSchema, 100, partitions = 10).save(rightTable) + + // Create a source for the GroupBy - this is needed for valueColumns to work + val query = Builders.Query(selects = Map("different_feature" -> "different_feature"), startPartition = oneYearAgo) + val source = Builders.Source.events(query, rightTable) + + // Create GroupBy with different derived column name that doesn't match external feature name + // The external part expects "ext_test_external_source_feature_value" but GroupBy produces "wrong_name" + val groupBy = Builders.GroupBy( + keyColumns = Seq("user_id"), + sources = Seq(source), + derivations = Seq( + Builders.Derivation(name = "wrong_name", expression = "different_feature") + ), + metaData = Builders.MetaData(name = "test_external_gb_value_mismatch", namespace = namespace) + ) + + // Create ExternalSource with incompatible schemas + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + externalSource.setOfflineGroupBy(groupBy) + + // Wrap in ExternalPart + val externalPart = Builders.ExternalPart(externalSource = externalSource) + + // Create Join with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = oneMonthAgo), table = leftTable), + externalParts = Seq(externalPart), + metaData = Builders.MetaData(name = "test_join_external_value_mismatch", namespace = namespace, team = "chronon") + ) + + // This should throw AssertionError due to validation errors + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skipTimestampCheck = true, validateTablePermission = false) + analyzer.analyzeJoin(joinConf, validationAssert = true) + } + + @Test + def testExternalSourceValidationWithNullOfflineGroupBy(): Unit = { + val spark: SparkSession = SparkSessionBuilder.build("AnalyzerTest", local = true) + val tableUtils = TableUtils(spark) + val namespace = "analyzer_test_ns" + "_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create left side table + val leftSchema = List(Column("user_id", api.StringType, 100)) + val leftTable = s"$namespace.left_table" + DataFrameGen.events(spark, leftSchema, 10, partitions = 5).save(leftTable) + + // Create ExternalSource without offlineGroupBy + val keySchema = StructType("key", Array(StructField("user_id", StringType))) + val valueSchema = StructType("value", Array(StructField("feature_value", DoubleType))) + + val metadata = Builders.MetaData(name = "test_external_source") + val externalSource = Builders.ExternalSource(metadata, keySchema, valueSchema) + // Don't set offlineGroupBy (it remains null) + + // Wrap in ExternalPart + val externalPart = Builders.ExternalPart(externalSource = externalSource) + + // Create Join with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = oneMonthAgo), table = leftTable), + externalParts = Seq(externalPart), + metaData = Builders.MetaData(name = "test_join_external_null_gb", namespace = namespace, team = "chronon") + ) + + // This should not throw - validation should be skipped when offlineGroupBy is null + val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, skipTimestampCheck = true, validateTablePermission = false) + val result = 
analyzer.analyzeJoin(joinConf) + assertTrue("Expected successful validation when offlineGroupBy is null", result != null) + } + } diff --git a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala index b389f6ad37..ed03a947e4 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala @@ -212,7 +212,7 @@ class ChainingFetcherTest extends TestCase { def executeFetch(joinConf: api.Join, endDs: String, namespace: String): (DataFrame, Seq[Row]) = { implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) implicit val tableUtils: TableUtils = TableUtils(spark) - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("ChainingFetcherTest") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore(s"ChainingFetcherTest_$namespace") val inMemoryKvStore = kvStoreFunc() val mockApi = new MockApi(kvStoreFunc, namespace) diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala new file mode 100644 index 0000000000..9f124ab2da --- /dev/null +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourceBackfillTest.scala @@ -0,0 +1,437 @@ +/* + * Copyright (C) 2023 The Chronon Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.chronon.spark.test + +import ai.chronon.aggregator.test.Column +import ai.chronon.api.{Accuracy, Builders, DoubleType, LongType, Operation, StringType, StructField, StructType, TimeUnit, Window} +import ai.chronon.spark.Extensions._ +import ai.chronon.spark._ +import ai.chronon.spark.catalog.TableUtils +import org.apache.spark.sql.SparkSession +import org.junit.Assert._ +import org.junit.Test + +import scala.util.Random + +class ExternalSourceBackfillTest { + val spark: SparkSession = SparkSessionBuilder.build("ExternalSourceBackfillTest", local = true) + private val tableUtils = TableUtils(spark) + private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) + + @Test + def testExternalSourceBackfillComputeJoin(): Unit = { + val spark: SparkSession = + SparkSessionBuilder.build("ExternalSourceBackfillTest" + "_" + Random.alphanumeric.take(6).mkString, local = true) + val tableUtils = TableUtils(spark) + val namespace = "test_namespace_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create user transaction data for offline GroupBy + // IMPORTANT: Use same cardinality as left source to ensure key overlap + val transactionColumns = List( + Column("user_id", StringType, 100), // Same cardinality as userEventColumns + Column("amount", LongType, 1000), + Column("transaction_type", StringType, 5) + ) + + val transactionTable = s"$namespace.user_transactions" + spark.sql(s"DROP TABLE IF EXISTS $transactionTable") + DataFrameGen.events(spark, transactionColumns, 5000, partitions = 100).save(transactionTable) + + // Create offline GroupBy for external source + val offlineGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "amount" -> "amount"), + timeColumn = "ts" + ), + table = transactionTable + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.SUM, + inputColumn = "amount", + windows = Seq(new Window(30, TimeUnit.DAYS)) + ) + ), + derivations = Seq( + Builders.Derivation.star(), // Keep all base aggregation columns + Builders.Derivation( + name = s"es_amount", + expression = "amount_sum_30d" + ) + ), + metaData = Builders.MetaData(name = s"gb_amount", namespace = namespace), + accuracy = Accuracy.SNAPSHOT + ) + + // Create ExternalSource with offline GroupBy + val externalSource = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"test_external_source"), + keySchema = StructType("external_keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("external_values", Array(StructField("es_amount", LongType))) + ) + externalSource.setOfflineGroupBy(offlineGroupBy) + + // Create left source (user events to join against) + val userEventColumns = List( + Column("user_id", StringType, 100), + Column("event_type", StringType, 10), + Column("session_id", StringType, 200) + ) + + val userEventTable = s"$namespace.user_events" + spark.sql(s"DROP TABLE IF EXISTS $userEventTable") + DataFrameGen.events(spark, userEventColumns, 1000, partitions = 50).save(userEventTable) + + // Create Join configuration with ExternalPart + val joinConf = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "event_type" -> "event_type", "session_id" -> "session_id"), + timeColumn = "ts" + ), + table = userEventTable + ), + 
joinParts = Seq(), + externalParts = Seq( + Builders.ExternalPart( + externalSource, + prefix = "txn" + ) + ), + metaData = Builders.MetaData(name = s"test_external_part", namespace = namespace) + ) + + // Run analyzer to ensure GroupBy tables are created + val analyzer = new Analyzer(tableUtils, joinConf, monthAgo, today) + analyzer.run() + + // Create Join and compute + val endPartition = monthAgo + val join = new Join(joinConf = joinConf, endPartition = endPartition, tableUtils) + val computed = join.computeJoin(Some(10)) + + // Verify results + assertNotNull("Computed result should not be null", computed) + assertTrue("Result should have rows", computed.count() > 0) + + // Verify that external source columns are present + val columns = computed.columns.toSet + assertTrue("Should contain left source columns", columns.contains("user_id")) + assertTrue("Should contain left source columns", columns.contains("event_type")) + assertTrue("Should contain left source columns", columns.contains("session_id")) + assertTrue("Should contain external source prefixed columns", + columns.exists(_.startsWith("ext_"))) + + // Verify that external source columns have non-null data + val externalColumns = computed.columns.filter(_.startsWith("ext_")) + assertTrue("Should have at least one external column", externalColumns.nonEmpty) + externalColumns.foreach { col => + val nonNullCount = computed.filter(s"$col IS NOT NULL").count() + assertTrue(s"External column $col should have non-null values (found $nonNullCount non-null rows)", + nonNullCount > 0) + } + + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } + + @Test + def testMixedExternalAndJoinParts(): Unit = { + val spark: SparkSession = + SparkSessionBuilder.build("ExternalSourceBackfillTest_Mixed" + "_" + Random.alphanumeric.take(6).mkString, local = true) + val tableUtils = TableUtils(spark) + val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) + val namespace = "test_namespace_" + Random.alphanumeric.take(6).mkString + tableUtils.createDatabase(namespace) + + // Create transaction data for external source GroupBy + val transactionColumns = List( + Column("user_id", StringType, 100), + Column("purchase_amount", LongType, 5000) + ) + + val transactionTable = s"$namespace.purchase_transactions" + spark.sql(s"DROP TABLE IF EXISTS $transactionTable") + DataFrameGen.events(spark, transactionColumns, 1500, partitions = 80).save(transactionTable) + + // Create session data for regular JoinPart GroupBy + val sessionColumns = List( + Column("user_id", StringType, 100), + Column("session_duration", LongType, 7200) + ) + + val sessionTable = s"$namespace.user_sessions" + spark.sql(s"DROP TABLE IF EXISTS $sessionTable") + DataFrameGen.events(spark, sessionColumns, 1200, partitions = 60).save(sessionTable) + + // Create GroupBy for external source (purchases) + val purchaseGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "purchase_amount" -> "purchase_amount"), + timeColumn = "ts" + ), + table = transactionTable + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.AVERAGE, + inputColumn = "purchase_amount", + windows = Seq(new Window(7, TimeUnit.DAYS)) + ) + ), + derivations = Seq( + Builders.Derivation.star(), // Keep all base aggregation columns + Builders.Derivation( + name = s"purchase_amount", + 
expression = "purchase_amount_average_7d" + ) + ), + metaData = Builders.MetaData(name = s"gb_purchase", namespace = namespace), + accuracy = Accuracy.SNAPSHOT + ) + + // Create GroupBy for regular JoinPart (sessions) + val sessionGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "session_duration" -> "session_duration"), + timeColumn = "ts" + ), + table = sessionTable + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.COUNT, + inputColumn = "session_duration", + windows = Seq(new Window(14, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = s"gb_session", namespace = namespace), + accuracy = Accuracy.SNAPSHOT + ) + + // Create ExternalSource with offline GroupBy + val externalSource = Builders.ExternalSource( + metadata = Builders.MetaData(name = s"es_purchase"), + keySchema = StructType("external_keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("external_values", Array(StructField("purchase_amount", DoubleType))) + ) + externalSource.setOfflineGroupBy(purchaseGroupBy) + + // Create left source + val userActivityColumns = List( + Column("user_id", StringType, 100), + Column("page_views", LongType, 50) + ) + + val userActivityTable = s"$namespace.user_activity" + spark.sql(s"DROP TABLE IF EXISTS $userActivityTable") + DataFrameGen.events(spark, userActivityColumns, 800, partitions = 40).save(userActivityTable) + + // Create Join with both ExternalPart and regular JoinPart + val joinConf = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "page_views" -> "page_views"), + timeColumn = "ts" + ), + table = userActivityTable + ), + joinParts = Seq( + Builders.JoinPart( + groupBy = sessionGroupBy, + prefix = "session" + ) + ), + externalParts = Seq( + Builders.ExternalPart( + externalSource, + prefix = "purchase" + ) + ), + metaData = Builders.MetaData(name = s"test_mixed_join", namespace = namespace) + ) + + // Run analyzer to ensure all GroupBy tables are created + val analyzer = new Analyzer(tableUtils, joinConf, monthAgo, today) + analyzer.run() + + // Create Join and compute + val endPartition = monthAgo + val join = new Join(joinConf = joinConf, endPartition = endPartition, tableUtils) + val computed = join.computeJoin(Some(10)) + + // Verify results + assertNotNull("Computed result should not be null", computed) + assertTrue("Result should have rows", computed.count() > 0) + + // Verify that both regular JoinPart and ExternalPart columns are present + val columns = computed.columns.toSet + assertTrue("Should contain left source columns", columns.contains("user_id")) + assertTrue("Should contain left source columns", columns.contains("page_views")) + assertTrue("Should contain regular JoinPart prefixed columns", + columns.exists(_.startsWith("session_"))) + assertTrue("Should contain external source prefixed columns", + columns.exists(_.endsWith("purchase_amount"))) + + // Verify that external source columns have non-null data + val externalColumns = computed.columns.filter(col => col.startsWith("ext_") || col.contains("purchase_amount")) + assertTrue("Should have at least one external column", externalColumns.nonEmpty) + externalColumns.foreach { col => + val nonNullCount = computed.filter(s"$col IS NOT NULL").count() + assertTrue(s"External column $col should have non-null values (found $nonNullCount non-null rows)", + nonNullCount > 0) + } 
+
+    // Verify that regular JoinPart columns have non-null data
+    val joinPartColumns = computed.columns.filter(_.startsWith("session_"))
+    assertTrue("Should have at least one JoinPart column", joinPartColumns.nonEmpty)
+    joinPartColumns.foreach { col =>
+      val nonNullCount = computed.filter(s"$col IS NOT NULL").count()
+      assertTrue(s"JoinPart column $col should have non-null values (found $nonNullCount non-null rows)",
+        nonNullCount > 0)
+    }
+
+    spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE")
+  }
+
+  @Test
+  def testExternalSourceBackfillWithKeyMapping(): Unit = {
+    val spark: SparkSession =
+      SparkSessionBuilder.build("ExternalSourceBackfillTest_KeyMapping" + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val tableUtils = TableUtils(spark)
+    val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
+    val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS))
+    val namespace = "test_namespace_" + Random.alphanumeric.take(6).mkString
+    tableUtils.createDatabase(namespace)
+
+    // Create feature data with internal_user_id
+    val featureColumns = List(
+      Column("internal_user_id", StringType, 100),
+      Column("feature_score", LongType, 100)
+    )
+
+    val featureTable = s"$namespace.user_features"
+    spark.sql(s"DROP TABLE IF EXISTS $featureTable")
+    // Generate 135 partitions to ensure we have enough data for 30-day window + 1-day shift + monthAgo (30 days) + buffer
+    // Need to cover: today back to (monthAgo - 30-day window - 1-day shift) = ~91 days + buffer
+    DataFrameGen.events(spark, featureColumns, 2000, partitions = 135).save(featureTable)
+
+    // Create GroupBy using internal_user_id
+    val featureGroupBy = Builders.GroupBy(
+      sources = Seq(
+        Builders.Source.events(
+          query = Builders.Query(
+            selects = Map("internal_user_id" -> "internal_user_id", "feature_score" -> "feature_score"),
+            timeColumn = "ts"
+          ),
+          table = featureTable
+        )
+      ),
+      keyColumns = Seq("internal_user_id"),
+      aggregations = Seq(
+        Builders.Aggregation(
+          operation = Operation.MAX,
+          inputColumn = "feature_score",
+          windows = Seq(new Window(30, TimeUnit.DAYS))
+        )
+      ),
+      derivations = Seq(
+        Builders.Derivation.star(), // Keep all base aggregation columns
+        Builders.Derivation(
+          name = s"feature_score",
+          expression = "feature_score_max_30d"
+        )
+      ),
+      metaData = Builders.MetaData(name = "gb_feature", namespace = namespace),
+      accuracy = Accuracy.SNAPSHOT
+    )
+
+    // Create ExternalSource that expects internal_user_id
+    val externalSource = Builders.ExternalSource(
+      metadata = Builders.MetaData(name = "es_feature"),
+      keySchema = StructType("external_keys", Array(StructField("internal_user_id", StringType))),
+      valueSchema = StructType("external_values", Array(StructField("feature_score", LongType)))
+    )
+    externalSource.setOfflineGroupBy(featureGroupBy)
+
+    // Create left source with external_user_id
+    val requestColumns = List(
+      Column("external_user_id", StringType, 100),
+      Column("request_type", StringType, 5)
+    )
+
+    val requestTable = s"$namespace.user_requests"
+    spark.sql(s"DROP TABLE IF EXISTS $requestTable")
+    // Generate 60 partitions for the left table to limit the backfill window
+    // Feature table needs 135 partitions to support 60-day backfill + 30-day window + buffer
+    DataFrameGen.events(spark, requestColumns, 600, partitions = 60).save(requestTable)
+
+    // Create Join with key mapping from external_user_id to internal_user_id
+    val joinConf = Builders.Join(
+      left = Builders.Source.events(
+        query = Builders.Query(
+          selects =
Map("external_user_id" -> "external_user_id", "request_type" -> "request_type"), + timeColumn = "ts" + ), + table = requestTable + ), + externalParts = Seq( + Builders.ExternalPart( + externalSource, + keyMapping = Map("external_user_id" -> "internal_user_id"), + prefix = "mapped" + ) + ), + metaData = Builders.MetaData(name = s"test_keymapping_join", namespace = namespace) + ) + + // Run analyzer to ensure GroupBy tables are created (skip validation for test) + val analyzer = new Analyzer(tableUtils, joinConf, monthAgo, today, validateTablePermission = false, skipTimestampCheck = true) + analyzer.run() + + // Create Join and compute + val endPartition = monthAgo + val join = new Join(joinConf = joinConf, endPartition = endPartition, tableUtils) + val computed = join.computeJoin(Some(10)) + + // Verify results + assertNotNull("Computed result should not be null", computed) + assertTrue("Result should have rows", computed.count() > 0) + + // Verify column structure + val columns = computed.columns.toSet + assertTrue("Should contain external_user_id from left", columns.contains("external_user_id")) + assertTrue("Should contain request_type from left", columns.contains("request_type")) + assertTrue("Should contain mapped external columns", + columns.exists(_.startsWith("ext_mapped_"))) + spark.sql(s"DROP DATABASE IF EXISTS $namespace CASCADE") + } +} \ No newline at end of file diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala index 659fcab36d..d326ccf036 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala @@ -91,7 +91,7 @@ class ExternalSourcesTest { ) // put this join into kv store - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("external_test") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("ExternalSourcesTest_testFetch") val mockApi = new MockApi(kvStoreFunc, "external_test") val fetcher = mockApi.buildFetcher(true) fetcher.kvStore.create(ChrononMetadataKey) @@ -197,7 +197,7 @@ class ExternalSourcesTest { ) // Setup MockApi with factory registration - val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("factory_test") + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("ExternalSourcesTest_testFactoryBasedExternalSources") val mockApi = new MockApi(kvStoreFunc, "factory_test") // Register a test factory that creates handlers dynamically @@ -227,6 +227,171 @@ class ExternalSourcesTest { assertEquals(numbers, (7 until 10).toSet) } + @Test + def testExternalSourceWithOfflineGroupBy(): Unit = { + // Create an offline GroupBy for the external source + val offlineGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "score" -> "score"), + timeColumn = "ts" + ), + table = "offline_table" + ) + ), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.SUM, + inputColumn = "score", + windows = Seq(new Window(7, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = "offline_gb", namespace = "test"), + accuracy = Accuracy.SNAPSHOT + ) + + // Create a regular GroupBy for comparison + val regularGroupBy = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "activity_count" -> "activity_count"), + timeColumn = "ts" + ), + table = "regular_table" + ) + ), 
+ keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.COUNT, + inputColumn = "activity_count", + windows = Seq(new Window(1, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = "regular_gb", namespace = "test"), + accuracy = Accuracy.SNAPSHOT + ) + + // Create factory configuration + val factoryConfig = new ExternalSourceFactoryConfig() + factoryConfig.setFactoryName("test-online-factory") + factoryConfig.setFactoryParams(Map("multiplier" -> "10").toJava) + + // Create external source WITH both offlineGroupBy and factory config + val externalSourceWithOffline = Builders.ExternalSource( + metadata = Builders.MetaData(name = "external_with_offline"), + keySchema = StructType("keys", Array(StructField("user_id", StringType))), + valueSchema = StructType("values", Array(StructField("score", LongType))) + ) + externalSourceWithOffline.setOfflineGroupBy(offlineGroupBy) + externalSourceWithOffline.setFactoryConfig(factoryConfig) + + val namespace = "offline_test" + val join = Builders.Join( + left = Builders.Source.events( + Builders.Query(selects = Map("user_id" -> "user_id")), + table = "non_existent_table" + ), + joinParts = Seq( + Builders.JoinPart( + groupBy = regularGroupBy, + prefix = "regular" + ) + ), + externalParts = Seq( + Builders.ExternalPart( + externalSourceWithOffline, + prefix = "offline" + ) + ), + metaData = Builders.MetaData(name = "test/offline_join", namespace = namespace, team = "chronon") + ) + + // Setup MockApi with factory registration + val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("ExternalSourcesTest_testExternalSourceWithOfflineGroupBy") + val mockApi = new MockApi(kvStoreFunc, "offline_test") + + // Register a test factory that returns specific values + // This proves online fetching uses the factory, NOT the offline GroupBy + mockApi.externalRegistry.addFactory("test-online-factory", new TestOnlineFactory()) + + val fetcher = mockApi.buildFetcher(true) + fetcher.kvStore.create(ChrononMetadataKey) + fetcher.putJoinConf(join) + + // Create test requests + val requests = Seq( + Request(join.metaData.name, Map("user_id" -> "user_1")), + Request(join.metaData.name, Map("user_id" -> "user_2")) + ) + + val responsesF = fetcher.fetchJoin(requests) + val responses = Await.result(responsesF, Duration(10, SECONDS)) + + // Verify responses came from the factory for external source, not from offline GroupBy + // This is the key test: even though offlineGroupBy is configured, online serving uses the factory + responses.foreach { response => + assertTrue("Response should be successful", response.values.isSuccess) + val responseMap = response.values.get + val keys = responseMap.keysIterator.toSet + + // Should have external source column from factory + assertTrue("Should contain external source column", keys.contains("ext_offline_external_with_offline_score")) + + // Should have regular GroupBy column + assertTrue("Should contain regular GroupBy column", keys.exists(_.startsWith("regular_"))) + + // Verify external source data comes from factory (100 or 200) + // This is the core assertion: values come from ExternalSourceFactory, NOT from offlineGroupBy + val score = responseMap("ext_offline_external_with_offline_score").asInstanceOf[Long] + assertTrue("Score should be from factory (100 or 200), proving online uses factory not offlineGroupBy", + score == 100L || score == 200L) + } + + // Verify both users got their expected scores from the factory + val scores = 
+      responses.map(_.values.get("ext_offline_external_with_offline_score").asInstanceOf[Long]).toSet
+    assertEquals("Both factory-generated scores should be present", Set(100L, 200L), scores)
+
+    // Additional verification: Confirm the join has both regular GroupBy and external parts
+    assertEquals("Join should have 1 regular join part", 1, join.joinParts.size())
+    assertEquals("Join should have 1 external part", 1, join.onlineExternalParts.size())
+
+    // Verify the external part has offlineGroupBy configured
+    val externalPart = join.onlineExternalParts.get(0)
+    assertNotNull("External source should have offlineGroupBy", externalPart.source.offlineGroupBy)
+    assertNotNull("External source should have factory config", externalPart.source.factoryConfig)
+  }
+
+  // Test factory implementation that returns different values than what offline GroupBy would produce
+  class TestOnlineFactory extends ai.chronon.online.ExternalSourceFactory {
+    import ai.chronon.online.Fetcher.{Request, Response}
+    import scala.concurrent.Future
+    import scala.util.Success
+
+    override def createExternalSourceHandler(
+        externalSource: ai.chronon.api.ExternalSource): ai.chronon.online.ExternalSourceHandler = {
+      new ai.chronon.online.ExternalSourceHandler {
+        override def fetch(requests: scala.collection.Seq[Request]): Future[scala.collection.Seq[Response]] = {
+          val responses = requests.map { request =>
+            val userId = request.keys("user_id").asInstanceOf[String]
+            // Return deterministic values based on user_id that would be different from offline GroupBy
+            val score = userId match {
+              case "user_1" => 100L
+              case "user_2" => 200L
+              case _ => 999L
+            }
+            val result: Map[String, AnyRef] = Map("score" -> Long.box(score))
+            Response(request = request, values = Success(result))
+          }
+          Future.successful(responses)
+        }
+      }
+    }
+  }
+
  // Test factory implementation for the factory-based registration test
  class TestExternalSourceFactory extends ai.chronon.online.ExternalSourceFactory {
    import ai.chronon.online.Fetcher.{Request, Response}
diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala
index ef91b3a243..f80836405c 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala
@@ -124,7 +124,7 @@ class FetchStatsTest extends TestCase {
    joinJob.computeJoin()

    // Load some data.
    implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
-    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetchStatsTest")
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetchStatsTest_testFetchStats")
    val inMemoryKvStore = kvStoreFunc()
    val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000)
    inMemoryKvStore.create(ChrononMetadataKey)
diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
index fea7c0c00e..ca58d46959 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -535,11 +535,12 @@ class FetcherTest extends TestCase {
                                   endDs: String,
                                   namespace: String,
                                   consistencyCheck: Boolean,
-                                   dropDsOnWrite: Boolean): Unit = {
+                                   dropDsOnWrite: Boolean,
+                                   sessionName: String): Unit = {
    implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
    val spark: SparkSession = createSparkSession()
    val tableUtils = TableUtils(spark)
-    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore(sessionName)
    val inMemoryKvStore = kvStoreFunc()
    val mockApi = new MockApi(kvStoreFunc, namespace)
@@ -668,7 +669,8 @@ class FetcherTest extends TestCase {
  def testTemporalFetchJoinDeterministic(): Unit = {
    val namespace = "deterministic_fetch"
    val joinConf = generateMutationData(namespace)
-    compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
+    compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true,
+      sessionName = "FetcherTest_testTemporalFetchJoinDeterministic")
  }

  def testTemporalFetchJoinDerivation(): Unit = {
@@ -682,7 +684,8 @@ class FetcherTest extends TestCase {
    )
    joinConf.setDerivations(derivations.toJava)
-    compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
+    compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true,
+      sessionName = "FetcherTest_testTemporalFetchJoinDerivation")
  }

  def testTemporalFetchJoinDerivationRenameOnly(): Unit = {
@@ -692,7 +695,8 @@ class FetcherTest extends TestCase {
      Seq(Builders.Derivation.star(), Builders.Derivation(name = "listing_id_renamed", expression = "listing_id"))
    joinConf.setDerivations(derivations.toJava)
-    compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
+    compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true,
+      sessionName = "FetcherTest_testTemporalFetchJoinDerivationRenameOnly")
  }

  def testTemporalFetchJoinGenerated(): Unit = {
@@ -702,13 +706,15 @@ class FetcherTest extends TestCase {
      dummyTableUtils.partitionSpec.at(System.currentTimeMillis()),
      namespace,
      consistencyCheck = true,
-      dropDsOnWrite = false)
+      dropDsOnWrite = false,
+      sessionName = "FetcherTest_testTemporalFetchJoinGenerated")
  }

  def testTemporalTiledFetchJoinDeterministic(): Unit = {
    val namespace = "deterministic_tiled_fetch"
    val joinConf = generateEventOnlyData(namespace, groupByCustomJson = Some("{\"enable_tiling\": true}"))
-    compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
+    compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false,
+      dropDsOnWrite = true,
+      sessionName = "FetcherTest_testTemporalTiledFetchJoinDeterministic")
  }

  // test soft-fail on missing keys
@@ -717,7 +723,7 @@
    val namespace = "empty_request"
    val joinConf = generateRandomData(namespace, 5, 5)
    implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
-    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest#empty_request")
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest_testEmptyRequest")
    val inMemoryKvStore = kvStoreFunc()
    val mockApi = new MockApi(kvStoreFunc, namespace)
@@ -743,7 +749,7 @@
    val joinConf = generateMutationData(namespace, Some(spark))
    val endDs = "2021-04-10"
    val tableUtils = TableUtils(spark)
-    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest_testTemporalFetchGroupByNonExistKey")
    val inMemoryKvStore = kvStoreFunc()
    val mockApi = new MockApi(kvStoreFunc, namespace)
    @transient lazy val fetcher = mockApi.buildFetcher(debug = false)
@@ -770,7 +776,7 @@
    implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
    val kvStoreFunc = () =>
-      OnlineUtils.buildInMemoryKVStore("FetcherTest#test_kv_store_partial_failure", hardFailureOnInvalidDataset = true)
+      OnlineUtils.buildInMemoryKVStore("FetcherTest_testKVStorePartialFailure", hardFailureOnInvalidDataset = true)
    val inMemoryKvStore = kvStoreFunc()
    val mockApi = new MockApi(kvStoreFunc, namespace)
@@ -798,7 +804,7 @@
    val groupByConf = joinConf.joinParts.toScala.head.groupBy
    val endDs = "2021-04-10"
    val tableUtils = TableUtils(spark)
-    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest_testGroupByServingInfoTtlCacheRefresh")
    OnlineUtils.serve(tableUtils, kvStoreFunc(), kvStoreFunc, namespace, endDs, groupByConf, dropDsOnWrite = true)
    val spyKvStore = spy(kvStoreFunc())
@@ -833,7 +839,7 @@
    val joinConf = generateMutationData(namespace, Some(spark))
    val endDs = "2021-04-10"
    val tableUtils = TableUtils(spark)
-    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest_testJoinConfTtlCacheRefresh")
    val inMemoryKvStore = kvStoreFunc()
    joinConf.joinParts.toScala.foreach(jp =>
      OnlineUtils.serve(tableUtils, inMemoryKvStore, kvStoreFunc, namespace, endDs, jp.groupBy, dropDsOnWrite = true))
diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala
index a6e9283fff..a3e844e4af 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByUploadTest.scala
@@ -304,9 +304,9 @@
      )
    )
-    val kvStore = OnlineUtils.buildInMemoryKVStore("chaining_test")
+    val kvStore = OnlineUtils.buildInMemoryKVStore("GroupByUploadTest_listingRatingCategoryJoinSourceTest")
    val endDs = "2023-08-15"
-    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("chaining_test")
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("GroupByUploadTest_listingRatingCategoryJoinSourceTest")
    // DO-NOT-SET debug=true here since the streaming job won't put
    // data into kv store
    joinConf.joinParts.toScala.foreach(jp =>
diff --git a/spark/src/test/scala/ai/chronon/spark/test/MetadataStoreTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MetadataStoreTest.scala
index be3be67486..2cb4bb684f 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/MetadataStoreTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/MetadataStoreTest.scala
@@ -24,7 +24,7 @@ class MetadataStoreTest extends TestCase {
  val acceptedEndPoints = List(MetadataEndPoint.ConfByKeyEndPointName, MetadataEndPoint.NameByTeamEndPointName)

  def testMetadataStoreSingleFile(): Unit = {
-    val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("FetcherTest")
+    val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("MetadataStoreTest_testMetadataStoreSingleFile")
    val singleFileDataSet = ChrononMetadataKey
    val singleFileMetaDataSet = NameByTeamEndPointName
    val singleFileMetadataStore = new MetadataStore(inMemoryKvStore, singleFileDataSet, timeoutMillis = 10000)
@@ -57,7 +57,7 @@ class MetadataStoreTest extends TestCase {
  }

  def testMetadataStoreDirectory(): Unit = {
-    val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("FetcherTest")
+    val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("MetadataStoreTest_testMetadataStoreDirectory")
    val directoryDataSetDataSet = ChrononMetadataKey
    val directoryMetadataDataSet = NameByTeamEndPointName
    val directoryMetadataStore = new MetadataStore(inMemoryKvStore, directoryDataSetDataSet, timeoutMillis = 10000)
diff --git a/spark/src/test/scala/ai/chronon/spark/test/ModelTransformsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ModelTransformsTest.scala
index a54a7caba7..6871cd5531 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/ModelTransformsTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/ModelTransformsTest.scala
@@ -85,7 +85,7 @@ class ModelTransformsTest {
    val spark = createSparkSession()
    val tableUtils = TableUtils(spark)
-    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("testModelTransformSimple")
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore(s"ModelTransformsTest_$namespace")
    val inMemoryKvStore = kvStoreFunc()
    val defaultEmbedding = Array(0.1, 0.2, 0.3, 0.4)
    val defaultMap = Map("embedding" -> defaultEmbedding.asInstanceOf[AnyRef])
diff --git a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala
index e5a67bfb2d..6f9dd13b10 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/SchemaEvolutionTest.scala
@@ -382,7 +382,7 @@ class SchemaEvolutionTest extends TestCase {
    assert(joinSuiteV1.joinConf.metaData.name == joinSuiteV2.joinConf.metaData.name,
           message = "Schema evolution can only be tested on changes of the SAME join")
    val tableUtils: TableUtils = TableUtils(spark)
-    val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore(namespace)
+    val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore(s"SchemaEvolutionTest_$namespace")
    val mockApi = new MockApi(() => inMemoryKvStore, namespace)
    inMemoryKvStore.create(ChrononMetadataKey)
    val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000)
@@ -520,7 +520,7 @@ class SchemaEvolutionTest extends TestCase {
    val joinTestSuite = createStructJoin(namespace)
    val tableUtils: TableUtils = TableUtils(spark)
-    val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore(namespace)
+    val inMemoryKvStore = OnlineUtils.buildInMemoryKVStore("SchemaEvolutionTest_testStructFeatures")
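+    // (store name above is now unique to this test rather than shared via the namespace)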
    val mockApi = new MockApi(() => inMemoryKvStore, namespace)
    inMemoryKvStore.create(ChrononMetadataKey)
    val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000)
diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala
index 05ca50fce3..b8be60d5f4 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala
@@ -428,7 +428,7 @@ class DerivationTest {
    )

    // Init artifacts to run online fetching and logging
-    val kvStore = OnlineUtils.buildInMemoryKVStore(namespace)
+    val kvStore = OnlineUtils.buildInMemoryKVStore(s"DerivationTest_$namespace")
    val mockApi = new MockApi(() => kvStore, namespace)
    OnlineUtils.serve(tableUtils, kvStore, () => kvStore, namespace, endDs, groupBy)
    val fetcher = mockApi.buildFetcher(debug = true)
diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala
index 51d99e64ac..51c71da0e1 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala
@@ -104,7 +104,7 @@ class LogBootstrapTest {
    val joinV2 = createBootstrapJoin(baseJoinV2)

    // Init artifacts to run online fetching and logging
-    val kvStore = OnlineUtils.buildInMemoryKVStore(namespace)
+    val kvStore = OnlineUtils.buildInMemoryKVStore("LogBootstrapTest_testBootstrap")
    val mockApi = new MockApi(() => kvStore, namespace)
    val endDs = spark.table(queryTable).select(max(tableUtils.partitionColumn)).head().getString(0)
    OnlineUtils.serve(tableUtils, kvStore, () => kvStore, namespace, endDs, groupBy)