diff --git a/maven_install.json b/maven_install.json
index 6384ee6c17..c5ec24e84b 100755
--- a/maven_install.json
+++ b/maven_install.json
@@ -1,6 +1,6 @@
 {
     "__AUTOGENERATED_FILE_DO_NOT_MODIFY_THIS_FILE_MANUALLY": "THERE_IS_NO_DATA_ONLY_ZUUL",
-    "__INPUT_ARTIFACTS_HASH": -542447257,
+    "__INPUT_ARTIFACTS_HASH": 1615229177,
     "__RESOLVED_ARTIFACTS_HASH": 733328384,
     "artifacts": {
         "ant:ant": {
diff --git a/spark/BUILD.bazel b/spark/BUILD.bazel
index e5e9ebfda7..bab1c3741e 100644
--- a/spark/BUILD.bazel
+++ b/spark/BUILD.bazel
@@ -43,6 +43,7 @@ scala_library(
         maven_artifact("org.apache.thrift:libthrift"),
         maven_artifact("org.apache.hadoop:hadoop-common"),
        maven_artifact("org.apache.hadoop:hadoop-client-api"),
+        maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
     ],
 )
 
@@ -84,7 +85,7 @@ scala_library(
 
 scala_library(
     name = "batch_lib",
-    srcs = glob(["src/main/scala/ai/chronon/spark/batch/*.scala"]),
+    srcs = glob(["src/main/scala/ai/chronon/spark/batch/**/*.scala"]),
     format = True,
     visibility = ["//visibility:public"],
     deps = [
@@ -102,6 +103,7 @@ scala_library(
         maven_artifact("org.apache.logging.log4j:log4j-core"),
         maven_artifact_with_suffix("org.rogach:scallop"),
         maven_artifact("com.google.code.gson:gson"),
+        maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
     ],
 )
 
@@ -154,6 +156,7 @@ scala_test_suite(
     deps = test_deps + [
         "test_lib",
         ":batch_lib",
+        maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
     ],
 )
 
diff --git a/spark/src/main/scala/ai/chronon/spark/batch/iceberg/IcebergPartitionStatsExtractor.scala b/spark/src/main/scala/ai/chronon/spark/batch/iceberg/IcebergPartitionStatsExtractor.scala
new file mode 100644
index 0000000000..337c212365
--- /dev/null
+++ b/spark/src/main/scala/ai/chronon/spark/batch/iceberg/IcebergPartitionStatsExtractor.scala
@@ -0,0 +1,81 @@
+package ai.chronon.spark.batch.iceberg
+
+import org.apache.iceberg.catalog.TableIdentifier
+import org.apache.iceberg.spark.SparkSessionCatalog
+import org.apache.iceberg.{DataFile, ManifestFiles, PartitionSpec, Table}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+case class PartitionStats(partitionColToValue: Map[String, String], rowCount: Long, colToNullCount: Map[String, Long])
+
+class IcebergPartitionStatsExtractor(spark: SparkSession) {
+
+  def extractPartitionedStats(catalogName: String, namespace: String, tableName: String): Seq[PartitionStats] = {
+
+    val catalog = spark.sessionState.catalogManager
+      .catalog(catalogName)
+      .asInstanceOf[SparkSessionCatalog[V2SessionCatalog]]
+      .icebergCatalog()
+
+    val tableId: TableIdentifier = TableIdentifier.of(namespace, tableName)
+    val table: Table = catalog.loadTable(tableId)
+
+    require(
+      table.spec().isPartitioned,
+      s"Illegal request to compute partition-stats of an un-partitioned table: ${table.name()}."
+    )
+
+    val currentSnapshot = Option(table.currentSnapshot())
+    val partitionStatsBuffer = mutable.Buffer[PartitionStats]()
+
+    currentSnapshot.foreach { snapshot =>
+      val manifestFiles = snapshot.allManifests(table.io()).asScala
+      manifestFiles.foreach { manifestFile =>
+        val manifestReader = ManifestFiles.read(manifestFile, table.io())
+
+        try {
+          manifestReader.forEach((file: DataFile) => {
+
+            val rowCount: Long = file.recordCount()
+
+            val schema = table.schema()
+
+            val partitionSpec: PartitionSpec = table.specs().get(file.specId())
+            val partitionFieldIds = partitionSpec.fields().asScala.map(_.sourceId()).toSet
+
+            // Per-column null counts, excluding the partition columns themselves.
+            val colToNullCount: Map[String, Long] = file
+              .nullValueCounts()
+              .asScala
+              .filterNot { case (fieldId, _) => partitionFieldIds.contains(fieldId) }
+              .map { case (fieldId, nullCount) =>
+                schema.findField(fieldId).name() -> nullCount.toLong
+              }
+              .toMap
+
+            // Partition values are read as strings, so this assumes string-typed (e.g. identity string) partition fields.
+            val partitionColToValue = partitionSpec
+              .fields()
+              .asScala
+              .zipWithIndex
+              .map { case (field, index) =>
+                val sourceField = schema.findField(field.sourceId())
+                val partitionValue = file.partition().get(index, classOf[String])
+
+                sourceField.name() -> String.valueOf(partitionValue)
+              }
+              .toMap
+
+            val fileStats = PartitionStats(partitionColToValue, rowCount, colToNullCount)
+            partitionStatsBuffer.append(fileStats)
+          })
+        } finally {
+          manifestReader.close()
+        }
+      }
+    }
+
+    partitionStatsBuffer
+  }
+}
diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/iceberg/IcebergPartitionStatsExtractorTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/iceberg/IcebergPartitionStatsExtractorTest.scala
new file mode 100644
index 0000000000..693b8354bc
--- /dev/null
+++ b/spark/src/test/scala/ai/chronon/spark/test/batch/iceberg/IcebergPartitionStatsExtractorTest.scala
@@ -0,0 +1,139 @@
+package ai.chronon.spark.test.batch.iceberg
+
+import ai.chronon.spark.batch.iceberg.IcebergPartitionStatsExtractor
+import org.apache.spark.sql.SparkSession
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+import java.nio.file.Files
+
+class IcebergPartitionStatsExtractorTest
+    extends AnyFlatSpec
+    with Matchers
+    with BeforeAndAfterAll
+    with BeforeAndAfterEach {
+
+  private var spark: SparkSession = _
+
+  override def beforeAll(): Unit = {
+
+    spark = SparkSession
+      .builder()
+      .appName("IcebergPartitionStatsExtractorTest")
+      .master("local[*]")
+      .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
+      .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog")
+      .config("spark.sql.warehouse.dir", s"${System.getProperty("java.io.tmpdir")}/warehouse")
+      .config("spark.sql.catalog.spark_catalog.type", "hadoop")
+      .config("spark.sql.catalog.spark_catalog.warehouse", Files.createTempDirectory("partition-stats-test").toString)
+      .config("spark.driver.bindAddress", "127.0.0.1")
+      .config("spark.ui.enabled", "false")
+      .enableHiveSupport()
+      .getOrCreate()
+
+    spark.sparkContext.setLogLevel("WARN")
+  }
+
+  override def afterAll(): Unit = {
+    if (spark != null) {
+      spark.stop()
+    }
+  }
+
+  override def beforeEach(): Unit = {
+    spark.sql("DROP TABLE IF EXISTS test_partitioned_table")
+  }
+
+  "IcebergPartitionStatsExtractor" should "throw exception for unpartitioned table" in {
+    spark.sql("""
+      CREATE TABLE test_partitioned_table (
+        id BIGINT,
+        name STRING,
+        value DOUBLE
+      ) USING iceberg
+    """)
+
+    val extractor = new IcebergPartitionStatsExtractor(spark)
+
+    val exception = intercept[IllegalArgumentException] {
+      extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")
+    }
+
+    exception.getMessage should include(
+      "Illegal request to compute partition-stats of an un-partitioned table: spark_catalog.default.test_partitioned_table")
+  }
+
+  it should "return empty sequence for empty partitioned table" in {
+    spark.sql("""
+      CREATE TABLE test_partitioned_table (
+        id BIGINT,
+        name STRING,
+        region STRING,
+        value DOUBLE
+      ) USING iceberg
+      PARTITIONED BY (region)
+    """)
+
+    val extractor = new IcebergPartitionStatsExtractor(spark)
+    val stats = extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")
+
+    stats should be(empty)
+  }
+
+  it should "extract partition stats from table with data" in {
+    spark.sql("""
+      CREATE TABLE test_partitioned_table (
+        id BIGINT,
+        name STRING,
+        region STRING,
+        value DOUBLE
+      ) USING iceberg
+      PARTITIONED BY (region)
+    """)
+
+    spark.sql("""
+      INSERT INTO test_partitioned_table VALUES
+      (1, 'Alice', 'North', 100.0),
+      (2, 'Bob', 'North', 200.0),
+      (3, 'Charlie', 'South', 150.0),
+      (4, NULL, 'South', 300.0),
+      (5, 'Eve', 'East', NULL)
+    """)
+
+    // Force table refresh to ensure snapshot is available
+    spark.sql("REFRESH TABLE test_partitioned_table")
+
+    // Verify data was inserted
+    val rowCount = spark.sql("SELECT COUNT(*) FROM test_partitioned_table").collect()(0).getLong(0)
+    rowCount should be(5)
+
+    val extractor = new IcebergPartitionStatsExtractor(spark)
+    val stats = extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")
+
+    stats should not be empty
+    stats.length should be(3)
+
+    val totalRows = stats.map(_.rowCount).sum
+    totalRows should be(5)
+
+    // Verify specific partition stats
+    val southStats = stats.find(_.partitionColToValue("region") == "South").get
+    southStats.rowCount should be(2)
+    southStats.colToNullCount("value") should be(0) // No null values in South partition
+    southStats.colToNullCount("name") should be(1) // One null name (Charlie, NULL)
+    southStats.colToNullCount("id") should be(0) // No null ids
+
+    val northStats = stats.find(_.partitionColToValue("region") == "North").get
+    northStats.rowCount should be(2)
+    northStats.colToNullCount("value") should be(0) // No null values in North partition
+    northStats.colToNullCount("name") should be(0) // No null names (Alice, Bob)
+    northStats.colToNullCount("id") should be(0) // No null ids
+
+    val eastStats = stats.find(_.partitionColToValue("region") == "East").get
+    eastStats.rowCount should be(1)
+    eastStats.colToNullCount("value") should be(1) // One null value (Eve has NULL value)
+    eastStats.colToNullCount("name") should be(0) // No null names (Eve)
+    eastStats.colToNullCount("id") should be(0) // No null ids
+  }
+}
diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/iceberg/IcebergSparkSPJTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/iceberg/IcebergSparkSPJTest.scala
new file mode 100644
index 0000000000..e1db8898ef
--- /dev/null
+++ b/spark/src/test/scala/ai/chronon/spark/test/batch/iceberg/IcebergSparkSPJTest.scala
@@ -0,0 +1,345 @@
+package ai.chronon.spark.test.batch.iceberg
+
+import org.apache.spark.sql.execution.{SortExec, SparkPlan}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+import java.nio.file.Files
+
+/** Comprehensive test suite for Apache Iceberg Storage Partitioned Join (SPJ) optimization.
+  * Tests verify that SPJ correctly eliminates exchange stages when joining partitioned tables.
+  */
+class IcebergSparkSPJTest extends AnyFlatSpec with Matchers with BeforeAndAfterAll with BeforeAndAfterEach {
+
+  private var spark: SparkSession = _
+  private val regions = Seq("North", "South", "East", "West")
+
+  private var warehouseDir: java.nio.file.Path = _
+
+  override def beforeAll(): Unit = {
+    warehouseDir = Files.createTempDirectory("storage-partition-join-test")
+    spark = SparkSession
+      .builder()
+      .appName("IcebergSPJTest")
+      .master("local[*]")
+      .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
+      .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog")
+      .config("spark.sql.warehouse.dir", warehouseDir.toString)
+      .config("spark.sql.shuffle.partitions", "4")
+      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
+      .config("spark.sql.catalog.spark_catalog.type", "hadoop")
+      .config("spark.sql.catalog.spark_catalog.warehouse", warehouseDir.toString)
+      .enableHiveSupport()
+      .getOrCreate()
+  }
+
+  override def afterAll(): Unit = {
+    if (spark != null) {
+      spark.stop()
+    }
+    if (warehouseDir != null) {
+      org.apache.commons.io.FileUtils.deleteDirectory(warehouseDir.toFile)
+    }
+  }
+
+  override def beforeEach(): Unit = {
+    // Clean up any existing test tables
+    Seq("customers", "orders", "shipping").foreach { table =>
+      spark.sql(s"DROP TABLE IF EXISTS $table")
+    }
+  }
+
+  /** Configure Spark session for SPJ optimization
+    */
+  private def enableSPJ(): Unit = {
+    spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
+    spark.conf.set("spark.sql.iceberg.planning.preserve-data-grouping", "true")
+    spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
+    spark.conf.set("spark.sql.requireAllClusterKeysForCoPartition", "false")
+    spark.conf.set("spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled", "true")
+    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") // Disable broadcast joins
+    spark.conf.set("spark.sql.adaptive.enabled", "false") // Disable AQE for predictable plans
+  }
+
+  /** Disable SPJ optimization
+    */
+  private def disableSPJ(): Unit = {
+    spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "false")
+    spark.conf.set("spark.sql.iceberg.planning.preserve-data-grouping", "false")
+  }
+
+  /** Create test tables with matching partition schemes
+    */
+  private def createTestTables(): Unit = {
+    // Create customers table with bucketed partitioning
+    spark.sql("""
+      CREATE TABLE customers (
+        customer_id BIGINT,
+        customer_name STRING,
+        region STRING,
+        email STRING,
+        created_date STRING
+      ) USING iceberg
+      PARTITIONED BY (region, bucket(4, customer_id))
+    """)
+
+    // Create orders table with identical partitioning
+    spark.sql("""
+      CREATE TABLE orders (
+        order_id BIGINT,
+        customer_id BIGINT,
+        region STRING,
+        amount DECIMAL(10,2),
+        order_date STRING
+      ) USING iceberg
+      PARTITIONED BY (region, bucket(4, customer_id))
+    """)
+
+    // Insert test data for customers
+    val customerData = (1 to 1000).map { i =>
+      (i.toLong, s"Customer_$i", regions(i % 4), s"customer_$i@test.com", "2024-01-01")
+    }
+
+    spark.createDataFrame(customerData).write.mode("append").insertInto("customers")
+
+    // Insert test data for orders
+    val orderData = (1 to 5000).flatMap { orderId =>
+      val customerId = (orderId % 1000) + 1
+      val region = regions(customerId.toInt % 4)
+      Seq((orderId.toLong, customerId.toLong, region, BigDecimal(orderId * 10.5), "2024-01-15"))
+    }
+
+    spark.createDataFrame(orderData).write.mode("append").insertInto("orders")
+  }
+
+  /** Check if a Spark plan contains Exchange operators
+    */
+  private def hasExchange(plan: SparkPlan): Boolean = {
+    val planString = plan.toString
+    planString.contains("Exchange") || planString.contains("ShuffleExchange")
+  }
+
+  /** Count Exchange operators in a plan
+    */
+  private def countExchanges(plan: SparkPlan): Int = {
+    val planString = plan.toString
+    val exchangeLines =
+      planString.split("\n").filter(line => line.contains("Exchange") || line.contains("ShuffleExchange"))
+    exchangeLines.length
+  }
+
+  /** Assert that a DataFrame's execution plan contains no Exchange operators
+    */
+  private def assertNoExchange(df: DataFrame, message: String = ""): Unit = {
+    val plan = df.queryExecution.executedPlan
+    val exchangeCount = countExchanges(plan)
+    assert(exchangeCount == 0,
+           s"Expected no Exchange operators but found $exchangeCount. $message\nPlan:\n${plan.toString}")
+  }
+
+  /** Assert that a DataFrame's execution plan contains Exchange operators
+    */
+  private def assertHasExchange(df: DataFrame, message: String = ""): Unit = {
+    val plan = df.queryExecution.executedPlan
+    assert(hasExchange(plan), s"Expected Exchange operators in plan but found none. $message\nPlan:\n${plan.toString}")
+  }
+
+  /** Analyze execution plan details
+    */
+  private case class PlanAnalysis(
+      hasExchange: Boolean,
+      exchangeCount: Int,
+      hasShuffle: Boolean,
+      hasBatchScan: Boolean,
+      hasSort: Boolean,
+      planString: String
+  )
+
+  private def analyzePlan(df: DataFrame): PlanAnalysis = {
+    val plan = df.queryExecution.executedPlan
+    val planString = plan.toString
+
+    df.explain()
+    PlanAnalysis(
+      hasExchange = hasExchange(plan),
+      exchangeCount = countExchanges(plan),
+      hasShuffle = planString.contains("Exchange hashpartitioning"),
+      hasBatchScan = planString.contains("BatchScan"),
+      hasSort = plan.exists(_.isInstanceOf[SortExec]),
+      planString = planString
+    )
+  }
+
+  "Iceberg storage partitioned join" should "eliminate exchange stages when SPJ is enabled and conditions are met" in {
+    createTestTables()
+    enableSPJ()
+
+    val customers = spark.table("customers")
+    val orders = spark.table("orders")
+
+    val joinDf = customers.join(orders,
+                                customers("region") === orders("region") &&
+                                  customers("customer_id") === orders("customer_id"))
+
+    assertNoExchange(joinDf, "SPJ should eliminate exchange for co-partitioned join")
+
+    // Verify the join still produces correct results
+    val resultCount = joinDf.count()
+    resultCount should be > 0L
+  }
+
+  it should "contain exchange stages when SPJ is disabled" in {
+    createTestTables()
+    disableSPJ()
+
+    val customers = spark.table("customers")
+    val orders = spark.table("orders")
+
+    val joinDf = customers.join(orders,
+                                customers("region") === orders("region") &&
+                                  customers("customer_id") === orders("customer_id"))
+
+    assertHasExchange(joinDf, "Join without SPJ should contain exchange")
+  }
+
+  it should "produce identical results with and without SPJ" in {
+    createTestTables()
+
+    val customers = spark.table("customers")
+    val orders = spark.table("orders")
+
+    // Get results without SPJ
+    disableSPJ()
+    val joinWithoutSPJ = customers.join(orders,
+                                        customers("region") === orders("region") &&
+                                          customers("customer_id") === orders("customer_id"))
+    val countWithoutSPJ = joinWithoutSPJ.count()
+    val samplesWithoutSPJ = joinWithoutSPJ.orderBy("order_id").limit(10).collect()
+    // Get results with SPJ
+    enableSPJ()
+    val joinWithSPJ = customers.join(orders,
+                                     customers("region") === orders("region") &&
+                                       customers("customer_id") === orders("customer_id"))
+    val countWithSPJ = joinWithSPJ.count()
+    val samplesWithSPJ = joinWithSPJ.orderBy("order_id").limit(10).collect()
+
+    // Compare results
+    countWithSPJ shouldEqual countWithoutSPJ
+    samplesWithSPJ should contain theSameElementsInOrderAs samplesWithoutSPJ
+  }
+
+  it should "require all partition columns in join condition" in {
+    createTestTables()
+    enableSPJ()
+
+    val customers = spark.table("customers")
+    val orders = spark.table("orders")
+
+    // Join only on region (missing customer_id from partition columns)
+    val partialJoin = customers.join(orders, Seq("region"))
+
+    // This should still have exchange because not all partition columns are used
+    assertHasExchange(partialJoin, "SPJ requires all partition columns in join condition")
+  }
+
+  it should "handle missing partition values with pushPartValues enabled" in {
+    createTestTables()
+
+    // Delete orders for West region to create missing partition scenario
+    spark.sql("DELETE FROM orders WHERE region = 'West'")
+
+    enableSPJ()
+
+    val customers = spark.table("customers")
+    val orders = spark.table("orders")
+
+    val leftJoin = customers.join(orders,
+                                  customers("region") === orders("region") &&
+                                    customers("customer_id") === orders("customer_id"),
+                                  "left")
+
+    // SPJ should still work with missing partitions when pushPartValues is enabled
+    assertNoExchange(leftJoin, "SPJ should handle missing partitions")
+
+    // Verify West customers appear with null order data
+    val westResults = leftJoin
+      .filter(col("customers.region") === "West")
+      .select("customers.customer_id", "orders.order_id")
+      .collect()
+
+    westResults.length should be > 0
+    westResults.foreach { row =>
+      row.isNullAt(1) shouldBe true // order_id should be null
+    }
+  }
+
+  it should "work with different join types" in {
+    createTestTables()
+    enableSPJ()
+
+    val customers = spark.table("customers")
+    val orders = spark.table("orders")
+    val joinCondition = customers("region") === orders("region") &&
+      customers("customer_id") === orders("customer_id")
+
+    // Test different join types
+    val joinTypes = Seq("inner", "left", "right", "left_semi", "left_anti")
+
+    joinTypes.foreach { joinType =>
+      val joinDf = customers.join(orders, joinCondition, joinType)
+      assertNoExchange(joinDf, s"SPJ should work for $joinType join")
+      joinDf.count() // Force execution
+    }
+  }
+
+  it should "verify detailed plan characteristics with SPJ" in {
+    createTestTables()
+    enableSPJ()
+
+    val customers = spark.table("customers")
+    val orders = spark.table("orders")
+
+    val joinDf = customers.join(orders,
+                                customers("region") === orders("region") &&
+                                  customers("customer_id") === orders("customer_id"))
+
+    val analysis = analyzePlan(joinDf)
+
+    // Verify plan characteristics
+    analysis.hasExchange shouldBe false
+    analysis.exchangeCount shouldBe 0
+    analysis.hasShuffle shouldBe false
+    analysis.hasBatchScan shouldBe true
+
+    // Print plan for debugging
+    info(s"Execution plan with SPJ:\n${analysis.planString}")
+  }
+
+  it should "compare execution plans with and without SPJ" in {
+    createTestTables()
+
+    val customers = spark.table("customers")
+    val orders = spark.table("orders")
+    val joinExpr = customers("region") === orders("region") &&
+      customers("customer_id") === orders("customer_id")
+
+    // Analyze without SPJ
+    disableSPJ()
+    val withoutSPJ = analyzePlan(customers.join(orders, joinExpr))
+
+    // Analyze with SPJ
+    enableSPJ()
+    val withSPJ = analyzePlan(customers.join(orders, joinExpr))
+
+    // Compare analyses
+    withoutSPJ.hasExchange shouldBe true
+    withSPJ.hasExchange shouldBe false
+
+    withoutSPJ.exchangeCount should be > 0
+    withSPJ.exchangeCount shouldBe 0
+
+    info(s"Plan without SPJ (${withoutSPJ.exchangeCount} exchanges):\n${withoutSPJ.planString}")
+    info(s"Plan with SPJ (${withSPJ.exchangeCount} exchanges):\n${withSPJ.planString}")
+  }
+}
diff --git a/tools/build_rules/dependencies/maven_repository.bzl b/tools/build_rules/dependencies/maven_repository.bzl
index 8f45d0f5ab..ddf3968eb4 100644
--- a/tools/build_rules/dependencies/maven_repository.bzl
+++ b/tools/build_rules/dependencies/maven_repository.bzl
@@ -76,7 +76,6 @@ maven_repository = repository(
         "org.json4s:json4s-core_2.12:3.7.0-M11",
         "org.json4s:json4s-ast_2.12:3.7.0-M11",
         "io.delta:delta-spark_2.12:3.2.0",
-        "org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.6.1",
         "org.apache.hudi:hudi-spark3.5-bundle_2.12:1.0.0",
 
         # grpc