diff --git a/spark/src/main/scala/ai/chronon/spark/Join.scala b/spark/src/main/scala/ai/chronon/spark/Join.scala
index ac80b3e5f8..d3d6c124e5 100644
--- a/spark/src/main/scala/ai/chronon/spark/Join.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -565,10 +565,14 @@ class Join(joinConf: api.Join,
         bootstrapDf = bootstrapDf
           .select(includedColumns.map(col): _*)
-          // TODO: allow customization of deduplication logic
-          .dropDuplicates(part.keys(joinConf, tableUtils.partitionColumn).toArray)
-        coalescedJoin(partialDf, bootstrapDf, part.keys(joinConf, tableUtils.partitionColumn).toSeq)
+        val dedupedBootstrap = dropDuplicatesUsingJoinShuffle(
+          bootstrapDf,
+          partialDf,
+          part.keys(joinConf, tableUtils.partitionColumn).toSeq
+        )
+
+        coalescedJoin(partialDf, dedupedBootstrap, part.keys(joinConf, tableUtils.partitionColumn).toSeq)
         // as part of the left outer join process, we update and maintain matched_hashes for each record
         // that summarizes whether there is a join-match for each bootstrap source.
         // later on we use this information to decide whether we still need to re-run the backfill logic
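For context on the call-site change above: it replaces the `dropDuplicates`-then-join pattern with the new `dropDuplicatesUsingJoinShuffle` helper (added to `JoinUtils.scala` below, where the shuffle mismatch is explained). A minimal standalone sketch of the before/after shapes, using hypothetical stand-in dataframes with a `Double` key -- the names `score`, and the toy `partialDf`/`bootstrapDf`, are illustrative, not the real Chronon values, and the real call site joins via `coalescedJoin`, though the shuffle behavior is the same for a plain left join:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

val spark = SparkSession.builder().master("local[*]").appName("dedup-shuffle-sketch").getOrCreate()
import spark.implicits._

// Hypothetical stand-ins for the real partialDf / bootstrapDf and bootstrap keys.
val partialDf: DataFrame = Seq((1.0, "left")).toDF("score", "left_val")
val bootstrapDf: DataFrame = Seq((1.0, "a"), (1.0, "b")).toDF("score", "boot_val")
val keys = Seq("score")

// Before: dropDuplicates partitions on the raw `score` column, while the join
// requires a distribution over the *normalized* floating-point key expression,
// so Spark inserts a second shuffle on the bootstrap side.
val before = partialDf.join(bootstrapDf.dropDuplicates(keys), keys, "left")

// After: deduplicate on the join's own key expressions, so the window's
// shuffle already satisfies the join's required distribution.
val after = partialDf.join(
  ai.chronon.spark.JoinUtils.dropDuplicatesUsingJoinShuffle(bootstrapDf, partialDf, keys),
  keys,
  "left")
```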
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
index 1a893a0a40..5b77d8d754 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -21,9 +21,10 @@ import ai.chronon.api.Constants
 import ai.chronon.api.DataModel.Events
 import ai.chronon.api.Extensions.{JoinOps, _}
 import ai.chronon.spark.Extensions._
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.expressions.UserDefinedFunction
-import org.apache.spark.sql.functions.{coalesce, col, udf}
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
+import org.apache.spark.sql.functions.{coalesce, col, lit, row_number, udf}
 import org.slf4j.LoggerFactory
 
 import java.util
@@ -115,6 +116,42 @@ object JoinUtils {
     PartitionRange(leftStart, leftEnd)(tableUtils)
   }
+  /**
+    * Deduplicate the given [[toDedup]] dataframe using the same shuffle as
+    * would be used when joining it to the [[joinPartner]] dataframe.
+    *
+    * This is an optimization for the scenario where a caller deduplicates a
+    * dataframe before joining:
+    *   joinPartner.join(toDedup.dropDuplicates(keys), keys)
+    *
+    * By default, Spark does a poor job of reusing the deduplication shuffle in
+    * this scenario -- the join query plan normalizes floating-point numbers
+    * but dropDuplicates does not, resulting in subtly different shuffles.
+    *
+    * This function plans a fake join and extracts the exact key expressions
+    * that the join will use, then deduplicates on those expressions so that
+    * the same shuffle can be reused.
+    */
+  def dropDuplicatesUsingJoinShuffle(toDedup: DataFrame, joinPartner: DataFrame, keys: Seq[String]): DataFrame = {
+    val condition = keys.map(key => joinPartner(key) === toDedup(key)).reduce(_ && _)
+    val plannedJoin = joinPartner.join(toDedup, condition)
+
+    plannedJoin.queryExecution.logical match {
+      case ExtractEquiJoinKeys(_, _, rightKeys, _, _, _, _) =>
+        // Keep the first row per (normalized) key, mimicking dropDuplicates.
+        val cols = rightKeys.map(new Column(_))
+        val w = Window.partitionBy(cols: _*).orderBy(cols: _*)
+        toDedup
+          .withColumn("dedup_row_number", row_number().over(w))
+          .filter(col("dedup_row_number") === lit(1))
+          .drop("dedup_row_number")
+      case _ =>
+        // This should never happen: the condition is a pure conjunction of column equalities.
+        logger.warn(s"Couldn't plan an equi-join for $condition. Falling back to dropDuplicates.")
+        toDedup.dropDuplicates(keys)
+    }
+  }
+
   /**
     *
     * join left and right dataframes, merging any shared columns if exists by the coalesce rule.
     * fails if there is any data type mismatch between shared columns.
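To see the effect, one can count `Exchange` nodes in the executed plan, which is exactly what the `reuseDedupShuffle` test below asserts. A sketch reusing the hypothetical `partialDf`/`bootstrapDf`/`keys` from the previous snippet:

```scala
import org.apache.spark.sql.execution.exchange.Exchange
import ai.chronon.spark.JoinUtils.dropDuplicatesUsingJoinShuffle

val deduped = dropDuplicatesUsingJoinShuffle(bootstrapDf, partialDf, keys)
val joined = partialDf.join(deduped, keys)

// Because the window partitions on the join's normalized key expressions, its
// output partitioning already satisfies the join's requirement on the dedup
// side, so the plan should contain exactly two exchanges -- one per join
// input -- rather than the three produced by dropDuplicates on a Double key.
val exchanges = joined.queryExecution.executedPlan.collect { case e: Exchange => e }.length
assert(exchanges == 2, s"expected 2 exchanges, found $exchanges")
```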
diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala
index 2189c6f616..166924da48 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala
@@ -19,10 +19,11 @@ package ai.chronon.spark.test
 import ai.chronon.aggregator.test.Column
 import ai.chronon.api
 import ai.chronon.api.{Builders, Constants}
-import ai.chronon.spark.JoinUtils.{contains_any, set_add}
+import ai.chronon.spark.JoinUtils.{contains_any, dropDuplicatesUsingJoinShuffle, set_add}
 import ai.chronon.spark.{GroupBy, JoinUtils, PartitionRange, SparkSessionBuilder, TableUtils}
 import ai.chronon.spark.Extensions._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
@@ -102,6 +103,147 @@ class JoinUtilsTest {
     }
   }
 
+  @Test
+  def reuseDedupShuffle(): Unit = {
+    val spark: SparkSession =
+      SparkSessionBuilder.build("JoinUtilsTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val leftDf = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+      (1.0, "a2")
+    )))
+    val rightDf = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+      (1.0, "a2")
+    )))
+
+    val keys = Seq("_1")
+    val deduped = dropDuplicatesUsingJoinShuffle(rightDf, leftDf, keys)
+      .as[(Double, String)](spark.implicits.newProductEncoder)
+
+    // With the dedup shuffle reused by the join, only one exchange per join
+    // input should remain in the physical plan.
+    val exchanges = leftDf.join(deduped, keys)
+      .queryExecution.executedPlan
+      .collect { case _: Exchange => true }
+      .length
+
+    assertEquals(2, exchanges)
+  }
+
+  private def runDropDuplicates(left: Seq[(String, String)],
+                                right: Seq[(String, String)],
+                                keys: Seq[String]): Seq[(String, String)] = {
+    val spark: SparkSession =
+      SparkSessionBuilder.build("JoinUtilsTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val leftDf = spark.createDataFrame(spark.sparkContext.parallelize(left))
+    val rightDf = spark.createDataFrame(spark.sparkContext.parallelize(right))
+
+    dropDuplicatesUsingJoinShuffle(rightDf, leftDf, keys)
+      .as[(String, String)](spark.implicits.newProductEncoder)
+      .collect()
+      .toSeq
+  }
+
+  @Test
+  def testNoDuplicates(): Unit = {
+    val left = Seq(
+      ("a1", "a2"),
+      ("b1", "b2"),
+      ("c1", "c2")
+    )
+    val right = Seq(
+      ("a1", "a2"),
+      ("b1", "b2"),
+      ("c1", "c2")
+    )
+    val keys = Seq("_1")
+
+    val expected = right
+    val result = runDropDuplicates(left, right, keys)
+
+    assertEquals(expected.length, result.length)
+    result.foreach { r => assertTrue(expected.contains(r)) }
+  }
+
+  @Test
+  def testNoDuplicatesColumn2(): Unit = {
+    val left = Seq(
+      ("a1", "a2"),
+      ("b1", "b2"),
+      ("c1", "c2")
+    )
+    val right = Seq(
+      ("a1", "a2"),
+      ("a1", "b2"),
+      ("a1", "c2")
+    )
+    val keys = Seq("_1", "_2")
+
+    val expected = right
+    val result = runDropDuplicates(left, right, keys)
+
+    assertEquals(expected.length, result.length)
+    result.foreach { r => assertTrue(expected.contains(r)) }
+  }
+
+  @Test
+  def testDuplicates(): Unit = {
+    val left = Seq(
+      ("a1", "a2"),
+      ("b1", "b2"),
+      ("c1", "c2")
+    )
+    val right = Seq(
+      ("a1", "a2"),
+      ("a1", "b2"),
+      ("c1", "c2")
+    )
+    val keys = Seq("_1")
+
+    // The window's sort over identical keys is nondeterministic, so either
+    // duplicate row may survive.
+    val expected1 = Seq(
+      ("a1", "a2"),
+      ("c1", "c2")
+    )
+    val expected2 = Seq(
+      ("a1", "b2"),
+      ("c1", "c2")
+    )
+    val result = runDropDuplicates(left, right, keys)
+
+    assertEquals(expected1.length, result.length)
+    result.foreach { r => assertTrue(expected1.contains(r) || expected2.contains(r)) }
+  }
+
+  @Test
+  def test2ColumnDedup(): Unit = {
+    val left = Seq(
+      ("a1", "a2"),
+      ("b1", "b2"),
+      ("c1", "c2")
+    )
+    val right = Seq(
+      ("a1", "a2"),
+      ("a1", "a2"),
+      ("a1", "c2")
+    )
+    val keys = Seq("_1", "_2")
+
+    val expected = Seq(
+      ("a1", "a2"),
+      ("a1", "c2")
+    )
+    val result = runDropDuplicates(left, right, keys)
+
+    assertEquals(expected.length, result.length)
+    result.foreach { r => assertTrue(expected.contains(r)) }
+  }
+
+  @Test
+  def testEmpty(): Unit = {
+    val left = Seq.empty[(String, String)]
+    val right = Seq.empty[(String, String)]
+    val keys = Seq("_1")
+
+    val result = runDropDuplicates(left, right, keys)
+
+    assertTrue(result.isEmpty)
+  }
+
   private def testJoinScenario(leftSchema: StructType,
                                rightSchema: StructType,
                                keys: Seq[String],