diff --git a/spark/src/main/scala/ai/chronon/spark/Extensions.scala b/spark/src/main/scala/ai/chronon/spark/Extensions.scala
index 3018f28e7a..a0e650aef3 100644
--- a/spark/src/main/scala/ai/chronon/spark/Extensions.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -26,9 +26,7 @@ import ai.chronon.online.SparkConversions
 import ai.chronon.online.TimeRange
 import org.apache.avro.Schema
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.LongType
@@ -298,30 +296,6 @@ object Extensions {
     }
   }
 
-  implicit class InternalRowOps(internalRow: InternalRow) {
-    def toRow: Row = {
-      new Row() {
-        override def length: Int = {
-          internalRow.numFields
-        }
-
-        override def get(i: Int): Any = {
-          internalRow.get(i, schema.fields(i).dataType)
-        }
-
-        override def copy(): Row = internalRow.copy().toRow
-      }
-    }
-  }
-
-  implicit class TupleToJMapOps[K, V](tuples: Iterator[(K, V)]) {
-    def toJMap: util.Map[K, V] = {
-      val map = new util.HashMap[K, V]()
-      tuples.foreach { case (k, v) => map.put(k, v) }
-      map
-    }
-  }
-
   implicit class DataPointerOps(dataPointer: DataPointer) {
     def toDf(implicit sparkSession: SparkSession): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath
diff --git a/spark/src/main/scala/ai/chronon/spark/Join.scala b/spark/src/main/scala/ai/chronon/spark/Join.scala
index d5de5aaa84..14a421ff2c 100644
--- a/spark/src/main/scala/ai/chronon/spark/Join.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -274,9 +274,12 @@ class Join(joinConf: api.Join,
       if (skipBloomFilter) {
         None
       } else {
-        val leftBlooms = joinConf.leftKeyCols.iterator.map { key =>
-          key -> bootstrapDf.generateBloomFilter(key, leftRowCount, joinConf.left.table, leftRange)
-        }.toJMap
+        val leftBlooms = joinConf.leftKeyCols.iterator
+          .map { key =>
+            key -> bootstrapDf.generateBloomFilter(key, leftRowCount, joinConf.left.table, leftRange)
+          }
+          .toMap
+          .asJava
         Some(leftBlooms)
       }
     }
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
index ef49e8397d..c3768ccdf7 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -304,10 +304,13 @@ object JoinUtils {
                           joinLevelBloomMapOpt: Option[util.Map[String, BloomFilter]]): Option[util.Map[String, BloomFilter]] = {
 
     val rightBlooms = joinLevelBloomMapOpt.map { joinBlooms =>
-      joinPart.rightToLeft.iterator.map {
-        case (rightCol, leftCol) =>
-          rightCol -> joinBlooms.get(leftCol)
-      }.toJMap
+      joinPart.rightToLeft.iterator
+        .map {
+          case (rightCol, leftCol) =>
+            rightCol -> joinBlooms.get(leftCol)
+        }
+        .toMap
+        .asJava
     }
 
     // print bloom sizes
diff --git a/spark/src/main/scala/ai/chronon/spark/LabelJoin.scala b/spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
index 633e07639f..66e9b0af64 100644
--- a/spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
+++ b/spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
@@ -156,9 +156,12 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
 
   def computeRange(leftDf: DataFrame, leftRange: PartitionRange, sanitizedLabelDs: String): DataFrame = {
     val leftDfCount = leftDf.count()
-    val leftBlooms = labelJoinConf.leftKeyCols.iterator.map { key =>
-      key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
-    }.toJMap
+    val leftBlooms = labelJoinConf.leftKeyCols.iterator
+      .map { key =>
+        key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
+      }
+      .toMap
+      .asJava
 
     // compute joinParts in parallel
     val rightDfs = labelJoinConf.labels.asScala.map { labelJoinPart =>
@@ -241,7 +244,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
       PartitionRange(labelDS, labelDS),
       tableUtils,
       computeDependency = true,
-      Option(rightBloomMap.iterator.toJMap),
+      Option(rightBloomMap.toMap.asJava),
       rightSkewFilter)
 
     val df = (joinConf.left.dataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match {
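
For context on the recurring change: every `.toJMap` call site above relied on the `TupleToJMapOps` implicit that this diff deletes from Extensions.scala. Below is a minimal before/after sketch, assuming Scala 2.13's `scala.jdk.CollectionConverters` is what brings `asJava` into scope (the hunks above don't show the import blocks of the touched files, so the exact converter import is an assumption):

    import java.util
    import scala.jdk.CollectionConverters._

    object ToJMapSketch {
      // Removed helper, reproduced from the deleted Extensions.scala block:
      // it copied each pair into a mutable java.util.HashMap by hand.
      implicit class TupleToJMapOps[K, V](tuples: Iterator[(K, V)]) {
        def toJMap: util.Map[K, V] = {
          val map = new util.HashMap[K, V]()
          tuples.foreach { case (k, v) => map.put(k, v) }
          map
        }
      }

      def main(args: Array[String]): Unit = {
        val pairs = Iterator("user_id" -> 1, "item_id" -> 2)

        // Standard-library replacement used by the diff: materialize a
        // Scala Map, then expose it through a java.util.Map wrapper.
        val jmap: util.Map[String, Int] = pairs.toMap.asJava
        println(jmap) // {user_id=1, item_id=2}
      }
    }

One behavioral nuance: the old helper returned a mutable `HashMap`, while `asJava` over an immutable Scala `Map` yields a read-only view that throws `UnsupportedOperationException` on mutation. That only matters if a caller writes to the bloom-filter map after construction.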
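The other deletion in Extensions.scala, `InternalRowOps`, was subtly broken: inside the anonymous `Row`, `schema` resolves to `Row.schema`, which defaults to `null` unless overridden, so any `get` call would have thrown a `NullPointerException` at runtime. For illustration only (this helper is gone, and the names below are hypothetical), a variant that actually works would thread the `StructType` through explicitly:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.StructType

    object InternalRowSketch {
      // Hypothetical fix for the deleted InternalRowOps.toRow: the schema
      // is passed in rather than read off Row.schema (which stays null in
      // an anonymous Row, the bug in the removed version).
      def toRow(internalRow: InternalRow, schema: StructType): Row = new Row {
        override def length: Int = internalRow.numFields
        override def get(i: Int): Any = internalRow.get(i, schema.fields(i).dataType)
        override def copy(): Row = toRow(internalRow.copy(), schema)
      }
    }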