diff --git a/api/python/ai/chronon/join.py b/api/python/ai/chronon/join.py
index 8b9466dfeb..be4eeead1a 100644
--- a/api/python/ai/chronon/join.py
+++ b/api/python/ai/chronon/join.py
@@ -356,7 +356,7 @@ def Join(
     left: api.Source,
     right_parts: List[api.JoinPart],
     version: int,
-    row_ids: Union[str, List[str]],
+    row_ids: Union[str, List[str]] = None,
     online_external_parts: List[api.ExternalPart] = None,
     bootstrap_parts: List[api.BootstrapPart] = None,
     bootstrap_from_log: bool = False,
@@ -478,7 +478,10 @@ def Join(
     # create a deep copy for case: multiple LeftOuterJoin use the same left,
     # validation will fail after the first iteration
     updated_left = copy.deepcopy(left)
-    if left.events and left.events.query.selects:
+
+    selects = None
+    if left.events:
+        selects = left.events.query.selects
         assert "ts" not in left.events.query.selects.keys(), (
             "'ts' is a reserved key word for Chronon,"
             " please specify the expression in timeColumn"
@@ -487,6 +490,13 @@ def Join(
         updated_left.events.query.selects.update(
             {"ts": updated_left.events.query.timeColumn}
         )
+    elif left.entities:
+        selects = left.entities.query.selects
+
+    if selects:
+        # For JoinSource, we can rely on the validation at the base join level
+        # TODO add more docs about row ID and link here
+        assert "row_id" in selects, "Left side of the join must contain `row_id` as a column."
 
     if label_part:
         label_metadata = api.MetaData(
diff --git a/api/python/ai/chronon/query.py b/api/python/ai/chronon/query.py
index 676c97026d..24ddc454b0 100644
--- a/api/python/ai/chronon/query.py
+++ b/api/python/ai/chronon/query.py
@@ -19,7 +19,7 @@
 
 
 def Query(
-    selects: Dict[str, str] = None,
+    selects: Dict[str, str],
     wheres: List[str] = None,
     start_partition: str = None,
     end_partition: str = None,
diff --git a/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_batch_v1__0 b/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_batch_v1__0
index 21d03cb7d0..912a2948ca 100644
--- a/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_batch_v1__0
+++ b/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_batch_v1__0
@@ -17,6 +17,7 @@
       "user_id_purchase_price_last10": "5609e04d61a47a8cafc67970785a3a59",
       "listing_id": "f2f6c814b8ae1521176b22a1ae7f2d0d",
       "user_id": "493d3df28f80664abd11b19fcd33b6e6",
+      "row_id": "6c80b6474532731fdb20d0a43bcb287f",
       "ts": "ad9fd4c611e20ad833819a4ce9d752bf"
     },
     "online": 1,
@@ -144,6 +145,7 @@
         "selects": {
           "listing_id": "EXPLODE(TRANSFORM(SPLIT(COALESCE(attributes['sold_listing_ids'], attributes['listing_id']), ','), e -> CAST(e AS LONG)))",
           "user_id": "attributes['user_id']",
+          "row_id": "request_UUID",
           "ts": "timestamp"
         },
         "timeColumn": "timestamp"
diff --git a/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_combined_v1__0 b/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_combined_v1__0
index 945d4eea9f..188bbedcdc 100644
--- a/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_combined_v1__0
+++ b/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_combined_v1__0
@@ -21,6 +21,7 @@
       "user_id_purchase_price_last10": "5609e04d61a47a8cafc67970785a3a59",
       "listing_id": "f2f6c814b8ae1521176b22a1ae7f2d0d",
       "user_id": "493d3df28f80664abd11b19fcd33b6e6",
+      "row_id": "6c80b6474532731fdb20d0a43bcb287f",
       "ts": "ad9fd4c611e20ad833819a4ce9d752bf"
     },
     "online": 1,
@@ -148,6 +149,7 @@
         "selects": {
           "listing_id": "EXPLODE(TRANSFORM(SPLIT(COALESCE(attributes['sold_listing_ids'], attributes['listing_id']), ','), e -> CAST(e AS LONG)))",
"user_id": "attributes['user_id']", + "row_id": "request_UUID", "ts": "timestamp" }, "timeColumn": "timestamp" diff --git a/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_streaming_v1__0 b/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_streaming_v1__0 index b2aaafca27..84d0e9893a 100644 --- a/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_streaming_v1__0 +++ b/api/python/test/canary/compiled/joins/gcp/item_event_join.canary_streaming_v1__0 @@ -11,6 +11,7 @@ "listing_id_favorite_sum_1d": "bb07e39a4b7ea81aca5baf027fd8f555", "listing_id": "f2f6c814b8ae1521176b22a1ae7f2d0d", "user_id": "493d3df28f80664abd11b19fcd33b6e6", + "row_id": "6c80b6474532731fdb20d0a43bcb287f", "ts": "ad9fd4c611e20ad833819a4ce9d752bf" }, "online": 1, @@ -138,6 +139,7 @@ "selects": { "listing_id": "EXPLODE(TRANSFORM(SPLIT(COALESCE(attributes['sold_listing_ids'], attributes['listing_id']), ','), e -> CAST(e AS LONG)))", "user_id": "attributes['user_id']", + "row_id": "request_UUID", "ts": "timestamp" }, "timeColumn": "timestamp" @@ -251,8 +253,5 @@ "useLongNames": 0 } ], - "rowIds": [ - "user_id" - ], "useLongNames": 0 } \ No newline at end of file diff --git a/api/python/test/canary/compiled/joins/gcp/training_set.v1_dev__0 b/api/python/test/canary/compiled/joins/gcp/training_set.v1_dev__0 index 9ba8366a4d..ea253074e8 100644 --- a/api/python/test/canary/compiled/joins/gcp/training_set.v1_dev__0 +++ b/api/python/test/canary/compiled/joins/gcp/training_set.v1_dev__0 @@ -16,6 +16,7 @@ "user_id_purchase_price_average_7d": "e3351cc3265d17cfd1b99f8da7fb083d", "user_id_purchase_price_last10": "e3351cc3265d17cfd1b99f8da7fb083d", "user_id": "a673c80ffc3260ee5ca173869ba07f7c", + "row_id": "cd335c9bb6f12c6e71b482b33109e07c", "ts": "a173b2adccfe3c386f262c4cd7ffe9bd" }, "online": 0, @@ -142,6 +143,7 @@ "query": { "selects": { "user_id": "user_id", + "row_id": "request_UUID", "ts": "ts" }, "startPartition": "2023-11-01", @@ -265,8 +267,5 @@ "useLongNames": 0 } ], - "rowIds": [ - "user_id" - ], "useLongNames": 0 } \ No newline at end of file diff --git a/api/python/test/canary/compiled/joins/gcp/training_set.v1_dev_notds__0 b/api/python/test/canary/compiled/joins/gcp/training_set.v1_dev_notds__0 index 597e2d7488..e12457cf1f 100644 --- a/api/python/test/canary/compiled/joins/gcp/training_set.v1_dev_notds__0 +++ b/api/python/test/canary/compiled/joins/gcp/training_set.v1_dev_notds__0 @@ -16,6 +16,7 @@ "user_id_purchase_price_average_7d": "1d85a493771d48ac9fb6581c18f186e2", "user_id_purchase_price_last10": "1d85a493771d48ac9fb6581c18f186e2", "user_id": "fbc25deb4a5dd5d9c1d6e960c3922d0e", + "row_id": "70230580430bc568b5ab415e20b802dc", "ts": "c4364195d7639fb8b71e2f87ee017d4d" }, "online": 0, @@ -142,6 +143,7 @@ "query": { "selects": { "user_id": "user_id", + "row_id": "request_UUID", "ts": "ts" }, "timeColumn": "ts", @@ -266,8 +268,5 @@ "useLongNames": 0 } ], - "rowIds": [ - "user_id" - ], "useLongNames": 0 } \ No newline at end of file diff --git a/api/python/test/canary/compiled/joins/gcp/training_set.v1_test__0 b/api/python/test/canary/compiled/joins/gcp/training_set.v1_test__0 index cdfe39c9af..db8d908b1b 100644 --- a/api/python/test/canary/compiled/joins/gcp/training_set.v1_test__0 +++ b/api/python/test/canary/compiled/joins/gcp/training_set.v1_test__0 @@ -16,6 +16,7 @@ "user_id_purchase_price_average_7d": "e3351cc3265d17cfd1b99f8da7fb083d", "user_id_purchase_price_last10": "e3351cc3265d17cfd1b99f8da7fb083d", "user_id": "a673c80ffc3260ee5ca173869ba07f7c", + "row_id": 
"cd335c9bb6f12c6e71b482b33109e07c", "ts": "a173b2adccfe3c386f262c4cd7ffe9bd" }, "online": 0, @@ -142,6 +143,7 @@ "query": { "selects": { "user_id": "user_id", + "row_id": "request_UUID", "ts": "ts" }, "startPartition": "2023-11-01", @@ -265,8 +267,5 @@ "useLongNames": 0 } ], - "rowIds": [ - "user_id" - ], "useLongNames": 0 } \ No newline at end of file diff --git a/api/python/test/canary/compiled/joins/gcp/training_set.v1_test_notds__0 b/api/python/test/canary/compiled/joins/gcp/training_set.v1_test_notds__0 index 96ddc9d756..1a92a8dc98 100644 --- a/api/python/test/canary/compiled/joins/gcp/training_set.v1_test_notds__0 +++ b/api/python/test/canary/compiled/joins/gcp/training_set.v1_test_notds__0 @@ -16,6 +16,7 @@ "user_id_purchase_price_average_7d": "1d85a493771d48ac9fb6581c18f186e2", "user_id_purchase_price_last10": "1d85a493771d48ac9fb6581c18f186e2", "user_id": "fbc25deb4a5dd5d9c1d6e960c3922d0e", + "row_id": "70230580430bc568b5ab415e20b802dc", "ts": "c4364195d7639fb8b71e2f87ee017d4d" }, "online": 0, @@ -142,6 +143,7 @@ "query": { "selects": { "user_id": "user_id", + "row_id": "request_UUID", "ts": "ts" }, "timeColumn": "ts", @@ -266,8 +268,5 @@ "useLongNames": 0 } ], - "rowIds": [ - "user_id" - ], "useLongNames": 0 } \ No newline at end of file diff --git a/api/python/test/canary/joins/gcp/item_event_join.py b/api/python/test/canary/joins/gcp/item_event_join.py index f48b5f4d7a..7407851662 100644 --- a/api/python/test/canary/joins/gcp/item_event_join.py +++ b/api/python/test/canary/joins/gcp/item_event_join.py @@ -11,6 +11,7 @@ selects=selects( listing_id="EXPLODE(TRANSFORM(SPLIT(COALESCE(attributes['sold_listing_ids'], attributes['listing_id']), ','), e -> CAST(e AS LONG)))", user_id="attributes['user_id']", + row_id="request_UUID", ), time_column="timestamp", ), @@ -20,7 +21,6 @@ # Join with just a streaming GB canary_streaming_v1 = Join( left=source, - row_ids="user_id", right_parts=[ JoinPart(group_by=item_event_canary.actions_pubsub_v2) ], diff --git a/api/python/test/canary/joins/gcp/training_set.py b/api/python/test/canary/joins/gcp/training_set.py index 75c274e655..e0b43b58e7 100644 --- a/api/python/test/canary/joins/gcp/training_set.py +++ b/api/python/test/canary/joins/gcp/training_set.py @@ -13,7 +13,8 @@ table="data.checkouts", query=Query( selects=selects( - "user_id" + user_id="user_id", + row_id="request_UUID", ), # The primary key used to join various GroupBys together start_partition="2023-11-01", time_column="ts", @@ -23,7 +24,6 @@ v1_test = Join( left=source, - row_ids="user_id", right_parts=[ JoinPart(group_by=purchases.v1_test) ], @@ -41,7 +41,6 @@ v1_dev = Join( left=source, - row_ids="user_id", right_parts=[ JoinPart(group_by=purchases.v1_dev) ], @@ -53,7 +52,8 @@ table="data.checkouts_notds", query=Query( selects=selects( - "user_id" + user_id = "user_id", + row_id="request_UUID", ), # The primary key used to join various GroupBys together time_column="ts", partition_column="notds" @@ -63,7 +63,6 @@ v1_test_notds = Join( left=source_notds, - row_ids=["user_id"], right_parts=[ JoinPart(group_by=purchases.v1_test_notds) ], @@ -72,7 +71,6 @@ v1_dev_notds = Join( left=source_notds, - row_ids=["user_id"], right_parts=[ JoinPart(group_by=purchases.v1_dev_notds) ], diff --git a/api/python/test/sample/joins/kaggle/outbrain.py b/api/python/test/sample/joins/kaggle/outbrain.py index f4ee3bc101..d564fec80f 100644 --- a/api/python/test/sample/joins/kaggle/outbrain.py +++ b/api/python/test/sample/joins/kaggle/outbrain.py @@ -21,8 +21,7 @@ training_set = Join( # left equi join 
left=outbrain_left_events( - "uuid", "display_id", "ad_id", "document_id", "clicked", "geo_location", "platform"), - row_ids="uuid", + "uuid", "display_id", "ad_id", "document_id", "clicked", "geo_location", "platform", "row_id"), right_parts=[JoinPart(group_by=group_by) for group_by in [ad_doc, ad_uuid, ad_streaming, ad_platform]], use_long_names = True, version=0, diff --git a/api/python/test/sample/joins/quickstart/training_set.py b/api/python/test/sample/joins/quickstart/training_set.py index 2f3170dba0..965d6332ed 100644 --- a/api/python/test/sample/joins/quickstart/training_set.py +++ b/api/python/test/sample/joins/quickstart/training_set.py @@ -29,7 +29,8 @@ table="data.checkouts", query=Query( selects=selects( - "user_id" + user_id="user_id", + row_id="request_UUID", ), # The primary key used to join various GroupBys together time_column="ts", ), # The event time used to compute feature values as-of @@ -38,7 +39,6 @@ v1 = Join( left=source, - row_ids="user_id", right_parts=[ JoinPart(group_by=group_by) for group_by in [purchases_v1, returns_v1, users] ], # Include the three GroupBys @@ -47,7 +47,6 @@ v2 = Join( left=source, - row_ids=["user_id"], right_parts=[ JoinPart(group_by=group_by) for group_by in [purchases_v1, returns_v1] ], # Include the two online GroupBys diff --git a/api/python/test/sample/joins/risk/user_transactions.py b/api/python/test/sample/joins/risk/user_transactions.py index c54ae5904b..29990fb1a1 100644 --- a/api/python/test/sample/joins/risk/user_transactions.py +++ b/api/python/test/sample/joins/risk/user_transactions.py @@ -8,13 +8,12 @@ source_users = Source( events=EventSource( - table="data.users", query=Query(selects=selects("user_id"), time_column="ts") + table="data.users", query=Query(selects=selects("user_id", "row_id"), time_column="ts") ) ) txn_join = Join( left=source_users, - row_ids="user_id", right_parts=[ JoinPart(group_by=txn_group_by_user, prefix="user"), JoinPart(group_by=txn_group_by_merchant, prefix="merchant"), diff --git a/api/python/test/sample/joins/sample_team/sample_backfill_mutation_join.py b/api/python/test/sample/joins/sample_team/sample_backfill_mutation_join.py index ba10e25366..563f7c0043 100644 --- a/api/python/test/sample/joins/sample_team/sample_backfill_mutation_join.py +++ b/api/python/test/sample/joins/sample_team/sample_backfill_mutation_join.py @@ -23,7 +23,6 @@ v0 = Join( left=test_sources.event_source, - row_ids="subject", right_parts=[JoinPart(group_by=mutation_sample_group_by.v0)], online=False, version=0, diff --git a/api/python/test/sample/joins/sample_team/sample_chaining_join.py b/api/python/test/sample/joins/sample_team/sample_chaining_join.py index b624df9518..2af73240a5 100644 --- a/api/python/test/sample/joins/sample_team/sample_chaining_join.py +++ b/api/python/test/sample/joins/sample_team/sample_chaining_join.py @@ -23,7 +23,6 @@ v1 = Join( left=test_sources.event_source, - row_ids=["subject", "event"], right_parts=[ JoinPart( group_by=chaining_group_by_v1, diff --git a/api/python/test/sample/joins/sample_team/sample_chaining_join_parent.py b/api/python/test/sample/joins/sample_team/sample_chaining_join_parent.py index 14d3246605..8260a58ea9 100644 --- a/api/python/test/sample/joins/sample_team/sample_chaining_join_parent.py +++ b/api/python/test/sample/joins/sample_team/sample_chaining_join_parent.py @@ -8,7 +8,6 @@ parent_join = Join( left=test_sources.event_source, - row_ids="group_by_subject", right_parts=[ JoinPart( group_by=event_sample_group_by.v1, diff --git 
a/api/python/test/sample/joins/sample_team/sample_join.py b/api/python/test/sample/joins/sample_team/sample_join.py index a81818b3f9..089e93e018 100644 --- a/api/python/test/sample/joins/sample_team/sample_join.py +++ b/api/python/test/sample/joins/sample_team/sample_join.py @@ -25,7 +25,6 @@ v1 = Join( left=test_sources.staging_entities, right_parts=[JoinPart(group_by=sample_group_by.v1)], - row_ids="place_id", table_properties={"config_json": """{"sample_key": "sample_value"}"""}, output_namespace="sample_namespace", env_vars=EnvironmentVariables( @@ -41,7 +40,6 @@ never = Join( left=test_sources.staging_entities, right_parts=[JoinPart(group_by=sample_group_by.v1)], - row_ids=["s2CellId", "place_id"], output_namespace="sample_namespace", offline_schedule="@never", version=0, @@ -50,7 +48,6 @@ group_by_of_group_by = Join( left=test_sources.staging_entities, right_parts=[JoinPart(group_by=sample_group_by_group_by.v1)], - row_ids="s2CellId", output_namespace="sample_namespace", version=0, ) @@ -58,7 +55,6 @@ consistency_check = Join( left=test_sources.staging_entities, right_parts=[JoinPart(group_by=sample_group_by.v1)], - row_ids="place_id", output_namespace="sample_namespace", check_consistency=True, version=0, @@ -67,7 +63,6 @@ no_log_flattener = Join( left=test_sources.staging_entities, right_parts=[JoinPart(group_by=sample_group_by.v1)], - row_ids=["place_id"], output_namespace="sample_namespace", sample_percent=0.0, version=0, diff --git a/api/python/test/sample/sources/test_sources.py b/api/python/test/sample/sources/test_sources.py index 400591df92..d561aab0e4 100644 --- a/api/python/test/sample/sources/test_sources.py +++ b/api/python/test/sample/sources/test_sources.py @@ -26,6 +26,7 @@ def basic_event_source(table): selects=selects( event="event_expr", group_by_subject="group_by_expr", + row_id="row_id_expr", ), start_partition="2021-04-09", time_column="ts", @@ -41,6 +42,7 @@ def basic_event_source(table): event="event_expr", group_by_subject="group_by_expr", subject="subject", + row_id="row_id_expr", ), start_partition="2021-04-09", time_column="ts", @@ -59,6 +61,7 @@ def basic_event_source(table): selects=selects( group_by_subject="group_by_subject_expr", entity="entity_expr", + row_id="row_id_expr", ), time_column="ts", ), @@ -72,6 +75,7 @@ def basic_event_source(table): selects=selects( group_by_subject="group_by_subject_expr", entity="entity_expr", + row_id="row_id_expr", ), time_column="ts", ), @@ -83,6 +87,7 @@ def basic_event_source(table): "viewed_unique_count_1d": "viewed_unique_count_1d", "s2CellId": "s2CellId", "place_id": "place_id", + "row_id": "row_id_expr", } ) @@ -108,6 +113,7 @@ def basic_event_source(table): **{ "group_by_subject": "group_by_subject_expr_old_version", "event": "event_expr_old_version", + "row_id": "row_id_expr", } ), time_column="UNIX_TIMESTAMP(ts) * 1000", @@ -124,6 +130,7 @@ def basic_event_source(table): **{ "group_by_subject": "possibly_different_group_by_subject_expr", "event": "possibly_different_event_expr", + "row_id": "possibly_different_row_id_expr", } ), time_column="__timestamp", diff --git a/api/src/main/scala/ai/chronon/api/Constants.scala b/api/src/main/scala/ai/chronon/api/Constants.scala index 07babb9f24..23ba648036 100644 --- a/api/src/main/scala/ai/chronon/api/Constants.scala +++ b/api/src/main/scala/ai/chronon/api/Constants.scala @@ -66,6 +66,7 @@ object Constants { val LabelViewPropertyFeatureTable: String = "feature_table" val LabelViewPropertyKeyLabelTable: String = "label_table" val ChrononRunDs: String = 
"CHRONON_RUN_DS" + val RowIDColumn: String = "row_id" val TiledSummaryDataset: String = "TILE_SUMMARIES" diff --git a/online/src/main/scala/ai/chronon/online/fetcher/MetadataStore.scala b/online/src/main/scala/ai/chronon/online/fetcher/MetadataStore.scala index 1b5a21049a..fd85016026 100644 --- a/online/src/main/scala/ai/chronon/online/fetcher/MetadataStore.scala +++ b/online/src/main/scala/ai/chronon/online/fetcher/MetadataStore.scala @@ -246,6 +246,7 @@ class MetadataStore(fetchContext: FetchContext) { val valueFields = new mutable.ListBuffer[StructField] val valueInfos = mutable.ListBuffer.empty[JoinCodec.ValueInfo] var hasPartialFailure = false + // collect keyFields and valueFields from joinParts/GroupBys joinConf.joinPartOps.foreach { joinPart => getGroupByServingInfo(joinPart.groupBy.metaData.getName) @@ -309,7 +310,10 @@ class MetadataStore(fetchContext: FetchContext) { } val joinName = joinConf.metaData.nameToFilePath - val keySchema = StructType(s"${joinName.sanitize}_key", keyFields.toArray) + val keyFieldsWithRowId = if (keyFields.nonEmpty) { + Array(StructField(Constants.RowIDColumn, StringType)) ++ keyFields + } else keyFields.toArray + val keySchema = StructType(s"${joinName.sanitize}_key", keyFieldsWithRowId) val keyCodec = AvroCodec.of(AvroConversions.fromChrononSchema(keySchema).toString) val baseValueSchema = StructType(s"${joinName.sanitize}_value", valueFields.toArray) val baseValueCodec = serde.AvroCodec.of(AvroConversions.fromChrononSchema(baseValueSchema).toString) diff --git a/spark/BUILD.bazel b/spark/BUILD.bazel index e5e9ebfda7..6fb39f96d0 100644 --- a/spark/BUILD.bazel +++ b/spark/BUILD.bazel @@ -132,6 +132,7 @@ test_deps = _SCALA_TEST_DEPS + [ maven_artifact("org.apache.hive:hive-exec"), maven_artifact("org.apache.hadoop:hadoop-common"), maven_artifact("org.apache.hadoop:hadoop-client-api"), + maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"), ] scala_library( diff --git a/spark/src/main/scala/ai/chronon/spark/Extensions.scala b/spark/src/main/scala/ai/chronon/spark/Extensions.scala index 02d0c97bfd..422dad43bb 100644 --- a/spark/src/main/scala/ai/chronon/spark/Extensions.scala +++ b/spark/src/main/scala/ai/chronon/spark/Extensions.scala @@ -141,13 +141,15 @@ object Extensions { def save(tableName: String, tableProperties: Map[String, String] = null, partitionColumns: Seq[String] = List(tableUtils.partitionColumn), - autoExpand: Boolean = false): Unit = { + autoExpand: Boolean = false, + bucketByRowId: Boolean = false): Unit = { TableUtils(df.sparkSession).insertPartitions(df, tableName, tableProperties, partitionColumns.toList, - autoExpand = autoExpand) + autoExpand = autoExpand, + bucketByRowId = bucketByRowId) } def prefixColumnNames(prefix: String, columns: Seq[String]): DataFrame = { diff --git a/spark/src/main/scala/ai/chronon/spark/Join.scala b/spark/src/main/scala/ai/chronon/spark/Join.scala index 01d9074826..0f38c94fb6 100644 --- a/spark/src/main/scala/ai/chronon/spark/Join.scala +++ b/spark/src/main/scala/ai/chronon/spark/Join.scala @@ -195,13 +195,7 @@ class Join(joinConf: api.Join, private def getRightPartsData(leftRange: PartitionRange): Seq[(JoinPart, DataFrame)] = { joinConfCloned.joinParts.asScala.map { joinPart => val partTable = joinConfCloned.partOutputTable(joinPart) - val effectiveRange = - if (joinConfCloned.left.dataModel != ENTITIES && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) { - leftRange.shift(-1) - } else { - leftRange - } - val wheres = effectiveRange.whereClauses + val wheres = 
leftRange.whereClauses val sql = QueryUtils.build(null, partTable, wheres) logger.info(s"Pulling data from joinPart table with: $sql") (joinPart, tableUtils.scanDfBase(null, partTable, List.empty, wheres, None)) @@ -212,12 +206,13 @@ class Join(joinConf: api.Join, val bootstrapDf = tableUtils.scanDf(query = null, table = bootstrapTable, range = Some(leftRange)).addTimebasedColIfExists() val rightPartsData = getRightPartsData(leftRange) + val joinedDfTry = try { Success( rightPartsData .foldLeft(bootstrapDf) { case (partialDf, (rightPart, rightDf)) => - joinWithLeft(partialDf, rightDf, rightPart) + JoinUtils.joinWithLeft(partialDf, rightDf, rightPart, tableUtils) } // drop all processing metadata columns .drop(Constants.MatchedHashes, Constants.TimePartitionColumn)) @@ -369,7 +364,7 @@ class Join(joinConf: api.Join, Success( rightResults .foldLeft(bootstrapDf.addTimebasedColIfExists()) { case (partialDf, (rightPart, rightDf)) => - joinWithLeft(partialDf, rightDf, rightPart) + JoinUtils.joinWithLeft(partialDf, rightDf, rightPart, tableUtils) } // drop all processing metadata columns .drop(Constants.MatchedHashes, Constants.TimePartitionColumn)) diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala index 9eaa18a871..8a0a106c00 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala @@ -64,67 +64,6 @@ abstract class JoinBase(val joinConfCloned: api.Join, protected val tableProps: Map[String, String] = confTableProps ++ Map(Constants.SemanticHashKey -> gson.toJson(joinConfCloned.semanticHash.asJava)) - def joinWithLeft(leftDf: DataFrame, rightDf: DataFrame, joinPart: JoinPart): DataFrame = { - val partLeftKeys = joinPart.rightToLeft.values.toArray - - // compute join keys, besides the groupBy keys - like ds, ts etc., - val additionalKeys: Seq[String] = { - if (joinConfCloned.left.dataModel == ENTITIES) { - Seq(tableUtils.partitionColumn) - } else if (joinPart.groupBy.inferredAccuracy == Accuracy.TEMPORAL) { - Seq(Constants.TimeColumn, tableUtils.partitionColumn) - } else { // left-events + snapshot => join-key = ds_of_left_ts - Seq(Constants.TimePartitionColumn) - } - } - val keys = partLeftKeys ++ additionalKeys - - // apply prefix to value columns - val nonValueColumns = joinPart.rightToLeft.keys.toArray ++ Array(Constants.TimeColumn, - tableUtils.partitionColumn, - Constants.TimePartitionColumn) - val valueColumns = rightDf.schema.names.filterNot(nonValueColumns.contains) - val prefixedRightDf = rightDf.prefixColumnNames(joinPart.columnPrefix, valueColumns) - - // apply key-renaming to key columns - val newColumns = prefixedRightDf.columns.map { column => - if (joinPart.rightToLeft.contains(column)) { - col(column).as(joinPart.rightToLeft(column)) - } else { - col(column) - } - } - val keyRenamedRightDf = prefixedRightDf.select(newColumns: _*) - - // adjust join keys - val joinableRightDf = if (additionalKeys.contains(Constants.TimePartitionColumn)) { - // increment one day to align with left side ts_ds - // because one day was decremented from the partition range for snapshot accuracy - keyRenamedRightDf - .withColumn( - Constants.TimePartitionColumn, - date_format(date_add(to_date(col(tableUtils.partitionColumn), tableUtils.partitionSpec.format), 1), - tableUtils.partitionSpec.format) - ) - .drop(tableUtils.partitionColumn) - } else { - keyRenamedRightDf - } - - logger.info(s""" - |Join keys for ${joinPart.groupBy.metaData.name}: ${keys.mkString(", ")} - 
|Left Schema: - |${leftDf.schema.pretty} - |Right Schema: - |${joinableRightDf.schema.pretty}""".stripMargin) - val joinedDf = coalescedJoin(leftDf, joinableRightDf, keys) - logger.info(s"""Final Schema: - |${joinedDf.schema.pretty} - |""".stripMargin) - - joinedDf - } - def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo, diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala index 579379bfc7..146bf367c4 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala @@ -18,6 +18,7 @@ package ai.chronon.spark import ai.chronon.api import ai.chronon.api._ +import ai.chronon.api.Constants import ai.chronon.api.DataModel.EVENTS import ai.chronon.api.Extensions._ import ai.chronon.api.ScalaJavaConversions._ @@ -166,6 +167,7 @@ object JoinUtils { } val joinedDf = leftDf.join(rightDf, keys.toSeq, joinType) + // find columns that exist both on left and right that are not keys and coalesce them val selects = keys.map(col) ++ leftDf.columns.flatMap { colName => @@ -474,4 +476,29 @@ object JoinUtils { def computeFullLeftSourceTableName(join: api.Join)(implicit tableUtils: TableUtils): String = { new JoinPlanner(join)(tableUtils.partitionSpec).leftSourceNode.metaData.outputTable } + + def joinWithLeft(leftDf: DataFrame, rightDf: DataFrame, joinPart: JoinPart, tableUtils: TableUtils): DataFrame = { + val nonValueColumns = joinPart.groupBy.keyColumns.toArray ++ Array(Constants.TimeColumn, + tableUtils.partitionColumn, + Constants.TimePartitionColumn, + Constants.RowIDColumn) + // apply prefix to value columns + val valueColumns = rightDf.schema.names.filterNot(nonValueColumns.contains) + val prefixedRightDf = rightDf.prefixColumnNames(joinPart.columnPrefix, valueColumns) + + logger.info(s""" + |Left Schema: + |${leftDf.schema.pretty} + |Right Schema: + |${prefixedRightDf.schema.pretty}""".stripMargin) + + val joinedDf = coalescedJoin(leftDf, prefixedRightDf, Seq(tableUtils.partitionColumn, Constants.RowIDColumn)) + + logger.info(s"""Final Schema: + |${joinedDf.schema.pretty} + |""".stripMargin) + + joinedDf + } + } diff --git a/spark/src/main/scala/ai/chronon/spark/batch/JoinPartJob.scala b/spark/src/main/scala/ai/chronon/spark/batch/JoinPartJob.scala index dcedb7947d..26faed6443 100644 --- a/spark/src/main/scala/ai/chronon/spark/batch/JoinPartJob.scala +++ b/spark/src/main/scala/ai/chronon/spark/batch/JoinPartJob.scala @@ -3,16 +3,16 @@ package ai.chronon.spark.batch import ai.chronon.api.DataModel.{ENTITIES, EVENTS} import ai.chronon.api.Extensions.{DateRangeOps, DerivationOps, GroupByOps, JoinPartOps, MetadataOps} import ai.chronon.api.PartitionRange.toTimeRange -import ai.chronon.api.ScalaJavaConversions.ListOps import ai.chronon.api._ import ai.chronon.online.metrics.Metrics import ai.chronon.planner.JoinPartNode import ai.chronon.spark.Extensions._ +import ai.chronon.spark.JoinUtils.coalescedJoin import ai.chronon.spark.catalog.TableUtils -import ai.chronon.spark.{GroupBy, JoinUtils} import ai.chronon.spark.join.UnionJoin +import ai.chronon.spark.{GroupBy, JoinUtils} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.{col, column, date_format} +import org.apache.spark.sql.functions._ import org.apache.spark.util.sketch.BloomFilter import org.slf4j.{Logger, LoggerFactory} @@ -43,11 +43,14 @@ class JoinPartJob(node: JoinPartNode, metaData: MetaData, range: DateRange, show val jobContext = 
context.getOrElse { // LeftTable is already computed by SourceJob, no need to apply query/filters/etc - val relevantLeftCols = - joinPart.rightToLeft.keys.toArray ++ Seq(tableUtils.partitionColumn) ++ (node.leftDataModel match { - case ENTITIES => None - case EVENTS => Some(Constants.TimeColumn) - }) + val entityCols = joinPart.rightToLeft.keys.toArray + val additionalCols = Seq(tableUtils.partitionColumn, Constants.RowIDColumn) + val timeCol = node.leftDataModel match { + case ENTITIES => None + case EVENTS => Some(Constants.TimeColumn) + } + + val relevantLeftCols = entityCols ++ additionalCols ++ timeCol val query = Builders.Query(selects = relevantLeftCols.map(t => t -> t).toMap) val cachedLeftDf = tableUtils.scanDf(query = query, leftTable, range = Some(dateRange)) @@ -82,8 +85,6 @@ class JoinPartJob(node: JoinPartNode, metaData: MetaData, range: DateRange, show // val partMetrics = Metrics.Context(metrics, joinPart) -- TODO is this metrics context sufficient, or should we pass thru for monolith join? val partMetrics = Metrics.Context(Metrics.Environment.JoinOffline, joinPart.groupBy) - val rightRange = JoinUtils.shiftDays(node.leftDataModel, joinPart, leftRange) - // Can kill the option after we deprecate monolith join job jobContext.leftDf.foreach { leftDf => try { @@ -95,8 +96,9 @@ class JoinPartJob(node: JoinPartNode, metaData: MetaData, range: DateRange, show // Cache join part data into intermediate table if (filledDf.isDefined) { - logger.info(s"Writing to join part table: $partTable for partition range $rightRange") - filledDf.get.save(partTable, jobContext.tableProps.toMap) + logger.info(s"Writing to join part table: $partTable for partition range $leftRange") + // Apply bucketing on row ID column if it exists in the DataFrame + filledDf.get.save(partTable, tableProperties = jobContext.tableProps.toMap, bucketByRowId = true) } else { logger.info(s"Skipping $partTable because no data in computed joinPart.") } @@ -112,7 +114,7 @@ class JoinPartJob(node: JoinPartNode, metaData: MetaData, range: DateRange, show } if (tableUtils.tableReachable(partTable)) { - Some(tableUtils.scanDf(query = null, partTable, range = Some(rightRange))) + Some(tableUtils.scanDf(query = null, partTable, range = Some(leftRange))) } else { // Happens when everything is handled by bootstrap None @@ -203,44 +205,62 @@ class JoinPartJob(node: JoinPartNode, metaData: MetaData, range: DateRange, show case c => renamedLeftRawDf.col(c) }.toList: _*) - val rightDf = (node.leftDataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match { - case (ENTITIES, EVENTS, _) => partitionRangeGroupBy.snapshotEvents(dateRange) - case (ENTITIES, ENTITIES, _) => partitionRangeGroupBy.snapshotEntities - case (EVENTS, EVENTS, Accuracy.SNAPSHOT) => - genGroupBy(shiftedPartitionRange).snapshotEvents(shiftedPartitionRange) - case (EVENTS, EVENTS, Accuracy.TEMPORAL) => - val skewFreeMode = tableUtils.sparkSession.conf - .get("spark.chronon.join.backfill.mode.skewFree", "false") - .toBoolean - - if (skewFreeMode) { - // Use UnionJoin for skewFree mode - it will handle column selection internally - logger.info(s"Using UnionJoin for TEMPORAL events join part: ${joinPart.groupBy.metaData.name}") - UnionJoin.computeJoinPart(renamedLeftDf, joinPart, unfilledPartitionRange, produceFinalJoinOutput = false) - } else { - // Use traditional temporalEvents approach - genGroupBy(unfilledPartitionRange).temporalEvents(renamedLeftDf, Some(toTimeRange(unfilledPartitionRange))) - } - - case (EVENTS, ENTITIES, Accuracy.SNAPSHOT) 
=> genGroupBy(shiftedPartitionRange).snapshotEntities + // When we implement versioning on JoinPartJob, we can modify this to also include the reused columns + val colsToJoinFromLeft = Seq(Constants.RowIDColumn) + + // RightDF is the joinPart data, shouldJoinToLeft indicates whether we need to join it back to the left to extract + // additional `colsToJoinFromLeft` or not. For some compute modes, the relevant columns can be "passed through" computation + // And we don't need to join them back to the leftDf. + val (rightDf, shouldJoinToLeft) = + (node.leftDataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match { + case (ENTITIES, EVENTS, _) => (partitionRangeGroupBy.snapshotEvents(dateRange), true) + case (ENTITIES, ENTITIES, _) => (partitionRangeGroupBy.snapshotEntities, true) + case (EVENTS, EVENTS, Accuracy.SNAPSHOT) => + (genGroupBy(shiftedPartitionRange).snapshotEvents(shiftedPartitionRange), true) + case (EVENTS, EVENTS, Accuracy.TEMPORAL) => + val skewFreeMode = tableUtils.sparkSession.conf + .get("spark.chronon.join.backfill.mode.skewFree", "false") + .toBoolean + + if (skewFreeMode) { + // Use UnionJoin for skewFree mode - it will handle column selection internally + logger.info(s"Using UnionJoin for TEMPORAL events join part: ${joinPart.groupBy.metaData.name}") + (UnionJoin.computeJoinPart(renamedLeftDf, joinPart, unfilledPartitionRange, produceFinalJoinOutput = false), + false) + } else { + // Use traditional temporalEvents approach + // TODO: Modify temporalEvents to include row ID column on output, then we can return false for shouldJoinToLeft + (genGroupBy(unfilledPartitionRange).temporalEvents(renamedLeftDf, + Some(toTimeRange(unfilledPartitionRange))), + true) + } + + case (EVENTS, ENTITIES, Accuracy.SNAPSHOT) => (genGroupBy(shiftedPartitionRange).snapshotEntities, true) + + case (EVENTS, ENTITIES, Accuracy.TEMPORAL) => + // Snapshots and mutations are partitioned with ds holding data between and ds <23:59>. + // TODO: Modify temporalEntities to include row ID column on output, then we can return false for shouldJoinToLeft + (genGroupBy(shiftedPartitionRange).temporalEntities(renamedLeftDf), true) + } - case (EVENTS, ENTITIES, Accuracy.TEMPORAL) => - // Snapshots and mutations are partitioned with ds holding data between and ds <23:59>. - genGroupBy(shiftedPartitionRange).temporalEntities(renamedLeftDf) + val rightDfWithAllCols = if (shouldJoinToLeft) { + joinWithLeft(renamedLeftDf, rightDf, colsToJoinFromLeft) + } else { + rightDf } val rightDfWithDerivations = if (joinPart.groupBy.hasDerivations) { val finalOutputColumns = joinPart.groupBy.derivationsScala.finalOutputColumn( - rightDf.columns, - ensureKeys = joinPart.groupBy.keys(tableUtils.partitionColumn) + rightDfWithAllCols.columns, + ensureKeys = joinPart.groupBy.keys(tableUtils.partitionColumn) ++ Seq(Constants.RowIDColumn) ) - val result = rightDf.select(finalOutputColumns: _*) + val result = rightDfWithAllCols.select(finalOutputColumns: _*) result } else { - rightDf + rightDfWithAllCols } if (showDf) { @@ -250,4 +270,63 @@ class JoinPartJob(node: JoinPartNode, metaData: MetaData, range: DateRange, show Some(rightDfWithDerivations) } + + def joinWithLeft(leftDf: DataFrame, + rightDf: DataFrame, + additionalLeftColumnsToInclude: Seq[String] = Seq.empty): DataFrame = { + + // This join logic does not do any bucket hinting because it is on pre-bucketed data. + // The output of this will get bucketed and written, which MergeJob will benefit from. 
+ val partLeftKeys = joinPart.rightToLeft.keys.toArray + + // compute join keys, besides the groupBy keys - like ds, ts etc., + val additionalKeys: Seq[String] = { + if (node.leftDataModel == ENTITIES) { + Seq(tableUtils.partitionColumn) + } else if (joinPart.groupBy.inferredAccuracy == Accuracy.TEMPORAL) { + Seq(Constants.TimeColumn, tableUtils.partitionColumn) + } else { // left-events + snapshot => join-key = ds_of_left_ts + Seq(Constants.TimePartitionColumn) + } + } + + val keys = partLeftKeys ++ additionalKeys + + val allLeftCols = keys ++ additionalLeftColumnsToInclude :+ tableUtils.partitionColumn + // Filter down left to only the columns that we want to keep on the joined output for the Joinpart + val leftDfWithRelevantCols = + if (node.leftDataModel == DataModel.EVENTS && !leftDf.columns.contains(Constants.TimePartitionColumn)) { + leftDf.withTimeBasedColumn(Constants.TimePartitionColumn) + } else { + leftDf + }.select(allLeftCols.map(column): _*) + + // adjust join keys + val joinableRightDf = if (additionalKeys.contains(Constants.TimePartitionColumn)) { + // increment one day to align with left side ts_ds + // because one day was decremented from the partition range for snapshot accuracy + rightDf + .withColumn( + Constants.TimePartitionColumn, + date_format(date_add(to_date(col(tableUtils.partitionColumn), tableUtils.partitionSpec.format), 1), + tableUtils.partitionSpec.format) + ) + } else { + rightDf + } + + logger.info(s""" + |Join keys for ${joinPart.groupBy.metaData.name}: ${keys.mkString(", ")} + |Left Schema: + |${leftDfWithRelevantCols.schema.pretty} + |Right Schema: + |${joinableRightDf.schema.pretty}""".stripMargin) + + val joinedDf = coalescedJoin(leftDfWithRelevantCols, joinableRightDf, keys) + logger.info(s"""Final Schema: + |${joinedDf.schema.pretty} + |""".stripMargin) + + joinedDf + } } diff --git a/spark/src/main/scala/ai/chronon/spark/batch/MergeJob.scala b/spark/src/main/scala/ai/chronon/spark/batch/MergeJob.scala index 8e18ed52ce..6c535b92ca 100644 --- a/spark/src/main/scala/ai/chronon/spark/batch/MergeJob.scala +++ b/spark/src/main/scala/ai/chronon/spark/batch/MergeJob.scala @@ -27,7 +27,7 @@ import org.slf4j.{Logger, LoggerFactory} import java.time.Instant import scala.collection.Seq -import scala.util.{Failure, Success} +import scala.util.{Failure, Success, Try} /** Result of analyzing join parts for reuse from production table. 
* @@ -57,7 +57,7 @@ class MergeJob(node: JoinMergeNode, metaData: MetaData, range: DateRange, joinPa // Processing metadata columns that get dropped in final output private val processingColumns = Set(Constants.MatchedHashes, Constants.TimePartitionColumn) - private val hashExclusionColumn: Set[String] = processingColumns ++ Set(tableUtils.partitionColumn) + private val hashExclusionColumns: Set[String] = processingColumns ++ Set(tableUtils.partitionColumn) private val archiveReuseTableSuffix = "_archive_reuse" private val colHashTablePropsKey = "column_hashes" @@ -79,7 +79,6 @@ class MergeJob(node: JoinMergeNode, metaData: MetaData, range: DateRange, joinPa val archiveReuseTable = outputTable + archiveReuseTableSuffix def run(): Unit = { - // Always check to see if we need to archive the current output table // Occurs when columns are changed/removed/added // Computed based on column level semantic hashing that occurs at compile time @@ -87,55 +86,56 @@ class MergeJob(node: JoinMergeNode, metaData: MetaData, range: DateRange, joinPa // This job benefits from a step day of 1 to avoid needing to shuffle on writing output (single partition) dateRange.steps(days = 1).foreach { dayStep => - // Scan left input table once to get schema and potentially reuse - val leftInputDf = tableUtils.scanDf(query = null, table = leftInputTable, range = Some(dayStep)) + val joinedDfTry = runDayStep(dayStep) + val tableProps = createTableProperties + joinedDfTry.get.save(outputTable, tableProperties = tableProps, autoExpand = true, bucketByRowId = true) + } + } - // Check if we can reuse columns from production table - val reuseAnalysis = analyzeJoinPartsForReuse(dayStep, leftInputDf) + def runDayStep(dayStep: PartitionRange): Try[DataFrame] = { + // Scan left input table once to get schema and potentially reuse + val leftInputDf = tableUtils.scanDf(query = null, table = leftInputTable, range = Some(dayStep)) - // Get left DataFrame with potentially reused columns from production - val leftDf = if (reuseAnalysis.reuseTable.isDefined) { - logger.info(s"Reusing ${reuseAnalysis.columnsToReuse.length} columns (${reuseAnalysis.columnsToReuse - .mkString(", ")}) from table: ${reuseAnalysis.reuseTable.get}") + // Check if we can reuse columns from production table + val reuseAnalysis = analyzeJoinPartsForReuse(dayStep, leftInputDf) - // Select left columns + reused columns from production table - val leftColumns = leftInputDf.schema.fieldNames.filterNot(processingColumns.contains) - val columnsToSelect = leftColumns ++ reuseAnalysis.columnsToReuse - val productionDf = tableUtils.scanDf(query = null, table = reuseAnalysis.reuseTable.get, range = Some(dayStep)) + // Get left DataFrame with potentially reused columns from production + val leftDf = if (reuseAnalysis.reuseTable.isDefined) { + logger.info(s"Reusing ${reuseAnalysis.columnsToReuse.length} columns (${reuseAnalysis.columnsToReuse + .mkString(", ")}) from table: ${reuseAnalysis.reuseTable.get}") - val selectedDf = productionDf.select(columnsToSelect.map(col): _*) + // Select left columns + reused columns from production table + val leftColumns = leftInputDf.schema.fieldNames.filterNot(processingColumns.contains) + val columnsToSelect = leftColumns ++ reuseAnalysis.columnsToReuse + val productionDf = tableUtils.scanDf(query = null, table = reuseAnalysis.reuseTable.get, range = Some(dayStep)) - // Add back ts_ds column if this is an EVENTS source and the column is missing - if (join.left.dataModel == DataModel.EVENTS && 
!selectedDf.columns.contains(Constants.TimePartitionColumn)) { - selectedDf.withTimeBasedColumn(Constants.TimePartitionColumn) - } else { - selectedDf - } + val selectedDf = productionDf.select(columnsToSelect.map(col): _*) + + // Add back ts_ds column if this is an EVENTS source and the column is missing + if (join.left.dataModel == DataModel.EVENTS && !selectedDf.columns.contains(Constants.TimePartitionColumn)) { + selectedDf.withTimeBasedColumn(Constants.TimePartitionColumn) } else { - leftInputDf + selectedDf } + } else { + leftInputDf + } - // Get right parts data only for join parts that need to be computed - val rightPartsData = getRightPartsData(dayStep, reuseAnalysis.joinPartsToCompute) - - val joinedDfTry = - try { - Success( - rightPartsData - .foldLeft(leftDf) { case (partialDf, (rightPart, rightDf)) => - joinWithLeft(partialDf, rightDf, rightPart) - } - // drop all processing metadata columns - .drop(Constants.MatchedHashes, Constants.TimePartitionColumn)) - } catch { - case e: Exception => - e.printStackTrace() - Failure(e) - } - - val tableProps = createTableProperties - - joinedDfTry.get.save(outputTable, tableProps, autoExpand = true) + // Get right parts data only for join parts that need to be computed + val rightPartsData = getRightPartsData(dayStep, reuseAnalysis.joinPartsToCompute) + + try { + Success( + rightPartsData + .foldLeft(leftDf) { case (partialDf, (rightPart, rightDf)) => + JoinUtils.joinWithLeft(partialDf, rightDf, rightPart, tableUtils) + } + // drop all processing metadata columns + .drop(Constants.MatchedHashes, Constants.TimePartitionColumn)) + } catch { + case e: Exception => + e.printStackTrace() + Failure(e) } } @@ -158,81 +158,13 @@ class MergeJob(node: JoinMergeNode, metaData: MetaData, range: DateRange, joinPa joinPartsToProcess.map { joinPart => // Use the RelevantLeftForJoinPart utility to get the part table name val partTable = RelevantLeftForJoinPart.fullPartTableName(join, joinPart) - val effectiveRange = - if (join.left.dataModel == DataModel.EVENTS && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) { - dayStep.shift(-1) - } else { - dayStep - } - val wheres = effectiveRange.whereClauses + val wheres = dayStep.whereClauses val sql = QueryUtils.build(null, partTable, wheres) logger.info(s"Pulling data from joinPart table with: $sql") (joinPart, tableUtils.scanDfBase(null, partTable, List.empty, wheres, None)) }.toSeq } - def joinWithLeft(leftDf: DataFrame, rightDf: DataFrame, joinPart: JoinPart): DataFrame = { - val partLeftKeys = joinPart.rightToLeft.values.toArray - - // compute join keys, besides the groupBy keys - like ds, ts etc., - val additionalKeys: Seq[String] = { - if (join.left.dataModel == ENTITIES) { - Seq(tableUtils.partitionColumn) - } else if (joinPart.groupBy.inferredAccuracy == Accuracy.TEMPORAL) { - Seq(Constants.TimeColumn, tableUtils.partitionColumn) - } else { // left-events + snapshot => join-key = ds_of_left_ts - Seq(Constants.TimePartitionColumn) - } - } - val keys = partLeftKeys ++ additionalKeys - - // apply prefix to value columns - val nonValueColumns = joinPart.rightToLeft.keys.toArray ++ Array(Constants.TimeColumn, - tableUtils.partitionColumn, - Constants.TimePartitionColumn) - val valueColumns = rightDf.schema.names.filterNot(nonValueColumns.contains) - val prefixedRightDf = rightDf.prefixColumnNames(joinPart.columnPrefix, valueColumns) - - // apply key-renaming to key columns - val newColumns = prefixedRightDf.columns.map { column => - if (joinPart.rightToLeft.contains(column)) { - 
col(column).as(joinPart.rightToLeft(column)) - } else { - col(column) - } - } - - val keyRenamedRightDf = prefixedRightDf.select(newColumns: _*) - - // adjust join keys - val joinableRightDf = if (additionalKeys.contains(Constants.TimePartitionColumn)) { - // increment one day to align with left side ts_ds - // because one day was decremented from the partition range for snapshot accuracy - keyRenamedRightDf - .withColumn( - Constants.TimePartitionColumn, - date_format(date_add(to_date(col(tableUtils.partitionColumn), tableUtils.partitionSpec.format), 1), - tableUtils.partitionSpec.format) - ) - .drop(tableUtils.partitionColumn) - } else { - keyRenamedRightDf - } - - logger.info(s""" - |Join keys for ${joinPart.groupBy.metaData.name}: ${keys.mkString(", ")} - |Left Schema: - |${leftDf.schema.pretty} - |Right Schema: - |${joinableRightDf.schema.pretty}""".stripMargin) - val joinedDf = coalescedJoin(leftDf, joinableRightDf, keys) - logger.info(s"""Final Schema: - |${joinedDf.schema.pretty} - |""".stripMargin) - - joinedDf - } - /** Check for columns that have mismatched semantic hashes between two hash maps * @param columns The columns to check * @param reuseTableColHashes Hash map from the reuse table @@ -355,12 +287,12 @@ class MergeJob(node: JoinMergeNode, metaData: MetaData, range: DateRange, joinPa // Check if left schemas are compatible using the production columns we just got // Also check semantic hashes to ensure left columns have matching semantics - val currentLeftColumns = currentLeftDf.schema.fieldNames.toSet -- hashExclusionColumn + val currentLeftColumns = currentLeftDf.schema.fieldNames.toSet -- hashExclusionColumns val leftSchemaMismatches = findMismatchedHashes(currentLeftColumns, reuseTableColHashes, currentColumnHashes) if (leftSchemaMismatches.nonEmpty) { logger.info( - s"Left columns have mismatched semantic hashes, cannot reuse from production table. Mismatched columns: ${leftSchemaMismatches - .mkString(", ")}") + s"Left columns have mismatched semantic hashes, cannot reuse from production table. 
Mismatched columns from $currentLeftColumns: ${leftSchemaMismatches + .mkString(", ")} - $reuseTableColHashes (reuse table) vs $currentColumnHashes (current)") return JoinPartReuseAnalysis(None, Seq.empty, joinParts) } @@ -379,7 +311,8 @@ class MergeJob(node: JoinMergeNode, metaData: MetaData, range: DateRange, joinPa logger.info(s"Join part $joinPartGroupByName schema: ${partSchema.pretty}") val partKeyColumns = joinPart.rightToLeft.keys.toSet ++ Set(Constants.TimeColumn, tableUtils.partitionColumn, - Constants.TimePartitionColumn) + Constants.TimePartitionColumn, + Constants.RowIDColumn) val partValueColumns = partSchema.fieldNames.filterNot(partKeyColumns.contains).map(joinPart.columnPrefix + _) @@ -414,4 +347,5 @@ class MergeJob(node: JoinMergeNode, metaData: MetaData, range: DateRange, joinPa JoinPartReuseAnalysis(Option(reuseTable), columnsToReuse.toSeq, joinPartsToRejoin.toSeq) } } + } diff --git a/spark/src/main/scala/ai/chronon/spark/batch/SourceJob.scala b/spark/src/main/scala/ai/chronon/spark/batch/SourceJob.scala index fe27dea6a0..903a8c6ae1 100644 --- a/spark/src/main/scala/ai/chronon/spark/batch/SourceJob.scala +++ b/spark/src/main/scala/ai/chronon/spark/batch/SourceJob.scala @@ -6,6 +6,7 @@ import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.planner.SourceWithFilterNode import ai.chronon.spark.Extensions._ import ai.chronon.spark.catalog.TableUtils +import org.slf4j.{Logger, LoggerFactory} import scala.collection.{Map, Seq} import scala.jdk.CollectionConverters._ @@ -15,6 +16,8 @@ Runs and materializes a `Source` for a given `dateRange`. Used in the Join compu then each join may have a further Bootstrap computation to produce the left side for use in the final join step. */ class SourceJob(node: SourceWithFilterNode, metaData: MetaData, range: DateRange)(implicit tableUtils: TableUtils) { + @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) + private val sourceWithFilter = node private val dateRange = range.toPartitionRange(tableUtils.partitionSpec) private val outputTable = metaData.outputTable @@ -62,8 +65,17 @@ class SourceJob(node: SourceWithFilterNode, metaData: MetaData, range: DateRange df } - // Save using the provided outputTable or compute one if not provided - dfWithTimeCol.save(outputTable, tableProperties = metaData.tableProps) + // Assert that row ID column is present (should be injected at Python level) + require( + dfWithTimeCol.columns.contains(Constants.RowIDColumn), + s"Row ID column ${Constants.RowIDColumn} must be present in source data for bucketing support" + ) + + logger.info(s"Found row ID column ${Constants.RowIDColumn} in source data for bucketing") + val dfWithRowId = dfWithTimeCol + + // Save with bucketing on row ID column + dfWithRowId.save(outputTable, tableProperties = metaData.tableProps, bucketByRowId = true) } } diff --git a/spark/src/main/scala/ai/chronon/spark/catalog/CreationUtils.scala b/spark/src/main/scala/ai/chronon/spark/catalog/CreationUtils.scala index 5578e45165..6199db224c 100644 --- a/spark/src/main/scala/ai/chronon/spark/catalog/CreationUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/catalog/CreationUtils.scala @@ -11,7 +11,9 @@ object CreationUtils { partitionColumns: List[String], tableProperties: Map[String, String], fileFormatString: String, - tableTypeString: String): String = { + tableTypeString: String, + bucketColumnName: Option[String] = None, + bucketNumber: Option[Int] = None): String = { require( tableTypeString.isEmpty || 
ALLOWED_TABLE_TYPES.contains(tableTypeString.toLowerCase), @@ -24,25 +26,40 @@ object CreationUtils { val createFragment = s"""CREATE TABLE $tableName ( - | ${noPartitions.toDDL} + | ${schema.toDDL} |) |${if (tableTypeString.isEmpty) "" else f"USING ${tableTypeString}"} |""".stripMargin - val partitionFragment = if (partitionColumns != null && partitionColumns.nonEmpty) { + val partitionFragment = if ( + (partitionColumns != null && partitionColumns.nonEmpty) || (bucketColumnName.isDefined && bucketNumber.isDefined) + ) { - val partitionDefinitions = schema - .filter(field => partitionColumns.contains(field.name)) - .map(field => s"${field.name} ${field.dataType.catalogString}") + val partitionDefinitions = if (partitionColumns != null && partitionColumns.nonEmpty) { + schema + .filter(field => partitionColumns.contains(field.name)) + .map(field => s"${field.name}") + } else { + List.empty[String] + } s"""PARTITIONED BY ( - | ${partitionDefinitions.mkString(",\n ")} + | ${partitionDefinitions.mkString(",\n ")}${bucketColumnName + .map((bucketCol) => s",\nbucket(${bucketNumber.get}, ${bucketCol})") + .getOrElse("")} + | |)""".stripMargin - } else { "" } +// val bucketFragment = if (bucketColumnName.isDefined && bucketNumber.isDefined) { +// // Todo: add `SORTED BY (${bucketColumnName.get})` below? +// s"CLUSTERED BY (${bucketColumnName.get}) INTO ${bucketNumber.get} BUCKETS" +// } else { +// "" +// } + val propertiesFragment = if (tableProperties != null && tableProperties.nonEmpty) { s"""TBLPROPERTIES ( | ${(tableProperties + ("file_format" -> fileFormatString) + ("table_type" -> tableTypeString)) @@ -55,7 +72,6 @@ object CreationUtils { } Seq(createFragment, partitionFragment, propertiesFragment).mkString("\n") - } // Needs provider diff --git a/spark/src/main/scala/ai/chronon/spark/catalog/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/catalog/TableUtils.scala index 49d68698ba..f3a8c84d3a 100644 --- a/spark/src/main/scala/ai/chronon/spark/catalog/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/catalog/TableUtils.scala @@ -78,6 +78,9 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable val joinPartParallelism: Int = sparkSession.conf.get("spark.chronon.join.part.parallelism", "1").toInt + val internalRowIdColumnName: String = Constants.RowIDColumn + val rowIdClusterNumber: Int = 360 + sparkSession.sparkContext.setLogLevel("ERROR") def tableReachable(tableName: String, ignoreFailure: Boolean = false): Boolean = { @@ -214,13 +217,23 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable tableName: String, partitionColumns: List[String] = List.empty, tableProperties: Map[String, String] = null, - fileFormat: String): Unit = { + fileFormat: String, + bucketByRowId: Boolean = false): Unit = { if (!tableReachable(tableName, ignoreFailure = true)) { try { + val bucketColumnName = if (bucketByRowId) Some(internalRowIdColumnName) else None + val bucketNumber = if (bucketByRowId) Some(rowIdClusterNumber) else None sql( CreationUtils - .createTableSql(tableName, df.schema, partitionColumns, tableProperties, fileFormat, tableWriteFormat)) + .createTableSql(tableName, + df.schema, + partitionColumns, + tableProperties, + fileFormat, + tableWriteFormat, + bucketColumnName, + bucketNumber)) } catch { case _: TableAlreadyExistsException => logger.info(s"Table $tableName already exists, skipping creation") @@ -238,14 +251,15 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable 
partitionColumns: List[String] = List(partitionColumn), saveMode: SaveMode = SaveMode.Overwrite, fileFormat: String = "PARQUET", - autoExpand: Boolean = false): Unit = { + autoExpand: Boolean = false, + bucketByRowId: Boolean = false): Unit = { // partitions to the last val colOrder = df.columns.diff(partitionColumns) ++ partitionColumns val dfRearranged = df.select(colOrder.map(colName => df.col(QuotingUtils.quoteIdentifier(colName))): _*) - createTable(dfRearranged, tableName, partitionColumns, tableProperties, fileFormat) + createTable(dfRearranged, tableName, partitionColumns, tableProperties, fileFormat, bucketByRowId) if (autoExpand) { expandTable(tableName, dfRearranged.schema) @@ -281,8 +295,9 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable } logger.info(s"Writing to $tableName ...") - finalizedDf.write - .mode(saveMode) + val dataFrameWriter = finalizedDf.write.mode(saveMode) + + dataFrameWriter // Requires table to exist before inserting. // Fails if schema does not match. // Does NOT overwrite the schema. @@ -588,11 +603,14 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable // TODO: this is a temporary fix to handle the case where the partition column is not a string. // This is the case for partitioned BigQuery native tables. + /* (if (df.schema.fieldNames.contains(partitionColumn)) { df.withColumn(partitionColumn, date_format(df.col(partitionColumn), partitionFormat)) } else { df }).coalesce(coalesceFactor * parallelism) + */ + df } def whereClauses(partitionRange: PartitionRange, partitionColumn: String = partitionColumn): Seq[String] = { diff --git a/spark/src/main/scala/ai/chronon/spark/join/UnionJoin.scala b/spark/src/main/scala/ai/chronon/spark/join/UnionJoin.scala index 985ba402a3..58905c591d 100644 --- a/spark/src/main/scala/ai/chronon/spark/join/UnionJoin.scala +++ b/spark/src/main/scala/ai/chronon/spark/join/UnionJoin.scala @@ -124,9 +124,11 @@ object UnionJoin { val selectedLeftDf = if (produceFinalJoinOutput) { leftDf } else { - val keyColumns = joinPart.leftToRight.keys.toSeq :+ Constants.TimeColumn :+ tableUtils.partitionColumn + val keyColumns = + joinPart.leftToRight.keys.toSeq :+ Constants.TimeColumn :+ tableUtils.partitionColumn :+ Constants.RowIDColumn val existingColumns = leftDf.columns.toSet val columnsToSelect = keyColumns.filter(existingColumns.contains) + // If row_id is present, there should be no dupes. This will have the correct behavior downstream. 
leftDf.select(columnsToSelect.map(F.col): _*).dropDuplicates(keyColumns) } diff --git a/spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala b/spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala index 28a9dd3d37..c19d9f42ac 100644 --- a/spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala +++ b/spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala @@ -170,16 +170,12 @@ object CompareJob { } def getJoinKeys(joinConf: api.Join, tableUtils: TableUtils): Array[String] = { - if (joinConf.isSetRowIds) { - joinConf.rowIds.toScala.toArray + val leftPartitionCol = joinConf.left.query.partitionSpec(tableUtils.partitionSpec).column + val keyCols = Array(Constants.RowIDColumn, leftPartitionCol) + if (joinConf.left.dataModel == EVENTS) { + keyCols ++ Seq(Constants.TimeColumn) } else { - val leftPartitionCol = joinConf.left.query.partitionSpec(tableUtils.partitionSpec).column - val keyCols = joinConf.leftKeyCols :+ leftPartitionCol - if (joinConf.left.dataModel == EVENTS) { - keyCols ++ Seq(Constants.TimeColumn) - } else { - keyCols - } + keyCols } } } diff --git a/spark/src/main/scala/ai/chronon/spark/stats/ConsistencyJob.scala b/spark/src/main/scala/ai/chronon/spark/stats/ConsistencyJob.scala index d215d65e15..b1056abdf2 100644 --- a/spark/src/main/scala/ai/chronon/spark/stats/ConsistencyJob.scala +++ b/spark/src/main/scala/ai/chronon/spark/stats/ConsistencyJob.scala @@ -20,6 +20,7 @@ import ai.chronon import ai.chronon.api.Extensions._ import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.api._ +import ai.chronon.api.Constants import ai.chronon.online.OnlineDerivationUtil.timeFields import ai.chronon.online.{fetcher, _} import ai.chronon.spark.Extensions._ @@ -49,6 +50,7 @@ class ConsistencyJob(session: SparkSession, joinConf: Join, endDate: String) ext val mapping = joinConf.leftKeyCols.map(k => k -> k) val selects = new util.HashMap[String, String]() mapping.foreach { case (key, value) => selects.put(key, value) } + selects.put(Constants.RowIDColumn, Constants.RowIDColumn) query.setSelects(selects) query.setTimeColumn(Constants.TimeColumn) query.setStartPartition(joinConf.left.query.startPartition) diff --git a/spark/src/main/scala/ai/chronon/spark/submission/ChrononKryoRegistrator.scala b/spark/src/main/scala/ai/chronon/spark/submission/ChrononKryoRegistrator.scala index 6d03bbc2de..6c4c7f524d 100644 --- a/spark/src/main/scala/ai/chronon/spark/submission/ChrononKryoRegistrator.scala +++ b/spark/src/main/scala/ai/chronon/spark/submission/ChrononKryoRegistrator.scala @@ -104,6 +104,30 @@ class ChrononKryoRegistrator extends KryoRegistrator { override def registerClasses(kryo: Kryo): Unit = { // kryo.setWarnUnregisteredClasses(true) val names = Seq( + "org.apache.iceberg.DataFile", + "org.apache.iceberg.FileContent", + "org.apache.iceberg.FileFormat", + "org.apache.iceberg.GenericDataFile", + "org.apache.iceberg.PartitionData", + "org.apache.iceberg.SerializableByteBufferMap", + "org.apache.iceberg.SerializableTable$SerializableConfSupplier", + "org.apache.iceberg.SnapshotRef", + "org.apache.iceberg.SnapshotRefType", + "org.apache.iceberg.encryption.PlaintextEncryptionManager", + "org.apache.iceberg.gcp.GCPProperties", + "org.apache.iceberg.hadoop.HadoopFileIO", + "org.apache.iceberg.hadoop.HadoopMetricsContext", + "org.apache.iceberg.MetadataTableType", + "org.apache.iceberg.io.ResolvingFileIO", + "org.apache.iceberg.spark.source.SerializableTableWithSize", + "org.apache.iceberg.spark.source.SerializableTableWithSize$SerializableMetadataTableWithSize", + 
"org.apache.iceberg.spark.source.SparkWrite$TaskCommit", + "org.apache.iceberg.types.Types$DateType", + "org.apache.iceberg.types.Types$NestedField", + "org.apache.iceberg.types.Types$StringType", + "org.apache.iceberg.types.Types$StructType", + "org.apache.iceberg.types.Types$IntegerType", + "org.apache.iceberg.util.SerializableMap", "ai.chronon.aggregator.base.ApproxHistogramIr", "ai.chronon.aggregator.base.MomentsIR", "ai.chronon.aggregator.base.UniqueOrderByLimit$State", diff --git a/spark/src/test/scala/ai/chronon/spark/test/DataFrameGen.scala b/spark/src/test/scala/ai/chronon/spark/test/DataFrameGen.scala index 76f39883c5..9176603c59 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/DataFrameGen.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/DataFrameGen.scala @@ -64,14 +64,22 @@ object DataFrameGen { count: Int, partitions: Int, partitionColumn: Option[String] = None, - partitionFormat: Option[String] = None): DataFrame = { + partitionFormat: Option[String] = None, + addRowID: Boolean = true): DataFrame = { val tableUtils = TableUtils(spark) val partitionColumnString = partitionColumn.getOrElse(tableUtils.partitionColumn) val partitionFormatString = partitionFormat.getOrElse(tableUtils.partitionFormat) val generated = gen(spark, columns :+ Column(Constants.TimeColumn, LongType, partitions), count, partitionColumn, partitionFormat) - generated.withColumn(partitionColumnString, - from_unixtime(generated.col(Constants.TimeColumn) / 1000, partitionFormatString)) + val generatedWithPartition = generated.withColumn( + partitionColumnString, + from_unixtime(generated.col(Constants.TimeColumn) / 1000, partitionFormatString)) + + if (addRowID) { + generatedWithPartition.withColumn(Constants.RowIDColumn, uuid()) + } else { + generatedWithPartition + } } // Generates Entity data @@ -80,13 +88,21 @@ object DataFrameGen { count: Int, partitions: Int, partitionColumn: Option[String] = None, - partitionFormat: Option[String] = None): DataFrame = { - val partitionColumnString = partitionColumn.getOrElse(TableUtils(spark).partitionColumn) - gen(spark, - columns :+ Column(partitionColumnString, StringType, partitions), - count, - partitionColumn, - partitionFormat) + partitionFormat: Option[String] = None, + addRowID: Boolean = true): DataFrame = { + val tableUtils = TableUtils(spark) + val partitionColumnString = partitionColumn.getOrElse(tableUtils.partitionColumn) + val generated = gen(spark, + columns :+ Column(partitionColumnString, StringType, partitions), + count, + partitionColumn, + partitionFormat) + + if (addRowID) { + generated.withColumn(Constants.RowIDColumn, uuid()) + } else { + generated + } } /** Mutations and snapshots generation. 
diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala index eba55d20f1..eec6ebcc7b 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ExternalSourcesTest.scala @@ -132,7 +132,8 @@ class ExternalSourcesTest extends AnyFlatSpec { "number", "str", "context_1", - "context_2" + "context_2", + Constants.RowIDColumn ), schema.keyFields.fields.map(_.name).toSet ) diff --git a/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala index 5a443a697d..26eb5e4382 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala @@ -107,7 +107,7 @@ class MigrationCompareTest extends AnyFlatSpec with BeforeAndAfterAll { // Run the staging query to generate the corresponding table for comparison val stagingQueryConf = Builders.StagingQuery( - query = s"select item, ts, ds from ${joinConf.metaData.outputTable}", + query = s"select item, ts, ds, ${tableUtils.internalRowIdColumnName} from ${joinConf.metaData.outputTable}", startPartition = ninetyDaysAgo, metaData = Builders.MetaData(name = "test.item_snapshot_features_sq_4", namespace = namespace, diff --git a/spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala b/spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala index b7ace29405..65f2550053 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala @@ -20,13 +20,14 @@ import ai.chronon.aggregator.test.Column import ai.chronon.api import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.api._ +import ai.chronon.api.Constants import ai.chronon.online.serde.SparkConversions import ai.chronon.spark.Extensions._ import ai.chronon.spark.catalog.TableUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, uuid} object TestUtils { def createViewsGroupBy(namespace: String, @@ -415,14 +416,19 @@ object TestUtils { val viewsDf = DataFrameGen .events(spark, viewsCols, 30, 7) .filter(col("user").isNotNull && col("listing").isNotNull) + .withColumn(Constants.RowIDColumn, uuid()) viewsDf.show() viewsDf.save(viewsTable) val joinConf = Builders.Join( left = Builders.Source.events( - Builders.Query(startPartition = "2023-06-01", selects = Builders.Selects("listing", "user")), + Builders.Query(startPartition = "2023-06-01", + selects = Builders.Selects.exprs("listing" -> "listing", + "user" -> "user", + Constants.RowIDColumn -> Constants.RowIDColumn)), table = viewsTable, - topic = topic), + topic = topic + ), joinParts = Seq( Builders.JoinPart(groupBy = priceGroupBy) ), diff --git a/spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationBootstrapTest.scala index cee6b8c420..f8a4e61ca8 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationBootstrapTest.scala @@ -138,7 +138,7 @@ class DerivationBootstrapTest extends AnyFlatSpec { ) val runner = new ai.chronon.spark.Join(baseJoin, today, tableUtils) - val outputDf = runner.computeJoin() 
+ val outputDf = runner.computeJoin().drop(tableUtils.internalRowIdColumnName) assertTrue( outputDf.columns.toSet == Set( @@ -248,7 +248,7 @@ class DerivationBootstrapTest extends AnyFlatSpec { ) val runner2 = new ai.chronon.spark.Join(bootstrapJoin, today, tableUtils) - val computed = runner2.computeJoin() + val computed = runner2.computeJoin().drop(tableUtils.internalRowIdColumnName) // Comparison val expected = outputDf @@ -352,7 +352,7 @@ class DerivationBootstrapTest extends AnyFlatSpec { ) val runner = new ai.chronon.spark.Join(joinConf, today, tableUtils) - val outputDf = runner.computeJoin() + val outputDf = runner.computeJoin().drop(tableUtils.internalRowIdColumnName) // assert that no computation happened for join part since all derivations have been bootstrapped assertFalse(tableUtils.tableReachable(joinConf.partOutputTable(joinPart))) diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/EvalTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/EvalTest.scala index 593fb571e4..b58a685af6 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/batch/EvalTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/batch/EvalTest.scala @@ -308,8 +308,14 @@ class EvalTest extends AnyFlatSpec { assertEquals(result.getExternalPartsSchema.toScala, Map("ext_return_one_value_number" -> "IntType")) - assertEquals(result.getLeftQuerySchema.toScala, - Map("user_name" -> "StringType", "user" -> "StringType", "ts" -> "LongType", "ds" -> "StringType")) + assertEquals( + result.getLeftQuerySchema.toScala, + Map("user_name" -> "StringType", + "user" -> "StringType", + "ts" -> "LongType", + "ds" -> "StringType", + tableUtils.internalRowIdColumnName -> "StringType") + ) } it should "evaluate staging query schema successfully" in { diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/MergeJobAnalyzeReuseTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/MergeJobAnalyzeReuseTest.scala index 51954e1c2e..01e2cc1116 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/batch/MergeJobAnalyzeReuseTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/batch/MergeJobAnalyzeReuseTest.scala @@ -6,6 +6,7 @@ import ai.chronon.api.ScalaJavaConversions.JIterableOps import ai.chronon.api.{ BooleanType, Builders, + Constants, DateRange, DoubleType, IntType, @@ -145,14 +146,15 @@ class MergeJobAnalyzeReuseTest extends AnyFlatSpec { SparkStructField("ts", SparkLongType), SparkStructField("ds", SparkStringType), SparkStructField("price_user_price_sum", SparkDoubleType), - SparkStructField("user_quantity_count", SparkLongType) + SparkStructField("user_quantity_count", SparkLongType), + SparkStructField(Constants.RowIDColumn, SparkStringType) )) val productionData = spark.createDataFrame( spark.sparkContext.parallelize( Seq( - SparkRow("user1", "item1", 1000L, monthAgo, 100.0, 5L), - SparkRow("user2", "item2", 2000L, monthAgo, 200.0, 10L) + SparkRow("user1", "item1", 1000L, monthAgo, 100.0, 5L, "row_id_1"), + SparkRow("user2", "item2", 2000L, monthAgo, 200.0, 10L, "row_id_2") )), productionSchema ) @@ -166,7 +168,8 @@ class MergeJobAnalyzeReuseTest extends AnyFlatSpec { "item" -> "hash_item", "ts" -> "hash_ts", "price_user_price_sum" -> "hash_price_user_price_sum", - "user_quantity_count" -> "hash_user_quantity_count" + "user_quantity_count" -> "hash_user_quantity_count", + Constants.RowIDColumn -> "hash_chronon_row_id" ) productionData.save(productionTable) @@ -225,6 +228,7 @@ class MergeJobAnalyzeReuseTest extends AnyFlatSpec { // Create left DataFrame for schema 
compatibility check val leftDf = DataFrameGen.events(spark, leftSchema, 10, 1) + leftDf.show() // Create MergeJob with production join reference val mergeNode = new JoinMergeNode() @@ -243,7 +247,7 @@ class MergeJobAnalyzeReuseTest extends AnyFlatSpec { "user" -> "hash_user", "item" -> "hash_item", "ts" -> "hash_ts", - "ds" -> "hash_ds", + Constants.RowIDColumn -> "hash_chronon_row_id", "price_user_price_sum" -> "hash_price_user_price_sum", "user_quantity_count" -> "hash_user_quantity_count" ).asJava) @@ -368,7 +372,8 @@ class MergeJobAnalyzeReuseTest extends AnyFlatSpec { "item" -> "hash_item", "ts" -> "hash_ts", "ds" -> "hash_ds", - "user_price_sum" -> "hash_user_price_sum" + "user_price_sum" -> "hash_user_price_sum", + Constants.RowIDColumn -> "hash_chronon_row_id" // Note: rating column hash is missing, so it can't be reused ) @@ -446,7 +451,8 @@ class MergeJobAnalyzeReuseTest extends AnyFlatSpec { "ts" -> "hash_ts", "ds" -> "hash_ds", "user_price_sum" -> "hash_user_price_sum", - "user_rating_average" -> "hash_user_rating_average" // This won't match production table + "user_rating_average" -> "hash_user_rating_average", // This won't match production table + Constants.RowIDColumn -> "hash_chronon_row_id" ).asJava // Test the analyzeJoinPartsForReuse method directly @@ -571,7 +577,8 @@ class MergeJobAnalyzeReuseTest extends AnyFlatSpec { "item" -> "hash_item", "ts" -> "hash_ts", "ds" -> "hash_ds", - "user_price_sum" -> "hash_user_price_sum_v0" // v0 hash + "user_price_sum" -> "hash_user_price_sum_v0", // v0 hash + Constants.RowIDColumn -> "hash_chronon_row_id" ) productionData.save(productionTable) @@ -626,7 +633,8 @@ class MergeJobAnalyzeReuseTest extends AnyFlatSpec { "item" -> "hash_item", "ts" -> "hash_ts", "ds" -> "hash_ds", - "user_price_sum" -> "hash_user_price_sum_v1" // v1 hash - different from production + "user_price_sum" -> "hash_user_price_sum_v1", // v1 hash - different from production + Constants.RowIDColumn -> "hash_chronon_row_id" ).asJava // Test the analyzeJoinPartsForReuse method directly diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/MergeJobVersioningTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/MergeJobVersioningTest.scala index 7202a89bc3..ffbe230786 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/batch/MergeJobVersioningTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/batch/MergeJobVersioningTest.scala @@ -4,6 +4,7 @@ import ai.chronon.aggregator.test.Column import ai.chronon.api.Extensions._ import ai.chronon.api._ import ai.chronon.api.planner.RelevantLeftForJoinPart + import scala.collection.JavaConverters._ import com.google.gson.Gson import ai.chronon.planner.{JoinMergeNode, JoinPartNode, SourceWithFilterNode} @@ -22,6 +23,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { import ai.chronon.spark.submission val spark: SparkSession = submission.SparkSessionBuilder.build("MergeJobVersioningTest", local = true) + import spark.implicits._ private implicit val tableUtils: TableTestUtils = TableTestUtils(spark) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) @@ -137,6 +139,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { "user" -> "hash_user", "item" -> "hash_item", "ts" -> "hash_ts", + Constants.RowIDColumn -> Constants.RowIDColumn, "shared_user_price_sum" -> "hash_shared_user_price_sum", "removed_user_quantity_count" -> "hash_removed_user_quantity_count" ) @@ -157,6 +160,9 @@ class MergeJobVersioningTest extends AnyFlatSpec { val v0SourceRunner = new 
SourceJob(v0LeftSourceWithFilter, v0SourceMetaData, dateRange) v0SourceRunner.run() + val v0SourceRowIDs: Set[String] = + tableUtils.scanDf(null, v0SourceOutputTable, None).select(col(Constants.RowIDColumn)).as[String].collect().toSet + // 1b. Run join part jobs for v0 (shared and removed) for (joinPart <- Seq(sharedJoinPart, removedJoinPart)) { val partTableName = RelevantLeftForJoinPart.partTableName(joinV0, joinPart) @@ -170,7 +176,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { .setJoinPart(joinPart) val joinPartJob = new JoinPartJob(joinPartNode, partMetaData, dateRange) - joinPartJob.run() + val result = joinPartJob.run() } // 1c. Run merge job for v0 @@ -185,11 +191,8 @@ class MergeJobVersioningTest extends AnyFlatSpec { v0MergeJob.run() // Step 2: Manually modify production table with literal values, maintaining partitioning - val productionTable = joinV0.metaData.outputTable - val existingProductionData = tableUtils.scanDf(null, productionTable, None) - - existingProductionData.show() - print(existingProductionData.schema.pretty) + val v0productionTable = joinV0.metaData.outputTable + val existingProductionData = tableUtils.scanDf(null, v0productionTable, None) val sharedColumnName = s"shared_user_price_sum" val removedColumnName = s"removed_user_quantity_count" @@ -206,41 +209,32 @@ class MergeJobVersioningTest extends AnyFlatSpec { // Use proper partition overwrite to maintain partitioning productionDataWithLiterals.write .mode(SaveMode.Overwrite) - .insertInto(productionTable) - - // Step 3: Run source job for v1 - val sourceOutputTable = JoinUtils.computeFullLeftSourceTableName(joinV1) - val sourceParts = sourceOutputTable.split("\\.", 2) - val sourceNamespace = sourceParts(0) - val sourceName = sourceParts(1) - - val sourceMetaData = new MetaData() - .setName(sourceName) - .setOutputNamespace(sourceNamespace) + .insertInto(v0productionTable) - tableUtils.sql(f"SELECT * from $leftTable").show() - tableUtils.sql(f"SELECT distinct ds from $leftTable order by ds desc").show(100) - - val leftSourceWithFilter = new SourceWithFilterNode().setSource(joinV1.left) - val sourceRunner = new SourceJob(leftSourceWithFilter, sourceMetaData, dateRange) - sourceRunner.run() + // Step 3: Should not need to run source job for v1 + val v1SourceOutputTable = JoinUtils.computeFullLeftSourceTableName(joinV1) + assert(v1SourceOutputTable == v0SourceOutputTable, + "Source table names should match for v1/v0 -- left semantics same.") // Step 4: Run join part job for the added GroupBy only (shared will be reused from production) val addedPartTableName = RelevantLeftForJoinPart.partTableName(joinV1, addedJoinPart) val addedPartFullTableName = RelevantLeftForJoinPart.fullPartTableName(joinV1, addedJoinPart) + // Compute added joinPart val addedPartMetaData = new MetaData() .setName(addedPartTableName) .setOutputNamespace(joinV1.metaData.outputNamespace) val addedJoinPartNode = new JoinPartNode() - .setLeftSourceTable(sourceOutputTable) + .setLeftSourceTable(v1SourceOutputTable) .setLeftDataModel(joinV1.getLeft.dataModel) .setJoinPart(addedJoinPart) val addedJoinPartJob = new JoinPartJob(addedJoinPartNode, addedPartMetaData, dateRange) addedJoinPartJob.run() + println(s"Added join part table created: $addedPartFullTableName") + // Step 5: Run MergeJob with production join reference val mergeNode = new JoinMergeNode() .setJoin(joinV1) @@ -258,6 +252,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { "user" -> "hash_user", "item" -> "hash_item", "ts" -> "hash_ts", + Constants.RowIDColumn -> 
Constants.RowIDColumn, "shared_user_price_sum" -> "hash_shared_user_price_sum", // This should match production table "added_user_rating_average" -> "hash_added_user_rating_average" // This won't be in production ).asJava) @@ -265,8 +260,10 @@ class MergeJobVersioningTest extends AnyFlatSpec { mergeJob.run() // Step 6: Verify results - val resultTable = joinV1.metaData.outputTable - val result = tableUtils.scanDf(null, resultTable, None) + val v1ResultTable = joinV1.metaData.outputTable + val result = tableUtils.scanDf(null, v1ResultTable, None) + println(s"V1 Join output: $v1ResultTable") + result.show() val resultRows = result.collect() assertTrue("Should have results", resultRows.length > 0) @@ -379,6 +376,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { "user" -> "hash_user", "product" -> "hash_product", "ts" -> "hash_ts", + Constants.RowIDColumn -> Constants.RowIDColumn, "user_user_price_sum" -> "hash_user_user_price_sum", "product_product_rating_average" -> "hash_product_product_rating_average" ) @@ -470,6 +468,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { "user" -> "hash_user", "product" -> "hash_product", "ts" -> "hash_ts", + Constants.RowIDColumn -> Constants.RowIDColumn, "user_user_price_sum" -> "hash_modified_user_user_price_sum", // Different hash due to filter change "product_product_rating_average" -> "hash_product_product_rating_average" // Same hash (unchanged) ).asJava) @@ -581,7 +580,6 @@ class MergeJobVersioningTest extends AnyFlatSpec { println(s"Test passed! Archive table: ${archiveTable}") println(s"Result schema: ${result.columns.mkString(", ")}") println(s"Reused column distinct values: ${reusedValues.length}") - println(s"Computed column non-null values: ${userValues.length}") } it should "archive current table but not reuse any columns when left time column changes" in { @@ -651,7 +649,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { left = Builders.Source.events( table = leftTable, query = Builders.Query( - selects = Builders.Selects("user", "product"), + selects = Builders.Selects("user", "product", Constants.RowIDColumn), timeColumn = "ts", // Simple time column startPartition = start ) @@ -667,6 +665,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { val originalColumnHashes = Map( "user" -> "hash_user_original", "product" -> "hash_product_original", + Constants.RowIDColumn -> "__chronon_row_id__original", "user_user_price_sum" -> "hash_user_user_price_sum_original", "product_product_rating_average" -> "hash_product_product_rating_average_original" ) @@ -727,7 +726,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { left = Builders.Source.events( table = leftTable, query = Builders.Query( - selects = Builders.Selects("user", "product"), + selects = Builders.Selects("user", "product", Constants.RowIDColumn), timeColumn = "CAST(ts AS DOUBLE)", // Different time column expression - changes ALL semantic hashes startPartition = start ) @@ -742,6 +741,7 @@ class MergeJobVersioningTest extends AnyFlatSpec { Map( "user" -> "hash_user_new_time", // Different hash due to time column change "product" -> "hash_product_new_time", // Different hash due to time column change + Constants.RowIDColumn -> "__chronon_row_id__new_time", "user_user_price_sum" -> "hash_user_user_price_sum_new_time", // Different hash due to time column change "product_product_rating_average" -> "hash_product_product_rating_average_new_time" // Different hash due to time column change ).asJava) diff --git 
a/spark/src/test/scala/ai/chronon/spark/test/batch/ModularJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/ModularJoinTest.scala index 24f5e32a74..952bcfc178 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/batch/ModularJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/batch/ModularJoinTest.scala @@ -18,11 +18,21 @@ class ModularJoinTest extends AnyFlatSpec { import ai.chronon.spark.submission - val spark: SparkSession = submission.SparkSessionBuilder.build("ModularJoinTest", local = true) + val spark: SparkSession = submission.SparkSessionBuilder.build( + "ModularJoinTest", + hiveSupport = false, + local = true, + additionalConfig = Some( + Map( + "spark.sql.sources.bucketing.enabled" -> "true", + "spark.sql.bucketing.coalesceBucketsInJoin.enabled" -> "true", + "spark.sql.autoBroadcastJoinThreshold" -> "-1" // Disable broadcast joins to force bucketed joins + )) + ) private implicit val tableUtils: TableTestUtils = TableTestUtils(spark) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) - val start = tableUtils.partitionSpec.minus(today, new Window(60, TimeUnit.DAYS)) + val start = tableUtils.partitionSpec.minus(today, new Window(35, TimeUnit.DAYS)) private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) private val yearAgo = tableUtils.partitionSpec.minus(today, new Window(365, TimeUnit.DAYS)) private val dayAndMonthBefore = tableUtils.partitionSpec.before(monthAgo) @@ -116,7 +126,7 @@ class ModularJoinTest extends AnyFlatSpec { val queryTable = s"$namespace.queries" DataFrameGen - .events(spark, queriesSchema, 4000, partitions = 100, partitionColumn = Some("date")) + .events(spark, queriesSchema, 4000, partitions = 100, partitionColumn = Some("date"), addRowID = true) .save(queryTable, partitionColumns = Seq("date")) // Make bootstrap part and table @@ -148,7 +158,7 @@ class ModularJoinTest extends AnyFlatSpec { query = Builders.Query( selects = Builders.Selects("user", "ts", "unit_test_user_transactions_amount_dollars_sum_10d"), startPartition = start, - endPartition = today + endPartition = monthAgo ), table = s"$namespace.bootstrap", keyColumns = Seq("user", "ts") @@ -210,12 +220,13 @@ class ModularJoinTest extends AnyFlatSpec { val sourceJobRange = new DateRange() .setStartDate(start) - .setEndDate(today) + .setEndDate(monthAgo) val sourceRunner = new SourceJob(leftSourceWithFilter, sourceMetaData, sourceJobRange) sourceRunner.run() tableUtils.sql(s"SELECT * FROM $sourceOutputTable").show() - val sourceExpected = spark.sql(s"SELECT *, date as ds FROM $queryTable WHERE date >= '$start' AND date <= '$today'") + val sourceExpected = + spark.sql(s"SELECT *, date as ds FROM $queryTable WHERE date >= '$start' AND date <= '$monthAgo'") val sourceComputed = tableUtils.sql(s"SELECT * FROM $sourceOutputTable").drop("ts_ds") val diff = Comparison.sideBySide(sourceComputed, sourceExpected, List("user_name", "user", "ts")) if (diff.count() > 0) { @@ -231,7 +242,7 @@ class ModularJoinTest extends AnyFlatSpec { val bootstrapOutputTable = joinConf.metaData.bootstrapTable val bootstrapJobRange = new DateRange() .setStartDate(start) - .setEndDate(today) + .setEndDate(monthAgo) // Split bootstrap output table val bootstrapParts = bootstrapOutputTable.split("\\.", 2) @@ -257,6 +268,7 @@ class ModularJoinTest extends AnyFlatSpec { "user", "ts", "user_name", + Constants.RowIDColumn, "ts_ds", "matched_hashes", "unit_test_user_transactions_amount_dollars_sum_10d", @@ -275,7 +287,7 @@ class ModularJoinTest 
extends AnyFlatSpec { val joinPartJobRange = new DateRange() .setStartDate(start) - .setEndDate(today) + .setEndDate(monthAgo) // Create metadata with name and namespace directly val metaData = new api.MetaData() @@ -314,7 +326,7 @@ class ModularJoinTest extends AnyFlatSpec { val mergeJobRange = new DateRange() .setStartDate(start) - .setEndDate(today) + .setEndDate(monthAgo) // Create metadata for merge job val mergeMetaData = new api.MetaData() @@ -333,7 +345,7 @@ class ModularJoinTest extends AnyFlatSpec { val derivationRange = new DateRange() .setStartDate(start) - .setEndDate(today) + .setEndDate(monthAgo) // Split derivation output table val derivationParts = derivationOutputTable.split("\\.", 2) @@ -365,7 +377,7 @@ class ModularJoinTest extends AnyFlatSpec { | AND ts IS NOT NULL | AND date IS NOT NULL | AND date >= '$start' - | AND date <= '$today') + | AND date <= '$monthAgo') | SELECT | queries.user, | queries.ts, diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/ShortNamesTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/ShortNamesTest.scala index 171fd0246e..148c9567af 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/batch/ShortNamesTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/batch/ShortNamesTest.scala @@ -342,6 +342,7 @@ class ShortNamesTest extends AnyFlatSpec { "user", "ts", "user_name", + Constants.RowIDColumn, "ts_ds", "matched_hashes", "user_amount_dollars_sum_10d", diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/StoragePartitionJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/StoragePartitionJoinTest.scala new file mode 100644 index 0000000000..e5d5f8940c --- /dev/null +++ b/spark/src/test/scala/ai/chronon/spark/test/batch/StoragePartitionJoinTest.scala @@ -0,0 +1,350 @@ +package ai.chronon.spark.test.batch + +import ai.chronon.aggregator.test.Column +import ai.chronon.api +import ai.chronon.api.Extensions._ +import ai.chronon.api._ +import ai.chronon.api.planner.RelevantLeftForJoinPart +import ai.chronon.planner.{JoinMergeNode, JoinPartNode, SourceWithFilterNode} +import ai.chronon.spark.Extensions._ +import ai.chronon.spark._ +import ai.chronon.spark.batch._ +import ai.chronon.spark.test.{DataFrameGen, TableTestUtils} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.exchange.Exchange +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.junit.Assert._ +import org.scalatest.flatspec.AnyFlatSpec + +import java.nio.file.Files + +class StoragePartitionJoinTest extends AnyFlatSpec { + + import ai.chronon.spark.submission + + val correctConfigsForSPJ = Map( + // V1 bucketing configurations + "spark.sql.sources.bucketing.enabled" -> "false", + "spark.sql.bucketing.coalesceBucketsInJoin.enabled" -> "false", + "spark.sql.autoBroadcastJoinThreshold" -> "-1", // Disable broadcast joins to force bucketed joins + "spark.sql.adaptive.enabled" -> "true", + + // V2 bucketing configurations for Iceberg + "spark.sql.sources.v2.bucketing.enabled" -> "true", + "spark.sql.sources.v2.bucketing.pushPartValues.enabled" -> "true", + "spark.sql.iceberg.planning.preserve-data-grouping" -> "true", + "spark.sql.requireAllClusterKeysForCoPartition" -> "false", + "spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled" -> "true", + + // Iceberg catalog configurations + "spark.sql.catalog.spark_catalog" -> "org.apache.iceberg.spark.SparkSessionCatalog", + "spark.sql.catalog.spark_catalog.type" -> "hadoop", + 
"spark.sql.catalog.spark_catalog.warehouse" -> Files + .createTempDirectory("storage-partition-join-test") + .toString, + "spark.sql.extensions" -> "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions", + + // Hive metastore configurations for Iceberg + "javax.jdo.option.ConnectionURL" -> "jdbc:derby:memory:metastore_db;create=true", + "javax.jdo.option.ConnectionDriverName" -> "org.apache.derby.jdbc.EmbeddedDriver", + "datanucleus.schema.autoCreateAll" -> "true", + "datanucleus.schema.autoCreateTables" -> "true", + "datanucleus.schema.autoCreateColumns" -> "true", + "datanucleus.schema.autoCreateConstraints" -> "true", + "spark.sql.join.preferSortMergeJoin" -> "false", + "hive.metastore.schema.verification" -> "false", + "hive.metastore.schema.verification.record.version" -> "false", + "hive.metastore.uris" -> "", + "hive.metastore.warehouse.dir" -> "file:///tmp/hive-warehouse", + "spark.chronon.table_write.format" -> "iceberg" + ) + + val spark: SparkSession = submission.SparkSessionBuilder.build( + "StoragePartitionJoinTest", + hiveSupport = true, + local = true, + additionalConfig = Some(correctConfigsForSPJ) + ) + private implicit val tableUtils: TableTestUtils = TableTestUtils(spark) + + private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) + val start = tableUtils.partitionSpec.minus(today, new Window(35, TimeUnit.DAYS)) + private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) + private val yearAgo = tableUtils.partitionSpec.minus(today, new Window(365, TimeUnit.DAYS)) + + private val namespace = "test_namespace_storage_partition_join" + tableUtils.createDatabase(namespace) + + def setSPJConfigs(enable: Boolean): Unit = { + val value = enable.toString + + val spjConfigs = Map( + "spark.sql.sources.bucketing.enabled" -> value, + "spark.sql.bucketing.coalesceBucketsInJoin.enabled" -> value, + "spark.sql.adaptive.enabled" -> value, + "spark.sql.sources.v2.bucketing.enabled" -> value, + "spark.sql.sources.v2.bucketing.pushPartValues.enabled" -> value, + "spark.sql.iceberg.planning.preserve-data-grouping" -> value, + "spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled" -> value, + "spark.sql.requireAllClusterKeysForCoPartition" -> (!enable).toString // this one is inverted + ) + + spjConfigs.foreach { case (key, v) => + spark.conf.set(key, v) + } + } + + def getPhysicalPlan(df: org.apache.spark.sql.DataFrame): String = { + val physicalPlan = df.queryExecution.executedPlan.toString + println(s"=== Physical plan ===") + println(physicalPlan) + println("=" * 50) + physicalPlan + } + + /** Verifies that a DataFrame's physical plan doesn't contain shuffles (Exchange operators) + * and uses bucketed joins as expected. 
+ */ + def verifyNoShuffle(df: org.apache.spark.sql.DataFrame, testName: String): Unit = { + val physicalPlan = getPhysicalPlan(df) + + // Assert no shuffles (Exchange operators) + assertFalse( + s"Physical plan should not contain Exchange (shuffle) for $testName", + physicalPlan.contains("Exchange") + ) + + // Assert successful bucketed join - look for SortMergeJoin or similar + val hasBucketedJoin = physicalPlan.contains("SortMergeJoin") + + assertTrue( + s"Physical plan should contain a join operator for $testName", + hasBucketedJoin + ) + + println(s"✓ $testName: No shuffles detected in physical plan") + } + + it should "test toy example" in { + val left = + s""" + |CREATE TABLE target (id INT, salary INT, dep STRING) + |USING iceberg + |PARTITIONED BY (dep, bucket(4, id)) + """.stripMargin + val right = + s""" + |CREATE TABLE source (id INT, salary INT, dep STRING) + |USING iceberg + |PARTITIONED BY (dep, bucket(4, id)) + |""".stripMargin + spark.sql(left) + spark.sql(right) + + // Insert dummy data into target table + spark.sql(""" + INSERT INTO target VALUES + (1, 50000, 'engineering'), + (2, 60000, 'marketing'), + (3, 55000, 'engineering'), + (4, 45000, 'sales'), + (5, 70000, 'marketing') + """) + + // Insert dummy data into source table + spark.sql(""" + INSERT INTO source VALUES + (1, 52000, 'engineering'), + (2, 58000, 'marketing'), + (6, 48000, 'sales'), + (7, 65000, 'engineering'), + (8, 40000, 'hr') + """) + + val tgt = spark.sql("SELECT * FROM target") + val src = spark.sql("SELECT * FROM source") + + assertNoExchange(tgt.join(src, tgt("dep") === src("dep") && tgt("id") === src("id"), "inner")) + + } + + private def countExchanges(plan: SparkPlan): Int = { + plan.collect { case _: Exchange => 1 }.sum + } + + private def assertNoExchange(df: DataFrame, message: String = ""): Unit = { + val plan = df.queryExecution.executedPlan + val exchangeCount = countExchanges(plan) + assert(exchangeCount == 0, + s"Expected no Exchange operators but found $exchangeCount. 
$message\nPlan:\n${plan.toString}") + } + + private def setupAndGetMergeDF(tableUtils: TableTestUtils) = { + implicit val tu: TableTestUtils = tableUtils + // Step 1: Setup test data - similar to ModularJoinTest but simplified with single joinPart + val userTransactions = List( + Column("user", StringType, 20), + Column("user_name", api.StringType, 20), + Column("ts", LongType, 200), + Column("amount_dollars", LongType, 1000) + ) + + val userTransactionTable = s"$namespace.user_transactions" + spark.sql(s"DROP TABLE IF EXISTS $userTransactionTable") + // Create larger dataset to ensure we don't hit broadcast threshold + DataFrameGen.entities(spark, userTransactions, 10000, partitions = 100).save(userTransactionTable) + + // Create the GroupBy source for the right side + val transactionSource = Builders.Source.entities( + query = Builders.Query( + selects = Builders.Selects("ts", "amount_dollars", "user_name", "user"), + startPartition = yearAgo, + endPartition = monthAgo + ), + snapshotTable = userTransactionTable + ) + + // Create a simple GroupBy with single aggregation + val groupBy = Builders.GroupBy( + sources = Seq(transactionSource), + keyColumns = Seq("user"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.SUM, + inputColumn = "amount_dollars", + windows = Seq(new Window(30, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = "unit_test.user_transaction_sum", namespace = namespace, team = "chronon") + ) + + // Create queries (left side) with row IDs for bucketing + val queriesSchema = List( + Column("user", api.StringType, 20) + ) + + val queryTable = s"$namespace.user_queries" + DataFrameGen + .events(spark, queriesSchema, 5000, partitions = 50, addRowID = true) + .save(queryTable) + + // Create the join configuration with single joinPart + val joinPart = Builders.JoinPart(groupBy = groupBy) + + val joinConf: ai.chronon.api.Join = Builders.Join( + left = Builders.Source.events( + query = Builders.Query(startPartition = start), + table = queryTable + ), + joinParts = Seq(joinPart), // Single joinPart only + metaData = + Builders.MetaData(name = "test.storage_partition_join_features", namespace = namespace, team = "chronon") + ) + + // Step 2: Run SourceJob to create bucketed left source table + val leftSourceWithFilter = new SourceWithFilterNode().setSource(joinConf.left) + val sourceOutputTable = JoinUtils.computeFullLeftSourceTableName(joinConf) + + println(s"Source output table: $sourceOutputTable") + + // Split the output table to get namespace and name + val sourceParts = sourceOutputTable.split("\\.", 2) + val sourceNamespace = sourceParts(0) + val sourceName = sourceParts(1) + + // Create metadata for source job + val sourceMetaData = new api.MetaData() + .setName(sourceName) + .setOutputNamespace(sourceNamespace) + + val sourceJobRange = new DateRange() + .setStartDate(start) + .setEndDate(monthAgo) + + val sourceRunner = new SourceJob(leftSourceWithFilter, sourceMetaData, sourceJobRange) + sourceRunner.run() + + val sourceRowCount = tableUtils.sql(s"SELECT * FROM $sourceOutputTable").count() + println(s"Source table row count: $sourceRowCount") + assertTrue("Source table should have data", sourceRowCount > 0) + + // Step 3: Run JoinPartJob to create bucketed right table + val joinPartTableName = RelevantLeftForJoinPart.partTableName(joinConf, joinPart) + val joinPartFullTableName = RelevantLeftForJoinPart.fullPartTableName(joinConf, joinPart) + val outputNamespace = joinConf.metaData.outputNamespace + + println(s"JoinPart output table: 
$joinPartFullTableName") + + val joinPartJobRange = new DateRange() + .setStartDate(start) + .setEndDate(monthAgo) + + // Create metadata for join part job + val joinPartMetaData = new api.MetaData() + .setName(joinPartTableName) + .setOutputNamespace(outputNamespace) + + val joinPartNode = new JoinPartNode() + .setLeftSourceTable(sourceOutputTable) + .setLeftDataModel(joinConf.getLeft.dataModel) + .setJoinPart(joinPart) + + val joinPartJob = new JoinPartJob(joinPartNode, joinPartMetaData, joinPartJobRange) + joinPartJob.run() + + val joinPartRowCount = tableUtils.sql(s"SELECT * FROM $joinPartFullTableName").count() + println(s"JoinPart table row count: $joinPartRowCount") + assertTrue("JoinPart table should have data", joinPartRowCount > 0) + + // Step 4: Call MergeJob.runDayStep to get joined DataFrame (instead of full MergeJob.run) + val mergeNode = new JoinMergeNode() + .setJoin(joinConf) + + val mergeMetaData = new api.MetaData() + .setName(joinConf.metaData.name) + .setOutputNamespace(namespace) + + val mergeJob = new MergeJob(mergeNode, mergeMetaData, joinPartJobRange, Seq(joinPart))(tableUtils) + + // Create a single day step for testing + val dayStep = joinPartJobRange.toPartitionRange(tableUtils.partitionSpec).steps(days = 1).head + println(s"Testing day step: ${dayStep.start} to ${dayStep.end}") + + // Call runDayStep directly to get the DataFrame + val joinedDfTry = mergeJob.runDayStep(dayStep) + assertTrue("MergeJob.runDayStep should succeed", joinedDfTry.isSuccess) + + joinedDfTry.get + + } + + it should "test storage partition bucketing with no shuffle in join" in { + val joinedDf = setupAndGetMergeDF(tableUtils) + + joinedDf.show() + joinedDf.explain(true) + + // Step 5: Analyze physical plan to verify no shuffles + verifyNoShuffle(joinedDf, "Storage Partition Bucketed Join") + + println("✓ Storage partition bucketing join test completed successfully!") + println("✓ No shuffles detected - bucketing optimization is working!") + } + + it should "NOT storage partitioned join due to incorrect configuration" in { + setSPJConfigs(false) + val tableUtils: TableTestUtils = TableTestUtils(spark) + + val joinedDf = setupAndGetMergeDF(tableUtils) + + joinedDf.show() + joinedDf.explain(true) + + // Step 5: Analyze physical plan to verify no shuffles + assertThrows[AssertionError](verifyNoShuffle(joinedDf, "Storage Partition Bucketed Join")) + + println("✓ Shuffles detected as expected when incorrect configs are used") + setSPJConfigs(true) + } +} diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala index fc8dff6211..8f1eb20a7e 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala @@ -19,6 +19,7 @@ package ai.chronon.spark.test.bootstrap import ai.chronon.api.Extensions._ import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.api._ +import ai.chronon.api.Constants import ai.chronon.online.fetcher.Fetcher.Request import ai.chronon.spark.Comparison import ai.chronon.spark.Extensions._ @@ -173,7 +174,7 @@ class LogBootstrapTest extends AnyFlatSpec { ) val joinJob = new ai.chronon.spark.Join(joinV2, endDs, tableUtils) - val computed = joinJob.computeJoin() + val computed = joinJob.computeJoin().drop(Constants.RowIDColumn) val overlapCount = baseOutput.join(logDf, Seq("request_id", "ds")).count() logger.info(s"""Debugging information: diff --git 
a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala index d4698c264c..0253ac13e7 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala @@ -19,6 +19,7 @@ package ai.chronon.spark.test.bootstrap import ai.chronon.api.Extensions.JoinOps import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.api._ +import ai.chronon.api.Constants import ai.chronon.spark.Comparison import ai.chronon.spark.Extensions._ import ai.chronon.spark.catalog.TableUtils @@ -124,7 +125,7 @@ class TableBootstrapTest extends AnyFlatSpec { // Runs through boostrap backfill which combines backfill and bootstrap val runner2 = new ai.chronon.spark.Join(bootstrapJoin, today, tableUtils) - val computed = runner2.computeJoin() + val computed = runner2.computeJoin().drop(Constants.RowIDColumn) // Comparison val expected = baseOutput diff --git a/spark/src/test/scala/ai/chronon/spark/test/fetcher/ChainingFetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/fetcher/ChainingFetcherTest.scala index 956726a1d0..d2c5a95d78 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/fetcher/ChainingFetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/fetcher/ChainingFetcherTest.scala @@ -64,16 +64,17 @@ class ChainingFetcherTest extends AnyFlatSpec { StructField("user", LongType), StructField("listing", LongType), StructField("ts", LongType), - StructField("ds", StringType) + StructField("ds", StringType), + StructField(Constants.RowIDColumn, StringType) ) ) val viewsData = Seq( - Row(12L, 59L, toTs("2021-04-15 10:00:00"), "2021-04-15"), - Row(12L, 1L, toTs("2021-04-15 08:00:00"), "2021-04-15"), - Row(12L, 123L, toTs("2021-04-15 12:00:00"), "2021-04-15"), - Row(88L, 1L, toTs("2021-04-15 11:00:00"), "2021-04-15"), - Row(88L, 59L, toTs("2021-04-15 01:10:00"), "2021-04-15"), - Row(88L, 456L, toTs("2021-04-15 12:00:00"), "2021-04-15") + Row(12L, 59L, toTs("2021-04-15 10:00:00"), "2021-04-15", "A"), + Row(12L, 1L, toTs("2021-04-15 08:00:00"), "2021-04-15", "B"), + Row(12L, 123L, toTs("2021-04-15 12:00:00"), "2021-04-15", "C"), + Row(88L, 1L, toTs("2021-04-15 11:00:00"), "2021-04-15", "D"), + Row(88L, 59L, toTs("2021-04-15 01:10:00"), "2021-04-15", "E"), + Row(88L, 456L, toTs("2021-04-15 12:00:00"), "2021-04-15", "F") ) // {listing, ts, rating, ds} val ratingSchema = StructType( @@ -121,7 +122,7 @@ class ChainingFetcherTest extends AnyFlatSpec { val leftSource = Builders.Source.events( query = Builders.Query( - selects = Builders.Selects("user", "listing", "ts"), + selects = Builders.Selects("user", "listing", "ts", Constants.RowIDColumn), startPartition = startPartition ), table = s"$namespace.${viewsSchema.name}" @@ -154,19 +155,22 @@ class ChainingFetcherTest extends AnyFlatSpec { // User search listing event. 
Schema: user, listing, ts, ds val searchSchema = StructType( "user_search_listing_event", - Array(StructField("user", LongType), - StructField("listing", LongType), - StructField("ts", LongType), - StructField("ds", StringType)) + Array( + StructField("user", LongType), + StructField("listing", LongType), + StructField("ts", LongType), + StructField("ds", StringType), + StructField(Constants.RowIDColumn, StringType) + ) // row id for chaining) ) val searchData = Seq( - Row(12L, 59L, toTs("2021-04-18 10:00:00"), "2021-04-18"), - Row(12L, 123L, toTs("2021-04-18 13:45:00"), "2021-04-18"), - Row(88L, 1L, toTs("2021-04-18 00:10:00"), "2021-04-18"), - Row(88L, 59L, toTs("2021-04-18 23:10:00"), "2021-04-18"), - Row(88L, 456L, toTs("2021-04-18 03:10:00"), "2021-04-18"), - Row(68L, 123L, toTs("2021-04-17 23:55:00"), "2021-04-18") + Row(12L, 59L, toTs("2021-04-18 10:00:00"), "2021-04-18", "A"), + Row(12L, 123L, toTs("2021-04-18 13:45:00"), "2021-04-18", "B"), + Row(88L, 1L, toTs("2021-04-18 00:10:00"), "2021-04-18", "C"), + Row(88L, 59L, toTs("2021-04-18 23:10:00"), "2021-04-18", "D"), + Row(88L, 456L, toTs("2021-04-18 03:10:00"), "2021-04-18", "E"), + Row(68L, 123L, toTs("2021-04-17 23:55:00"), "2021-04-18", "F") ).toList TestUtils.makeDf(spark, searchSchema, searchData).save(s"$namespace.${searchSchema.name}") @@ -194,7 +198,7 @@ class ChainingFetcherTest extends AnyFlatSpec { val leftSource = Builders.Source.events( query = Builders.Query( - selects = Builders.Selects("user", "listing", "ts"), + selects = Builders.Selects("user", "listing", "ts", Constants.RowIDColumn), startPartition = startPartition, endPartition = endPartition ), @@ -289,11 +293,13 @@ class ChainingFetcherTest extends AnyFlatSpec { responseDf.show() // remove user during comparison since `user` is not the key - val diff = Comparison.sideBySide(responseDf.drop(ignoreCol), - expectedDf.drop(ignoreCol), - keyishColumns, - aName = "online", - bName = "offline") + val diff = Comparison.sideBySide( + responseDf.drop(ignoreCol).drop(Constants.RowIDColumn), + expectedDf.drop(ignoreCol).drop(Constants.RowIDColumn), + keyishColumns, + aName = "online", + bName = "offline" + ) assertEquals(expectedDf.count(), responseDf.count()) if (diff.count() > 0) { logger.info(s"Total count: ${responseDf.count()}") diff --git a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherFailureTest.scala b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherFailureTest.scala index a9a34c475b..88a610fd19 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherFailureTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherFailureTest.scala @@ -82,7 +82,8 @@ class FetcherFailureTest extends AnyFlatSpec { val responseMap = responses.head.values.get val exceptionKeys = joinConf.joinPartOps.map(jp => jp.columnPrefix + "exception") - println(responseMap) + println(responseMap.keys.toSeq) + println(exceptionKeys) exceptionKeys.foreach(k => assertTrue(responseMap.contains(k))) } diff --git a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTestUtil.scala b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTestUtil.scala index c4caa9c060..139a590722 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTestUtil.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTestUtil.scala @@ -3,12 +3,13 @@ package ai.chronon.spark.test.fetcher import ai.chronon.aggregator.test.Column import ai.chronon.api import ai.chronon.api.Builders.Derivation -import 
ai.chronon.api.Constants.MetadataDataset -import ai.chronon.api.Extensions.{JoinOps, MetadataOps} +import ai.chronon.api.Constants.{ContextualSourceName, MetadataDataset} +import ai.chronon.api.Extensions.{JoinOps, MetadataOps, SourceOps} import ai.chronon.api.{ Accuracy, BooleanType, Builders, + Constants, DoubleType, IntType, ListType, @@ -19,8 +20,7 @@ import ai.chronon.api.{ StructType, TimeUnit, TsUtils, - Window, - Constants + Window } import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.online._ @@ -171,10 +171,15 @@ object FetcherTestUtil { val mockApi = new MockApi(kvStoreFunc, namespace) mockApi.setFlagStore(tilingEnabledFlagStore) + println("left table:") + tableUtils.loadTable(joinConf.left.table).show() + val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin() val joinTable = s"$namespace.join_test_expected_${joinConf.metaData.cleanName}" joinedDf.save(joinTable) val endDsExpected = tableUtils.sql(s"SELECT * FROM $joinTable WHERE ds='$endDs'") + println("join result:") + endDsExpected.show() joinConf.joinParts.toScala.foreach(jp => OnlineUtils.serve(tableUtils, @@ -192,7 +197,7 @@ object FetcherTestUtil { s"SELECT * FROM $joinTable WHERE ts >= unix_timestamp('$endDs', '${tableUtils.partitionSpec.format}')") } // Keep only left-side columns (keys, ts, ds) and drop all feature columns - val keys = joinConf.leftKeyCols + val keys = joinConf.leftKeyCols :+ Constants.RowIDColumn val leftSideColumns = keys ++ Array(Constants.TimeColumn, tableUtils.partitionColumn) val columnsToKeep = endDsEvents.schema.fieldNames.filter(leftSideColumns.contains) val endDsQueries = endDsEvents.select(columnsToKeep.map(col): _*) @@ -295,11 +300,11 @@ object FetcherTestUtil { tableUtils.createDatabase(namespace) def toTs(arg: String): Long = TsUtils.datetimeToTs(arg) val eventData = Seq( - Row(595125622443733822L, toTs("2021-04-10 09:00:00"), "2021-04-10"), - Row(595125622443733822L, toTs("2021-04-10 23:00:00"), "2021-04-10"), // Query for added event - Row(595125622443733822L, toTs("2021-04-10 23:45:00"), "2021-04-10"), // Query for mutated event - Row(1L, toTs("2021-04-10 00:10:00"), "2021-04-10"), // query for added event - Row(1L, toTs("2021-04-10 03:10:00"), "2021-04-10") // query for mutated event + Row(595125622443733822L, toTs("2021-04-10 09:00:00"), "2021-04-10", "A"), + Row(595125622443733822L, toTs("2021-04-10 23:00:00"), "2021-04-10", "B"), // Query for added event + Row(595125622443733822L, toTs("2021-04-10 23:45:00"), "2021-04-10", "C"), // Query for mutated event + Row(1L, toTs("2021-04-10 00:10:00"), "2021-04-10", "D"), // query for added event + Row(1L, toTs("2021-04-10 03:10:00"), "2021-04-10", "E") // query for mutated event ) val snapshotData = Seq( Row(1L, toTs("2021-04-04 00:30:00"), 4, "2021-04-08"), @@ -350,12 +355,15 @@ object FetcherTestUtil { ) // {..., event (generic event column), ...} - val eventSchema = StructType("listing_events_fetcher", - Array( - StructField("listing_id", LongType), - StructField("ts", LongType), - StructField("ds", StringType) - )) + val eventSchema = StructType( + "listing_events_fetcher", + Array( + StructField("listing_id", LongType), + StructField("ts", LongType), + StructField("ds", StringType), + StructField(Constants.RowIDColumn, StringType) + ) + ) val sourceData: Map[StructType, Seq[Row]] = Map( eventSchema -> eventData, @@ -389,7 +397,7 @@ object FetcherTestUtil { val leftSource = Builders.Source.events( query = Builders.Query( - selects = Builders.Selects("listing_id", "ts"), + selects = 
Builders.Selects("listing_id", "ts", Constants.RowIDColumn), startPartition = startPartition ), table = s"$namespace.${eventSchema.name}" @@ -433,8 +441,8 @@ object FetcherTestUtil { // Create manual struct data for UniqueTopK testing val eventData = Seq( - Row(1L, toTs("2021-04-10 09:00:00"), "2021-04-10"), - Row(2L, toTs("2021-04-10 23:00:00"), "2021-04-10") + Row(1L, toTs("2021-04-10 09:00:00"), "2021-04-10", "A"), + Row(2L, toTs("2021-04-10 23:00:00"), "2021-04-10", "B") ) val structData = Seq( @@ -466,7 +474,10 @@ object FetcherTestUtil { // Event schema val eventSchema = StructType( "listing_events_struct", - Array(StructField("listing_id", LongType), StructField("ts", LongType), StructField("ds", StringType)) + Array(StructField("listing_id", LongType), + StructField("ts", LongType), + StructField("ds", StringType), + StructField(Constants.RowIDColumn, StringType)) ) // Struct snapshot schema @@ -532,7 +543,7 @@ object FetcherTestUtil { val leftSource = Builders.Source.events( query = Builders.Query( - selects = Builders.Selects("listing_id", "ts"), + selects = Builders.Selects("listing_id", "ts", Constants.RowIDColumn), startPartition = "2021-04-01" ), table = s"$namespace.${eventSchema.name}" @@ -734,7 +745,19 @@ object FetcherTestUtil { val joinConf = Builders .Join( - left = Builders.Source.events(Builders.Query(startPartition = today), table = queriesTable), + left = Builders.Source.events( + Builders.Query( + startPartition = today, + selects = Builders.Selects.exprs( + "user_id" -> "user_id", + "vendor_id" -> "vendor_id", + "ts" -> "ts", + "ds" -> "ds", + Constants.RowIDColumn -> Constants.RowIDColumn + ) + ), + table = queriesTable + ), joinParts = Seq( Builders .JoinPart(groupBy = vendorRatingsGroupBy, keyMapping = Map("vendor_id" -> "vendor")) @@ -773,8 +796,8 @@ object FetcherTestUtil { def toTs(arg: String): Long = TsUtils.datetimeToTs(arg) val listingEventData = Seq( - Row(1L, toTs("2021-04-10 03:10:00"), "2021-04-10"), - Row(2L, toTs("2021-04-10 03:10:00"), "2021-04-10") + Row(1L, toTs("2021-04-10 03:10:00"), "2021-04-10", "A"), + Row(2L, toTs("2021-04-10 03:10:00"), "2021-04-10", "B") ) val ratingEventData = Seq( // 1L listing id event data @@ -799,12 +822,15 @@ object FetcherTestUtil { ) // Schemas // {..., event (generic event column), ...} - val listingsSchema = StructType("listing_events_fetcher", - Array( - StructField("listing_id", LongType), - StructField("ts", LongType), - StructField("ds", StringType) - )) + val listingsSchema = StructType( + "listing_events_fetcher", + Array( + StructField("listing_id", LongType), + StructField("ts", LongType), + StructField("ds", StringType), + StructField(Constants.RowIDColumn, StringType) + ) + ) val ratingsSchema = StructType( "listing_ratings_fetcher", @@ -835,7 +861,7 @@ object FetcherTestUtil { val leftSource = Builders.Source.events( query = Builders.Query( - selects = Builders.Selects("listing_id", "ts"), + selects = Builders.Selects("listing_id", "ts", Constants.RowIDColumn), startPartition = startPartition ), table = s"$namespace.${listingsSchema.name}" diff --git a/spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByTest.scala index c9bc244a6e..0c403f2b32 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByTest.scala @@ -601,11 +601,12 @@ class GroupByTest extends AnyFlatSpec { println("Expected input DF: ") expectedInputDf.show() 
println("Computed input DF: ") - newGroupBy.inputDf.show() + val computedInputDf = newGroupBy.inputDf.drop(Constants.RowIDColumn) + computedInputDf.show() - val diff = Comparison.sideBySide(newGroupBy.inputDf, expectedInputDf, List("listing", "user", "ds")) + val diff = Comparison.sideBySide(computedInputDf, expectedInputDf, List("listing", "user", "ds")) if (diff.count() > 0) { - println(s"Actual count: ${newGroupBy.inputDf.count()}") + println(s"Actual count: ${computedInputDf.count()}") println(s"Expected count: ${expectedInputDf.count()}") println(s"Diff count: ${diff.count()}") diff.show() diff --git a/spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByUploadTest.scala b/spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByUploadTest.scala index 2953d59c26..23e8cbacc7 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByUploadTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByUploadTest.scala @@ -208,7 +208,11 @@ class GroupByUploadTest extends AnyFlatSpec { val leftRatings = Builders.Source.entities( - Builders.Query(selects = Builders.Selects("review", "rating", "category_ratings", "ts")), + Builders.Query(selects = Builders.Selects.exprs("review" -> "review", + "rating" -> "rating", + "category_ratings" -> "category_ratings", + "ts" -> "ts", + tableUtils.internalRowIdColumnName -> "review")), snapshotTable = ratingsTable, mutationTopic = s"${ratingsTable}_mutations", mutationTable = s"${ratingsTable}_mutations" diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/BaseJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/BaseJoinTest.scala index d2438523cc..056c8aa905 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/BaseJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/BaseJoinTest.scala @@ -40,7 +40,16 @@ abstract class BaseJoinTest extends AnyFlatSpec { import ai.chronon.spark.submission - val spark: SparkSession = submission.SparkSessionBuilder.build("JoinTest", local = true) + val spark: SparkSession = submission.SparkSessionBuilder.build( + "JoinTest", + local = true, + additionalConfig = Some( + Map( + "spark.sql.sources.bucketing.enabled" -> "true", + "spark.sql.bucketing.coalesceBucketsInJoin.enabled" -> "true", + "spark.sql.autoBroadcastJoinThreshold" -> "-1" + )) + ) protected implicit val tableUtils: TableTestUtils = TableTestUtils(spark) protected val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) @@ -92,10 +101,15 @@ abstract class BaseJoinTest extends AnyFlatSpec { // left side val itemQueries = List(Column("item", api.StringType, 100)) val itemQueriesTable = s"$namespace.item_queries" + tableUtils.sql(s"DROP TABLE IF EXISTS $itemQueriesTable").drop(Constants.RowIDColumn) val itemQueriesDf = DataFrameGen .events(spark, itemQueries, 1000, partitions = 100) + // duplicate the events - itemQueriesDf.union(itemQueriesDf).save(itemQueriesTable) // .union(itemQueriesDf) + val duplicated = itemQueriesDf.union(itemQueriesDf) + + // Add row ID here to avoid duplicating it + duplicated.withColumn(Constants.RowIDColumn, uuid()).save(itemQueriesTable) val start = tableUtils.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) val suffix = if (nameSuffix.isEmpty) "" else s"_$nameSuffix" diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/DifferentPartitionColumnsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/DifferentPartitionColumnsTest.scala index 304e026204..65e6da02d0 100644 --- 
a/spark/src/test/scala/ai/chronon/spark/test/join/DifferentPartitionColumnsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/DifferentPartitionColumnsTest.scala @@ -18,7 +18,7 @@ package ai.chronon.spark.test.join import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.{Builders, Operation, TimeUnit, Window} +import ai.chronon.api.{Builders, Constants, Operation, TimeUnit, Window} import ai.chronon.spark._ import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.DataFrameGen @@ -81,14 +81,18 @@ class DifferentPartitionColumnsTest extends BaseJoinTest { val start = tableUtils.partitionSpec.minus(today, new Window(60, TimeUnit.DAYS)) val end = tableUtils.partitionSpec.minus(today, new Window(15, TimeUnit.DAYS)) val joinConf = Builders.Join( - left = Builders.Source.entities(Builders.Query(startPartition = start), snapshotTable = countryTable), + left = Builders.Source.entities( + Builders.Query(selects = Map("country" -> "country", Constants.RowIDColumn -> Constants.RowIDColumn), + startPartition = start), + snapshotTable = countryTable + ), joinParts = Seq(Builders.JoinPart(groupBy = weightGroupBy), Builders.JoinPart(groupBy = heightGroupBy)), metaData = Builders.MetaData(name = "test.country_features_partition_test", namespace = namespace, team = "chronon") ) val runner = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = end, tableUtils = tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).drop(Constants.RowIDColumn) val expected = tableUtils.sql(s""" |WITH | countries AS (SELECT country, ds from $countryTable where ds >= '$start' and ds <= '$end'), diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/EndPartitionJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/EndPartitionJoinTest.scala index 3ea92c8382..99f265ed2e 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/EndPartitionJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/EndPartitionJoinTest.scala @@ -17,6 +17,7 @@ package ai.chronon.spark.test.join import ai.chronon.api.Builders +import ai.chronon.api.Constants import ai.chronon.api.Extensions._ import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.spark._ @@ -30,8 +31,12 @@ class EndPartitionJoinTest extends BaseJoinTest { val start = join.getLeft.query.startPartition val end = tableUtils.partitionSpec.after(start) val limitedJoin = Builders.Join( - left = - Builders.Source.events(Builders.Query(startPartition = start, endPartition = end), table = join.getLeft.table), + left = Builders.Source.events( + Builders.Query(selects = Map("item" -> "item", Constants.RowIDColumn -> Constants.RowIDColumn), + startPartition = start, + endPartition = end), + table = join.getLeft.table + ), joinParts = join.getJoinParts.toScala, metaData = join.metaData ) diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/EntitiesEntitiesTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/EntitiesEntitiesTest.scala index d2c403be7a..f9d7a812df 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/EntitiesEntitiesTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/EntitiesEntitiesTest.scala @@ -18,7 +18,7 @@ package ai.chronon.spark.test.join import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.{Builders, Operation, TimeUnit, Window} +import ai.chronon.api.{Builders, Constants, Operation, TimeUnit, Window} import ai.chronon.spark._ import 
ai.chronon.spark.Extensions._ import ai.chronon.spark.test.DataFrameGen @@ -41,7 +41,9 @@ class EntitiesEntitiesTest extends BaseJoinTest { val weightSource = Builders.Source.entities( query = Builders - .Query(selects = Builders.Selects("weight"), startPartition = yearAgo, endPartition = dayAndMonthBefore) + .Query(selects = Builders.Selects("weight", Constants.RowIDColumn), + startPartition = yearAgo, + endPartition = dayAndMonthBefore) .setPartitionFormat("yyyyMMdd"), snapshotTable = weightTable ) @@ -61,7 +63,7 @@ class EntitiesEntitiesTest extends BaseJoinTest { val heightTable = s"$namespace.heights" DataFrameGen.entities(spark, heightSchema, 100, partitions = 400).save(heightTable) val heightSource = Builders.Source.entities( - query = Builders.Query(selects = Builders.Selects("height"), startPartition = monthAgo), + query = Builders.Query(selects = Builders.Selects("height", Constants.RowIDColumn), startPartition = monthAgo), snapshotTable = heightTable ) @@ -86,7 +88,7 @@ class EntitiesEntitiesTest extends BaseJoinTest { ) val runner = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = end, tableUtils = tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).drop(Constants.RowIDColumn) val expected = tableUtils.sql(s""" |WITH diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEntitiesSnapshotTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEntitiesSnapshotTest.scala index bbc557a561..634ecce18e 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEntitiesSnapshotTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEntitiesSnapshotTest.scala @@ -52,7 +52,7 @@ class EventsEntitiesSnapshotTest extends BaseJoinTest { val dollarSource = Builders.Source.entities( query = Builders.Query( - selects = Builders.Selects("ts", "amount_dollars", "user_name", "user"), + selects = Builders.Selects("ts", "amount_dollars", "user_name", "user", Constants.RowIDColumn), startPartition = yearAgo, endPartition = dayAndMonthBefore, setups = @@ -65,10 +65,13 @@ class EventsEntitiesSnapshotTest extends BaseJoinTest { val rupeeSource = Builders.Source.entities( query = Builders.Query( - selects = Map("ts" -> "ts", - "amount_dollars" -> "CAST(amount_rupees/70 as long)", - "user_name" -> "user_name", - "user" -> "user"), + selects = Map( + "ts" -> "ts", + "amount_dollars" -> "CAST(amount_rupees/70 as long)", + "user_name" -> "user_name", + "user" -> "user", + Constants.RowIDColumn -> Constants.RowIDColumn + ), startPartition = monthAgo, setups = Seq( "create temporary function temp_replace_right_b as 'org.apache.hadoop.hive.ql.udf.UDFRegExpReplace'", @@ -141,7 +144,7 @@ class EventsEntitiesSnapshotTest extends BaseJoinTest { resetUDFs() val runner2 = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = end, tableUtils = tableUtils) - val computed = runner2.computeJoin(Some(3)) + val computed = runner2.computeJoin(Some(3)).drop(Constants.RowIDColumn) println(s"join start = $start") val expectedQuery = s""" @@ -221,7 +224,7 @@ class EventsEntitiesSnapshotTest extends BaseJoinTest { val runner3 = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = end, tableUtils = tableUtils) val expected2 = spark.sql(expectedQuery) - val computed2 = runner3.computeJoin(Some(3)) + val computed2 = runner3.computeJoin(Some(3)).drop(Constants.RowIDColumn) val diff2 = Comparison.sideBySide(computed2, expected2, List("user_name", "user", "ts", "ds")) if (diff2.count() > 0) { diff --git 
a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsCumulativeTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsCumulativeTest.scala index f12bd26f53..b3b15905aa 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsCumulativeTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsCumulativeTest.scala @@ -17,7 +17,7 @@ package ai.chronon.spark.test.join import ai.chronon.api.Builders -import ai.chronon.api.{Window, TimeUnit} +import ai.chronon.api.{Constants, Window, TimeUnit} import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.spark._ import ai.chronon.spark.Extensions._ @@ -49,7 +49,7 @@ class EventsEventsCumulativeTest extends BaseJoinTest { spark.sql(q).show() val start = tableUtils.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) val join = new Join(joinConf = joinConf, endPartition = dayAndMonthBefore, tableUtils) - val computed = join.computeJoin(Some(100)) + val computed = join.computeJoin(Some(100)).drop(Constants.RowIDColumn) computed.show() val expected = tableUtils.sql(s""" diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsSnapshotTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsSnapshotTest.scala index 1acec31190..fbcd134c88 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsSnapshotTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsSnapshotTest.scala @@ -34,10 +34,11 @@ class EventsEventsSnapshotTest extends BaseJoinTest { ) val viewsTable = s"$namespace.view_events" - DataFrameGen.events(spark, viewsSchema, count = 100, partitions = 200).drop("ts").save(viewsTable) + DataFrameGen.events(spark, viewsSchema, count = 1000, partitions = 200).drop("ts").save(viewsTable) val viewsSource = Builders.Source.events( - query = Builders.Query(selects = Builders.Selects("time_spent_ms"), startPartition = yearAgo), + query = + Builders.Query(selects = Builders.Selects("time_spent_ms", Constants.RowIDColumn), startPartition = yearAgo), table = viewsTable ) @@ -55,10 +56,11 @@ class EventsEventsSnapshotTest extends BaseJoinTest { val itemQueries = List(Column("item", api.StringType, 100)) val itemQueriesTable = s"$namespace.item_queries" DataFrameGen - .events(spark, itemQueries, 100, partitions = 100) + .events(spark, itemQueries, 100, partitions = 10) .save(itemQueriesTable) - val start = tableUtils.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) + val start = tableUtils.partitionSpec.minus(today, new Window(5, TimeUnit.DAYS)) + val end = tableUtils.partitionSpec.minus(today, new Window(5, TimeUnit.DAYS)) val joinConf = Builders.Join( left = Builders.Source.events(Builders.Query(startPartition = start), table = itemQueriesTable), @@ -66,23 +68,25 @@ class EventsEventsSnapshotTest extends BaseJoinTest { metaData = Builders.MetaData(name = "test.item_snapshot_features_2", namespace = namespace, team = "chronon") ) - (new Analyzer(tableUtils, joinConf, monthAgo, today)).run() - val join = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = monthAgo, tableUtils = tableUtils) - val computed = join.computeJoin() + (new Analyzer(tableUtils, joinConf, start, today)).run() + val join = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = end, tableUtils = tableUtils) + val computed = join.computeJoin(overrideStartPartition = Option(start)).drop(Constants.RowIDColumn) + println("Computed:") computed.show() val expected = tableUtils.sql(s""" |WITH - | queries AS (SELECT item, ts, ds 
from $itemQueriesTable where ds >= '$start' and ds <= '$monthAgo') + | queries AS (SELECT item, ts, ds from $itemQueriesTable where ds >= '$start' and ds <= '$end') | SELECT queries.item, | queries.ts, | queries.ds, | AVG(IF(queries.ds > $viewsTable.ds, time_spent_ms, null)) as user_unit_test_item_views_time_spent_ms_average | FROM queries left outer join $viewsTable | ON queries.item = $viewsTable.item - | WHERE ($viewsTable.item IS NOT NULL) AND $viewsTable.ds >= '$yearAgo' AND $viewsTable.ds <= '$dayAndMonthBefore' + | WHERE ($viewsTable.item IS NOT NULL) AND $viewsTable.ds >= '$yearAgo' AND $viewsTable.ds <= '$end' | GROUP BY queries.item, queries.ts, queries.ds, from_unixtime(queries.ts/1000, 'yyyy-MM-dd') |""".stripMargin) + println("Expected:") expected.show() val diff = Comparison.sideBySide(computed, expected, List("item", "ts", "ds")) diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsTemporalTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsTemporalTest.scala index 68b23a30e7..f84871b3e9 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsTemporalTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsTemporalTest.scala @@ -18,7 +18,7 @@ package ai.chronon.spark.test.join import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.{Builders, Operation, TimeUnit, Window} +import ai.chronon.api.{Builders, Constants, Operation, TimeUnit, Window} import ai.chronon.spark._ import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.{DataFrameGen, TableTestUtils} @@ -30,7 +30,13 @@ class EventsEventsTemporalTest extends BaseJoinTest { val sparkSkewFree: SparkSession = submission.SparkSessionBuilder.build( "JoinTest", local = true, - additionalConfig = Option(Map("spark.chronon.join.backfill.mode.skewFree" -> "true")) + additionalConfig = Option( + Map( + "spark.chronon.join.backfill.mode.skewFree" -> "true", + "spark.sql.sources.bucketing.enabled" -> "true", + "spark.sql.bucketing.coalesceBucketsInJoin.enabled" -> "true", + "spark.sql.autoBroadcastJoinThreshold" -> "-1" + )) ) protected implicit val tableUtilsSkewFree: TableTestUtils = TableTestUtils(sparkSkewFree) @@ -44,7 +50,7 @@ class EventsEventsTemporalTest extends BaseJoinTest { val viewsTable = s"$namespace.view_temporal" DataFrameGen - .events(sparkSkewFree, viewsSchema, count = 100, partitions = 200) + .events(sparkSkewFree, viewsSchema, count = 100, partitions = 200, addRowID = true) .save(viewsTable, Map("tblProp1" -> "1")) val viewsSource = Builders.Source.events( @@ -73,7 +79,7 @@ class EventsEventsTemporalTest extends BaseJoinTest { val start = tableUtilsSkewFree.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) (new Analyzer(tableUtilsSkewFree, joinConf, monthAgo, today)).run() val join = new Join(joinConf = joinConf, endPartition = dayAndMonthBefore, tableUtilsSkewFree) - val computed = join.computeJoin(Some(100)) + val computed = join.computeJoin(Some(100)).drop(Constants.RowIDColumn) computed.show() val expected = tableUtilsSkewFree.sql(s""" diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsTemporalWithGBDerivation.scala b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsTemporalWithGBDerivation.scala index 64afc4fd8a..500427e184 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsTemporalWithGBDerivation.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/EventsEventsTemporalWithGBDerivation.scala @@ -2,7 +2,7 @@ package 
ai.chronon.spark.test.join import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.{Builders, Operation, TimeUnit, Window} +import ai.chronon.api.{Builders, Constants, Operation, TimeUnit, Window} import ai.chronon.spark._ import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.{DataFrameGen, TableTestUtils} @@ -15,7 +15,13 @@ class EventsEventsTemporalWithGBDerivation extends BaseJoinTest { val sparkSkewFree: SparkSession = submission.SparkSessionBuilder.build( "JoinTest", local = true, - additionalConfig = Option(Map("spark.chronon.join.backfill.mode.skewFree" -> "true")) + additionalConfig = Option( + Map( + "spark.chronon.join.backfill.mode.skewFree" -> "true", + "spark.sql.sources.bucketing.enabled" -> "true", + "spark.sql.bucketing.coalesceBucketsInJoin.enabled" -> "true", + "spark.sql.autoBroadcastJoinThreshold" -> "-1" + )) ) protected implicit val tableUtilsSkewFree: TableTestUtils = TableTestUtils(sparkSkewFree) @@ -34,7 +40,8 @@ class EventsEventsTemporalWithGBDerivation extends BaseJoinTest { val viewsSource = Builders.Source.events( table = viewsTable, - query = Builders.Query(selects = Builders.Selects("time_spent_ms"), startPartition = yearAgo) + query = + Builders.Query(selects = Builders.Selects("time_spent_ms", Constants.RowIDColumn), startPartition = yearAgo) ) // left side @@ -48,7 +55,7 @@ class EventsEventsTemporalWithGBDerivation extends BaseJoinTest { val start = tableUtilsSkewFree.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) (new Analyzer(tableUtilsSkewFree, joinConf, monthAgo, today)).run() val join = new Join(joinConf = joinConf, endPartition = dayAndMonthBefore, tableUtilsSkewFree) - val computed = join.computeJoin(Some(100)) + val computed = join.computeJoin(Some(100)).drop(tableUtilsSkewFree.internalRowIdColumnName) computed.show() val expected = tableUtilsSkewFree.sql(s""" @@ -131,7 +138,8 @@ class EventsEventsTemporalWithGBDerivation extends BaseJoinTest { val viewsSource = Builders.Source.events( table = viewsTable, - query = Builders.Query(selects = Builders.Selects("time_spent_ms"), startPartition = yearAgo) + query = + Builders.Query(selects = Builders.Selects("time_spent_ms", Constants.RowIDColumn), startPartition = yearAgo) ) spark.sql(s"DROP TABLE IF EXISTS $viewsTable") diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/HeterogeneousPartitionColumnsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/HeterogeneousPartitionColumnsTest.scala index 869d809618..23f35e5512 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/HeterogeneousPartitionColumnsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/HeterogeneousPartitionColumnsTest.scala @@ -137,6 +137,7 @@ class HeterogeneousPartitionColumnsTest extends BaseJoinTest { left = Builders.Source.events( Builders .Query( + selects = Map("item" -> "item", Constants.RowIDColumn -> Constants.RowIDColumn), partitionColumn = leftCustomPartitionCol ) .setPartitionFormat(leftCustomFormat), @@ -149,7 +150,7 @@ class HeterogeneousPartitionColumnsTest extends BaseJoinTest { ) val join = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = "2025-05-08", tableUtils = tableUtils) - val computed = join.computeJoin() + val computed = join.computeJoin().drop(Constants.RowIDColumn) assert(computed.collect().nonEmpty) } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/KeyMappingOverlappingFieldsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/KeyMappingOverlappingFieldsTest.scala index 
08fe941cb2..26d0e14cf7 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/KeyMappingOverlappingFieldsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/KeyMappingOverlappingFieldsTest.scala @@ -18,7 +18,7 @@ package ai.chronon.spark.test.join import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.{Builders, TimeUnit, Window} +import ai.chronon.api.{Builders, Constants, TimeUnit, Window} import ai.chronon.spark._ import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.DataFrameGen @@ -38,11 +38,11 @@ class KeyMappingOverlappingFieldsTest extends BaseJoinTest { DataFrameGen.entities(spark, namesSchema, 100, partitions = 400).save(namesTable) val namesSource = Builders.Source.entities( - query = - Builders.Query(selects = - Builders.Selects.exprs("user" -> "user", "user_id" -> "user", "attribute" -> "attribute"), - startPartition = yearAgo, - endPartition = dayAndMonthBefore), + query = Builders.Query( + selects = Builders.Selects.exprs("user" -> "user", "user_id" -> "user", "attribute" -> "attribute"), + startPartition = yearAgo, + endPartition = dayAndMonthBefore + ), snapshotTable = namesTable ) @@ -61,7 +61,9 @@ class KeyMappingOverlappingFieldsTest extends BaseJoinTest { val start = tableUtils.partitionSpec.minus(today, new Window(60, TimeUnit.DAYS)) val end = tableUtils.partitionSpec.minus(today, new Window(15, TimeUnit.DAYS)) val joinConf = Builders.Join( - left = Builders.Source.entities(Builders.Query(selects = Map("user_id" -> "user_id"), startPartition = start), + left = Builders.Source.entities(Builders.Query(selects = Map("user_id" -> "user_id", + Constants.RowIDColumn -> Constants.RowIDColumn), + startPartition = start), snapshotTable = usersTable), joinParts = Seq( Builders.JoinPart(groupBy = namesGroupBy, @@ -73,7 +75,7 @@ class KeyMappingOverlappingFieldsTest extends BaseJoinTest { ) val runner = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = end, tableUtils = tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).drop(Constants.RowIDColumn) assertFalse(computed.isEmpty) } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/NoAggTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/NoAggTest.scala index 567883e9ab..3684a5be33 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/NoAggTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/NoAggTest.scala @@ -18,10 +18,11 @@ package ai.chronon.spark.test.join import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.{Builders, TimeUnit, Window} +import ai.chronon.api.{Builders, Constants, TimeUnit, Window} import ai.chronon.spark._ import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.DataFrameGen +import org.apache.spark.sql.functions.uuid import org.junit.Assert._ class NoAggTest extends BaseJoinTest { @@ -34,7 +35,7 @@ class NoAggTest extends BaseJoinTest { Column("name", api.StringType, 10) ) val namesTable = s"$namespace.names" - DataFrameGen.entities(spark, namesSchema, 100, partitions = 400).save(namesTable) + DataFrameGen.entities(spark, namesSchema, 100, partitions = 400, addRowID = false).save(namesTable) val namesSource = Builders.Source.entities( query = @@ -63,14 +64,17 @@ class NoAggTest extends BaseJoinTest { val start = tableUtils.partitionSpec.minus(today, new Window(60, TimeUnit.DAYS)) val end = tableUtils.partitionSpec.minus(today, new Window(15, TimeUnit.DAYS)) val joinConf = Builders.Join( - left = 
Builders.Source.entities(Builders.Query(selects = Map("user" -> "user"), startPartition = start), - snapshotTable = usersTable), + left = + Builders.Source.entities(Builders.Query(selects = + Map("user" -> "user", Constants.RowIDColumn -> Constants.RowIDColumn), + startPartition = start), + snapshotTable = usersTable), joinParts = Seq(Builders.JoinPart(groupBy = namesGroupBy)), metaData = Builders.MetaData(name = "test.user_features", namespace = namespace, team = "chronon") ) val runner = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = end, tableUtils = tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).drop(Constants.RowIDColumn) println(s"join start = $start") val expected = tableUtils.sql(s""" |WITH diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/NoHistoricalBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/NoHistoricalBackfillTest.scala index ce2a02bc23..0b16270633 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/NoHistoricalBackfillTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/NoHistoricalBackfillTest.scala @@ -18,7 +18,7 @@ package ai.chronon.spark.test.join import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.{Builders, Operation, TimeUnit, Window} +import ai.chronon.api.{Builders, Constants, Operation, TimeUnit, Window} import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.DataFrameGen import org.junit.Assert._ @@ -56,7 +56,11 @@ class NoHistoricalBackfillTest extends BaseJoinTest { val start = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS)) val end = tableUtils.partitionSpec.minus(today, new Window(5, TimeUnit.DAYS)) val joinConf = Builders.Join( - left = Builders.Source.entities(Builders.Query(startPartition = start), snapshotTable = countryTable), + left = Builders.Source.entities( + Builders.Query(selects = Map("country" -> "country", Constants.RowIDColumn -> Constants.RowIDColumn), + startPartition = start), + snapshotTable = countryTable + ), joinParts = Seq(Builders.JoinPart(groupBy = weightGroupBy)), metaData = Builders.MetaData(name = "test.country_no_historical_backfill", namespace = namespace, @@ -65,7 +69,7 @@ class NoHistoricalBackfillTest extends BaseJoinTest { ) val runner = new ai.chronon.spark.Join(joinConf = joinConf, endPartition = end, tableUtils = tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).drop(Constants.RowIDColumn) println("showing join result") computed.show() diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/SelectedJoinPartsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/SelectedJoinPartsTest.scala index 2c4a1e18c8..079f39c748 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/SelectedJoinPartsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/SelectedJoinPartsTest.scala @@ -19,7 +19,7 @@ package ai.chronon.spark.test.join import ai.chronon.aggregator.test.Column import ai.chronon.api import ai.chronon.api.planner.RelevantLeftForJoinPart -import ai.chronon.api.{Accuracy, Builders, Operation} +import ai.chronon.api.{Accuracy, Builders, Constants, Operation} import ai.chronon.spark._ import ai.chronon.spark.Extensions._ import ai.chronon.spark.test.DataFrameGen @@ -44,7 +44,8 @@ class SelectedJoinPartsTest extends BaseJoinTest { spark.sql(s"DROP TABLE IF EXISTS $itemQueriesTable") spark.sql(s"DROP TABLE IF EXISTS ${itemQueriesTable}_tmp") 
DataFrameGen.events(spark, itemQueries, 100, partitions = 30).save(s"${itemQueriesTable}_tmp") - val leftDf = tableUtils.sql(s"SELECT item, value, ts, ds FROM ${itemQueriesTable}_tmp") + val leftDf = + tableUtils.sql(s"SELECT item, value, ts, ds, ${tableUtils.internalRowIdColumnName} FROM ${itemQueriesTable}_tmp") leftDf.save(itemQueriesTable) val start = monthAgo @@ -63,7 +64,8 @@ class SelectedJoinPartsTest extends BaseJoinTest { sources = Seq( Builders.Source.events( table = viewsTable, - query = Builders.Query(startPartition = start) + query = + Builders.Query(selects = Builders.Selects("user", "value", Constants.RowIDColumn), startPartition = start) )), keyColumns = Seq("item"), aggregations = Seq( @@ -80,7 +82,8 @@ class SelectedJoinPartsTest extends BaseJoinTest { sources = Seq( Builders.Source.events( table = viewsTable, - query = Builders.Query(startPartition = start) + query = + Builders.Query(selects = Builders.Selects("user", "value", Constants.RowIDColumn), startPartition = start) )), keyColumns = Seq("item"), aggregations = Seq( @@ -96,7 +99,8 @@ class SelectedJoinPartsTest extends BaseJoinTest { sources = Seq( Builders.Source.events( table = viewsTable, - query = Builders.Query(startPartition = start) + query = + Builders.Query(selects = Builders.Selects("user", "value", Constants.RowIDColumn), startPartition = start) )), keyColumns = Seq("item"), aggregations = Seq( diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/SkipBloomFilterJoinBackfillTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/SkipBloomFilterJoinBackfillTest.scala index e4556d323e..a634f41833 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/SkipBloomFilterJoinBackfillTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/SkipBloomFilterJoinBackfillTest.scala @@ -32,10 +32,17 @@ class SkipBloomFilterJoinBackfillTest extends BaseJoinTest { it should "test skip bloom filter join backfill" in { import ai.chronon.spark.submission val testSpark: SparkSession = - submission.SparkSessionBuilder.build("JoinTest", - local = true, - additionalConfig = - Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100"))) + submission.SparkSessionBuilder.build( + "JoinTest", + local = true, + additionalConfig = Some( + Map( + "spark.chronon.backfill.bloomfilter.threshold" -> "100", + "spark.sql.sources.bucketing.enabled" -> "true", + "spark.sql.bucketing.coalesceBucketsInJoin.enabled" -> "true", + "spark.sql.autoBroadcastJoinThreshold" -> "-1" + )) + ) val testTableUtils = TableUtils(testSpark) val viewsSchema = List( Column("user", api.StringType, 10), @@ -70,12 +77,19 @@ class SkipBloomFilterJoinBackfillTest extends BaseJoinTest { val start = testTableUtils.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) val joinConf = Builders.Join( - left = Builders.Source.events(Builders.Query(startPartition = start), table = itemQueriesTable), + left = Builders.Source.events( + Builders.Query(selects = Map("item" -> "item", + testTableUtils.internalRowIdColumnName -> testTableUtils.internalRowIdColumnName), + startPartition = start), + table = itemQueriesTable + ), joinParts = Seq(Builders.JoinPart(groupBy = viewsGroupBy, prefix = "user")), metaData = Builders.MetaData(name = "test.item_snapshot_bloom_test", namespace = namespace, team = "chronon") ) val skipBloomComputed = - new ai.chronon.spark.Join(joinConf = joinConf, endPartition = today, tableUtils = testTableUtils).computeJoin() + new ai.chronon.spark.Join(joinConf = joinConf, endPartition = today, tableUtils = 
testTableUtils) + .computeJoin() + .drop(testTableUtils.internalRowIdColumnName) val leftSideCount = testSpark.sql(s"SELECT item, ts, ds from $itemQueriesTable where ds >= '$start'").count() println("computed count: " + skipBloomComputed.count()) assertEquals(leftSideCount, skipBloomComputed.count()) diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/StructJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/StructJoinTest.scala index ff230b1c79..65f09ff8d4 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/StructJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/StructJoinTest.scala @@ -34,7 +34,7 @@ class StructJoinTest extends BaseJoinTest { itemQueriesDf.save(s"${itemQueriesTable}_tmp") val structLeftDf = tableUtils.sql( - s"SELECT item, NAMED_STRUCT('item_repeat', item) as item_struct, ts, ds FROM ${itemQueriesTable}_tmp") + s"SELECT item, NAMED_STRUCT('item_repeat', item) as item_struct, ts, ds, ${Constants.RowIDColumn} FROM ${itemQueriesTable}_tmp") structLeftDf.save(itemQueriesTable) val start = tableUtils.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) @@ -49,7 +49,8 @@ class StructJoinTest extends BaseJoinTest { val viewsSource = Builders.Source.events( table = viewsTable, - query = Builders.Query(selects = Builders.Selects("time_spent_ms", "item_struct"), startPartition = yearAgo) + query = Builders.Query(selects = Builders.Selects("time_spent_ms", "item_struct", Constants.RowIDColumn), + startPartition = yearAgo) ) spark.sql(s"DROP TABLE IF EXISTS $viewsTable") df.save(s"${viewsTable}_tmp", Map("tblProp1" -> "1")) diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/UnionJoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/UnionJoinTest.scala index 8d710e59c1..e8d9e1ef79 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/UnionJoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/UnionJoinTest.scala @@ -29,7 +29,7 @@ class UnionJoinTest extends BaseJoinTest { val viewsSource = Builders.Source.events( table = viewsTable, topic = "", - query = Builders.Query(selects = Builders.Selects("time_spent_ms"), + query = Builders.Query(selects = Builders.Selects("time_spent_ms", Constants.RowIDColumn), startPartition = tableUtils.partitionSpec.minus(today, new Window(20, TimeUnit.DAYS))) ) @@ -99,7 +99,7 @@ class UnionJoinTest extends BaseJoinTest { val eventsSource = Builders.Source.events( table = eventsTable, query = Builders.Query( - selects = Builders.Selects("amount", "category"), + selects = Builders.Selects("amount", "category", Constants.RowIDColumn), startPartition = tableUtils.partitionSpec.minus(today, new Window(40, TimeUnit.DAYS)) // Increased window ) ) @@ -215,7 +215,7 @@ class UnionJoinTest extends BaseJoinTest { val eventsSource = Builders.Source.events( table = eventsTable, query = Builders.Query( - selects = Builders.Selects("amount", "category"), + selects = Builders.Selects("amount", "category", Constants.RowIDColumn), startPartition = tableUtils.partitionSpec.minus(today, new Window(40, TimeUnit.DAYS)) // Increased window ) ) @@ -250,10 +250,15 @@ class UnionJoinTest extends BaseJoinTest { // Join with derivations val joinWithSingleJP = Builders.Join( left = Builders.Source.events( - Builders.Query(selects = - Builders.Selects("user_id", "item_id", "amount", "category"), // Select all left cols here + Builders.Query(selects = Builders.Selects("user_id", + "item_id", + "amount", + "category", + Constants.RowIDColumn + ), // Select all left cols here startPartition 
= start), - table = eventsTable), + table = eventsTable + ), joinParts = Seq(Builders.JoinPart(groupBy = groupByWithDerivations)), metaData = Builders.MetaData(name = "test.user_features_derived.union_join", namespace = namespace, team = "user_team") diff --git a/spark/src/test/scala/ai/chronon/spark/test/join/VersioningTest.scala b/spark/src/test/scala/ai/chronon/spark/test/join/VersioningTest.scala index 12ca24c4b0..0f8dcf4b91 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/join/VersioningTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/join/VersioningTest.scala @@ -17,6 +17,7 @@ package ai.chronon.spark.test.join import ai.chronon.api.Builders +import ai.chronon.api.Constants import ai.chronon.api.Extensions._ import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.spark._ @@ -28,9 +29,17 @@ class VersioningTest extends BaseJoinTest { it should "test versioning" in { val joinConf = getEventsEventsTemporal("versioning") + val itemQueriesTable = joinConf.getLeft.getEvents.table + val start = joinConf.getLeft.getEvents.getQuery.getStartPartition + val queriesBare = + tableUtils.sql(s"SELECT item, ts, ds from $itemQueriesTable where ds >= '$start' and ds <= '$dayAndMonthBefore'") + // Run the old join to ensure that tables exist val oldJoin = new Join(joinConf = joinConf, endPartition = dayAndMonthBefore, tableUtils) - oldJoin.computeJoin(Some(100)) + val oldDf = oldJoin.computeJoin(Some(100)) + assert(oldDf.count() == queriesBare.count(), s"Join output count ${oldDf.count()} does not match left count ${queriesBare.count()}") + println(s"OLD DF COUNT: ${oldDf.count()}") // Make sure that there is no versioning-detected changes at this phase val joinPartsToRecomputeNoChange = JoinUtils.tablesToRecompute(joinConf, joinConf.metaData.outputTable, tableUtils) @@ -59,7 +68,8 @@ assertEquals(addPartRecompute.size, 1) assertEquals(addPartRecompute, Seq(addPartJoinConf.metaData.outputTable)) // Compute to ensure that it works and to set the stage for the next assertion - addPartJoin.computeJoin(Some(100)) + val addPartDf = addPartJoin.computeJoin(Some(100)) + assert(addPartDf.count() == oldDf.count(), "Final output counts should match after adding a join part") // Test modifying only one of two joinParts val rightModJoinConf = addPartJoinConf.deepCopy() @@ -74,12 +84,11 @@ rightModJoinConf.getJoinParts.get(0).setPrefix("user_4") val rightModBothJoin = new Join(joinConf = rightModJoinConf, endPartition = dayAndMonthBefore, tableUtils) // Compute to ensure that it works - val computed = rightModBothJoin.computeJoin(Some(100)) + val computed = rightModBothJoin.computeJoin(Some(100)).drop(Constants.RowIDColumn) + println(s"computed DF COUNT: ${computed.count()}") // Now assert that the actual output is correct after all these runs computed.show() - val itemQueriesTable = joinConf.getLeft.getEvents.table - val start = joinConf.getLeft.getEvents.getQuery.getStartPartition val viewsTable = s"$namespace.view_versioning" val expected = tableUtils.sql(s""" @@ -105,8 +114,6 @@ expected.show() val diff = Comparison.sideBySide(expected, computed, List("item", "ts", "ds")) - val queriesBare = - tableUtils.sql(s"SELECT item, ts, ds from $itemQueriesTable where ds >= '$start' and ds <= '$dayAndMonthBefore'") assertEquals(queriesBare.count(), computed.count()) if (diff.count() > 0) { println(s"Diff count: 
${diff.count()}") diff --git a/spark/src/test/scala/ai/chronon/spark/test/streaming/MutationsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/streaming/MutationsTest.scala index b4ec7da2dc..508cb183ba 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/streaming/MutationsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/streaming/MutationsTest.scala @@ -18,13 +18,14 @@ package ai.chronon.spark.test.streaming import ai.chronon.aggregator.test.Column import ai.chronon.api -import ai.chronon.api.{Builders, Operation, TimeUnit, TsUtils, Window} +import ai.chronon.api.{Builders, Constants, Operation, TimeUnit, TsUtils, Window} import ai.chronon.spark.Extensions._ import ai.chronon.spark.submission.SparkSessionBuilder import ai.chronon.spark.test.DataFrameGen import ai.chronon.spark.{Comparison, Join} import ai.chronon.spark.catalog.TableUtils import ai.chronon.spark.submission.SparkSessionBuilder +import org.apache.spark.sql.functions.uuid import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.scalatest.flatspec.AnyFlatSpec @@ -151,7 +152,10 @@ class MutationsTest extends AnyFlatSpec { operation: Operation = Operation.AVERAGE): DataFrame = { val testNamespace = namespace(suffix) tableUtils.sql(s"CREATE DATABASE IF NOT EXISTS $testNamespace") - spark.createDataFrame(spark.sparkContext.parallelize(leftData), leftSchema).save(s"$testNamespace.$eventTable") + spark + .createDataFrame(spark.sparkContext.parallelize(leftData), leftSchema) + .withColumn(Constants.RowIDColumn, uuid()) + .save(s"$testNamespace.$eventTable") spark .createDataFrame(spark.sparkContext.parallelize(snapshotData), snapshotSchema) .save(s"$testNamespace.$snapshotTable") @@ -182,7 +186,7 @@ class MutationsTest extends AnyFlatSpec { val leftSource = Builders.Source.events( query = Builders.Query( - selects = Builders.Selects("listing_id", "ts", "event"), + selects = Builders.Selects("listing_id", "ts", "event", Constants.RowIDColumn), startPartition = startPartition ), table = s"$testNamespace.$eventTable" @@ -209,7 +213,7 @@ class MutationsTest extends AnyFlatSpec { ) val runner = new Join(joinConf, endPartition, tableUtils) - runner.computeJoin() + runner.computeJoin().drop(Constants.RowIDColumn) } /** Compute the no windows average based on the tables using pure sql