diff --git a/api/python/ai/chronon/repo/constants.py b/api/python/ai/chronon/repo/constants.py
index c0c4616bab..c2fa33ac9f 100644
--- a/api/python/ai/chronon/repo/constants.py
+++ b/api/python/ai/chronon/repo/constants.py
@@ -26,6 +26,9 @@
     "log-flattener",
     "metadata-export",
     "label-join",
+    "source-job",
+    "join-part-job",
+    "merge-job",
 ]
 
 MODES_USING_EMBEDDED = ["metadata-upload", "fetch", "local-streaming"]
@@ -52,6 +55,9 @@
     "metadata-export": OFFLINE_ARGS,
     "label-join": OFFLINE_ARGS,
     "streaming-client": ONLINE_WRITE_ARGS,
+    "source-job": OFFLINE_ARGS,
+    "join-part-job": OFFLINE_ARGS,
+    "merge-job": OFFLINE_ARGS,
     "metastore": "--partition-names={partition_names}",
     "info": "",
 }
@@ -83,6 +89,9 @@
         "log-flattener": "log-flattener",
         "metadata-export": "metadata-export",
         "label-join": "label-join",
+        "source-job": "source-job",
+        "join-part-job": "join-part-job",
+        "merge-job": "merge-job",
     },
     "staging_queries": {
         "backfill": "staging-query-backfill",
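
The three new modes slot into the existing dispatch tables exactly like the other offline modes: each resolves to OFFLINE_ARGS and routes to a Driver subcommand of the same name. A minimal lookup sketch, assuming the dicts these hunks extend are named MODE_ARGS and ROUTES as in upstream Chronon (their names fall outside the visible diff context):

# Sketch only: MODE_ARGS and ROUTES are assumed names for the dicts extended above.
from ai.chronon.repo.constants import MODE_ARGS, ROUTES

mode = "join-part-job"
arg_template = MODE_ARGS[mode]      # OFFLINE_ARGS, shared with e.g. "label-join"
subcommand = ROUTES["joins"][mode]  # the Driver subcommand this mode routes to
print(arg_template, subcommand)
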
diff --git a/api/python/ai/chronon/repo/run.py b/api/python/ai/chronon/repo/run.py
index 87f89025ca..631c783bc4 100755
--- a/api/python/ai/chronon/repo/run.py
+++ b/api/python/ai/chronon/repo/run.py
@@ -164,6 +164,7 @@ def set_defaults(ctx):
               help="Validate the catalyst util Spark expression evaluation logic")
 @click.option("--validate-rows", default="10000",
               help="Number of rows to run the validation on")
+@click.option("--join-part-name", help="Name of the join part to use for join-part-job")
 @click.pass_context
 def main(
     ctx,
@@ -196,7 +197,8 @@ def main(
     mock_source,
     savepoint_uri,
     validate,
-    validate_rows
+    validate_rows,
+    join_part_name
 ):
     unknown_args = ctx.args
     click.echo("Running with args: {}".format(ctx.params))
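
With the option wired through, each new mode is invocable from run.py. A hedged invocation sketch for a single join part: --join-part-name is the flag added above, while --mode, --conf, and --ds assume run.py's existing options, and the conf path, part name, and date are placeholders.

import subprocess

# Illustrative only: all values below are placeholders.
subprocess.run(
    [
        "python", "run.py",
        "--mode", "join-part-job",
        "--conf", "production/joins/team/example_join",
        "--join-part-name", "example_group_by_prefix",
        "--ds", "2023-11-01",
    ],
    check=True,
)
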
diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala
index 356cae3ca8..abda8b95f5 100644
--- a/spark/src/main/scala/ai/chronon/spark/Driver.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -17,18 +17,16 @@ package ai.chronon.spark
 
 import ai.chronon.api
-import ai.chronon.api.Constants
+import ai.chronon.api.{Constants, DateRange, RelevantLeftForJoinPart, ThriftJsonCodec}
 import ai.chronon.api.Constants.MetadataDataset
-import ai.chronon.api.Extensions.GroupByOps
-import ai.chronon.api.Extensions.MetadataOps
-import ai.chronon.api.Extensions.SourceOps
-import ai.chronon.api.ThriftJsonCodec
+import ai.chronon.api.Extensions.{GroupByOps, JoinPartOps, MetadataOps, SourceOps}
 import ai.chronon.api.thrift.TBase
 import ai.chronon.online.Api
 import ai.chronon.online.MetadataDirWalker
 import ai.chronon.online.MetadataEndPoint
 import ai.chronon.online.TopicChecker
 import ai.chronon.online.fetcher.{ConfPathOrName, FetchContext, FetcherMain, MetadataStore}
+import ai.chronon.orchestration.{JoinMergeNode, JoinPartNode}
 import ai.chronon.spark.stats.CompareBaseJob
 import ai.chronon.spark.stats.CompareJob
 import ai.chronon.spark.stats.ConsistencyJob
@@ -823,6 +821,173 @@ object Driver {
     }
   }
 
+  object SourceJobRun {
+    @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+    class Args
+        extends Subcommand("source-job")
+        with OfflineSubcommand
+        with LocalExportTableAbility
+        with ResultValidationAbility {
+      lazy val joinConf: api.Join = parseConf[api.Join](confPath())
+      override def subcommandName(): String = s"source_job_${joinConf.metaData.name}"
+    }
+
+    def run(args: Args): Unit = {
+      val tableUtils = args.buildTableUtils()
+      val join = args.joinConf
+
+      // Create a SourceWithFilterNode from the join's left source
+      val source = join.left
+      val outputTable = JoinUtils.computeLeftSourceTableName(join)
+      logger.info(s"Source output table: $outputTable")
+
+      val sourceWithFilterNode = new ai.chronon.orchestration.SourceWithFilterNode()
+      sourceWithFilterNode.setSource(source)
+      sourceWithFilterNode.setExcludeKeys(join.skewKeys)
+
+      // Split the output table into namespace and name for the job metadata
+      val sourceParts = outputTable.split("\\.", 2)
+      val sourceMetaData = new api.MetaData()
+        .setName(sourceParts(1))
+        .setOutputNamespace(sourceParts(0))
+      sourceWithFilterNode.setMetaData(sourceMetaData)
+
+      // Calculate the date range
+      val endDate = args.endDate()
+      val startDate = args.startPartitionOverride.getOrElse(endDate)
+      val dateRange = new DateRange()
+        .setStartDate(startDate)
+        .setEndDate(endDate)
+
+      // Run the SourceJob
+      val sourceJob = new SourceJob(sourceWithFilterNode, dateRange)(tableUtils)
+      sourceJob.run()
+
+      logger.info(s"SourceJob completed. Output table: $outputTable")
+
+      if (args.shouldExport()) {
+        args.exportTableToLocal(outputTable, tableUtils)
+      }
+    }
+  }
+
+  object JoinPartJobRun {
+    @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+    class Args
+        extends Subcommand("join-part-job")
+        with OfflineSubcommand
+        with LocalExportTableAbility
+        with ResultValidationAbility {
+
+      val joinPartName: ScallopOption[String] =
+        opt[String](required = true, descr = "Name of the join part to run")
+
+      lazy val joinConf: api.Join = parseConf[api.Join](confPath())
+      override def subcommandName(): String = s"join_part_job_${joinConf.metaData.name}"
+    }
+
+    def run(args: Args): Unit = {
+      val tableUtils = args.buildTableUtils()
+      val join = args.joinConf
+      val joinPartName = args.joinPartName()
+
+      // Find the selected join part by its full prefix
+      val joinPart = join.joinParts.asScala
+        .find(part => part.fullPrefix == joinPartName)
+        .getOrElse(
+          throw new RuntimeException(s"JoinPart with name $joinPartName not found in join ${join.metaData.name}"))
+
+      logger.info(s"Found join part: ${joinPart.fullPrefix}")
+
+      // Create a JoinPartNode from the join part
+      val joinPartNode = new JoinPartNode()
+        .setJoinPart(joinPart)
+        .setLeftSourceTable(JoinUtils.computeLeftSourceTableName(join))
+        .setLeftDataModel(join.left.dataModel.toString)
+        .setSkewKeys(join.skewKeys)
+
+      // Set the metadata
+      val joinPartTableName = RelevantLeftForJoinPart.partTableName(join, joinPart)
+      val metadata = new api.MetaData()
+        .setName(joinPartTableName)
+        .setOutputNamespace(join.metaData.outputNamespace)
+      joinPartNode.setMetaData(metadata)
+
+      // Calculate the date range
+      val endDate = args.endDate()
+      val startDate = args.startPartitionOverride.getOrElse(endDate)
+      val dateRange = new DateRange()
+        .setStartDate(startDate)
+        .setEndDate(endDate)
+
+      // Run the JoinPartJob
+      val joinPartJob = new JoinPartJob(joinPartNode, dateRange, showDf = false)(tableUtils)
+      joinPartJob.run()
+
+      logger.info(s"JoinPartJob completed. Output table: ${joinPartNode.metaData.outputTable}")
+
+      if (args.shouldExport()) {
+        args.exportTableToLocal(joinPartNode.metaData.outputTable, tableUtils)
+      }
+    }
+  }
+
+  object MergeJobRun {
+    @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+    class Args
+        extends Subcommand("merge-job")
+        with OfflineSubcommand
+        with LocalExportTableAbility
+        with ResultValidationAbility {
+      lazy val joinConf: api.Join = parseConf[api.Join](confPath())
+      override def subcommandName(): String = s"merge_job_${joinConf.metaData.name}"
+    }
+
+    def run(args: Args): Unit = {
+      val tableUtils = args.buildTableUtils()
+      val joinConf = args.joinConf
+
+      // TODO -- when we support bootstrapping in the modular flow from Driver, we'll need to omit
+      // bootstrapped JoinParts here
+      val allJoinParts = joinConf.joinParts.asScala
+
+      // Calculate the date range
+      val endDate = args.endDate()
+      val startDate = args.startPartitionOverride.getOrElse(endDate)
+      val dateRange = new DateRange()
+        .setStartDate(startDate)
+        .setEndDate(endDate)
+
+      // Create metadata for the merge job
+      val mergeMetaData = new api.MetaData()
+        .setName(joinConf.metaData.name)
+        .setOutputNamespace(joinConf.metaData.outputNamespace)
+
+      val mergeNode = new JoinMergeNode()
+        .setJoin(joinConf)
+        .setMetaData(mergeMetaData)
+
+      // Run the MergeJob over all join parts
+      val mergeJob = new MergeJob(mergeNode, dateRange, allJoinParts)(tableUtils)
+      mergeJob.run()
+
+      logger.info(s"MergeJob completed. Output table: ${joinConf.metaData.outputTable}")
+      if (args.shouldExport()) {
+        args.exportTableToLocal(joinConf.metaData.outputTable, tableUtils)
+      }
+    }
+  }
+
   object CheckPartitions {
     private val helpNamingConvention =
       "Please follow the naming convention: --partition-names=schema.table/pk1=pv1/pk2=pv2"
@@ -960,6 +1125,12 @@
     addSubcommand(CreateStatsTableArgs)
     object SummarizeAndUploadArgs extends SummarizeAndUpload.Args
     addSubcommand(SummarizeAndUploadArgs)
+    object SourceJobRunArgs extends SourceJobRun.Args
+    addSubcommand(SourceJobRunArgs)
+    object JoinPartJobRunArgs extends JoinPartJobRun.Args
+    addSubcommand(JoinPartJobRunArgs)
+    object MergeJobRunArgs extends MergeJobRun.Args
+    addSubcommand(MergeJobRunArgs)
     object CheckPartitionArgs extends CheckPartitions.Args
     addSubcommand(CheckPartitionArgs)
     requireSubcommand()
@@ -1003,6 +1174,9 @@
       case args.JoinBackfillFinalArgs => JoinBackfillFinal.run(args.JoinBackfillFinalArgs)
       case args.CreateStatsTableArgs => CreateSummaryDataset.run(args.CreateStatsTableArgs)
       case args.SummarizeAndUploadArgs => SummarizeAndUpload.run(args.SummarizeAndUploadArgs)
+      case args.SourceJobRunArgs => SourceJobRun.run(args.SourceJobRunArgs)
+      case args.JoinPartJobRunArgs => JoinPartJobRun.run(args.JoinPartJobRunArgs)
+      case args.MergeJobRunArgs => MergeJobRun.run(args.MergeJobRunArgs)
       case args.CheckPartitionArgs => CheckPartitions.run(args.CheckPartitionArgs)
       case _ => logger.info(s"Unknown subcommand: $x")
     }
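
Taken together, the three subcommands decompose a join backfill into a small DAG: source-job materializes the filtered left side into JoinUtils.computeLeftSourceTableName(join), each join-part-job reads that table and writes RelevantLeftForJoinPart.partTableName(join, joinPart), and merge-job stitches the part tables into the final join output. A hedged orchestration sketch, where run_mode and every value are hypothetical and each step shells out to run.py as shown earlier:

import subprocess

CONF = "production/joins/team/example_join"  # placeholder conf path
DS = "2023-11-01"                            # placeholder end partition
JOIN_PARTS = ("example_group_by_prefix",)    # hypothetical joinPart fullPrefix values

def run_mode(mode, extra=()):
    """Run one modular step via run.py (flags assume its existing options)."""
    subprocess.run(
        ["python", "run.py", "--mode", mode, "--conf", CONF, "--ds", DS, *extra],
        check=True,
    )

run_mode("source-job")   # 1. materialize the left source table
for part in JOIN_PARTS:  # 2. one independent job per join part
    run_mode("join-part-job", ("--join-part-name", part))
run_mode("merge-job")    # 3. merge the part tables into the final output

Since the per-part jobs depend only on the source table, a real orchestrator can fan them out in parallel before running the merge.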