api/src/main/scala/ai/chronon/api/planner/NodeRunner.scala (6 changes: 4 additions & 2 deletions)
@@ -5,11 +5,13 @@ import ai.chronon.api.PartitionRange
 import ai.chronon.planner.NodeContent
 trait NodeRunner {
 
-  val DefaultTablePartitionsDataset = "TABLE_PARTITIONS"
-
   def run(metadata: api.MetaData, conf: NodeContent, range: Option[PartitionRange]): Unit
 }
 
+object NodeRunner {
+  val DefaultTablePartitionsDataset = "TABLE_PARTITIONS"
+}
+
 object LineageOfflineRunner {
   def readFiles(folderPath: String): Seq[Any] = {
     // read files from folder using metadata
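To make the intent of that first hunk concrete, here is a minimal usage sketch; only the example object and value names are invented, everything else comes straight from the diff.

import ai.chronon.api.planner.NodeRunner

object PartitionsDatasetExample {
  // Previously the constant lived on the NodeRunner trait and was reached through implementors,
  // e.g. ai.chronon.spark.batch.BatchNodeRunner.DefaultTablePartitionsDataset; after this change
  // every caller reads the single shared value from the NodeRunner companion object.
  val dataset: String = NodeRunner.DefaultTablePartitionsDataset // "TABLE_PARTITIONS"
}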
spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala (157 changes: 65 additions & 92 deletions)
@@ -6,7 +6,6 @@ import ai.chronon.api.{MetaData, PartitionRange, PartitionSpec, ThriftJsonCodec}
 import ai.chronon.online.Api
 import ai.chronon.online.KVStore.PutRequest
 import ai.chronon.planner._
-import ai.chronon.spark.batch.BatchNodeRunner.DefaultTablePartitionsDataset
 import ai.chronon.spark.catalog.TableUtils
 import ai.chronon.spark.join.UnionJoin
 import ai.chronon.spark.submission.SparkSessionBuilder
@@ -42,19 +41,15 @@ class BatchNodeRunnerArgs(args: Array[String]) extends ScallopConf(args) {
 
   val tablePartitionsDataset = opt[String](required = true,
                                            descr = "Name of table in kv store to use to keep track of partitions",
-                                           default = Option(DefaultTablePartitionsDataset))
+                                           default = Option(NodeRunner.DefaultTablePartitionsDataset))
 
   verify()
 }
 
-object BatchNodeRunner extends NodeRunner {
-
+class BatchNodeRunner(node: Node, tableUtils: TableUtils) extends NodeRunner {
   @transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass)
 
-  def checkPartitions(conf: ExternalSourceSensorNode,
-                      metadata: MetaData,
-                      tableUtils: TableUtils,
-                      range: PartitionRange): Try[Unit] = {
+  def checkPartitions(conf: ExternalSourceSensorNode, metadata: MetaData, range: PartitionRange): Try[Unit] = {
     val tableName = conf.sourceName
     val retryCount = if (conf.isSetRetryCount) conf.retryCount else 3L
     val retryIntervalMin = if (conf.isSetRetryIntervalMin) conf.retryIntervalMin else 3L
@@ -88,47 +83,7 @@ object BatchNodeRunner extends NodeRunner {
     retry(0)
   }
 
-  private def createTableUtils(name: String): TableUtils = {
-    val spark = SparkSessionBuilder.build(s"batch-node-runner-${name}")
-    TableUtils(spark)
-  }
-
-  private[batch] def run(metadata: MetaData, conf: NodeContent, range: PartitionRange, tableUtils: TableUtils): Unit = {
-
-    conf.getSetField match {
-      case NodeContent._Fields.MONOLITH_JOIN =>
-        runMonolithJoin(metadata, conf.getMonolithJoin, range, tableUtils)
-      case NodeContent._Fields.GROUP_BY_UPLOAD =>
-        runGroupByUpload(metadata, conf.getGroupByUpload, range, tableUtils)
-      case NodeContent._Fields.GROUP_BY_BACKFILL =>
-        logger.info(s"Running groupBy backfill for '${metadata.name}' for range: [${range.start}, ${range.end}]")
-        GroupBy.computeBackfill(
-          conf.getGroupByBackfill.groupBy,
-          range.end,
-          tableUtils,
-          overrideStartPartition = Option(range.start)
-        )
-        logger.info(s"Successfully completed groupBy backfill for '${metadata.name}'")
-      case NodeContent._Fields.STAGING_QUERY =>
-        runStagingQuery(metadata, conf.getStagingQuery, range, tableUtils)
-      case NodeContent._Fields.EXTERNAL_SOURCE_SENSOR => {
-
-        checkPartitions(conf.getExternalSourceSensor, metadata, tableUtils, range) match {
-          case Success(_) => System.exit(0)
-          case Failure(exception) =>
-            logger.error(s"ExternalSourceSensor check failed.", exception)
-            System.exit(1)
-        }
-      }
-      case _ =>
-        throw new UnsupportedOperationException(s"Unsupported NodeContent type: ${conf.getSetField}")
-    }
-  }
-
-  private def runStagingQuery(metaData: MetaData,
-                              stagingQuery: StagingQueryNode,
-                              range: PartitionRange,
-                              tableUtils: TableUtils): Unit = {
+  private def runStagingQuery(metaData: MetaData, stagingQuery: StagingQueryNode, range: PartitionRange): Unit = {
     require(stagingQuery.isSetStagingQuery, "StagingQueryNode must have a stagingQuery set")
     logger.info(s"Running staging query for '${metaData.name}'")
     val stagingQueryConf = stagingQuery.stagingQuery
@@ -143,22 +98,16 @@
     logger.info(s"Successfully completed staging query for '${metaData.name}'")
   }
 
-  private def runGroupByUpload(metadata: MetaData,
-                               groupByUpload: GroupByUploadNode,
-                               range: PartitionRange,
-                               tableUtils: TableUtils): Unit = {
+  private def runGroupByUpload(metadata: MetaData, groupByUpload: GroupByUploadNode, range: PartitionRange): Unit = {
     require(groupByUpload.isSetGroupBy, "GroupByUploadNode must have a groupBy set")
     val groupBy = groupByUpload.groupBy
     logger.info(s"Running groupBy upload for '${metadata.name}' for day: ${range.end}")
 
-    GroupByUpload.run(groupBy, range.end, Some(tableUtils))
+    GroupByUpload.run(groupBy, range.end, Option(tableUtils))
     logger.info(s"Successfully completed groupBy upload for '${metadata.name}' for day: ${range.end}")
   }
 
-  private def runMonolithJoin(metadata: MetaData,
-                              monolithJoin: MonolithJoinNode,
-                              range: PartitionRange,
-                              tableUtils: TableUtils): Unit = {
+  private def runMonolithJoin(metadata: MetaData, monolithJoin: MonolithJoinNode, range: PartitionRange): Unit = {
     require(monolithJoin.isSetJoin, "MonolithJoinNode must have a join set")
 
     val joinConf = monolithJoin.join
@@ -185,21 +134,47 @@
     }
   }
 
-  override def run(metadata: MetaData, conf: NodeContent, range: Option[PartitionRange]): Unit = {
-    require(range.isDefined, "Partition range must be defined for batch node runner")
+  override def run(metadata: MetaData, conf: NodeContent, maybeRange: Option[PartitionRange]): Unit = {
+    require(maybeRange.isDefined, "Partition range must be defined for batch node runner")
+    val range = maybeRange.get
+    conf.getSetField match {
+      case NodeContent._Fields.MONOLITH_JOIN =>
+        runMonolithJoin(metadata, conf.getMonolithJoin, range)
+      case NodeContent._Fields.GROUP_BY_UPLOAD =>
+        runGroupByUpload(metadata, conf.getGroupByUpload, range)
+      case NodeContent._Fields.GROUP_BY_BACKFILL =>
+        logger.info(s"Running groupBy backfill for '${metadata.name}' for range: [${range.start}, ${range.end}]")
+        GroupBy.computeBackfill(
+          conf.getGroupByBackfill.groupBy,
+          range.end,
+          tableUtils,
+          overrideStartPartition = Option(range.start)
+        )
+        logger.info(s"Successfully completed groupBy backfill for '${metadata.name}'")
+      case NodeContent._Fields.STAGING_QUERY =>
+        runStagingQuery(metadata, conf.getStagingQuery, range)
+      case NodeContent._Fields.EXTERNAL_SOURCE_SENSOR => {
 
-    run(metadata, conf, range.get, createTableUtils(metadata.name))
+        checkPartitions(conf.getExternalSourceSensor, metadata, range) match {
+          case Success(_) => System.exit(0)
+          case Failure(exception) =>
+            logger.error(s"ExternalSourceSensor check failed.", exception)
+            System.exit(1)
+        }
+      }
+      case _ =>
+        throw new UnsupportedOperationException(s"Unsupported NodeContent type: ${conf.getSetField}")
+    }
   }
 
-  def runFromArgs(api: Api,
-                  confPath: String,
-                  startDs: String,
-                  endDs: String,
-                  tablePartitionsDataset: String): Try[Unit] = {
+  def runFromArgs(
+      api: Api,
+      startDs: String,
+      endDs: String,
+      tablePartitionsDataset: String
+  ): Try[Unit] = {
     Try {
-      val node = ThriftJsonCodec.fromJsonFile[Node](confPath, check = true)
       val metadata = node.metaData
-      val tableUtils = createTableUtils(metadata.name)
      val range = PartitionRange(startDs, endDs)(PartitionSpec.daily)
       val kvStore = api.genKvStore
 
@@ -275,37 +250,35 @@
       }
     }
   }
+}
 
-  def instantiateApi(onlineClass: String, props: Map[String, String]): Api = {
-    val cl = Thread.currentThread().getContextClassLoader
-    val cls = cl.loadClass(onlineClass)
-    val constructor = cls.getConstructors.apply(0)
-    val onlineImpl = constructor.newInstance(props)
-    onlineImpl.asInstanceOf[Api]
-  }
+object BatchNodeRunner {
 
   def main(args: Array[String]): Unit = {
-    try {
-      val batchArgs = new BatchNodeRunnerArgs(args)
-      val api = instantiateApi(batchArgs.onlineClass(), batchArgs.apiProps)
-      runFromArgs(api,
-                  batchArgs.confPath(),
-                  batchArgs.startDs(),
-                  batchArgs.endDs(),
-                  batchArgs.tablePartitionsDataset()) match {
+    val batchArgs = new BatchNodeRunnerArgs(args)
+    val node = ThriftJsonCodec.fromJsonFile[Node](batchArgs.confPath(), check = true)
+    val tableUtils = TableUtils(SparkSessionBuilder.build(s"batch-node-runner-${node.metaData.name}"))
+    val runner = new BatchNodeRunner(node, tableUtils)
+    val api = instantiateApi(batchArgs.onlineClass(), batchArgs.apiProps)
+    val exitCode = {
+      runner.runFromArgs(api, batchArgs.startDs(), batchArgs.endDs(), batchArgs.tablePartitionsDataset()) match {
        case Success(_) =>
-          logger.info("Batch node runner completed successfully")
-          System.exit(0)
+          println("Batch node runner succeeded")
+          0
        case Failure(exception) =>
-          logger.error("Batch node runner failed", exception)
-          System.exit(1)
+          println("Batch node runner failed", exception)
+          1
      }
-    } catch {
-      case e: Exception =>
-        logger.error("Failed to parse arguments or initialize runner", e)
-        System.exit(1)
    }
+    tableUtils.sparkSession.stop()
+    System.exit(exitCode)
  }
 
-  // override def tablePartitionsDataset(): String = tablePartitionsDataset
+  def instantiateApi(onlineClass: String, props: Map[String, String]): Api = {
+    val cl = Thread.currentThread().getContextClassLoader
+    val cls = cl.loadClass(onlineClass)
+    val constructor = cls.getConstructors.apply(0)
+    val onlineImpl = constructor.newInstance(props)
+    onlineImpl.asInstanceOf[Api]
+  }
 }
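Taken together, the BatchNodeRunner changes replace the singleton object with a class that owns its Node and TableUtils, while a new companion object wires them up in main() and exits via a computed exit code instead of calling System.exit inside the match. A sketch of driving the class programmatically, mirroring main(); the example object, conf path, dates, Api class name, and props are placeholders, not values from the PR.

import ai.chronon.api.ThriftJsonCodec
import ai.chronon.online.Api
import ai.chronon.planner.Node
import ai.chronon.spark.batch.BatchNodeRunner
import ai.chronon.spark.catalog.TableUtils
import ai.chronon.spark.submission.SparkSessionBuilder

object BatchNodeRunnerExample {
  def main(args: Array[String]): Unit = {
    // Placeholder inputs: swap in a real compiled node conf, date range, and Api implementation.
    val node = ThriftJsonCodec.fromJsonFile[Node]("/path/to/compiled/node.json", check = true)
    val tableUtils = TableUtils(SparkSessionBuilder.build(s"batch-node-runner-${node.metaData.name}"))
    val runner = new BatchNodeRunner(node, tableUtils)
    val api: Api = BatchNodeRunner.instantiateApi("com.example.MyApiImpl", Map("host" -> "example"))

    // runFromArgs returns a Try, so the caller decides how to exit, retry, or report.
    val result = runner.runFromArgs(api, "2025-01-01", "2025-01-31", "TABLE_PARTITIONS")
    tableUtils.sparkSession.stop()
    println(result)
  }
}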
spark/src/main/scala/ai/chronon/spark/kv_store/KVUploadNodeRunner.scala
@@ -2,12 +2,11 @@ package ai.chronon.spark.kv_store
 
 import ai.chronon.api.Constants.MetadataDataset
 import ai.chronon.api.Extensions.MetadataOps
-import ai.chronon.api.planner.NodeRunner
 import ai.chronon.api._
+import ai.chronon.api.planner.NodeRunner
 import ai.chronon.online.Api
 import ai.chronon.online.fetcher.{FetchContext, MetadataStore}
 import ai.chronon.planner.{Node, NodeContent}
-import ai.chronon.spark.batch.BatchNodeRunner.DefaultTablePartitionsDataset
 import org.rogach.scallop.ScallopConf
 import org.slf4j.{Logger, LoggerFactory}
 
@@ -82,7 +81,6 @@ class KVUploadNodeRunner(api: Api) extends NodeRunner {
     }
   }
 
-  // override def tablePartitionsDataset(): String = tablePartitionsDataset
 }
 
 object KVUploadNodeRunner {
@@ -96,7 +94,7 @@ object KVUploadNodeRunner {
    val apiProps: Map[String, String] = props[String]('Z', descr = "Props to configure API Store")
    val tablePartitionsDataset = opt[String](required = true,
                                             descr = "Name of table in kv store to use to keep track of partitions",
-                                            default = Option(DefaultTablePartitionsDataset))
+                                            default = Option(NodeRunner.DefaultTablePartitionsDataset))
    verify()
  }
 
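With that last hunk, the KV upload CLI resolves the same fallback dataset name as the batch CLI. A small sketch of that resolution; the ExampleUploadArgs class and the check object below are illustrative stand-ins that mirror the pattern in the diff, not code from this PR, and they assume Scallop's usual behavior of treating a supplied default as satisfying a required option.

import ai.chronon.api.planner.NodeRunner
import org.rogach.scallop.ScallopConf

// Illustrative args class mirroring the pattern above; not part of the PR.
class ExampleUploadArgs(arguments: Seq[String]) extends ScallopConf(arguments) {
  val tablePartitionsDataset = opt[String](
    required = true,
    descr = "Name of table in kv store to use to keep track of partitions",
    default = Option(NodeRunner.DefaultTablePartitionsDataset))
  verify()
}

object ExampleUploadArgsCheck {
  def main(args: Array[String]): Unit = {
    // With no flag supplied, the option should fall back to "TABLE_PARTITIONS".
    val parsed = new ExampleUploadArgs(Seq.empty)
    println(parsed.tablePartitionsDataset())
  }
}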