diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 4dc5ce1de047b..190118828d29e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -591,6 +591,18 @@ public boolean anyNull() {
     return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
   }
 
+  /**
+   * Returns whether this UnsafeRow is null in every column.
+   */
+  public boolean allNull() {
+    for (int i = 0; i < numFields; i++) {
+      if (!BitSetMethods.isSet(baseObject, baseOffset, i)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * Writes the content of this row into a memory address, identified by an object and an offset.
    * The target memory address must already been allocated, and have enough space to hold all the
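A standalone sketch of what the new allNull() adds over the existing anyNull(), illustrative only and not part of the patch; it builds rows the same way HashedRelationSuite below does and assumes a catalyst test classpath:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.types.LongType

    // Two nullable long key columns.
    val proj = UnsafeProjection.create(
      Seq(BoundReference(0, LongType, true), BoundReference(1, LongType, true)))

    // copy() because the projection reuses its output buffer.
    val partlyNull = proj(InternalRow(null, 2L)).copy()
    val fullyNull = proj(InternalRow(null, null)).copy()

    // anyNull() is true as soon as one column is null; allNull() only when every column is.
    assert(partlyNull.anyNull() && !partlyNull.allNull())
    assert(fullyNull.anyNull() && fullyNull.allNull())

The NAAJ code paths below switch from anyNull() to allNull() because with multiple key columns only an entirely-null streamed key can be treated as an unconditional match and dropped; a partially-null key still has to be probed against the build side.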
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 5a994f1ad0a39..a42d44cd0d40d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.planning
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
@@ -391,35 +393,45 @@ object PhysicalWindow {
   }
 }
 
-object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with PredicateHelper {
+object ExtractNullAwareAntiJoinKeys extends JoinSelectionHelper with PredicateHelper {
 
-  // TODO support multi column NULL-aware anti join in future.
-  // See. http://www.vldb.org/pvldb/vol2/vldb09-423.pdf Section 6
-  // multi-column null aware anti join is much more complicated than single column ones.
+  // FYI: background material on null-aware anti join:
+  // https://dl.acm.org/doi/10.14778/1687553.1687563
+  // http://www.vldb.org/pvldb/vol2/vldb09-423.pdf
 
   // streamedSideKeys, buildSideKeys
   private type ReturnType = (Seq[Expression], Seq[Expression])
 
-  /**
-   * See. [SPARK-32290]
-   * LeftAnti(condition: Or(EqualTo(a=b), IsNull(EqualTo(a=b)))
-   * will almost certainly be planned as a Broadcast Nested Loop join,
-   * which is very time consuming because it's an O(M*N) calculation.
-   * But if it's a single column case O(M*N) calculation could be optimized into O(M)
-   * using hash lookup instead of loop lookup.
-   */
   def unapply(join: Join): Option[ReturnType] = join match {
-    case Join(left, right, LeftAnti,
-        Some(Or(e @ EqualTo(leftAttr: AttributeReference, rightAttr: AttributeReference),
-          IsNull(e2 @ EqualTo(_, _)))), _)
-        if SQLConf.get.optimizeNullAwareAntiJoin &&
-          e.semanticEquals(e2) =>
-      if (canEvaluate(leftAttr, left) && canEvaluate(rightAttr, right)) {
-        Some(Seq(leftAttr), Seq(rightAttr))
-      } else if (canEvaluate(leftAttr, right) && canEvaluate(rightAttr, left)) {
-        Some(Seq(rightAttr), Seq(leftAttr))
-      } else {
+    case Join(left, right, LeftAnti, condition, _) if SQLConf.get.optimizeNullAwareAntiJoin =>
+      val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
+      if (predicates.isEmpty ||
+          predicates.length > SQLConf.get.optimizeNullAwareAntiJoinMaxNumKeys) {
         None
+      } else {
+        val joinKeys = ArrayBuffer[(Expression, Expression)]()
+
+        // Every predicate must match the pattern Or(EqualTo(a=b), IsNull(EqualTo(a=b))).
+        val allMatch = predicates.forall {
+          case Or(e @ EqualTo(leftExpr: Expression, rightExpr: Expression),
+              IsNull(e2 @ EqualTo(_, _))) if e.semanticEquals(e2) =>
+            if (canEvaluate(leftExpr, left) && canEvaluate(rightExpr, right)) {
+              joinKeys += ((leftExpr, rightExpr))
+              true
+            } else if (canEvaluate(leftExpr, right) && canEvaluate(rightExpr, left)) {
+              joinKeys += ((rightExpr, leftExpr))
+              true
+            } else {
+              false
+            }
+          case _ => false
+        }
+
+        if (allMatch) {
+          Some(joinKeys.unzip)
+        } else {
+          None
+        }
       }
     case _ => None
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index bae41114caf1c..5dcc5745b056b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2678,14 +2678,26 @@ object SQLConf {
     .checkValue(_ >= 0, "The value must be non-negative.")
     .createWithDefault(8)
 
+  val OPTIMIZE_NULL_AWARE_ANTI_JOIN_MAX_NUM_KEYS =
+    buildConf("spark.sql.optimizeNullAwareAntiJoin.maxNumKeys")
+      .internal()
+      .doc("The maximum number of join keys supported by the NAAJ optimization. " +
+        "With NAAJ the build side data is expanded to (2^numKeys - 1) times its size, " +
+        "so a larger number of keys grows the broadcast exponentially and may cause driver OOM.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ > 0, "The value must be positive.")
+      .createWithDefault(3)
+
   val OPTIMIZE_NULL_AWARE_ANTI_JOIN =
     buildConf("spark.sql.optimizeNullAwareAntiJoin")
       .internal()
       .doc("When true, NULL-aware anti join execution will be planed into " +
         "BroadcastHashJoinExec with flag isNullAwareAntiJoin enabled, " +
         "optimized from O(M*N) calculation into O(M) calculation " +
-        "using Hash lookup instead of Looping lookup." +
-        "Only support for singleColumn NAAJ for now.")
+        "using Hash lookup instead of Looping lookup. " +
+        "The maximum number of keys supported for NAAJ is configured by " +
+        s"${OPTIMIZE_NULL_AWARE_ANTI_JOIN_MAX_NUM_KEYS.key}.")
       .version("3.1.0")
       .booleanConf
       .createWithDefault(true)
@@ -3300,6 +3312,9 @@ class SQLConf extends Serializable with Logging {
 
   def optimizeNullAwareAntiJoin: Boolean = getConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN)
 
+  def optimizeNullAwareAntiJoinMaxNumKeys: Int =
+    getConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN_MAX_NUM_KEYS)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
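A rough end-to-end sketch of the query shape the renamed extractor targets and of how the new key limit interacts with the existing switch (illustrative only, not part of the patch; table names and data are made up):

    import org.apache.spark.sql.SparkSession

    object NaajMultiColumnSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("naaj-sketch").getOrCreate()
        import spark.implicits._

        // Both confs are internal; set here only to illustrate the new limit next to the
        // switch introduced by SPARK-32290.
        spark.conf.set("spark.sql.optimizeNullAwareAntiJoin", "true")
        spark.conf.set("spark.sql.optimizeNullAwareAntiJoin.maxNumKeys", "3")

        Seq((1, 10L), (2, 20L)).toDF("a", "b").createOrReplaceTempView("l")
        Seq((Some(1), 10L), (None, 30L)).toDF("c", "d").createOrReplaceTempView("r")

        // The optimizer rewrites the NOT IN below into a LeftAnti join whose condition is
        // roughly (a = c OR ISNULL(a = c)) AND (b = d OR ISNULL(b = d));
        // ExtractNullAwareAntiJoinKeys splits that conjunction into two key pairs, so the
        // plan should show a BroadcastHashJoin with isNullAwareAntiJoin=true. If maxNumKeys
        // were set below the number of key columns, it would fall back to
        // BroadcastNestedLoopJoin.
        spark.sql("SELECT * FROM l WHERE (a, b) NOT IN (SELECT c, d FROM r)").explain()

        spark.stop()
      }
    }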
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index eb32bfcecae7b..8f5d2fc2b5b57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -232,7 +232,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None }
           .getOrElse(createJoinWithoutHint())
 
-      case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys) =>
+      case j @ ExtractNullAwareAntiJoinKeys(leftKeys, rightKeys) =>
         Seq(joins.BroadcastHashJoinExec(leftKeys, rightKeys, LeftAnti, BuildRight,
           None, planLater(j.left), planLater(j.right), isNullAwareAntiJoin = true))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
index bcf9dc1544ce3..eab45a3e55c20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
-import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractSingleColumnNullAwareAntiJoin}
+import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractNullAwareAntiJoinKeys}
 import org.apache.spark.sql.catalyst.plans.LeftAnti
 import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
 import org.apache.spark.sql.execution.{joins, SparkPlan}
@@ -49,7 +49,7 @@ object LogicalQueryStageStrategy extends Strategy with PredicateHelper {
         Seq(BroadcastHashJoinExec(
           leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
 
-    case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys)
+    case j @ ExtractNullAwareAntiJoinKeys(leftKeys, rightKeys)
         if isBroadcastStage(j.right) =>
       Seq(joins.BroadcastHashJoinExec(leftKeys, rightKeys, LeftAnti, BuildRight,
         None, planLater(j.left), planLater(j.right), isNullAwareAntiJoin = true))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index e4935c8c72228..f7646b62b5405 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -49,8 +49,6 @@ case class BroadcastHashJoinExec(
   extends HashJoin with CodegenSupport {
 
   if (isNullAwareAntiJoin) {
-    require(leftKeys.length == 1, "leftKeys length should be 1")
-    require(rightKeys.length == 1, "rightKeys length should be 1")
     require(joinType == LeftAnti, "joinType must be LeftAnti.")
     require(buildSide == BuildRight, "buildSide must be BuildRight.")
     require(condition.isEmpty, "null aware anti join optimize condition should be empty.")
@@ -156,7 +154,7 @@ case class BroadcastHashJoinExec(
           )
           streamedIter.filter(row => {
             val lookupKey: UnsafeRow = keyGenerator(row)
-            if (lookupKey.anyNull()) {
+            if (lookupKey.allNull()) {
               false
             } else {
               // Anti Join: Drop the row on the streamed side if it is a match on the build
@@ -225,9 +223,10 @@ case class BroadcastHashJoinExec(
   protected override def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     if (isNullAwareAntiJoin) {
       val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
-      val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+      val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input, true)
       val (matched, _, _) = getJoinCondition(ctx, input)
       val numOutput = metricTerm(ctx, "numOutputRows")
+      val isLongHashedRelation = broadcastRelation.value.isInstanceOf[LongHashedRelation]
 
       if (broadcastRelation.value == EmptyHashedRelation) {
         s"""
@@ -245,11 +244,10 @@ case class BroadcastHashJoinExec(
            |boolean $found = false;
            |// generate join key for stream side
            |${keyEv.code}
-           |if ($anyNull) {
+           |if (${keyEv.value}.allNull()) {
            |  $found = true;
            |} else {
-           |  UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-           |  if ($matched != null) {
+           |  if ($relationTerm.get(${keyEv.value}) != null) {
            |    $found = true;
            |  }
            |}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 1c6504b141890..e59bf1d99b3fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -342,9 +342,11 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    */
   protected def genStreamSideJoinKey(
       ctx: CodegenContext,
-      input: Seq[ExprCode]): (ExprCode, String) = {
+      input: Seq[ExprCode],
+      forceUnsafe: Boolean = false): (ExprCode, String) = {
     ctx.currentVars = input
-    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType &&
+        !forceUnsafe) {
       // generate the join key as Long
       val ev = streamedBoundKeys.head.genCode(ctx)
       (ev, ev.isNull)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index f2835c2fa6626..65f648eb82132 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -19,8 +19,11 @@ package org.apache.spark.sql.execution.joins
 
 import java.io._
 
+import scala.collection.mutable
+
 import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
 import com.esotericsoftware.kryo.io.{Input, Output}
+import org.roaringbitmap.longlong.Roaring64Bitmap
 
 import org.apache.spark.{SparkConf, SparkEnv, SparkException}
 import org.apache.spark.internal.config.{BUFFER_PAGESIZE, MEMORY_OFFHEAP_ENABLED}
@@ -28,7 +31,7 @@ import org.apache.spark.memory._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
-import org.apache.spark.sql.types.LongType
+import org.apache.spark.sql.types.{DataType, LongType}
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.{KnownSizeEstimation, Utils}
@@ -108,12 +111,16 @@ private[execution] object HashedRelation {
         0)
     }
 
-    if (isNullAware && !input.hasNext) {
-      EmptyHashedRelation
+    if (isNullAware) {
+      if (!input.hasNext) {
+        EmptyHashedRelation
+      } else {
+        InvertedIndexHashedRelation(input, key, sizeEstimate, mm)
+      }
     } else if (key.length == 1 && key.head.dataType == LongType) {
-      LongHashedRelation(input, key, sizeEstimate, mm, isNullAware)
+      LongHashedRelation(input, key, sizeEstimate, mm)
     } else {
-      UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware)
+      UnsafeHashedRelation(input, key, sizeEstimate, mm)
     }
   }
 }
@@ -313,8 +320,7 @@ private[joins] object UnsafeHashedRelation {
       input: Iterator[InternalRow],
       key: Seq[Expression],
       sizeEstimate: Int,
-      taskMemoryManager: TaskMemoryManager,
-      isNullAware: Boolean = false): HashedRelation = {
+      taskMemoryManager: TaskMemoryManager): HashedRelation = {
 
     val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
       .getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024))
 
@@ -342,8 +348,6 @@ private[joins] object UnsafeHashedRelation {
           throw new SparkOutOfMemoryError("There is not enough memory to build hash map")
          // scalastyle:on throwerror
        }
-      } else if (isNullAware) {
-        return EmptyHashedRelationWithAllNullKeys
       }
     }
 
@@ -895,8 +899,7 @@ private[joins] object LongHashedRelation {
       input: Iterator[InternalRow],
       key: Seq[Expression],
       sizeEstimate: Int,
-      taskMemoryManager: TaskMemoryManager,
-      isNullAware: Boolean = false): HashedRelation = {
+      taskMemoryManager: TaskMemoryManager): HashedRelation = {
 
     val map = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
     val keyGenerator = UnsafeProjection.create(key)
 
@@ -910,8 +913,6 @@ private[joins] object LongHashedRelation {
       if (!rowKey.isNullAt(0)) {
         val key = rowKey.getLong(0)
         map.append(key, unsafeRow)
-      } else if (isNullAware) {
-        return EmptyHashedRelationWithAllNullKeys
       }
     }
     map.optimize()
@@ -963,6 +964,103 @@ object EmptyHashedRelationWithAllNullKeys extends NullAwareHashedRelation {
   override def asReadOnlyCopy(): EmptyHashedRelationWithAllNullKeys.type = this
 }
 
+class NullAwareUnsafeRowHashSet(
+    numKeys: Int,
+    dataTypes: Seq[DataType],
+    keyMaps: Seq[mutable.HashMap[Any, Roaring64Bitmap]])
+  extends Serializable {
+
+  private var numRows: Long = 0L
+  private[joins] def this() = this(0, null, null) // Needed for serialization
+
+  def add(key: UnsafeRow): Unit = {
+    require(key.numFields() == numKeys)
+    0.until(numKeys).foreach { index =>
+      keyMaps(index).getOrElseUpdate(
+        if (key.isNullAt(index)) null else key.get(index, dataTypes(index)),
+        new Roaring64Bitmap
+      ).addLong(numRows)
+    }
+    numRows += 1
+  }
+
+  def contains(key: UnsafeRow): Boolean = {
+    require(key.numFields() == numKeys)
+    // An all-null lookup key never reaches here; it is short-circuited on the streamed side.
+    val bitmapSeq: Seq[Roaring64Bitmap] = 0.until(numKeys).map { index =>
+      if (key.isNullAt(index)) {
+        null
+      } else {
+        val result = new Roaring64Bitmap
+        result.or(keyMaps(index).getOrElse(key.get(index, dataTypes(index)), new Roaring64Bitmap))
+        if (numKeys > 1) {
+          result.or(keyMaps(index).getOrElse(null, new Roaring64Bitmap))
+        }
+        result
+      }
+    }.filter(_ != null)
+
+    !bitmapSeq.reduceLeft[Roaring64Bitmap]((a, b) => {
+      a.and(b)
+      a
+    }).isEmpty
+  }
+}
+
+class InvertedIndexHashedRelation(
+    private var numKeys: Int,
+    private var hashSet: NullAwareUnsafeRowHashSet)
+  extends NullAwareHashedRelation with Externalizable {
+
+  private[joins] def this() = this(0, null) // Needed for serialization
+  override def asReadOnlyCopy(): InvertedIndexHashedRelation = this
+
+  override def get(key: InternalRow): Iterator[InternalRow] = {
+    if (hashSet.contains(key.asInstanceOf[UnsafeRow])) {
+      Seq.empty[InternalRow].toIterator
+    } else {
+      null
+    }
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    out.writeInt(numKeys)
+    out.writeObject(hashSet)
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    numKeys = in.readInt()
+    hashSet = in.readObject().asInstanceOf[NullAwareUnsafeRowHashSet]
+  }
+}
+
+object InvertedIndexHashedRelation {
+  def apply(
+      input: Iterator[InternalRow],
+      key: Seq[Expression],
+      sizeEstimate: Int,
+      taskMemoryManager: TaskMemoryManager): HashedRelation = {
+
+    val keyGenerator = UnsafeProjection.create(key)
+    val hashSet: NullAwareUnsafeRowHashSet = new NullAwareUnsafeRowHashSet(
+      key.length,
+      key.map(_.dataType),
+      key.map(_ => new mutable.HashMap[Any, Roaring64Bitmap]())
+    )
+
+    while (input.hasNext) {
+      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
+      // Copy the projected key: the projection reuses its buffer, and the extracted
+      // values are retained inside the per-column hash maps.
+      val rowKey = keyGenerator(unsafeRow).copy()
+      if (rowKey.allNull()) {
+        return EmptyHashedRelationWithAllNullKeys
+      }
+      hashSet.add(rowKey)
+    }
+
+    new InvertedIndexHashedRelation(key.length, hashSet)
+  }
+}
+
 /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */
 case class HashedRelationBroadcastMode(key: Seq[Expression], isNullAware: Boolean = false)
   extends BroadcastMode {
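To make the membership test above easier to follow, a standalone sketch of the same per-column inverted-index idea using plain Scala sets of row ids instead of Roaring64Bitmap (illustrative only; names and data are invented, and the all-null build key case is assumed to have been short-circuited into EmptyHashedRelationWithAllNullKeys before this structure is ever built):

    object InvertedIndexSketch {
      // One map per key column: column value (or null) -> ids of build rows holding that value.
      type Index = Seq[Map[Any, Set[Long]]]

      def build(rows: Seq[Seq[Any]], numKeys: Int): Index = {
        (0 until numKeys).map { i =>
          rows.zipWithIndex
            .groupBy { case (row, _) => row(i) }
            .map { case (value, group) => value -> group.map(_._2.toLong).toSet }
        }
      }

      // NAAJ "match" test: the streamed key is dropped if some build row could equal it,
      // i.e. for every non-null probe column the build row has the same value or a null.
      def contains(index: Index, key: Seq[Any]): Boolean = {
        val candidates = key.zipWithIndex.collect { case (v, i) if v != null =>
          index(i).getOrElse(v, Set.empty[Long]) ++
            (if (key.length > 1) index(i).getOrElse(null, Set.empty[Long]) else Set.empty[Long])
        }
        candidates.nonEmpty && candidates.reduceLeft(_ intersect _).nonEmpty
      }

      def main(args: Array[String]): Unit = {
        val buildRows = Seq(Seq(1, "x"), Seq(null, "y"), Seq(2, null))
        val index = build(buildRows, numKeys = 2)
        assert(contains(index, Seq(1, "x")))   // row 0 is an exact match
        assert(contains(index, Seq(2, "y")))   // (2, null) and (null, "y") could both match
        assert(!contains(index, Seq(3, "z")))  // no build row can possibly equal this key
      }
    }

The real implementation keeps the row ids in Roaring64Bitmap instances and intersects them in place, but the matching rule is the same.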
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index bedfbffc789ac..e36cd04ee006d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1154,20 +1154,15 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) {
       // positive not in subquery case
       var joinExec = assertJoin((
-        "select * from testData where key not in (select a from testData2)",
+        "SELECT * FROM testData WHERE key NOT IN (SELECT a FROM testData2)",
         classOf[BroadcastHashJoinExec]))
       assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
 
-      // negative not in subquery case since multi-column is not supported
-      assertJoin((
-        "select * from testData where (key, key + 1) not in (select * from testData2)",
-        classOf[BroadcastNestedLoopJoinExec]))
-
       // positive hand-written left anti join
       // testData.key nullable false
       // testData3.b nullable true
       joinExec = assertJoin((
-        "select * from testData left anti join testData3 ON key = b or isnull(key = b)",
+        "SELECT * FROM testData LEFT ANTI JOIN testData3 ON key = b OR ISNULL(key = b)",
         classOf[BroadcastHashJoinExec]))
       assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
 
@@ -1176,15 +1171,53 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       // testData2.a nullable false
       // isnull(key = a) will be optimized to true literal and removed
       joinExec = assertJoin((
-        "select * from testData left anti join testData2 ON key = a or isnull(key = a)",
+        "SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a OR ISNULL(key = a)",
         classOf[BroadcastHashJoinExec]))
       assert(!joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
 
       // negative hand-written left anti join
       // not match pattern Or(EqualTo(a=b), IsNull(EqualTo(a=b))
       assertJoin((
-        "select * from testData2 left anti join testData3 ON testData2.a = testData3.b or " +
-          "isnull(testData2.b = testData3.b)",
+        "SELECT * FROM testData2 LEFT ANTI JOIN testData3 ON testData2.a = testData3.b OR " +
+          "ISNULL(testData2.b = testData3.b)",
         classOf[BroadcastNestedLoopJoinExec]))
     }
   }
 
+  test("SPARK-32494: Null Aware Anti Join Optimize Support Multi-Column") {
+    withSQLConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) {
+      // positive not in subquery case
+      var joinExec = assertJoin((
+        "SELECT * FROM testData WHERE (key, key + 1) NOT IN (SELECT * FROM testData2)",
+        classOf[BroadcastHashJoinExec]))
+      assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
+
+      // positive hand-written left anti join
+      // testData.key nullable false
+      // testData3.b nullable true
+      joinExec = assertJoin((
+        "SELECT * FROM testData LEFT ANTI JOIN testData3 ON (key = b OR ISNULL(key = b)) " +
+          "AND (key + 1 = b OR ISNULL(key + 1 = b))",
+        classOf[BroadcastHashJoinExec]))
+      assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
+
+      // negative hand-written left anti join
+      // testData.key nullable false
+      // testData3.a nullable false
+      // isnull(key = a) and isnull(key + 1 = a) will be optimized to true literal and removed
+      joinExec = assertJoin((
+        "SELECT * FROM testData LEFT ANTI JOIN testData3 ON (key = a OR ISNULL(key = a)) " +
+          "AND (key + 1 = a OR ISNULL(key + 1 = a))",
+        classOf[BroadcastHashJoinExec]))
+      assert(!joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
+
+      // negative case: exceeds OPTIMIZE_NULL_AWARE_ANTI_JOIN_MAX_NUM_KEYS
+      joinExec = assertJoin((
+        "SELECT * FROM testData LEFT ANTI JOIN testData3 ON (key = b OR ISNULL(key = b)) " +
+          "AND (key + 2 = b OR ISNULL(key + 2 = b)) " +
+          "AND (key + 3 = b OR ISNULL(key + 3 = b)) " +
+          "AND (key + 4 = b OR ISNULL(key + 4 = b))",
+        classOf[BroadcastNestedLoopJoinExec]))
+    }
+  }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index a21c461e84588..3b4b56a4b3d0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -1648,6 +1648,12 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     checkAnswer(df, Nil)
   }
 
+  private def findJoinExec(df: DataFrame): BaseJoinExec = {
+    df.queryExecution.sparkPlan.collectFirst {
+      case j: BaseJoinExec => j
+    }.get
+  }
+
   test("SPARK-32290: SingleColumn Null Aware Anti Join Optimize") {
     Seq(true, false).foreach { enableNAAJ =>
       Seq(true, false).foreach { enableAQE =>
@@ -1657,17 +1663,11 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
           SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
           SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enableCodegen.toString) {
 
-          def findJoinExec(df: DataFrame): BaseJoinExec = {
-            df.queryExecution.sparkPlan.collectFirst {
-              case j: BaseJoinExec => j
-            }.get
-          }
-
           var df: DataFrame = null
           var joinExec: BaseJoinExec = null
 
           // single column not in subquery -- empty sub-query
-          df = sql("select * from l where a not in (select c from r where c > 10)")
+          df = sql("SELECT * FROM l WHERE a NOT IN (SELECT c FROM r WHERE c > 10)")
           checkAnswer(df, spark.table("l"))
           if (enableNAAJ) {
             joinExec = findJoinExec(df)
@@ -1678,7 +1678,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
           }
 
           // single column not in subquery -- sub-query include null
-          df = sql("select * from l where a not in (select c from r where d < 6.0)")
+          df = sql("SELECT * FROM l WHERE a NOT IN (SELECT c FROM r WHERE d < 6.0)")
           checkAnswer(df, Seq.empty)
           if (enableNAAJ) {
             joinExec = findJoinExec(df)
@@ -1689,8 +1689,8 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
           }
 
           // single column not in subquery -- streamedSide row is null
-          df =
-            sql("select * from l where b = 5.0 and a not in(select c from r where c is not null)")
+          df = sql("SELECT * FROM l WHERE b = 5.0 AND a NOT IN " +
+            "(SELECT c FROM r WHERE c IS NOT NULL)")
           checkAnswer(df, Seq.empty)
           if (enableNAAJ) {
             joinExec = findJoinExec(df)
@@ -1702,7 +1702,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
           // single column not in subquery -- streamedSide row is not null, match found
           df =
-            sql("select * from l where a = 6 and a not in (select c from r where c is not null)")
+            sql("SELECT * FROM l WHERE a = 6 AND a NOT IN (SELECT c FROM r WHERE c IS NOT NULL)")
           checkAnswer(df, Seq.empty)
           if (enableNAAJ) {
             joinExec = findJoinExec(df)
@@ -1714,7 +1714,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
           // single column not in subquery -- streamedSide row is not null, match not found
           df =
-            sql("select * from l where a = 1 and a not in (select c from r where c is not null)")
+            sql("SELECT * FROM l WHERE a = 1 AND a NOT IN (SELECT c FROM r WHERE c IS NOT NULL)")
           checkAnswer(df, Row(1, 2.0) :: Row(1, 2.0) :: Nil)
           if (enableNAAJ) {
             joinExec = findJoinExec(df)
@@ -1725,7 +1725,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
           }
 
           // single column not in subquery -- d = b + 10 joinKey found, match ExtractEquiJoinKeys
-          df = sql("select * from l where a not in (select c from r where d = b + 10)")
+          df = sql("SELECT * FROM l WHERE a NOT IN (SELECT c FROM r WHERE d = b + 10)")
           checkAnswer(df, spark.table("l"))
           joinExec = findJoinExec(df)
           assert(joinExec.isInstanceOf[BroadcastHashJoinExec])
@@ -1734,7 +1734,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
           // single column not in subquery -- d = b + 10 and b = 5.0 => d = 15, joinKey not found
           // match ExtractSingleColumnNullAwareAntiJoin
           df =
-            sql("select * from l where b = 5.0 and a not in (select c from r where d = b + 10)")
+            sql("SELECT * FROM l WHERE b = 5.0 AND a NOT IN (SELECT c FROM r WHERE d = b + 10)")
           checkAnswer(df, Row(null, 5.0) :: Nil)
           if (enableNAAJ) {
             joinExec = findJoinExec(df)
@@ -1743,11 +1743,82 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
           } else {
             assert(findJoinExec(df).isInstanceOf[BroadcastNestedLoopJoinExec])
           }
+        }
+      }
+    }
+  }
+
+  test("SPARK-32494: Null Aware Anti Join Optimize Support Multi-Column") {
+    Seq(true, false).foreach { enableNAAJ =>
+      Seq(true, false).foreach { enableAQE =>
+        Seq(true, false).foreach { enableCodegen =>
+          withSQLConf(
+            SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN.key -> enableNAAJ.toString,
+            SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+            SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enableCodegen.toString) {
+
+          var df: DataFrame = null
+          var joinExec: BaseJoinExec = null
 
-          // multi column not in subquery
-          df = sql("select * from l where (a, b) not in (select c, d from r where c > 10)")
+          // multi column not in subquery -- empty sub-query
+          df = sql("SELECT * FROM l WHERE (a, b) NOT IN (SELECT * FROM r WHERE c > 10)")
           checkAnswer(df, spark.table("l"))
-          assert(findJoinExec(df).isInstanceOf[BroadcastNestedLoopJoinExec])
+          if (enableNAAJ) {
+            joinExec = findJoinExec(df)
+            assert(joinExec.isInstanceOf[BroadcastHashJoinExec])
+            assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
+          } else {
+            assert(findJoinExec(df).isInstanceOf[BroadcastNestedLoopJoinExec])
+          }
+
+          // multi column not in subquery -- sub-query include all null column key
+          df = sql(
+            "SELECT * FROM l WHERE (a, b) NOT IN (SELECT * FROM r WHERE c IS NULL AND d IS NULL)")
+          checkAnswer(df, Seq.empty)
+          if (enableNAAJ) {
+            joinExec = findJoinExec(df)
+            assert(joinExec.isInstanceOf[BroadcastHashJoinExec])
+            assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
+          } else {
+            assert(findJoinExec(df).isInstanceOf[BroadcastNestedLoopJoinExec])
+          }
+
+          // multi column not in subquery -- streamedSide row is all null column key
+          df = sql("SELECT * FROM l WHERE a IS NULL AND b IS NULL " +
+            "AND (a, b) NOT IN (SELECT * FROM r WHERE c IS NOT NULL)")
+          checkAnswer(df, Seq.empty)
+          if (enableNAAJ) {
+            joinExec = findJoinExec(df)
+            assert(joinExec.isInstanceOf[BroadcastHashJoinExec])
+            assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
+          } else {
+            assert(findJoinExec(df).isInstanceOf[BroadcastNestedLoopJoinExec])
+          }
+
+          // multi column not in subquery -- streamedSide row is not all null, match found
+          df = sql("SELECT * FROM l WHERE a = 6 " +
+            "AND (a, b) NOT IN (SELECT * FROM r WHERE c IS NOT NULL)")
+          checkAnswer(df, Seq.empty)
+          if (enableNAAJ) {
+            joinExec = findJoinExec(df)
+            assert(joinExec.isInstanceOf[BroadcastHashJoinExec])
+            assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
+          } else {
+            assert(findJoinExec(df).isInstanceOf[BroadcastNestedLoopJoinExec])
+          }
+
+          // multi column not in subquery -- streamedSide row is not all null, match not found
+          df = sql("SELECT * FROM l WHERE a = 1 " +
+            "AND (a, b) NOT IN (SELECT * FROM r WHERE c IS NOT NULL)")
+          checkAnswer(df, Row(1, 2.0) :: Row(1, 2.0) :: Nil)
+          if (enableNAAJ) {
+            joinExec = findJoinExec(df)
+            assert(joinExec.isInstanceOf[BroadcastHashJoinExec])
+            assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)
+          } else {
+            assert(findJoinExec(df).isInstanceOf[BroadcastNestedLoopJoinExec])
+          }
         }
       }
     }
   }
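The expected answers in the suite above all follow from SQL's three-valued NOT IN semantics; a tiny standalone sketch of that rule (illustrative only, plain Scala, no Spark involved):

    object NotInSemanticsSketch {
      // "x NOT IN (v1, v2, ...)" is NOT(x = v1 OR x = v2 OR ...). Any comparison involving
      // NULL is UNKNOWN (None here), and a WHERE clause keeps a row only when its predicate
      // is TRUE, which is why sub-queries containing NULL make the whole result empty.
      def notIn(x: Option[Int], values: Seq[Option[Int]]): Option[Boolean] = {
        val equalities = values.map {
          case Some(v) => x.map(_ == v) // None (UNKNOWN) when x is NULL
          case None => None             // comparing anything with NULL is UNKNOWN
        }
        if (equalities.contains(Some(true))) Some(false)
        else if (equalities.contains(None)) None
        else Some(true)
      }

      def main(args: Array[String]): Unit = {
        assert(notIn(Some(1), Seq(Some(2), Some(3))) == Some(true))  // row kept
        assert(notIn(Some(2), Seq(Some(2), Some(3))) == Some(false)) // row dropped
        assert(notIn(Some(1), Seq(Some(2), None)) == None)           // UNKNOWN -> row dropped
        assert(notIn(None, Seq(Some(2), Some(3))) == None)           // UNKNOWN -> row dropped
      }
    }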
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 21ee88f0d7426..1ae867998ddb6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -580,4 +580,100 @@ class HashedRelationSuite extends SharedSparkSession {
       assert(proj(packedKeys).get(0, dt) == -i - 1)
     }
   }
+
+  test("NullAwareHashedRelation") {
+    val singleKey = Seq(BoundReference(0, LongType, true))
+    val multiKey = Seq(BoundReference(0, IntegerType, true),
+      BoundReference(1, LongType, true),
+      BoundReference(2, StringType, true)
+    )
+
+    val singleProjection = UnsafeProjection.create(singleKey)
+    val multiProjection = UnsafeProjection.create(multiKey)
+
+    var hashedRelation: HashedRelation = null
+
+    // singleKey EmptyHashedRelation
+    hashedRelation = HashedRelation(
+      Seq.empty[InternalRow].iterator,
+      singleKey, taskMemoryManager = mm, isNullAware = true)
+    assert(hashedRelation == EmptyHashedRelation)
+
+    // multiKey EmptyHashedRelation
+    hashedRelation = HashedRelation(
+      Seq.empty[InternalRow].iterator,
+      multiKey, taskMemoryManager = mm, isNullAware = true)
+    assert(hashedRelation == EmptyHashedRelation)
+
+    // singleKey EmptyHashedRelationWithAllNullKeys
+    val data1 = Seq(
+      Seq(1L),
+      Seq(null),
+      Seq(2L)
+    )
+    hashedRelation = HashedRelation(
+      data1.map(InternalRow.fromSeq).map(row => singleProjection(row).copy()).iterator,
+      singleKey, taskMemoryManager = mm, isNullAware = true)
+    assert(hashedRelation == EmptyHashedRelationWithAllNullKeys)
+
+    // multiKey EmptyHashedRelationWithAllNullKeys
+    val data2 = Seq(
+      Seq(1, 1L, UTF8String.fromString("1")),
+      Seq(null, null, null),
+      Seq(2, 2L, UTF8String.fromString("2"))
+    )
+    hashedRelation = HashedRelation(
+      data2.map(InternalRow.fromSeq).map(row => multiProjection(row).copy()).iterator,
+      multiKey, taskMemoryManager = mm, isNullAware = true)
+    assert(hashedRelation == EmptyHashedRelationWithAllNullKeys)
+
+    // singleKey, no all-null key exists
+    val data3 = Seq(
+      Seq(1L),
+      Seq(2L)
+    )
+    hashedRelation = HashedRelation(
+      data3.map(InternalRow.fromSeq).map(row => singleProjection(row).copy()).iterator,
+      singleKey, taskMemoryManager = mm, isNullAware = true)
+    assert(hashedRelation.isInstanceOf[LongHashedRelation])
+    assert(hashedRelation.keys().length == 2)
+
+    // multiKey, no all-null key exists
+    val data4 = Seq(
+      Seq(1, null, UTF8String.fromString("1")),
+      Seq(null, 5L, null),
+      Seq(2, 2L, UTF8String.fromString("2"))
+    )
+    hashedRelation = HashedRelation(
+      data4.map(InternalRow.fromSeq).map(row => multiProjection(row).copy()).iterator,
+      multiKey, taskMemoryManager = mm, isNullAware = true)
+    assert(hashedRelation.isInstanceOf[UnsafeHashedRelation])
+    // The original 3 records are expanded 7x (2^3 - 1 key combinations per record),
+    // which is 21 keys in total.
+    assert(hashedRelation.asInstanceOf[UnsafeHashedRelation].keys().length == 21)
+
+    // key verification after distinct
+    val data5 = Seq(
+      Seq(1, null, UTF8String.fromString("1")),
+      Seq(null, null, UTF8String.fromString("1")),
+      Seq(1, null, null),
+      Seq(null, null, null),
+      Seq(null, 5L, null),
+      Seq(2, 2L, UTF8String.fromString("2")),
+      Seq(null, 2L, UTF8String.fromString("2")),
+      Seq(2, null, UTF8String.fromString("2")),
+      Seq(2, 2L, null),
+      Seq(null, null, UTF8String.fromString("2")),
+      Seq(null, 2L, null),
+      Seq(2, null, null)
+    )
+    assert(
+      data5.map(InternalRow.fromSeq)
+        .map(row => multiProjection(row).copy())
+        .forall(row => hashedRelation.get(row) != null)
+    )
+  }
 }
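A footnote on the 7x / 21-key arithmetic asserted above and on the (2^numKeys - 1) growth mentioned in the new conf's doc (a hedged sketch only; it simply enumerates the non-empty subsets of key columns a build record could be registered under in an expansion-based scheme):

    object KeyExpansionArithmetic {
      // For numKeys key columns there are 2^numKeys - 1 non-empty column subsets, which is
      // where the "expanded to (2^numKeys - 1) times" figure in the conf doc comes from.
      def combinations(numKeys: Int): Seq[Seq[Int]] =
        (1 to numKeys).flatMap(k => (0 until numKeys).combinations(k).map(_.toList))

      def main(args: Array[String]): Unit = {
        val perRecord = combinations(3).length
        assert(perRecord == 7)       // 2^3 - 1
        assert(3 * perRecord == 21)  // 3 build records -> 21 expanded keys in the test above
      }
    }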