2 changes: 1 addition & 1 deletion maven_install.json
@@ -1,6 +1,6 @@
{
"__AUTOGENERATED_FILE_DO_NOT_MODIFY_THIS_FILE_MANUALLY": "THERE_IS_NO_DATA_ONLY_ZUUL",
"__INPUT_ARTIFACTS_HASH": -542447257,
"__INPUT_ARTIFACTS_HASH": 1615229177,
"__RESOLVED_ARTIFACTS_HASH": 733328384,
"artifacts": {
"ant:ant": {
5 changes: 4 additions & 1 deletion spark/BUILD.bazel
@@ -43,6 +43,7 @@ scala_library(
maven_artifact("org.apache.thrift:libthrift"),
maven_artifact("org.apache.hadoop:hadoop-common"),
maven_artifact("org.apache.hadoop:hadoop-client-api"),
maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
Collaborator comment: create a separate build target for this.
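One possible shape for that suggestion, sketched against the BUILD conventions already in this file; the target name, srcs path, and dep list here are assumptions, not part of this change:

scala_library(
    name = "iceberg_partition_stats_lib",  # hypothetical target name
    srcs = glob(["src/main/scala/ai/chronon/spark/batch/iceberg/*.scala"]),
    format = True,
    visibility = ["//visibility:public"],
    deps = [
        maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
        # plus whatever Spark deps the surrounding targets already use for SparkSession
    ],
)

batch_lib and the test suite could then depend on ":iceberg_partition_stats_lib" instead of each adding the Iceberg runtime artifact directly.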

],
)

@@ -84,7 +85,7 @@ scala_library(

scala_library(
name = "batch_lib",
srcs = glob(["src/main/scala/ai/chronon/spark/batch/*.scala"]),
srcs = glob(["src/main/scala/ai/chronon/spark/batch/**/*.scala"]),
format = True,
visibility = ["//visibility:public"],
deps = [
@@ -102,6 +103,7 @@
maven_artifact("org.apache.logging.log4j:log4j-core"),
maven_artifact_with_suffix("org.rogach:scallop"),
maven_artifact("com.google.code.gson:gson"),
maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
],
)

@@ -154,6 +156,7 @@ scala_test_suite(
deps = test_deps + [
"test_lib",
":batch_lib",
maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
],
)

New file: IcebergPartitionStatsExtractor.scala
@@ -0,0 +1,81 @@
package ai.chronon.spark.batch.iceberg

import org.apache.iceberg.catalog.TableIdentifier
import org.apache.iceberg.spark.SparkSessionCatalog
import org.apache.iceberg.{DataFile, ManifestFiles, PartitionSpec, Table}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog

import scala.collection.JavaConverters._
import scala.collection.mutable

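/** Stats for one Iceberg data file: the partition values it belongs to, its record count,
  * and per-column null counts (partition columns excluded).
  */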
case class PartitionStats(partitionColToValue: Map[String, String], rowCount: Long, colToNullCount: Map[String, Long])

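/** Walks the table's current snapshot manifests and emits one PartitionStats entry per data file. */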
class IcebergPartitionStatsExtractor(spark: SparkSession) {

def extractPartitionedStats(catalogName: String, namespace: String, tableName: String): Seq[PartitionStats] = {

val catalog = spark.sessionState.catalogManager
.catalog(catalogName)
.asInstanceOf[SparkSessionCatalog[V2SessionCatalog]]
.icebergCatalog()

val tableId: TableIdentifier = TableIdentifier.of(namespace, tableName)
val table: Table = catalog.loadTable(tableId)

require(
table.spec().isPartitioned,
s"Illegal request to compute partition-stats of an un-partitioned table: ${table.name()}."
)

val currentSnapshot = Option(table.currentSnapshot())
val partitionStatsBuffer = mutable.Buffer[PartitionStats]()

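    // An empty table has no current snapshot, so the buffer stays empty and an empty Seq is returned.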
currentSnapshot.foreach { snapshot =>
val manifestFiles = snapshot.allManifests(table.io()).asScala
manifestFiles.foreach { manifestFile =>
val manifestReader = ManifestFiles.read(manifestFile, table.io())

try {
manifestReader.forEach((file: DataFile) => {

val rowCount: Long = file.recordCount()

val schema = table.schema()

val partitionSpec: PartitionSpec = table.specs().get(file.specId())
val partitionFieldIds = partitionSpec.fields().asScala.map(_.sourceId()).toSet

val colToNullCount: Map[String, Long] = file
.nullValueCounts()
.asScala
.filterNot { case (fieldId, _) => partitionFieldIds.contains(fieldId) }
.map { case (fieldId, nullCount) =>
schema.findField(fieldId).name() -> nullCount.toLong
}
.toMap

val partitionColToValue = partitionSpec
.fields()
.asScala
.zipWithIndex
.map { case (field, index) =>
val sourceField = schema.findField(field.sourceId())
val partitionValue = file.partition().get(index, classOf[String])

sourceField.name() -> String.valueOf(partitionValue)
}
.toMap

val fileStats = PartitionStats(partitionColToValue, rowCount, colToNullCount)
partitionStatsBuffer.append(fileStats)
})
} finally {
manifestReader.close()
}
}
}

partitionStatsBuffer
}
}
New file: IcebergPartitionStatsExtractorTest.scala
@@ -0,0 +1,139 @@
package ai.chronon.spark.test.batch.iceberg

import ai.chronon.spark.batch.iceberg.IcebergPartitionStatsExtractor
import org.apache.spark.sql.SparkSession
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import java.nio.file.Files

class IcebergPartitionStatsExtractorTest
extends AnyFlatSpec
with Matchers
with BeforeAndAfterAll
with BeforeAndAfterEach {

private var spark: SparkSession = _

override def beforeAll(): Unit = {

spark = SparkSession
.builder()
.appName("IcebergPartitionStatsExtractorTest")
.master("local[*]")
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
.config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog")
.config("spark.sql.warehouse.dir", s"${System.getProperty("java.io.tmpdir")}/warehouse")
.config("spark.sql.catalog.spark_catalog.type", "hadoop")
.config("spark.sql.catalog.spark_catalog.warehouse", Files.createTempDirectory("partition-stats-test").toString)
.config("spark.driver.bindAddress", "127.0.0.1")
.config("spark.ui.enabled", "false")
.enableHiveSupport()
.getOrCreate()

spark.sparkContext.setLogLevel("WARN")
}

override def afterAll(): Unit = {
if (spark != null) {
spark.stop()
}
}

override def beforeEach(): Unit = {
spark.sql("DROP TABLE IF EXISTS test_partitioned_table")
}

"IcebergPartitionStatsExtractor" should "throw exception for unpartitioned table" in {
spark.sql("""
CREATE TABLE test_partitioned_table (
id BIGINT,
name STRING,
value DOUBLE
) USING iceberg
""")

val extractor = new IcebergPartitionStatsExtractor(spark)

val exception = intercept[IllegalArgumentException] {
extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")
}

exception.getMessage should include(
"Illegal request to compute partition-stats of an un-partitioned table: spark_catalog.default.test_partitioned_table")
}

it should "return empty sequence for empty partitioned table" in {
spark.sql("""
CREATE TABLE test_partitioned_table (
id BIGINT,
name STRING,
region STRING,
value DOUBLE
) USING iceberg
PARTITIONED BY (region)
""")

val extractor = new IcebergPartitionStatsExtractor(spark)
val stats = extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")

stats should be(empty)
}

it should "extract partition stats from table with data" in {
spark.sql("""
CREATE TABLE test_partitioned_table (
id BIGINT,
name STRING,
region STRING,
value DOUBLE
) USING iceberg
PARTITIONED BY (region)
""")

spark.sql("""
INSERT INTO test_partitioned_table VALUES
(1, 'Alice', 'North', 100.0),
(2, 'Bob', 'North', 200.0),
(3, 'Charlie', 'South', 150.0),
(4, NULL, 'South', 300.0),
(5, 'Eve', 'East', NULL)
""")

// Force table refresh to ensure snapshot is available
spark.sql("REFRESH TABLE test_partitioned_table")

// Verify data was inserted
val rowCount = spark.sql("SELECT COUNT(*) FROM test_partitioned_table").collect()(0).getLong(0)
rowCount should be(5)

val extractor = new IcebergPartitionStatsExtractor(spark)
val stats = extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")

stats should not be empty
stats.length should be(3)

val totalRows = stats.map(_.rowCount).sum
totalRows should be(5)

// Verify specific partition stats
val southStats = stats.find(_.partitionColToValue("region") == "South").get
southStats.rowCount should be(2)
southStats.colToNullCount("value") should be(0) // No null values in South partition
southStats.colToNullCount("name") should be(1) // One null name (Charlie, NULL)
southStats.colToNullCount("id") should be(0) // No null ids

val northStats = stats.find(_.partitionColToValue("region") == "North").get
northStats.rowCount should be(2)
northStats.colToNullCount("value") should be(0) // No null values in North partition
northStats.colToNullCount("name") should be(0) // No null names (Alice, Bob)
northStats.colToNullCount("id") should be(0) // No null ids

val eastStats = stats.find(_.partitionColToValue("region") == "East").get
eastStats.rowCount should be(1)
eastStats.colToNullCount("value") should be(1) // One null value (Eve has NULL value)
eastStats.colToNullCount("name") should be(0) // No null names (Eve)
eastStats.colToNullCount("id") should be(0) // No null ids
}
}