Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import ai.chronon.api.planner.{DependencyResolver, NodeRunner}
import ai.chronon.api.{MetaData, PartitionRange, PartitionSpec, ThriftJsonCodec}
import ai.chronon.online.Api
import ai.chronon.online.KVStore.PutRequest
import ai.chronon.planner.{GroupByUploadNode, MonolithJoinNode, Node, NodeContent, StagingQueryNode}
import ai.chronon.planner._
import ai.chronon.spark.catalog.TableUtils
import ai.chronon.spark.join.UnionJoin
import ai.chronon.spark.submission.SparkSessionBuilder
import ai.chronon.spark.{GroupByUpload, Join}
import org.rogach.scallop.{ScallopConf, ScallopOption}
import org.slf4j.{Logger, LoggerFactory}

import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.concurrent.Await
import scala.concurrent.duration.Duration
Expand Down Expand Up @@ -45,6 +46,43 @@ object BatchNodeRunner extends NodeRunner {

@transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass)

def checkPartitions(conf: ExternalSourceSensorNode,
metadata: MetaData,
tableUtils: TableUtils,
range: PartitionRange): Try[Unit] = {
val tableName = conf.sourceName
val retryCount = if (conf.isSetRetryCount) conf.retryCount else 3L
val retryIntervalMin = if (conf.isSetRetryIntervalMin) conf.retryIntervalMin else 3L

val spec = metadata.executionInfo.tableDependencies.asScala
.find(_.tableInfo.table == tableName)
.map(_.tableInfo.partitionSpec(tableUtils.partitionSpec))

@tailrec
def retry(attempt: Long): Try[Unit] = {
val result = Try {
val partitionsInRange =
tableUtils.partitions(tableName, partitionRange = Option(range), tablePartitionSpec = spec)
val missingPartitions = range.partitions.diff(partitionsInRange)
if (missingPartitions.nonEmpty) {
throw new RuntimeException(
s"Input table ${tableName} is missing partitions: ${missingPartitions.mkString(", ")}")
} else {
logger.info(s"Input table ${tableName} has the requested range present: ${range}.")
}
}
result match {
case Success(value) => Success(value)
case Failure(exception) if attempt < retryCount =>
logger.warn(s"Attempt ${attempt + 1} failed, retrying in ${retryIntervalMin} minutes", exception)
Thread.sleep(retryIntervalMin * 60 * 1000)
retry(attempt + 1)
case failure => failure
}
}
retry(0)
}

private def createTableUtils(name: String): TableUtils = {
val spark = SparkSessionBuilder.build(s"batch-node-runner-${name}")
TableUtils(spark)
Expand All @@ -59,6 +97,15 @@ object BatchNodeRunner extends NodeRunner {
runGroupByUpload(metadata, conf.getGroupByUpload, range, tableUtils)
case NodeContent._Fields.STAGING_QUERY =>
runStagingQuery(metadata, conf.getStagingQuery, range, tableUtils)
case NodeContent._Fields.EXTERNAL_SOURCE_SENSOR => {

checkPartitions(conf.getExternalSourceSensor, metadata, tableUtils, range) match {
case Success(_) => System.exit(0)
case Failure(exception) =>
logger.error(s"ExternalSourceSensor check failed.", exception)
System.exit(1)
}
}
case _ =>
throw new UnsupportedOperationException(s"Unsupported NodeContent type: ${conf.getSetField}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import ai.chronon.api.Extensions._
import ai.chronon.api._
import ai.chronon.api.planner.TableDependencies
import ai.chronon.online.KVStore.PutRequest
import ai.chronon.planner.{MonolithJoinNode, Node, NodeContent}
import ai.chronon.planner.{ExternalSourceSensorNode, MonolithJoinNode, Node, NodeContent}
import ai.chronon.spark.batch.BatchNodeRunner
import ai.chronon.spark.submission.SparkSessionBuilder
import ai.chronon.spark.test.{MockKVStore, TableTestUtils}
Expand Down Expand Up @@ -430,6 +430,113 @@ class BatchNodeRunnerTest extends AnyFlatSpec with BeforeAndAfterAll with Before
}
}

"BatchNodeRunner.checkPartitions" should "succeed when all partitions are available" in {
val sensorNode = new ExternalSourceSensorNode()
.setSourceName("test_db.input_table")
.setRetryCount(0L)
.setRetryIntervalMin(1L)

val metadata = createTestMetadata("test_db.input_table", "test_db.output_table")
val range = PartitionRange(twoDaysAgo, yesterday)(tableUtils.partitionSpec)

val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)

result match {
case Success(_) =>
// Test passed
case Failure(exception) =>
fail(s"checkPartitions should have succeeded but failed with: ${exception.getMessage}")
}
}

it should "fail when partitions are missing and no retries configured" in {
val sensorNode = new ExternalSourceSensorNode()
.setSourceName("test_db.external_table")
.setRetryCount(0L)
.setRetryIntervalMin(1L)

val metadata = createTestMetadata("test_db.external_table", "test_db.output_table")
val range = PartitionRange(today, today)(tableUtils.partitionSpec) // today's partition doesn't exist

val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)

result match {
case Success(_) =>
fail("checkPartitions should have failed due to missing partitions")
case Failure(exception) =>
assertTrue("Exception should mention missing partitions", exception.getMessage.contains("missing partitions"))
assertTrue("Exception should mention table name", exception.getMessage.contains("test_db.external_table"))
assertTrue("Exception should mention specific partition", exception.getMessage.contains(today))
}
}

it should "use default retry values when not set" in {
val sensorNode = new ExternalSourceSensorNode()
.setSourceName("test_db.external_table")
// Not setting retryCount and retryIntervalMin to test defaults

val metadata = createTestMetadata("test_db.external_table", "test_db.output_table")
val range = PartitionRange(today, today)(tableUtils.partitionSpec) // today's partition doesn't exist

val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)

result match {
case Success(_) =>
fail("checkPartitions should have failed due to missing partitions")
case Failure(exception) =>
// Should fail immediately with default retry count of 0
assertTrue("Exception should mention missing partitions", exception.getMessage.contains("missing partitions"))
}
}

it should "retry when configured but eventually fail if partitions never appear" in {
val sensorNode = new ExternalSourceSensorNode()
.setSourceName("test_db.external_table")
.setRetryCount(2L) // Will try 3 times total (initial + 2 retries)
.setRetryIntervalMin(0L) // Set to 0 to avoid actual delays in test

val metadata = createTestMetadata("test_db.external_table", "test_db.output_table")
val range = PartitionRange(today, today)(tableUtils.partitionSpec) // today's partition doesn't exist

val startTime = System.currentTimeMillis()
val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)
val endTime = System.currentTimeMillis()

result match {
case Success(_) =>
fail("checkPartitions should have failed due to missing partitions")
case Failure(exception) =>
assertTrue("Exception should mention missing partitions", exception.getMessage.contains("missing partitions"))
// Since we set retry interval to 0, the test should complete quickly
assertTrue("Test should complete within reasonable time", (endTime - startTime) < 5000)
}
}

it should "handle non-existent table gracefully" in {
val sensorNode = new ExternalSourceSensorNode()
.setSourceName("test_db.nonexistent_table")
.setRetryCount(0L)
.setRetryIntervalMin(1L)

val metadata = createTestMetadata("test_db.nonexistent_table", "test_db.output_table")
val range = PartitionRange(yesterday, yesterday)(tableUtils.partitionSpec)

val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)

result match {
case Success(_) =>
fail("checkPartitions should have failed for nonexistent table")
case Failure(exception) =>
// Should fail with some kind of table not found or similar error
assertTrue(
"Exception should indicate table issue",
exception.getMessage.contains("nonexistent_table") ||
exception.getMessage.toLowerCase.contains("not found") ||
exception.getMessage.toLowerCase.contains("table")
)
}
}

override def afterAll(): Unit = {
spark.sql("DROP DATABASE IF EXISTS test_db CASCADE")
spark.stop()
Expand Down