Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions api/src/main/scala/ai/chronon/api/planner/NodeRunner.scala

This file was deleted.

47 changes: 47 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/batch/BatchNodeRunner.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package ai.chronon.spark.batch

import ai.chronon.api.{MetaData, PartitionRange}
import ai.chronon.planner.NodeContent
import ai.chronon.spark.join.UnionJoin
import ai.chronon.spark.Driver.JoinBackfill.logger
import ai.chronon.spark.Join
import ai.chronon.spark.catalog.TableUtils
import ai.chronon.spark.submission.SparkSessionBuilder

/** [[NodeRunner]] implementation for batch nodes executed on Spark.
  *
  * Only `MONOLITH_JOIN` node content is handled; every other content type is rejected.
  */
class BatchNodeRunner extends NodeRunner {

  /** Executes the node described by `conf` over `range`.
    *
    * For a monolith join, the Spark conf key `spark.chronon.join.backfill.mode.skewFree`
    * (default `"false"`) selects between `UnionJoin.computeJoinAndSave` and the regular
    * `Join` backfill. Exactly one of the two paths runs per invocation.
    *
    * @param metadata   node metadata; its name labels the Spark application
    * @param conf       node content union; the monolith join field must be the one set
    * @param range      partition range to backfill (`range.start` to `range.end`)
    * @param tableUtils table utilities; its Spark session supplies the skew-free flag
    * @throws IllegalArgumentException      if the monolith join node has no join set
    * @throws UnsupportedOperationException if `conf` holds any other content type
    */
  override def run(metadata: MetaData, conf: NodeContent, range: PartitionRange, tableUtils: TableUtils): Unit =
    conf.getSetField match {
      case NodeContent._Fields.MONOLITH_JOIN =>
        val monolithJoin = conf.getMonolithJoin
        require(monolithJoin.isSetJoin, "MonolithJoinNode must have a join set")

        // Kept for its side effect only; the returned session was never used by this method.
        // NOTE(review): the session is intentionally not stopped here - presumably the builder
        // hands back the same session that `tableUtils` wraps, which callers keep using after
        // `run` returns. Confirm against SparkSessionBuilder before adding a stop().
        SparkSessionBuilder.build(s"node-batch-${metadata.name}")

        val joinConf = monolithJoin.join
        val joinName = joinConf.metaData.name
        val skewFree =
          tableUtils.sparkSession.conf.get("spark.chronon.join.backfill.mode.skewFree", "false").toBoolean

        // Logged before branching so both modes report what is being filled.
        logger.info(s"Filling partitions for join:$joinName, partitions:[${range.start}, ${range.end}]")

        if (skewFree) {
          logger.info(s" >>> Running join backfill in skew free mode <<< ")
          logger.info(s"Processing range $range")
          UnionJoin.computeJoinAndSave(joinConf, range)(tableUtils)
          logger.info(s"Wrote range $range")
        } else {
          // Regular backfill. This must not run after the skew-free path, otherwise the same
          // join is computed twice.
          val join = new Join(
            joinConf,
            range.end,
            tableUtils
          )

          val df = join.computeJoin(overrideStartPartition = Option(range.start))

          df.show(numRows = 3, truncate = 0, vertical = true)
          logger.info(s"\nShowing three rows of output above.\nQuery table `${joinName}` for more.\n")
        }

      // Report the union field that is set: the class name is the same for every NodeContent.
      case other => throw new UnsupportedOperationException(s"Unsupported NodeContent type: $other")
    }
}
17 changes: 17 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/batch/NodeRunner.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ai.chronon.spark.batch

import ai.chronon.api
import ai.chronon.api.PartitionRange
import ai.chronon.planner.NodeContent
import ai.chronon.spark.catalog.TableUtils
/** Contract for running a single node over a partition range. */
trait NodeRunner {

  /** Runs the node described by `conf`.
    *
    * @param metadata   metadata of the node being run
    * @param conf       content of the node to run
    * @param range      partition range the run applies to
    * @param tableUtils table utilities made available to the run
    */
  def run(
      metadata: api.MetaData,
      conf: NodeContent,
      range: PartitionRange,
      tableUtils: TableUtils
  ): Unit
}

/** Holder for `readFiles`, which is currently an unimplemented stub. */
object LineageOfflineRunner {

  /** Stub: intended to read files from a folder using metadata, but nothing is read yet.
    *
    * @param folderPath folder to read from; currently ignored
    * @return always an empty sequence
    */
  def readFiles(folderPath: String): Seq[Any] =
    // read files from folder using metadata
    Seq.empty[Any]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package ai.chronon.spark.test.batch

import ai.chronon.aggregator.test.Column
import ai.chronon.api
import ai.chronon.api.{Accuracy, Builders, Operation, PartitionRange, TimeUnit, Window}
import ai.chronon.spark.batch.BatchNodeRunner
import ai.chronon.spark.join.UnionJoin
import ai.chronon.spark.test.DataFrameGen
import ai.chronon.spark.test.join.BaseJoinTest
import ai.chronon.planner.{MonolithJoinNode, NodeContent}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import ai.chronon.spark.Extensions._

/** Runs a monolith join node through `BatchNodeRunner` and checks the written output table. */
class BatchNodeRunnerTest extends BaseJoinTest {

  "BatchNodeRunner" should "run a monolith join node" in {

    // right side: view events that feed the group-by
    val viewsSchema = List(
      Column("user", api.StringType, 1),
      Column("item", api.StringType, 1),
      Column("time_spent_ms", api.LongType, 5000)
    )

    val viewsTable = s"$namespace.view_union_temporal"
    DataFrameGen
      .events(spark, viewsSchema, count = 10000, partitions = 20)
      .save(viewsTable, Map("tblProp1" -> "1"))

    val viewsSource = Builders.Source.events(
      table = viewsTable,
      topic = "",
      query = Builders.Query(selects = Builders.Selects("time_spent_ms"),
                             startPartition = tableUtils.partitionSpec.minus(today, new Window(20, TimeUnit.DAYS)))
    )

    val viewsGroupBy = Builders
      .GroupBy(
        sources = Seq(viewsSource),
        keyColumns = Seq("item"),
        aggregations = Seq(
          Builders.Aggregation(operation = Operation.AVERAGE, inputColumn = "time_spent_ms"),
          Builders.Aggregation(
            operation = Operation.LAST_K,
            argMap = Map("k" -> "50"),
            inputColumn = "time_spent_ms",
            windows = Seq(new Window(2, TimeUnit.DAYS))
          )
        ),
        metaData = Builders.MetaData(name = "unit_test.item_views", namespace = namespace)
      )
      .setAccuracy(Accuracy.TEMPORAL)

    // left side
    val itemQueries = List(Column("item", api.StringType, 1))
    val itemQueriesTable = s"$namespace.item_queries_union_temporal"
    val itemQueriesDf = DataFrameGen
      .events(spark, itemQueries, 10000, partitions = 10)

    // duplicate the events
    itemQueriesDf.union(itemQueriesDf).save(itemQueriesTable)

    val queriesDf = tableUtils.loadTable(itemQueriesTable)

    val start = tableUtils.partitionSpec.minus(today, new Window(20, TimeUnit.DAYS))
    val dateRange = PartitionRange(start, today)(tableUtils.partitionSpec)

    val joinConf = Builders.Join(
      left = Builders.Source.events(Builders.Query(startPartition = start), table = itemQueriesTable),
      joinParts = Seq(Builders.JoinPart(groupBy = viewsGroupBy, prefix = "user")),
      metaData =
        Builders.MetaData(name = "item_temporal_features_union_join", namespace = namespace, team = "item_team")
    )

    // The output table must be produced by the runner alone. Computing the join directly
    // beforehand would pre-populate it and let the assertion pass even if `run` did nothing.
    val batchNodeRunner = new BatchNodeRunner()

    val joinNodeContent = new NodeContent()
    joinNodeContent.setMonolithJoin(new MonolithJoinNode().setJoin(joinConf))

    batchNodeRunner.run(joinConf.metaData, joinNodeContent, dateRange, tableUtils)

    val outputDf = tableUtils.loadTable(s"$namespace.${joinConf.metaData.name}")

    // Every non-null left-side row is expected to appear in the join output.
    val outputData = outputDf.where("item IS NOT NULL and ts IS NOT NULL").collect()
    val queriesData = queriesDf.where("item IS NOT NULL and ts IS NOT NULL").collect()
    outputData.length shouldBe queriesData.length
  }

}
Loading