diff --git a/api/src/main/scala/ai/chronon/api/planner/NodeRunner.scala b/api/src/main/scala/ai/chronon/api/planner/NodeRunner.scala
index 057bf23b14..e6d0e00551 100644
--- a/api/src/main/scala/ai/chronon/api/planner/NodeRunner.scala
+++ b/api/src/main/scala/ai/chronon/api/planner/NodeRunner.scala
@@ -1,14 +1,11 @@
 package ai.chronon.api.planner
 
-import ai.chronon.api.{PartitionRange, PartitionSpec}
 import ai.chronon.api
+import ai.chronon.api.PartitionRange
+import ai.chronon.planner.NodeContent
 
+trait NodeRunner {
-trait BatchRunContext {
-  def partitionSpec: PartitionSpec
-}
-// run context in our case will be tableUtils
-trait NodeRunner[Conf] {
-  def run(metadata: api.MetaData, conf: Conf, range: PartitionRange, batchContext: BatchRunContext)
+  def run(metadata: api.MetaData, conf: NodeContent, range: Option[PartitionRange]): Unit
 }
 
 object LineageOfflineRunner {
diff --git a/api/src/main/scala/ai/chronon/api/planner/StagingQueryPlanner.scala b/api/src/main/scala/ai/chronon/api/planner/StagingQueryPlanner.scala
index 6ed1f34ac4..9209d8b495 100644
--- a/api/src/main/scala/ai/chronon/api/planner/StagingQueryPlanner.scala
+++ b/api/src/main/scala/ai/chronon/api/planner/StagingQueryPlanner.scala
@@ -4,7 +4,7 @@
 import ai.chronon.api.{StagingQuery, PartitionSpec}
 import ai.chronon.planner.ConfPlan
 import scala.collection.JavaConverters._
 
-class StagingQueryPlanner(stagingQuery: StagingQuery)(implicit outputPartitionSpec: PartitionSpec)
+case class StagingQueryPlanner(stagingQuery: StagingQuery)(implicit outputPartitionSpec: PartitionSpec)
     extends Planner[StagingQuery](stagingQuery)(outputPartitionSpec) {
   override def buildPlan: ConfPlan = {
diff --git a/api/src/main/scala/ai/chronon/api/planner/TableDependencies.scala b/api/src/main/scala/ai/chronon/api/planner/TableDependencies.scala
index 9f11924166..bdb9b16cc2 100644
--- a/api/src/main/scala/ai/chronon/api/planner/TableDependencies.scala
+++ b/api/src/main/scala/ai/chronon/api/planner/TableDependencies.scala
@@ -3,18 +3,18 @@
 import ai.chronon.api
 import ai.chronon.api.Extensions._
 import ai.chronon.api.ScalaJavaConversions.IteratorOps
 import ai.chronon.api.{Accuracy, DataModel, PartitionSpec, TableDependency, TableInfo, Window}
+import scala.collection.JavaConverters._
 
 object TableDependencies {
   def fromStagingQuery(stagingQuery: api.StagingQuery)(implicit spec: PartitionSpec): Seq[TableDependency] = {
-    stagingQuery.tableDependencies
-      .iterator()
-      .toScala
+    Option(stagingQuery.tableDependencies)
+      .map(_.asScala.toSeq)
+      .getOrElse(Seq.empty)
       .map { tableDep =>
         new TableDependency()
           .setTableInfo(tableDep)
       }
-      .toSeq
   }
 
   def fromJoin(join: api.Join)(implicit spec: PartitionSpec): Seq[TableDependency] = {
diff --git a/spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala b/spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala
new file mode 100644
index 0000000000..cebc70295d
--- /dev/null
+++ b/spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala
@@ -0,0 +1,62 @@
+package ai.chronon.spark.batch
+
+import ai.chronon.api.{MetaData, PartitionRange}
+import ai.chronon.api.planner.NodeRunner
+import ai.chronon.planner.NodeContent
+import ai.chronon.spark.Join
+import ai.chronon.spark.catalog.TableUtils
+import ai.chronon.spark.join.UnionJoin
+import ai.chronon.spark.submission.SparkSessionBuilder
+import org.slf4j.{Logger, LoggerFactory}
+
+object BatchNodeRunner extends NodeRunner {
+
+  @transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+
+  private def tableUtils(name: String): TableUtils = {
+    val spark = SparkSessionBuilder.build(s"batch-node-runner-${name}")
+    TableUtils(spark)
+  }
+
+  private[batch] def run(metadata: MetaData, conf: NodeContent, range: PartitionRange, tableUtils: TableUtils): Unit = {
+    conf.getSetField match {
+      case NodeContent._Fields.MONOLITH_JOIN => {
+        val monolithJoin = conf.getMonolithJoin
+        require(monolithJoin.isSetJoin, "MonolithJoinNode must have a join set")
+        val joinConf = monolithJoin.join
+        val joinName = metadata.name
+        val skewFreeMode =
+          tableUtils.sparkSession.conf.get("spark.chronon.join.backfill.mode.skewFree", "false").toBoolean
+        logger.info(s" >>> Running join backfill with skewFree mode set to: ${skewFreeMode} <<< ")
+        if (skewFreeMode) {
+
+          logger.info(s"Filling partitions for join:$joinName, partitions:[${range.start}, ${range.end}]")
+
+          logger.info(s"Processing range $range)")
+          UnionJoin.computeJoinAndSave(joinConf, range)(tableUtils)
+          logger.info(s"Wrote range $range)")
+
+        } else {
+
+          val join = new Join(
+            joinConf,
+            range.end,
+            tableUtils
+          )
+
+          val df = join.computeJoin(overrideStartPartition = Option(range.start))
+
+          df.show(numRows = 3, truncate = 0, vertical = true)
+          logger.info(s"\nShowing three rows of output above.\nQuery table `${joinName}` for more.\n")
+        }
+
+      }
+      case _ => throw new UnsupportedOperationException("Unsupported NodeContent type: " + conf.getClass.getName)
+    }
+  }
+
+  override def run(metadata: MetaData, conf: NodeContent, range: Option[PartitionRange]): Unit = {
+    require(range.isDefined, "Partition range must be defined for batch node runner")
+    run(metadata, conf, range.get, tableUtils(metadata.name))
+  }
+}
diff --git a/spark/src/test/scala/ai/chronon/spark/test/batch/BatchNodeRunnerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/batch/BatchNodeRunnerTest.scala
new file mode 100644
index 0000000000..65d7c0c565
--- /dev/null
+++ b/spark/src/test/scala/ai/chronon/spark/test/batch/BatchNodeRunnerTest.scala
@@ -0,0 +1,91 @@
+package ai.chronon.spark.batch
+
+import ai.chronon.aggregator.test.Column
+import ai.chronon.api
+import ai.chronon.api.{Accuracy, Builders, Operation, PartitionRange, TimeUnit, Window}
+import ai.chronon.planner.{MonolithJoinNode, NodeContent}
+import ai.chronon.spark.join.UnionJoin
+import ai.chronon.spark.test.DataFrameGen
+import ai.chronon.spark.test.join.BaseJoinTest
+import ai.chronon.spark.Extensions._
+import org.scalatest.matchers.should.Matchers
+
+class BatchNodeRunnerTest extends BaseJoinTest with Matchers {
+
+  "BatchNodeRunner" should "run a monolith join node" in {
+
+    val viewsSchema = List(
+      Column("user", api.StringType, 1),
+      Column("item", api.StringType, 1),
+      Column("time_spent_ms", api.LongType, 5000)
+    )
+
+    val viewsTable = s"$namespace.view_union_temporal"
+    DataFrameGen
+      .events(spark, viewsSchema, count = 10000, partitions = 20)
+      .save(viewsTable, Map("tblProp1" -> "1"))
+
+    val viewsDf = tableUtils.loadTable(viewsTable)
+
+    val viewsSource = Builders.Source.events(
+      table = viewsTable,
+      topic = "",
+      query = Builders.Query(selects = Builders.Selects("time_spent_ms"),
+                             startPartition = tableUtils.partitionSpec.minus(today, new Window(20, TimeUnit.DAYS)))
+    )
+
+    val viewsGroupBy = Builders
+      .GroupBy(
+        sources = Seq(viewsSource),
+        keyColumns = Seq("item"),
+        aggregations = Seq(
+          Builders.Aggregation(operation = Operation.AVERAGE, inputColumn = "time_spent_ms"),
+          Builders.Aggregation(
+            operation = Operation.LAST_K,
+            argMap = Map("k" -> "50"),
+            inputColumn = "time_spent_ms",
+            windows = Seq(new Window(2, TimeUnit.DAYS))
+          )
+        ),
+        metaData = Builders.MetaData(name = "unit_test.item_views", namespace = namespace)
+      )
+      .setAccuracy(Accuracy.TEMPORAL)
+
+    // left side
+    val itemQueries = List(Column("item", api.StringType, 1))
+    val itemQueriesTable = s"$namespace.item_queries_union_temporal"
+    val itemQueriesDf = DataFrameGen
+      .events(spark, itemQueries, 10000, partitions = 10)
+
+    // duplicate the events
+    itemQueriesDf.union(itemQueriesDf).save(itemQueriesTable)
+
+    val queriesDf = tableUtils.loadTable(itemQueriesTable)
+
+    val start = tableUtils.partitionSpec.minus(today, new Window(20, TimeUnit.DAYS))
+    val dateRange = PartitionRange(start, today)(tableUtils.partitionSpec)
+
+    val joinConf = Builders.Join(
+      left = Builders.Source.events(Builders.Query(startPartition = start), table = itemQueriesTable),
+      joinParts = Seq(Builders.JoinPart(groupBy = viewsGroupBy, prefix = "user")),
+      metaData =
+        Builders.MetaData(name = s"item_temporal_features_union_join", namespace = namespace, team = "item_team")
+    )
+
+    // Test UnionJoin.computeJoinAndSave method
+
+    UnionJoin.computeJoinAndSave(joinConf, dateRange)
+
+    val joinNodeContent = new NodeContent()
+    joinNodeContent.setMonolithJoin(new MonolithJoinNode().setJoin(joinConf))
+
+    BatchNodeRunner.run(joinConf.metaData, joinNodeContent, dateRange, tableUtils)
+
+    val outputDf = tableUtils.loadTable(f"${namespace}.${joinConf.metaData.name}")
+
+    val outputData = outputDf.where("item IS NOT NULL and ts IS NOT NULL").collect()
+    val queriesData = queriesDf.where("item IS NOT NULL and ts IS NOT NULL").collect()
+    outputData.length shouldBe queriesData.length
+  }
+
+}