26 changes: 0 additions & 26 deletions spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -26,9 +26,7 @@ import ai.chronon.online.SparkConversions
import ai.chronon.online.TimeRange
import org.apache.avro.Schema
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.LongType
@@ -298,30 +296,6 @@ object Extensions {
    }
  }

-  implicit class InternalRowOps(internalRow: InternalRow) {
-    def toRow: Row = {

Collaborator (Author) commented: Seems unused.

-      new Row() {
-        override def length: Int = {
-          internalRow.numFields
-        }
-
-        override def get(i: Int): Any = {
-          internalRow.get(i, schema.fields(i).dataType)
-        }
-
-        override def copy(): Row = internalRow.copy().toRow
-      }
-    }
-  }
-
-  implicit class TupleToJMapOps[K, V](tuples: Iterator[(K, V)]) {
-    def toJMap: util.Map[K, V] = {
-      val map = new util.HashMap[K, V]()
-      tuples.foreach { case (k, v) => map.put(k, v) }
-      map
-    }
-  }
-
  implicit class DataPointerOps(dataPointer: DataPointer) {
    def toDf(implicit sparkSession: SparkSession): DataFrame = {
      val tableOrPath = dataPointer.tableOrPath
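The remaining files in this PR replace the deleted toJMap helper with the standard-library conversion. A minimal, self-contained sketch of the new pattern, assuming Scala 2.13's scala.jdk.CollectionConverters (the object name and sample values are illustrative, not from the repo):

import java.util
import scala.jdk.CollectionConverters._

object ToJMapSketch {
  def main(args: Array[String]): Unit = {
    val tuples: Iterator[(String, Int)] = Iterator("a" -> 1, "a" -> 2, "b" -> 3)

    // Old helper: looped over the iterator, calling put on a util.HashMap.
    // New pattern: materialize an immutable Scala Map, then expose it as a
    // java.util.Map view.
    val jmap: util.Map[String, Int] = tuples.toMap.asJava

    // Duplicate keys behave identically under both: put overwrites and
    // .toMap keeps the last binding.
    println(jmap) // {a=2, b=3}
  }
}

The touched call sites already use .asScala (see LabelJoin below), so a converters import is plausibly already in scope in those files.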
9 changes: 6 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -274,9 +274,12 @@ class Join(joinConf: api.Join,
    if (skipBloomFilter) {
      None
    } else {
-     val leftBlooms = joinConf.leftKeyCols.iterator.map { key =>
-       key -> bootstrapDf.generateBloomFilter(key, leftRowCount, joinConf.left.table, leftRange)
-     }.toJMap
+     val leftBlooms = joinConf.leftKeyCols.iterator
+       .map { key =>
+         key -> bootstrapDf.generateBloomFilter(key, leftRowCount, joinConf.left.table, leftRange)
+       }
+       .toMap
+       .asJava
      Some(leftBlooms)
    }
  }
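One behavioral nuance of the new pattern, worth flagging in review (an observation, not a change in this PR): the removed helper returned a mutable util.HashMap, whereas .toMap.asJava yields a read-only wrapper over an immutable map. This only matters if a downstream consumer mutates the bloom-filter map, which this diff does not show. A sketch:

import scala.jdk.CollectionConverters._

object ReadOnlySketch {
  def main(args: Array[String]): Unit = {
    val jmap = Map("key_col" -> 1).asJava
    // put on the wrapper of an immutable Map throws, unlike the mutable
    // util.HashMap the removed toJMap helper returned.
    try jmap.put("other_col", 2)
    catch { case _: UnsupportedOperationException => println("read-only view") }
  }
}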
11 changes: 7 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -304,10 +304,13 @@ object JoinUtils {
                          joinLevelBloomMapOpt: Option[util.Map[String, BloomFilter]]): Option[util.Map[String, BloomFilter]] = {

    val rightBlooms = joinLevelBloomMapOpt.map { joinBlooms =>
-     joinPart.rightToLeft.iterator.map {
-       case (rightCol, leftCol) =>
-         rightCol -> joinBlooms.get(leftCol)
-     }.toJMap
+     joinPart.rightToLeft.iterator
+       .map {
+         case (rightCol, leftCol) =>
+           rightCol -> joinBlooms.get(leftCol)
+       }
+       .toMap
+       .asJava
    }

    // print bloom sizes
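One lookup nuance in the hunk above, unchanged by the refactor: joinBlooms is a java.util.Map, so joinBlooms.get(leftCol) returns null for a missing left column, and .toMap stores that null as a value just as the old put-based helper did. A sketch with illustrative values:

import java.util
import scala.jdk.CollectionConverters._

object NullLookupSketch {
  def main(args: Array[String]): Unit = {
    val joinBlooms: util.Map[String, String] = Map("left_col" -> "bloom").asJava
    // util.Map.get yields null for a missing key; both the old and the new
    // conversion keep that null as a map value.
    val rightBlooms = Iterator("right_col" -> joinBlooms.get("missing")).toMap.asJava
    println(rightBlooms) // {right_col=null}
  }
}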
11 changes: 7 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
@@ -156,9 +156,12 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {

  def computeRange(leftDf: DataFrame, leftRange: PartitionRange, sanitizedLabelDs: String): DataFrame = {
    val leftDfCount = leftDf.count()
-   val leftBlooms = labelJoinConf.leftKeyCols.iterator.map { key =>
-     key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
-   }.toJMap
+   val leftBlooms = labelJoinConf.leftKeyCols.iterator
+     .map { key =>
+       key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
+     }
+     .toMap
+     .asJava

    // compute joinParts in parallel
    val rightDfs = labelJoinConf.labels.asScala.map { labelJoinPart =>
@@ -241,7 +244,7 @@
        PartitionRange(labelDS, labelDS),
        tableUtils,
        computeDependency = true,
-       Option(rightBloomMap.iterator.toJMap),
+       Option(rightBloomMap.toMap.asJava),
        rightSkewFilter)

    val df = (joinConf.left.dataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match {
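In the last hunk, dropping .iterator works because .toMap is defined on any collection of pairs. If rightBloomMap is a mutable map (its declaration is not visible in this diff), .toMap also takes an immutable snapshot first, as sketched here with a hypothetical stand-in value:

import scala.collection.mutable
import scala.jdk.CollectionConverters._

object SnapshotSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical stand-in for rightBloomMap; the real map holds BloomFilters.
    val rightBloomMap = mutable.Map("key_col" -> 1)

    // .toMap copies into an immutable Map before wrapping, so later mutation
    // of the source map does not leak into the Java view handed downstream.
    val snapshot = rightBloomMap.toMap.asJava
    rightBloomMap("key_col") = 2

    println(snapshot) // {key_col=1}
  }
}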