api/python/ai/chronon/repo/constants.py (9 additions, 0 deletions)
@@ -26,6 +26,9 @@
"log-flattener",
"metadata-export",
"label-join",
"source-job",
"join-part-job",
"merge-job",
]
MODES_USING_EMBEDDED = ["metadata-upload", "fetch", "local-streaming"]

@@ -52,6 +55,9 @@
"metadata-export": OFFLINE_ARGS,
"label-join": OFFLINE_ARGS,
"streaming-client": ONLINE_WRITE_ARGS,
"source-job": OFFLINE_ARGS,
"join-part-job": OFFLINE_ARGS,
"merge-job": OFFLINE_ARGS,
"metastore": "--partition-names={partition_names}",
"info": "",
}
@@ -83,6 +89,9 @@
"log-flattener": "log-flattener",
"metadata-export": "metadata-export",
"label-join": "label-join",
"source-job": "source-job",
"join-part-job": "join-part-job",
"merge-job": "merge-job",
},
"staging_queries": {
"backfill": "staging-query-backfill",
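These three additions register the new offline modes end to end: the mode list, the argument-template lookup, and the route to the matching Spark subcommand. A minimal sketch of that resolution, assuming MODE_ARGS and ROUTES are the names of the two mappings extended above (only OFFLINE_ARGS and MODES_USING_EMBEDDED are visible in these hunks):

```python
# Sketch of mode resolution; MODE_ARGS and ROUTES are assumed names for the
# argument-template and subcommand mappings extended in this file.
from ai.chronon.repo.constants import MODE_ARGS, ROUTES

mode = "join-part-job"
spark_subcommand = ROUTES["joins"][mode]  # -> "join-part-job", the Driver subcommand
arg_template = MODE_ARGS[mode]            # -> OFFLINE_ARGS, the shared offline flags
```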
api/python/ai/chronon/repo/run.py (3 additions, 1 deletion)
@@ -164,6 +164,7 @@ def set_defaults(ctx):
help="Validate the catalyst util Spark expression evaluation logic")
@click.option("--validate-rows", default="10000",
help="Number of rows to run the validation on")
@click.option("--join-part-name", help="Name of the join part to use for join-part-job")
@click.pass_context
def main(
ctx,
@@ -196,7 +197,8 @@ def main(
     mock_source,
     savepoint_uri,
     validate,
-    validate_rows
+    validate_rows,
+    join_part_name
 ):
     unknown_args = ctx.args
     click.echo("Running with args: {}".format(ctx.params))
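With the route registered and the flag plumbed through, the new jobs become reachable from the Python CLI. A hypothetical invocation; the entry point and the --mode/--conf flags are assumptions based on the existing CLI, and only --join-part-name is introduced by this diff:

```python
# Hypothetical invocation of the new mode via the run.py entry point; the
# module path, --mode, and --conf flags are assumptions, and the conf path
# is made up for illustration.
import subprocess

subprocess.run(
    [
        "python", "-m", "ai.chronon.repo.run",
        "--mode", "join-part-job",
        "--conf", "production/joins/team/my_join.v1",
        "--join-part-name", "my_prefix_user_features",
    ],
    check=True,
)
```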
spark/src/main/scala/ai/chronon/spark/Driver.scala (179 additions, 5 deletions)
@@ -17,18 +17,16 @@
 package ai.chronon.spark
 
 import ai.chronon.api
-import ai.chronon.api.Constants
+import ai.chronon.api.{Constants, DateRange, RelevantLeftForJoinPart, ThriftJsonCodec}
 import ai.chronon.api.Constants.MetadataDataset
-import ai.chronon.api.Extensions.GroupByOps
-import ai.chronon.api.Extensions.MetadataOps
-import ai.chronon.api.Extensions.SourceOps
-import ai.chronon.api.ThriftJsonCodec
+import ai.chronon.api.Extensions.{GroupByOps, JoinPartOps, MetadataOps, SourceOps}
 import ai.chronon.api.thrift.TBase
 import ai.chronon.online.Api
 import ai.chronon.online.MetadataDirWalker
 import ai.chronon.online.MetadataEndPoint
 import ai.chronon.online.TopicChecker
 import ai.chronon.online.fetcher.{ConfPathOrName, FetchContext, FetcherMain, MetadataStore}
+import ai.chronon.orchestration.{JoinMergeNode, JoinPartNode}
 import ai.chronon.spark.stats.CompareBaseJob
 import ai.chronon.spark.stats.CompareJob
 import ai.chronon.spark.stats.ConsistencyJob
@@ -823,6 +821,173 @@ object Driver {
     }
   }
 
+  object SourceJobRun {
+    @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+    class Args
+        extends Subcommand("source-job")
+        with OfflineSubcommand
+        with LocalExportTableAbility
+        with ResultValidationAbility {
+      lazy val joinConf: api.Join = parseConf[api.Join](confPath())
+      override def subcommandName(): String = s"source_job_${joinConf.metaData.name}"
+    }
+
+    def run(args: Args): Unit = {
+      val tableUtils = args.buildTableUtils()
+      val join = args.joinConf
+
+      // Create a SourceWithFilterNode from the join's left source
+      val source = join.left
+      val outputTable = JoinUtils.computeLeftSourceTableName(join)
+
+      // Create a SourceWithFilterNode with the extracted information
+      val sourceWithFilterNode = new ai.chronon.orchestration.SourceWithFilterNode()
+      sourceWithFilterNode.setSource(source)
+      sourceWithFilterNode.setExcludeKeys(join.skewKeys)
+
+      // Set the metadata
+      val sourceOutputTable = JoinUtils.computeLeftSourceTableName(join)
+      println(s"Source output table: $sourceOutputTable")
+
+      // Split the output table to get namespace and name
+      val sourceParts = sourceOutputTable.split("\\.", 2)
+      val sourceNamespace = sourceParts(0)
+      val sourceName = sourceParts(1)
+
+      // Create metadata for source job
+      val sourceMetaData = new api.MetaData()
+        .setName(sourceName)
+        .setOutputNamespace(sourceNamespace)
+
+      sourceWithFilterNode.setMetaData(sourceMetaData)
+
+      // Calculate the date range
+      val endDate = args.endDate()
+      val startDate: String = args.startPartitionOverride.getOrElse(args.endDate())
+      val dateRange = new DateRange()
+        .setStartDate(startDate)
+        .setEndDate(endDate)
+
+      // Run the SourceJob
+      val sourceJob = new SourceJob(sourceWithFilterNode, dateRange)(tableUtils)
+      sourceJob.run()
+
+      logger.info(s"SourceJob completed. Output table: ${outputTable}")
+
+      if (args.shouldExport()) {
+        args.exportTableToLocal(outputTable, tableUtils)
+      }
+    }
+  }
+
+  object JoinPartJobRun {
+    @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+    class Args
+        extends Subcommand("join-part-job")
+        with OfflineSubcommand
+        with LocalExportTableAbility
+        with ResultValidationAbility {
+
+      val joinPartName: ScallopOption[String] =
+        opt[String](required = true, descr = "Name of the join part to run")
+
+      lazy val joinConf: api.Join = parseConf[api.Join](confPath())
+      override def subcommandName(): String = s"join_part_job_${joinConf.metaData.name}"
+    }
+
+    def run(args: Args): Unit = {
+      val tableUtils = args.buildTableUtils()
+      val join = args.joinConf
+      val joinPartName = args.joinPartName()
+
+      // Find the selected join part
+      val joinPart = join.joinParts.asScala
+        .find(part => part.fullPrefix == joinPartName)
+        .getOrElse(
+          throw new RuntimeException(s"JoinPart with name $joinPartName not found in join ${join.metaData.name}"))
+
+      logger.info(s"Found join part: ${joinPart.fullPrefix}")
+
+      // Create a JoinPartNode from the join part
+      val joinPartNode = new JoinPartNode()
+        .setJoinPart(joinPart)
+        .setLeftSourceTable(JoinUtils.computeLeftSourceTableName(join))
+        .setLeftDataModel(join.left.dataModel.toString)
+        .setSkewKeys(join.skewKeys)
+
+      // Set the metadata
+      val joinPartTableName = RelevantLeftForJoinPart.partTableName(join, joinPart)
+      val outputNamespace = join.metaData.outputNamespace
+      val metadata = new ai.chronon.api.MetaData()
+        .setName(joinPartTableName)
+        .setOutputNamespace(outputNamespace)
+
+      joinPartNode.setMetaData(metadata)
+
+      // Calculate the date range
+      val endDate = args.endDate()
+      val startDate = args.startPartitionOverride.getOrElse(args.endDate())
+      val dateRange = new DateRange()
+        .setStartDate(startDate)
+        .setEndDate(endDate)
+
+      // Run the JoinPartJob
+      val joinPartJob = new JoinPartJob(joinPartNode, dateRange, showDf = false)(tableUtils)
+      joinPartJob.run()
+
+      logger.info(s"JoinPartJob completed. Output table: ${joinPartNode.metaData.outputTable}")
+
+      if (args.shouldExport()) {
+        args.exportTableToLocal(joinPartNode.metaData.outputTable, tableUtils)
+      }
+    }
+  }
+
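The lookup above matches on joinPart.fullPrefix, so the value passed to --join-part-name must be exact. A sketch of enumerating candidate names from a compiled join conf; the JSON layout and the prefix-plus-sanitized-group-by-name rule are assumptions inferred from the fullPrefix accessor, not a confirmed API:

```python
# Sketch: list candidate --join-part-name values from a compiled join conf.
# The JSON shape and the "<prefix>_<sanitized group by name>" construction
# are assumptions based on the fullPrefix lookup above.
import json
import re

with open("production/joins/team/my_join.v1") as f:  # made-up path
    join = json.load(f)

for part in join.get("joinParts", []):
    gb_name = re.sub(r"[^a-zA-Z0-9_]", "_", part["groupBy"]["metaData"]["name"])
    prefix = part.get("prefix")
    print(f"{prefix}_{gb_name}" if prefix else gb_name)
```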
+  object MergeJobRun {
+    @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+    class Args
+        extends Subcommand("merge-job")
+        with OfflineSubcommand
+        with LocalExportTableAbility
+        with ResultValidationAbility {
+      lazy val joinConf: api.Join = parseConf[api.Join](confPath())
+      override def subcommandName(): String = s"merge_job_${joinConf.metaData.name}"
+    }
+
+    def run(args: Args): Unit = {
+      val tableUtils = args.buildTableUtils()
+      val joinConf = args.joinConf
+
+      // TODO -- when we support bootstrapping in the modular flow from Driver, we'll need to omit
+      // Bootstrapped JoinParts here
+      val allJoinParts = joinConf.joinParts.asScala
+
+      val endDate = args.endDate()
+      val startDate = args.startPartitionOverride.getOrElse(endDate)
+      val dateRange = new DateRange()
+        .setStartDate(startDate)
+        .setEndDate(endDate)
+
+      // Create metadata for merge job
+      val mergeMetaData = new api.MetaData()
+        .setName(joinConf.metaData.name)
+        .setOutputNamespace(joinConf.metaData.outputNamespace)
+
+      val mergeNode = new JoinMergeNode()
+        .setJoin(joinConf)
+        .setMetaData(mergeMetaData)
+
+      val mergeJob = new MergeJob(mergeNode, dateRange, allJoinParts)(tableUtils)
+
+      mergeJob.run()
+
+      logger.info(s"MergeJob completed. Output table: ${joinConf.metaData.outputTable}")
+      if (args.shouldExport()) {
+        args.exportTableToLocal(joinConf.metaData.outputTable, tableUtils)
+      }
+    }
+  }
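Taken together, the three subcommands implement the modular flow the TODO above refers to: materialize the filtered left source, compute each join part against it, then merge the part tables into the final join output. A sketch of driving that sequence through run.py, under the same assumed entry point and flags as the earlier example:

```python
# Sketch of the modular pipeline in dependency order; entry point, flag
# names, and paths are the same assumptions as in the earlier example.
import subprocess

CONF = "production/joins/team/my_join.v1"  # made-up conf path

def run_mode(mode: str, *extra: str) -> None:
    subprocess.run(
        ["python", "-m", "ai.chronon.repo.run", "--mode", mode, "--conf", CONF, *extra],
        check=True,
    )

run_mode("source-job")                                                    # left source table
run_mode("join-part-job", "--join-part-name", "my_prefix_user_features")  # once per join part
run_mode("merge-job")                                                     # final join output
```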

   object CheckPartitions {
     private val helpNamingConvention =
       "Please follow the naming convention: --partition-names=schema.table/pk1=pv1/pk2=pv2"
Expand Down Expand Up @@ -960,6 +1125,12 @@ object Driver {
     addSubcommand(CreateStatsTableArgs)
     object SummarizeAndUploadArgs extends SummarizeAndUpload.Args
     addSubcommand(SummarizeAndUploadArgs)
+    object SourceJobRunArgs extends SourceJobRun.Args
+    addSubcommand(SourceJobRunArgs)
+    object JoinPartJobRunArgs extends JoinPartJobRun.Args
+    addSubcommand(JoinPartJobRunArgs)
+    object MergeJobRunArgs extends MergeJobRun.Args
+    addSubcommand(MergeJobRunArgs)
     object CheckPartitionArgs extends CheckPartitions.Args
     addSubcommand(CheckPartitionArgs)
     requireSubcommand()
@@ -1003,6 +1174,9 @@
         case args.JoinBackfillFinalArgs => JoinBackfillFinal.run(args.JoinBackfillFinalArgs)
         case args.CreateStatsTableArgs => CreateSummaryDataset.run(args.CreateStatsTableArgs)
         case args.SummarizeAndUploadArgs => SummarizeAndUpload.run(args.SummarizeAndUploadArgs)
+        case args.SourceJobRunArgs => SourceJobRun.run(args.SourceJobRunArgs)
+        case args.JoinPartJobRunArgs => JoinPartJobRun.run(args.JoinPartJobRunArgs)
+        case args.MergeJobRunArgs => MergeJobRun.run(args.MergeJobRunArgs)
         case args.CheckPartitionArgs => CheckPartitions.run(args.CheckPartitionArgs)
         case _ => logger.info(s"Unknown subcommand: $x")