diff --git a/spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala b/spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala
index 07873bdfcf..40cbef1b30 100644
--- a/spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala
+++ b/spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala
@@ -5,7 +5,7 @@ import ai.chronon.api.planner.{DependencyResolver, NodeRunner}
 import ai.chronon.api.{MetaData, PartitionRange, PartitionSpec, ThriftJsonCodec}
 import ai.chronon.online.Api
 import ai.chronon.online.KVStore.PutRequest
-import ai.chronon.planner.{GroupByUploadNode, MonolithJoinNode, Node, NodeContent, StagingQueryNode}
+import ai.chronon.planner._
 import ai.chronon.spark.catalog.TableUtils
 import ai.chronon.spark.join.UnionJoin
 import ai.chronon.spark.submission.SparkSessionBuilder
@@ -13,6 +13,7 @@ import ai.chronon.spark.{GroupByUpload, Join}
 import org.rogach.scallop.{ScallopConf, ScallopOption}
 import org.slf4j.{Logger, LoggerFactory}
 
+import scala.annotation.tailrec
 import scala.collection.JavaConverters._
 import scala.concurrent.Await
 import scala.concurrent.duration.Duration
@@ -45,6 +46,43 @@ object BatchNodeRunner extends NodeRunner {
   @transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass)
 
+  def checkPartitions(conf: ExternalSourceSensorNode,
+                      metadata: MetaData,
+                      tableUtils: TableUtils,
+                      range: PartitionRange): Try[Unit] = {
+    val tableName = conf.sourceName
+    val retryCount = if (conf.isSetRetryCount) conf.retryCount else 3L
+    val retryIntervalMin = if (conf.isSetRetryIntervalMin) conf.retryIntervalMin else 3L
+
+    val spec = metadata.executionInfo.tableDependencies.asScala
+      .find(_.tableInfo.table == tableName)
+      .map(_.tableInfo.partitionSpec(tableUtils.partitionSpec))
+
+    @tailrec
+    def retry(attempt: Long): Try[Unit] = {
+      val result = Try {
+        val partitionsInRange =
+          tableUtils.partitions(tableName, partitionRange = Option(range), tablePartitionSpec = spec)
+        val missingPartitions = range.partitions.diff(partitionsInRange)
+        if (missingPartitions.nonEmpty) {
+          throw new RuntimeException(
+            s"Input table ${tableName} is missing partitions: ${missingPartitions.mkString(", ")}")
+        } else {
+          logger.info(s"Input table ${tableName} has the requested range present: ${range}.")
+        }
+      }
+      result match {
+        case Success(value) => Success(value)
+        case Failure(exception) if attempt < retryCount =>
+          logger.warn(s"Attempt ${attempt + 1} failed, retrying in ${retryIntervalMin} minutes", exception)
+          Thread.sleep(retryIntervalMin * 60 * 1000)
+          retry(attempt + 1)
+        case failure => failure
+      }
+    }
+    retry(0)
+  }
+
   private def createTableUtils(name: String): TableUtils = {
     val spark = SparkSessionBuilder.build(s"batch-node-runner-${name}")
     TableUtils(spark)
   }
@@ -59,6 +97,15 @@
       case NodeContent._Fields.GROUP_BY_UPLOAD =>
         runGroupByUpload(metadata, conf.getGroupByUpload, range, tableUtils)
       case NodeContent._Fields.STAGING_QUERY =>
         runStagingQuery(metadata, conf.getStagingQuery, range, tableUtils)
+      case NodeContent._Fields.EXTERNAL_SOURCE_SENSOR => {
+
+        checkPartitions(conf.getExternalSourceSensor, metadata, tableUtils, range) match {
+          case Success(_) => System.exit(0)
+          case Failure(exception) =>
+            logger.error(s"ExternalSourceSensor check failed.", exception)
+            System.exit(1)
+        }
+      }
       case _ =>
         throw new UnsupportedOperationException(s"Unsupported NodeContent type: ${conf.getSetField}")
     }
diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/BatchNodeRunnerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/BatchNodeRunnerTest.scala
index 42b9f2f426..374c34202e 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/batch/BatchNodeRunnerTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/batch/BatchNodeRunnerTest.scala
@@ -20,7 +20,7 @@ import ai.chronon.api.Extensions._
 import ai.chronon.api._
 import ai.chronon.api.planner.TableDependencies
 import ai.chronon.online.KVStore.PutRequest
-import ai.chronon.planner.{MonolithJoinNode, Node, NodeContent}
+import ai.chronon.planner.{ExternalSourceSensorNode, MonolithJoinNode, Node, NodeContent}
 import ai.chronon.spark.batch.BatchNodeRunner
 import ai.chronon.spark.submission.SparkSessionBuilder
 import ai.chronon.spark.test.{MockKVStore, TableTestUtils}
@@ -430,6 +430,113 @@ class BatchNodeRunnerTest extends AnyFlatSpec with BeforeAndAfterAll with Before
     }
   }
 
+  "BatchNodeRunner.checkPartitions" should "succeed when all partitions are available" in {
+    val sensorNode = new ExternalSourceSensorNode()
+      .setSourceName("test_db.input_table")
+      .setRetryCount(0L)
+      .setRetryIntervalMin(1L)
+
+    val metadata = createTestMetadata("test_db.input_table", "test_db.output_table")
+    val range = PartitionRange(twoDaysAgo, yesterday)(tableUtils.partitionSpec)
+
+    val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)
+
+    result match {
+      case Success(_) =>
+      // Test passed
+      case Failure(exception) =>
+        fail(s"checkPartitions should have succeeded but failed with: ${exception.getMessage}")
+    }
+  }
+
+  it should "fail when partitions are missing and no retries configured" in {
+    val sensorNode = new ExternalSourceSensorNode()
+      .setSourceName("test_db.external_table")
+      .setRetryCount(0L)
+      .setRetryIntervalMin(1L)
+
+    val metadata = createTestMetadata("test_db.external_table", "test_db.output_table")
+    val range = PartitionRange(today, today)(tableUtils.partitionSpec) // today's partition doesn't exist
+
+    val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)
+
+    result match {
+      case Success(_) =>
+        fail("checkPartitions should have failed due to missing partitions")
+      case Failure(exception) =>
+        assertTrue("Exception should mention missing partitions", exception.getMessage.contains("missing partitions"))
+        assertTrue("Exception should mention table name", exception.getMessage.contains("test_db.external_table"))
+        assertTrue("Exception should mention specific partition", exception.getMessage.contains(today))
+    }
+  }
+
+  it should "use default retry values when not set" in {
+    val sensorNode = new ExternalSourceSensorNode()
+      .setSourceName("test_db.external_table")
+    // Not setting retryCount and retryIntervalMin to test defaults
+
+    val metadata = createTestMetadata("test_db.external_table", "test_db.output_table")
+    val range = PartitionRange(today, today)(tableUtils.partitionSpec) // today's partition doesn't exist
+
+    val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)
+
+    result match {
+      case Success(_) =>
+        fail("checkPartitions should have failed due to missing partitions")
+      case Failure(exception) =>
+        // Defaults are retryCount = 3 and retryIntervalMin = 3, so the failure surfaces only after retries are exhausted
+        assertTrue("Exception should mention missing partitions", exception.getMessage.contains("missing partitions"))
+    }
+  }
+
+  it should "retry when configured but eventually fail if partitions never appear" in {
+    val sensorNode = new ExternalSourceSensorNode()
+      .setSourceName("test_db.external_table")
+      .setRetryCount(2L) // Will try 3 times total (initial + 2 retries)
+      .setRetryIntervalMin(0L) // Set to 0 to avoid actual delays in test
+
+    val metadata = createTestMetadata("test_db.external_table", "test_db.output_table")
+    val range = PartitionRange(today, today)(tableUtils.partitionSpec) // today's partition doesn't exist
+
+    val startTime = System.currentTimeMillis()
+    val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)
+    val endTime = System.currentTimeMillis()
+
+    result match {
+      case Success(_) =>
+        fail("checkPartitions should have failed due to missing partitions")
+      case Failure(exception) =>
+        assertTrue("Exception should mention missing partitions", exception.getMessage.contains("missing partitions"))
+        // Since we set retry interval to 0, the test should complete quickly
+        assertTrue("Test should complete within reasonable time", (endTime - startTime) < 5000)
+    }
+  }
+
+  it should "handle non-existent table gracefully" in {
+    val sensorNode = new ExternalSourceSensorNode()
+      .setSourceName("test_db.nonexistent_table")
+      .setRetryCount(0L)
+      .setRetryIntervalMin(1L)
+
+    val metadata = createTestMetadata("test_db.nonexistent_table", "test_db.output_table")
+    val range = PartitionRange(yesterday, yesterday)(tableUtils.partitionSpec)
+
+    val result = BatchNodeRunner.checkPartitions(sensorNode, metadata, tableUtils, range)
+
+    result match {
+      case Success(_) =>
+        fail("checkPartitions should have failed for nonexistent table")
+      case Failure(exception) =>
+        // Should fail with some kind of table not found or similar error
+        assertTrue(
+          "Exception should indicate table issue",
+          exception.getMessage.contains("nonexistent_table") ||
+            exception.getMessage.toLowerCase.contains("not found") ||
+            exception.getMessage.toLowerCase.contains("table")
+        )
+    }
+  }
+
   override def afterAll(): Unit = {
     spark.sql("DROP DATABASE IF EXISTS test_db CASCADE")
     spark.stop()