UnsafeRow.java
@@ -591,6 +591,18 @@ public boolean anyNull() {
     return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
   }
 
+  /**
+   * return whether an UnsafeRow is null on every column.
+   */
+  public boolean allNull() {
+    for (int i = 0; i < numFields; i++) {
+      if (!BitSetMethods.isSet(baseObject, baseOffset, i)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * Writes the content of this row into a memory address, identified by an object and an offset.
    * The target memory address must already be allocated, and have enough space to hold all the

Review thread on the new javadoc:
Member: nit: return -> Return, and please add tests in UnsafeRowSuite if you add a new method here.

Review thread on allNull():
Member: Can you add doc for the method?
Contributor Author: done.
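The review asks for coverage in UnsafeRowSuite; the following is a minimal sketch of such a test, assuming three fixed-width fields, and is illustrative rather than the PR's actual test:

    test("allNull returns true only when every field is null") {
      val numFields = 3
      // Backing buffer: the null-tracking bit set plus 8 bytes per fixed-width field.
      val buffer = new Array[Byte](
        UnsafeRow.calculateBitSetWidthInBytes(numFields) + numFields * 8)
      val row = new UnsafeRow(numFields)
      row.pointTo(buffer, buffer.length)
      (0 until numFields).foreach(row.setNullAt)
      assert(row.allNull())
      row.setLong(0, 1L) // setLong clears the null bit for field 0
      assert(!row.allNull())
    }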
patterns.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst.planning
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
@@ -391,35 +393,45 @@ object PhysicalWindow {
   }
 }
 
-object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with PredicateHelper {
+object ExtractNullAwareAntiJoinKeys extends JoinSelectionHelper with PredicateHelper {
 
-  // TODO support multi column NULL-aware anti join in future.
-  // See. http://www.vldb.org/pvldb/vol2/vldb09-423.pdf Section 6
+  // Multi-column null-aware anti join is much more complicated than the single-column case.
+  // FYI, extra information about null-aware anti join:
+  // https://dl.acm.org/doi/10.14778/1687553.1687563
+  // http://www.vldb.org/pvldb/vol2/vldb09-423.pdf
 
   // streamedSideKeys, buildSideKeys
   private type ReturnType = (Seq[Expression], Seq[Expression])
 
-  /**
-   * See. [SPARK-32290]
-   * LeftAnti(condition: Or(EqualTo(a=b), IsNull(EqualTo(a=b)))
-   * will almost certainly be planned as a Broadcast Nested Loop join,
-   * which is very time consuming because it's an O(M*N) calculation.
-   * But if it's a single column case O(M*N) calculation could be optimized into O(M)
-   * using hash lookup instead of loop lookup.
-   */
   def unapply(join: Join): Option[ReturnType] = join match {
-    case Join(left, right, LeftAnti,
-        Some(Or(e @ EqualTo(leftAttr: AttributeReference, rightAttr: AttributeReference),
-          IsNull(e2 @ EqualTo(_, _)))), _)
-        if SQLConf.get.optimizeNullAwareAntiJoin &&
-          e.semanticEquals(e2) =>
-      if (canEvaluate(leftAttr, left) && canEvaluate(rightAttr, right)) {
-        Some(Seq(leftAttr), Seq(rightAttr))
-      } else if (canEvaluate(leftAttr, right) && canEvaluate(rightAttr, left)) {
-        Some(Seq(rightAttr), Seq(leftAttr))
-      } else {
-        None
-      }
+    case Join(left, right, LeftAnti, condition, _) if SQLConf.get.optimizeNullAwareAntiJoin =>
+      val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
+      if (predicates.isEmpty ||
+          predicates.length > SQLConf.get.optimizeNullAwareAntiJoinMaxNumKeys) {
+        None
+      } else {
+        val joinKeys = ArrayBuffer[(Expression, Expression)]()
+
+        // All predicates must match the pattern: Or(EqualTo(a, b), IsNull(EqualTo(a, b)))
+        val allMatch = predicates.forall {
+          case Or(e @ EqualTo(leftExpr: Expression, rightExpr: Expression),
+              IsNull(e2 @ EqualTo(_, _))) if e.semanticEquals(e2) =>
+            if (canEvaluate(leftExpr, left) && canEvaluate(rightExpr, right)) {
+              joinKeys += ((leftExpr, rightExpr))
+              true
+            } else if (canEvaluate(leftExpr, right) && canEvaluate(rightExpr, left)) {
+              joinKeys += ((rightExpr, leftExpr))
+              true
+            } else {
+              false
+            }
+          case _ => false
+        }
+
+        if (allMatch) {
+          Some(joinKeys.unzip)
+        } else {
+          None
+        }
+      }
     case _ => None
   }

Review thread on the renamed extractor object:
Member (@maropu, Aug 2, 2020): Hm, I think it's better to add some fine-grained tests for this extractor. WDYT?
Contributor Author: Agreed, I will add a test for it.

Review thread on the removed TODO comment:
Member: How about keeping this comment for the reference to NAAJ? I think it is good material for understanding how NAAJ works. (NOTE: IMHO the link to the ACM page is better than the direct link to the PDF.)

Review thread on the predicates.forall block:
Member: nit format; how about this?

    val joinKeys = ArrayBuffer[(Expression, Expression)]()

    // All predicates must match the pattern: Or(EqualTo(a, b), IsNull(EqualTo(a, b)))
    val allMatch = predicates.forall {
      case Or(e @ EqualTo(leftExpr: Expression, rightExpr: Expression),
          IsNull(e2 @ EqualTo(_, _))) if e.semanticEquals(e2) =>
        if (canEvaluate(leftExpr, left) && canEvaluate(rightExpr, right)) {
          joinKeys += ((leftExpr, rightExpr))
          true
        } else if (canEvaluate(leftExpr, right) && canEvaluate(rightExpr, left)) {
          joinKeys += ((rightExpr, leftExpr))
          true
        } else {
          false
        }
      case _ =>
        false
    }

    if (allMatch) {
      Some(joinKeys.unzip)
    } else {
      None
    }

Contributor Author: Updated.

Review thread on the Or(EqualTo, IsNull(EqualTo)) pattern:
Member: IIUC this pattern matching depends on the RewritePredicateSubquery code:

    val nullAwareJoinConds = baseJoinConds.map(c => Or(c, IsNull(c)))

This is okay now, but I'm a little worried that it will not work well if the RewritePredicateSubquery code is updated; for example, if both attributes in a join condition are non-nullable, we might be able to remove IsNull(c) as an optimization in the RewritePredicateSubquery rule.
Contributor Author: Yes, the case where IsNull is removed is considered: we only apply the NAAJ optimization while the Or condition still exists. Basically, the NAAJ switch is triggered in SparkStrategies, by which point the optimizer has done its job, so it's safe to put this pattern check at the physical planning stage:

    // negative hand-written left anti join
    // testData.key nullable false
    // testData2.a nullable false
    // isnull(key = a) isnull(key+1=a) will be optimized to true literal and removed
    joinExec = assertJoin((
      "SELECT * FROM testData LEFT ANTI JOIN testData3 ON (key = a OR ISNULL(key = a)) " +
        "AND (key + 1 = a OR ISNULL(key + 1 = a))",
      classOf[BroadcastHashJoinExec]))
    assert(!joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin)

Contributor Author: In JoinSuite line 1209, FYI.
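To make the matched shape concrete, here is a hedged sketch (not code from this PR) of the per-key condition that RewritePredicateSubquery emits for a NOT IN subquery, built with the Catalyst expression DSL; the attribute names are made up:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{EqualTo, IsNull, Or}

    val a = 'a.int // streamed-side (left) attribute, hypothetical
    val x = 'x.int // build-side (right) attribute, hypothetical
    val eq = EqualTo(a, x)
    // The per-conjunct shape the extractor requires; eq and its copy are
    // semantically equal, satisfying the e.semanticEquals(e2) guard.
    val nullAwareCond = Or(eq, IsNull(eq.copy()))

A two-key NOT IN thus produces And(Or(a = x, IsNull(a = x)), Or(b = y, IsNull(b = y))); splitConjunctivePredicates splits the And, and every conjunct must match the Or(EqualTo, IsNull(EqualTo)) shape above for the extractor to return the zipped (streamedSideKeys, buildSideKeys).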
SQLConf.scala
@@ -2678,14 +2678,26 @@ object SQLConf {
       .checkValue(_ >= 0, "The value must be non-negative.")
       .createWithDefault(8)
 
+  val OPTIMIZE_NULL_AWARE_ANTI_JOIN_MAX_NUM_KEYS =
+    buildConf("spark.sql.optimizeNullAwareAntiJoin.maxNumKeys")
+      .internal()
+      .doc("The maximum number of keys supported by the NAAJ optimization. " +
+        "With the NAAJ optimization, the buildSide data is expanded to " +
+        "(2^numKeys - 1) times its size, so increasing numKeys can cause " +
+        "driver OOM, since the growth is exponential.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ > 0, "The value must be positive.")
+      .createWithDefault(3)

Review thread on .intConf:
Member: Please add checkValue; I think only a positive value seems reasonable.
Contributor Author: done
Member: add .version("3.1.0"): #29335

   val OPTIMIZE_NULL_AWARE_ANTI_JOIN =
     buildConf("spark.sql.optimizeNullAwareAntiJoin")
       .internal()
       .doc("When true, NULL-aware anti join execution will be planned into " +
         "BroadcastHashJoinExec with the flag isNullAwareAntiJoin enabled, " +
         "optimized from an O(M*N) calculation into an O(M) calculation " +
-        "using Hash lookup instead of Looping lookup." +
-        "Only support for singleColumn NAAJ for now.")
+        "using Hash lookup instead of Looping lookup. " +
+        "The number of keys supported for NAAJ is configured by " +
+        s"${OPTIMIZE_NULL_AWARE_ANTI_JOIN_MAX_NUM_KEYS.key}.")
       .version("3.1.0")
       .booleanConf
       .createWithDefault(true)
@@ -3300,6 +3312,9 @@ class SQLConf extends Serializable with Logging {
   def optimizeNullAwareAntiJoin: Boolean =
     getConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN)
 
+  def optimizeNullAwareAntiJoinMaxNumKeys: Int =
+    getConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN_MAX_NUM_KEYS)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
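A quick arithmetic sketch of the growth the maxNumKeys doc warns about; reading (2^numKeys - 1) as one expanded build-side row per non-empty subset of the keys is my interpretation of the doc text, not code from the PR:

    def expansionFactor(numKeys: Int): Int = (1 << numKeys) - 1

    assert(expansionFactor(1) == 1)   // single-column NAAJ: no blow-up
    assert(expansionFactor(3) == 7)   // the default cap: a 7x larger broadcast
    assert(expansionFactor(8) == 255) // why a small positive cap protects the driver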
SparkStrategies.scala
@@ -232,7 +232,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None }
           .getOrElse(createJoinWithoutHint())
 
-      case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys) =>
+      case j @ ExtractNullAwareAntiJoinKeys(leftKeys, rightKeys) =>
         Seq(joins.BroadcastHashJoinExec(leftKeys, rightKeys, LeftAnti, BuildRight,
           None, planLater(j.left), planLater(j.right), isNullAwareAntiJoin = true))
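For intuition, a hedged end-to-end check that this strategy fires; the table and column names are assumptions, spark is an active SparkSession, and the plan is inspected with AQE set aside for simplicity:

    import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

    // A multi-column NOT IN should now plan as a null-aware BroadcastHashJoinExec
    // when spark.sql.optimizeNullAwareAntiJoin is true (the default).
    val df = spark.sql("SELECT * FROM t1 WHERE (a, b) NOT IN (SELECT x, y FROM t2)")
    val naaj = df.queryExecution.executedPlan.collectFirst {
      case j: BroadcastHashJoinExec if j.isNullAwareAntiJoin => j
    }
    assert(naaj.isDefined)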
LogicalQueryStageStrategy.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
-import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractSingleColumnNullAwareAntiJoin}
+import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractNullAwareAntiJoinKeys}
 import org.apache.spark.sql.catalyst.plans.LeftAnti
 import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
 import org.apache.spark.sql.execution.{joins, SparkPlan}
@@ -49,7 +49,7 @@ object LogicalQueryStageStrategy extends Strategy with PredicateHelper {
       Seq(BroadcastHashJoinExec(
         leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
 
-    case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys)
+    case j @ ExtractNullAwareAntiJoinKeys(leftKeys, rightKeys)
         if isBroadcastStage(j.right) =>
       Seq(joins.BroadcastHashJoinExec(leftKeys, rightKeys, LeftAnti, BuildRight,
         None, planLater(j.left), planLater(j.right), isNullAwareAntiJoin = true))
BroadcastHashJoinExec.scala
@@ -49,8 +49,6 @@ case class BroadcastHashJoinExec(
   extends HashJoin with CodegenSupport {
 
   if (isNullAwareAntiJoin) {
-    require(leftKeys.length == 1, "leftKeys length should be 1")
-    require(rightKeys.length == 1, "rightKeys length should be 1")
     require(joinType == LeftAnti, "joinType must be LeftAnti.")
     require(buildSide == BuildRight, "buildSide must be BuildRight.")
     require(condition.isEmpty, "null aware anti join optimize condition should be empty.")
@@ -156,7 +154,7 @@ case class BroadcastHashJoinExec(
       )
       streamedIter.filter(row => {
         val lookupKey: UnsafeRow = keyGenerator(row)
-        if (lookupKey.anyNull()) {
+        if (lookupKey.allNull()) {
           false
         } else {
           // Anti Join: Drop the row on the streamed side if it is a match on the build
@@ -225,9 +223,10 @@ case class BroadcastHashJoinExec(
   protected override def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     if (isNullAwareAntiJoin) {
       val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
-      val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+      val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input, true)
       val (matched, _, _) = getJoinCondition(ctx, input)
       val numOutput = metricTerm(ctx, "numOutputRows")
+      val isLongHashedRelation = broadcastRelation.value.isInstanceOf[LongHashedRelation]
 
       if (broadcastRelation.value == EmptyHashedRelation) {
         s"""

Review thread on the genStreamSideJoinKey change:
Reviewer: Hi @leanken, I am curious: why do we force unsafe for NAAJ? Is that for efficiency, or does the implementation assume UnsafeRow?
Contributor Author: Because if it's a single Long key, keyEv will be codegened as a Java long, not an UnsafeRow. InvertedIndexHashedRelation only takes UnsafeRow as input for build-side lookups.
@@ -245,11 +244,10 @@ case class BroadcastHashJoinExec(
           |boolean $found = false;
           |// generate join key for stream side
           |${keyEv.code}
-          |if ($anyNull) {
+          |if (${s"${keyEv.value}.allNull()"}) {
           |  $found = true;
           |} else {
-          |  UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-          |  if ($matched != null) {
+          |  if ($relationTerm.get(${keyEv.value}) != null) {
           |    $found = true;
           |  }
           |}
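A small hedged model, in plain Scala rather than Spark code, of the row-level decision that the anyNull -> allNull change encodes; it ignores the build-side subset expansion that handles partially-null keys:

    // A key is modeled as Seq[Option[Int]]; None is a NULL column.
    def keepStreamedRow(key: Seq[Option[Int]], buildKeys: Set[Seq[Option[Int]]]): Boolean = {
      if (key.forall(_.isEmpty)) {
        false // all columns null: NOT IN evaluates to NULL, so LeftAnti drops the row
      } else {
        !buildKeys.contains(key) // drop on a build-side match, keep otherwise
      }
    }

    assert(!keepStreamedRow(Seq(None, None), Set.empty)) // all-null key: dropped
    assert(keepStreamedRow(Seq(Some(1), Some(2)), Set(Seq(Some(3), Some(4))))) // no match: kept
    assert(!keepStreamedRow(Seq(Some(1), Some(2)), Set(Seq(Some(1), Some(2))))) // match: dropped

With a single key column, anyNull was the right fast path; with multiple keys a partially-null key can still match against the expanded build side, so only the all-null case short-circuits on the streamed side.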
HashJoin.scala
@@ -342,9 +342,11 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    */
   protected def genStreamSideJoinKey(
       ctx: CodegenContext,
-      input: Seq[ExprCode]): (ExprCode, String) = {
+      input: Seq[ExprCode],
+      forceUnsafe: Boolean = false): (ExprCode, String) = {
     ctx.currentVars = input
-    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType &&
+        !forceUnsafe) {
       // generate the join key as Long
       val ev = streamedBoundKeys.head.genCode(ctx)
       (ev, ev.isNull)
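For context, a hedged sketch of the method's two paths; the else branch is recalled from the surrounding Spark code rather than shown in this diff, so treat it as an assumption. With a single LongType key the default path emits a primitive long, and forceUnsafe = true skips that fast path so NAAJ always receives an UnsafeRow it can call allNull() on:

    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType &&
        !forceUnsafe) {
      // fast path: the join key is generated as a raw Java long
      val ev = streamedBoundKeys.head.genCode(ctx)
      (ev, ev.isNull)
    } else {
      // general path: the join key is packed into an UnsafeRow
      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
      (ev, s"${ev.value}.anyNull()")
    }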