Merged
2 changes: 2 additions & 0 deletions spark/BUILD.bazel
@@ -43,6 +43,7 @@ scala_library(
maven_artifact("org.apache.thrift:libthrift"),
maven_artifact("org.apache.hadoop:hadoop-common"),
maven_artifact("org.apache.hadoop:hadoop-client-api"),
maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
Collaborator: create a separate build target for this.

],
)

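One way to read this suggestion (the target name, srcs layout, and exact deps below are assumptions, not taken from the PR) is to isolate the Iceberg runtime jar behind its own target, so that only Iceberg-specific code pulls it in:

# Hypothetical standalone target for the Iceberg integration, keeping the
# runtime jar out of the main spark library's dependency graph.
scala_library(
    name = "iceberg",
    srcs = glob(["src/main/scala/ai/chronon/spark/iceberg/*.scala"]),
    deps = [
        maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
    ],
)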
@@ -132,6 +133,7 @@ test_deps = _SCALA_TEST_DEPS + [
maven_artifact("org.apache.hive:hive-exec"),
maven_artifact("org.apache.hadoop:hadoop-common"),
maven_artifact("org.apache.hadoop:hadoop-client-api"),
maven_artifact_with_suffix("org.apache.iceberg:iceberg-spark-runtime-3.5"),
]

scala_library(
83 changes: 83 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/iceberg/IcebergPartitionStatsExtractor.scala
@@ -0,0 +1,83 @@
package ai.chronon.spark.iceberg

import org.apache.iceberg.catalog.TableIdentifier
import org.apache.iceberg.spark.SparkSessionCatalog
import org.apache.iceberg.{DataFile, ManifestFiles, PartitionSpec, Table}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog

import scala.collection.JavaConverters._
import scala.collection.mutable

case class PartitionStats(partitionColToValue: Map[String, String],
                          rowCount: Long,
                          colToNullCount: Map[String, Long])

class IcebergPartitionStatsExtractor(spark: SparkSession) {

  def extractPartitionedStats(catalogName: String,
                              namespace: String,
                              tableName: String): Seq[PartitionStats] = {

    val catalog = spark
Collaborator: take this from tableUtils
      .sessionState
      .catalogManager
      .catalog(catalogName)
      .asInstanceOf[SparkSessionCatalog[V2SessionCatalog]]
      .icebergCatalog()
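A possible reading of the comment above — the TableUtils object below is hypothetical, since the PR's actual tableUtils API is not shown here — is to centralize the catalog lookup in a shared helper instead of reaching into Spark session internals at each call site:

// Hypothetical helper that encapsulates the Iceberg catalog lookup.
object TableUtils {
  def icebergCatalog(spark: SparkSession, catalogName: String): org.apache.iceberg.catalog.Catalog =
    spark.sessionState.catalogManager
      .catalog(catalogName)
      .asInstanceOf[SparkSessionCatalog[V2SessionCatalog]]
      .icebergCatalog()
}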

    val tableId: TableIdentifier = TableIdentifier.of(namespace, tableName)
    val table: Table = catalog.loadTable(tableId)

    require(
      table.spec().isPartitioned,
      s"Illegal request to compute partition-stats of an un-partitioned table: ${table.name()}."
    )

    val currentSnapshot = Option(table.currentSnapshot())
    val partitionStatsBuffer = mutable.Buffer[PartitionStats]()

    currentSnapshot.foreach { snapshot =>

      val manifestFiles = snapshot.allManifests(table.io()).asScala
      manifestFiles.foreach { manifestFile =>
        val manifestReader = ManifestFiles.read(manifestFile, table.io())

        try {
          manifestReader.forEach((file: DataFile) => {

            val rowCount: Long = file.recordCount()

            val schema = table.schema()

            val partitionSpec: PartitionSpec = table.specs().get(file.specId())
            val partitionFieldIds = partitionSpec.fields().asScala.map(_.sourceId()).toSet

            val colToNullCount: Map[String, Long] = file
              .nullValueCounts()
              .asScala
              .filterNot { case (fieldId, _) => partitionFieldIds.contains(fieldId) }
              .map { case (fieldId, nullCount) =>
                schema.findField(fieldId).name() -> nullCount.toLong
              }.toMap

            val partitionColToValue = partitionSpec.fields().asScala.zipWithIndex.map {
              case (field, index) =>
                val sourceField = schema.findField(field.sourceId())
                val partitionValue = file.partition().get(index, classOf[String])

                sourceField.name() -> String.valueOf(partitionValue)
Contributor:

🛠️ Refactor suggestion

Handle null partition values explicitly: String.valueOf converts null to the string "null". Consider explicit null handling, for example returning an empty string or a specific placeholder instead.

Suggested change:

-                val partitionValue = file.partition().get(index, classOf[String])
-
-                sourceField.name() -> String.valueOf(partitionValue)
+                val partitionValue = file.partition().get(index, classOf[String])
+                sourceField.name() -> Option(partitionValue).getOrElse("")

            }.toMap

            val fileStats = PartitionStats(partitionColToValue, rowCount, colToNullCount)
            partitionStatsBuffer.append(fileStats)
          })
        } finally {
          manifestReader.close()
        }
      }
    }

    partitionStatsBuffer
  }
}
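Note that the extractor emits one PartitionStats per data file, so a partition written by several files appears as several entries (the tests below happen to write one file per partition). A minimal sketch of rolling per-file stats up to per-partition totals — the mergePartitionStats helper is an assumption for illustration, not part of this PR:

import ai.chronon.spark.iceberg.PartitionStats

// Hypothetical aggregation: group per-file stats by their partition values,
// then sum row counts and per-column null counts within each group.
def mergePartitionStats(fileStats: Seq[PartitionStats]): Seq[PartitionStats] =
  fileStats
    .groupBy(_.partitionColToValue)
    .map { case (partitionValues, stats) =>
      val totalRows = stats.map(_.rowCount).sum
      val nullCounts = stats
        .flatMap(_.colToNullCount.toSeq)
        .groupBy { case (col, _) => col }
        .map { case (col, counts) => col -> counts.map(_._2).sum }
      PartitionStats(partitionValues, totalRows, nullCounts)
    }
    .toSeq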
137 changes: 137 additions & 0 deletions IcebergPartitionStatsExtractorTest.scala
@@ -0,0 +1,137 @@
package ai.chronon.spark.test.batch.iceberg

import ai.chronon.spark.iceberg.IcebergPartitionStatsExtractor
import org.apache.spark.sql.SparkSession
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import java.nio.file.Files

class IcebergPartitionStatsExtractorTest extends AnyFlatSpec with Matchers with BeforeAndAfterAll with BeforeAndAfterEach {

  private var spark: SparkSession = _

  override def beforeAll(): Unit = {

    spark = SparkSession.builder()
      .appName("IcebergPartitionStatsExtractorTest")
      .master("local[*]")
      .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
      .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog")
      .config("spark.sql.warehouse.dir", s"${System.getProperty("java.io.tmpdir")}/warehouse")
      .config("spark.sql.catalog.spark_catalog.type", "hadoop")
      .config("spark.sql.catalog.spark_catalog.warehouse", Files.createTempDirectory("partition-stats-test").toString)
      .config("spark.driver.bindAddress", "127.0.0.1")
      .config("spark.ui.enabled", "false")
      .enableHiveSupport()
      .getOrCreate()

    spark.sparkContext.setLogLevel("WARN")
  }

  override def afterAll(): Unit = {
    if (spark != null) {
      spark.stop()
    }
  }

  override def beforeEach(): Unit = {
    spark.sql("DROP TABLE IF EXISTS test_partitioned_table")
  }

"IcebergPartitionStatsExtractor" should "throw exception for unpartitioned table" in {
spark.sql(
"""
CREATE TABLE test_partitioned_table (
id BIGINT,
name STRING,
value DOUBLE
) USING iceberg
""")

val extractor = new IcebergPartitionStatsExtractor(spark)

val exception = intercept[IllegalArgumentException] {
extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")
}

exception.getMessage should include("Illegal request to compute partition-stats of an un-partitioned table: spark_catalog.default.test_partitioned_table")
}

it should "return empty sequence for empty partitioned table" in {
spark.sql(
"""
CREATE TABLE test_partitioned_table (
id BIGINT,
name STRING,
region STRING,
value DOUBLE
) USING iceberg
PARTITIONED BY (region)
""")

val extractor = new IcebergPartitionStatsExtractor(spark)
val stats = extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")

stats should be(empty)
}

it should "extract partition stats from table with data" in {
spark.sql(
"""
CREATE TABLE test_partitioned_table (
id BIGINT,
name STRING,
region STRING,
value DOUBLE
) USING iceberg
PARTITIONED BY (region)
""")

spark.sql(
"""
INSERT INTO test_partitioned_table VALUES
(1, 'Alice', 'North', 100.0),
(2, 'Bob', 'North', 200.0),
(3, 'Charlie', 'South', 150.0),
(4, NULL, 'South', 300.0),
(5, 'Eve', 'East', NULL)
""")

// Force table refresh to ensure snapshot is available
spark.sql("REFRESH TABLE test_partitioned_table")

// Verify data was inserted
val rowCount = spark.sql("SELECT COUNT(*) FROM test_partitioned_table").collect()(0).getLong(0)
rowCount should be(5)

val extractor = new IcebergPartitionStatsExtractor(spark)
val stats = extractor.extractPartitionedStats("spark_catalog", "default", "test_partitioned_table")

stats should not be empty
stats.length should be(3)

val totalRows = stats.map(_.rowCount).sum
totalRows should be(5)

// Verify specific partition stats
val southStats = stats.find(_.partitionColToValue("region") == "South").get
southStats.rowCount should be(2)
southStats.colToNullCount("value") should be(0) // No null values in South partition
southStats.colToNullCount("name") should be(1) // One null name (Charlie, NULL)
southStats.colToNullCount("id") should be(0) // No null ids

val northStats = stats.find(_.partitionColToValue("region") == "North").get
northStats.rowCount should be(2)
northStats.colToNullCount("value") should be(0) // No null values in North partition
northStats.colToNullCount("name") should be(0) // No null names (Alice, Bob)
northStats.colToNullCount("id") should be(0) // No null ids

val eastStats = stats.find(_.partitionColToValue("region") == "East").get
eastStats.rowCount should be(1)
eastStats.colToNullCount("value") should be(1) // One null value (Eve has NULL value)
eastStats.colToNullCount("name") should be(0) // No null names (Eve)
eastStats.colToNullCount("id") should be(0) // No null ids
}
}