160 changes: 67 additions & 93 deletions spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala
@@ -6,7 +6,6 @@ import ai.chronon.api.{MetaData, PartitionRange, PartitionSpec, ThriftJsonCodec}
import ai.chronon.online.Api
import ai.chronon.online.KVStore.PutRequest
import ai.chronon.planner._
import ai.chronon.spark.batch.BatchNodeRunner.DefaultTablePartitionsDataset
import ai.chronon.spark.catalog.TableUtils
import ai.chronon.spark.join.UnionJoin
import ai.chronon.spark.submission.SparkSessionBuilder
@@ -42,19 +41,18 @@ class BatchNodeRunnerArgs(args: Array[String]) extends ScallopConf(args) {

val tablePartitionsDataset = opt[String](required = true,
descr = "Name of table in kv store to use to keep track of partitions",
default = Option(DefaultTablePartitionsDataset))
default = Option("TABLE_PARTITIONS"))

verify()
}

object BatchNodeRunner extends NodeRunner {

class BatchNodeRunner(node: Node) extends NodeRunner {
@transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass)

def checkPartitions(conf: ExternalSourceSensorNode,
metadata: MetaData,
tableUtils: TableUtils,
range: PartitionRange): Try[Unit] = {
private val sparkSession = SparkSessionBuilder.build(s"batch-node-runner-${node.metaData.name}")
private lazy val tableUtils = TableUtils(sparkSession)

def checkPartitions(conf: ExternalSourceSensorNode, metadata: MetaData, range: PartitionRange): Try[Unit] = {
val tableName = conf.sourceName
val retryCount = if (conf.isSetRetryCount) conf.retryCount else 3L
val retryIntervalMin = if (conf.isSetRetryIntervalMin) conf.retryIntervalMin else 3L
@@ -88,47 +86,7 @@ object BatchNodeRunner extends NodeRunner {
retry(0)
}
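
The body of the retry loop is collapsed in this diff; what follows is a minimal sketch of the pattern it presumably implements, written against a generic check function rather than Chronon's actual TableUtils API (retryUntilPresent and partitionExists below are hypothetical):

import scala.util.{Failure, Success, Try}

object RetrySketch {
  // Hypothetical helper mirroring the collapsed retry loop: poll a check until it
  // passes or the retry budget is exhausted, sleeping between attempts.
  def retryUntilPresent(partitionExists: () => Boolean,
                        retryCount: Long,
                        retryIntervalMin: Long): Try[Unit] = {
    @scala.annotation.tailrec
    def retry(attempt: Long): Try[Unit] =
      if (partitionExists()) Success(())
      else if (attempt >= retryCount)
        Failure(new RuntimeException(s"Input partition still missing after $retryCount retries"))
      else {
        Thread.sleep(retryIntervalMin * 60 * 1000) // wait retryIntervalMin minutes between checks
        retry(attempt + 1)
      }
    retry(0)
  }
}

With the defaults visible above (retryCount = 3, retryIntervalMin = 3), such a helper would poll three times, three minutes apart, before returning a Failure.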

private def createTableUtils(name: String): TableUtils = {
val spark = SparkSessionBuilder.build(s"batch-node-runner-${name}")
TableUtils(spark)
}

private[batch] def run(metadata: MetaData, conf: NodeContent, range: PartitionRange, tableUtils: TableUtils): Unit = {

conf.getSetField match {
case NodeContent._Fields.MONOLITH_JOIN =>
runMonolithJoin(metadata, conf.getMonolithJoin, range, tableUtils)
case NodeContent._Fields.GROUP_BY_UPLOAD =>
runGroupByUpload(metadata, conf.getGroupByUpload, range, tableUtils)
case NodeContent._Fields.GROUP_BY_BACKFILL =>
logger.info(s"Running groupBy backfill for '${metadata.name}' for range: [${range.start}, ${range.end}]")
GroupBy.computeBackfill(
conf.getGroupByBackfill.groupBy,
range.end,
tableUtils,
overrideStartPartition = Option(range.start)
)
logger.info(s"Successfully completed groupBy backfill for '${metadata.name}'")
case NodeContent._Fields.STAGING_QUERY =>
runStagingQuery(metadata, conf.getStagingQuery, range, tableUtils)
case NodeContent._Fields.EXTERNAL_SOURCE_SENSOR => {

checkPartitions(conf.getExternalSourceSensor, metadata, tableUtils, range) match {
case Success(_) => System.exit(0)
case Failure(exception) =>
logger.error(s"ExternalSourceSensor check failed.", exception)
System.exit(1)
}
}
case _ =>
throw new UnsupportedOperationException(s"Unsupported NodeContent type: ${conf.getSetField}")
}
}

private def runStagingQuery(metaData: MetaData,
stagingQuery: StagingQueryNode,
range: PartitionRange,
tableUtils: TableUtils): Unit = {
private def runStagingQuery(metaData: MetaData, stagingQuery: StagingQueryNode, range: PartitionRange): Unit = {
require(stagingQuery.isSetStagingQuery, "StagingQueryNode must have a stagingQuery set")
logger.info(s"Running staging query for '${metaData.name}'")
val stagingQueryConf = stagingQuery.stagingQuery
@@ -143,22 +101,16 @@
logger.info(s"Successfully completed staging query for '${metaData.name}'")
}

private def runGroupByUpload(metadata: MetaData,
groupByUpload: GroupByUploadNode,
range: PartitionRange,
tableUtils: TableUtils): Unit = {
private def runGroupByUpload(metadata: MetaData, groupByUpload: GroupByUploadNode, range: PartitionRange): Unit = {
require(groupByUpload.isSetGroupBy, "GroupByUploadNode must have a groupBy set")
val groupBy = groupByUpload.groupBy
logger.info(s"Running groupBy upload for '${metadata.name}' for day: ${range.end}")

GroupByUpload.run(groupBy, range.end, Some(tableUtils))
GroupByUpload.run(groupBy, range.end, Option(tableUtils))
logger.info(s"Successfully completed groupBy upload for '${metadata.name}' for day: ${range.end}")
}

private def runMonolithJoin(metadata: MetaData,
monolithJoin: MonolithJoinNode,
range: PartitionRange,
tableUtils: TableUtils): Unit = {
private def runMonolithJoin(metadata: MetaData, monolithJoin: MonolithJoinNode, range: PartitionRange): Unit = {
require(monolithJoin.isSetJoin, "MonolithJoinNode must have a join set")

val joinConf = monolithJoin.join
@@ -185,21 +137,47 @@
}
}

override def run(metadata: MetaData, conf: NodeContent, range: Option[PartitionRange]): Unit = {
require(range.isDefined, "Partition range must be defined for batch node runner")
override def run(metadata: MetaData, conf: NodeContent, maybeRange: Option[PartitionRange]): Unit = {
require(maybeRange.isDefined, "Partition range must be defined for batch node runner")
val range = maybeRange.get
conf.getSetField match {
case NodeContent._Fields.MONOLITH_JOIN =>
runMonolithJoin(metadata, conf.getMonolithJoin, range)
case NodeContent._Fields.GROUP_BY_UPLOAD =>
runGroupByUpload(metadata, conf.getGroupByUpload, range)
case NodeContent._Fields.GROUP_BY_BACKFILL =>
logger.info(s"Running groupBy backfill for '${metadata.name}' for range: [${range.start}, ${range.end}]")
GroupBy.computeBackfill(
conf.getGroupByBackfill.groupBy,
range.end,
tableUtils,
overrideStartPartition = Option(range.start)
)
logger.info(s"Successfully completed groupBy backfill for '${metadata.name}'")
case NodeContent._Fields.STAGING_QUERY =>
runStagingQuery(metadata, conf.getStagingQuery, range)
case NodeContent._Fields.EXTERNAL_SOURCE_SENSOR => {

run(metadata, conf, range.get, createTableUtils(metadata.name))
checkPartitions(conf.getExternalSourceSensor, metadata, range) match {
case Success(_) => System.exit(0)
case Failure(exception) =>
logger.error(s"ExternalSourceSensor check failed.", exception)
System.exit(1)
}
}
case _ =>
throw new UnsupportedOperationException(s"Unsupported NodeContent type: ${conf.getSetField}")
}
}

def runFromArgs(api: Api,
confPath: String,
startDs: String,
endDs: String,
tablePartitionsDataset: String): Try[Unit] = {
def runFromArgs(
api: Api,
startDs: String,
endDs: String,
tablePartitionsDataset: String
): Try[Unit] = {
Try {
val node = ThriftJsonCodec.fromJsonFile[Node](confPath, check = true)
val metadata = node.metaData
val tableUtils = createTableUtils(metadata.name)
val range = PartitionRange(startDs, endDs)(PartitionSpec.daily)
val kvStore = api.genKvStore

@@ -275,6 +253,27 @@
}
}
}
}

object BatchNodeRunner {

def main(args: Array[String]): Unit = {
val batchArgs = new BatchNodeRunnerArgs(args)
val exitCode = Try { ThriftJsonCodec.fromJsonFile[Node](batchArgs.confPath(), check = true) }
.flatMap((node) => {
val runner = new BatchNodeRunner(node)
val api = instantiateApi(batchArgs.onlineClass(), batchArgs.apiProps)
runner.runFromArgs(api, batchArgs.startDs(), batchArgs.endDs(), batchArgs.tablePartitionsDataset())
}) match {
case Success(_) =>
println("Batch node runner succeeded")
0
case Failure(exception) =>
println("Batch node runner failed", exception)
1
}
System.exit(exitCode)
}

def instantiateApi(onlineClass: String, props: Map[String, String]): Api = {
val cl = Thread.currentThread().getContextClassLoader
@@ -283,29 +282,4 @@
val onlineImpl = constructor.newInstance(props)
onlineImpl.asInstanceOf[Api]
}
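
The lines that resolve the class are collapsed above; a self-contained sketch of the reflection pattern the visible calls suggest, with Api replaced by a placeholder trait since ai.chronon.online.Api is not part of this diff:

object ApiReflectionSketch {
  trait Api // stand-in for ai.chronon.online.Api, not shown in this diff

  // Load a user-supplied Api implementation by fully-qualified class name and
  // construct it with the props map, mirroring the visible newInstance(props) call.
  def instantiateApi(onlineClass: String, props: Map[String, String]): Api = {
    val cl = Thread.currentThread().getContextClassLoader
    val cls = cl.loadClass(onlineClass)
    val constructor = cls.getConstructors.head // assumes a single props-taking constructor
    constructor.newInstance(props).asInstanceOf[Api]
  }
}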

def main(args: Array[String]): Unit = {
try {
val batchArgs = new BatchNodeRunnerArgs(args)
val api = instantiateApi(batchArgs.onlineClass(), batchArgs.apiProps)
runFromArgs(api,
batchArgs.confPath(),
batchArgs.startDs(),
batchArgs.endDs(),
batchArgs.tablePartitionsDataset()) match {
case Success(_) =>
logger.info("Batch node runner completed successfully")
System.exit(0)
case Failure(exception) =>
logger.error("Batch node runner failed", exception)
System.exit(1)
}
} catch {
case e: Exception =>
logger.error("Failed to parse arguments or initialize runner", e)
System.exit(1)
}
}

// override def tablePartitionsDataset(): String = tablePartitionsDataset
}
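
Taken together, the refactor moves all wiring into main, driven by the Scallop arguments above. A hedged example invocation, assuming Scallop's default camelCase-to-kebab-case flag naming; the paths, dates, and Api class name below are illustrative only:

object BatchNodeRunnerInvocation {
  // Assumes this lives alongside ai.chronon.spark.batch.BatchNodeRunner.
  def main(cliArgs: Array[String]): Unit =
    BatchNodeRunner.main(Array(
      "--conf-path", "/tmp/plans/my_join_node.json",     // hypothetical compiled Node JSON
      "--online-class", "com.example.chronon.MyApiImpl", // hypothetical Api implementation
      "--start-ds", "2024-01-01",
      "--end-ds", "2024-01-07",
      "--table-partitions-dataset", "TABLE_PARTITIONS"
    ))
}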
@@ -2,12 +2,11 @@ package ai.chronon.spark.kv_store

import ai.chronon.api.Constants.MetadataDataset
import ai.chronon.api.Extensions.MetadataOps
import ai.chronon.api.planner.NodeRunner
import ai.chronon.api._
import ai.chronon.api.planner.NodeRunner
import ai.chronon.online.Api
import ai.chronon.online.fetcher.{FetchContext, MetadataStore}
import ai.chronon.planner.{Node, NodeContent}
import ai.chronon.spark.batch.BatchNodeRunner.DefaultTablePartitionsDataset
import org.rogach.scallop.ScallopConf
import org.slf4j.{Logger, LoggerFactory}

@@ -5,16 +5,12 @@ import ai.chronon.api
import ai.chronon.api.Extensions._
import ai.chronon.api.ScalaJavaConversions.MapOps
import ai.chronon.api._
import ai.chronon.planner._
import ai.chronon.spark.Extensions._
import ai.chronon.spark.batch._
import ai.chronon.spark.submission.SparkSessionBuilder
import ai.chronon.spark.test.{DataFrameGen, TableTestUtils}
import ai.chronon.spark.{GroupBy, Join, _}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, rand}
import org.junit.Assert.assertEquals
import org.scalatest.flatspec.AnyFlatSpec
import ai.chronon.planner.{JoinBootstrapNode, JoinDerivationNode, JoinMergeNode, JoinPartNode, SourceWithFilterNode}
import ai.chronon.spark.{Join, _}
import org.slf4j.LoggerFactory

class ShortNamesTest extends AnyFlatSpec {