Closed
Changes from all commits (52 commits)
b7a4a3e
Secondary sort Sort Merge Join optimization. Not finished yet.
Mar 22, 2018
fd81613
SortMergeJoin secondary sort optimization
Mar 22, 2018
f1efa9b
Sort-Merge "inner range join" (secondary sort) - code generation
Apr 4, 2018
f533f65
Sort-Merge "inner range join" (secondary sort) - two bug fixes - work…
Apr 5, 2018
2ff492c
Code simplification
Apr 9, 2018
3ff654a
Bug fix
Apr 10, 2018
85039bb
Scalastyle fixes
Apr 13, 2018
bbcb400
Scalastyle fixes
Apr 13, 2018
640aa6d
SMJ range join unit tests
Apr 13, 2018
2548b7d
Scalastyle
Apr 13, 2018
069bc01
Scalastyle
Apr 13, 2018
3bc71c5
Scalastyle
Apr 13, 2018
5c62f98
Fix generated code - dequeue method missing
Apr 13, 2018
16e3e1b
Bug fix: include other binary comparisons in range conditions match
Apr 13, 2018
e7f7bdf
Test fix: sortWithinPartitions; Bug Fix: check references in rangeCon…
Apr 16, 2018
41cde27
Test fix
Apr 16, 2018
080ab0d
Test fix
Apr 17, 2018
8628216
Fix required child ordering for inner range queries
Apr 18, 2018
7bd6732
Parameter for turning off inner range optimization
Apr 19, 2018
094f66b
Scala style
Apr 19, 2018
efd595e
Bug fix - NPE when inner range optimization turned off
Apr 19, 2018
a8372e3
Adding test case when inner range optimization is turned off
Apr 19, 2018
4396985
Scala style
Apr 19, 2018
9c14368
Scala style
Apr 19, 2018
6cbf9fe
Remove range condition extraction when inner range join optimization …
Apr 19, 2018
82943b8
Scala style
Apr 19, 2018
bbddf7a
Unit test fix
Apr 19, 2018
c4060d7
Unit test fix
Apr 19, 2018
5b0f2b5
- Turning off inner range optimization when whole stage code generati…
Apr 27, 2018
68e00c0
Switch off inner range optimization when whole stage codegen is off.
Apr 27, 2018
f5b9ca8
SMJ inner range optimization benchmarks
Apr 27, 2018
7457ab3
Removing "expensive function" from the SMJ inner range optimization b…
Apr 28, 2018
c47c8cd
SMJ inner range optimization with wholeStage codegen turned off - cod…
May 10, 2018
3fbedfc
Unit test fix. Benchmark results update.
May 10, 2018
b8e1ee4
Scalastyle for comments
May 10, 2018
2710957
Code changes based on review comments.
May 15, 2018
52f2b70
Code review changes
Jun 7, 2018
169bd70
Removing exception when numRowsInMemoryBufferThreshold is reached in …
Jun 7, 2018
89169de
Scala style
Jun 7, 2018
eeaf048
Unneeded import
Jun 13, 2018
75ce55d
A dot
Jun 13, 2018
77dd2a8
A dot
Jun 14, 2018
eac81b4
A dot
Jun 15, 2018
1abde55
A dot
Jun 18, 2018
dfb4c0f
A dot
Jun 19, 2018
6d4c031
Fixes for some rebase issues.
ctslater Jun 28, 2018
6d9cd12
Merge with upstream
zecevicp Jun 29, 2018
7742c10
Merge with upstream
zecevicp Jun 29, 2018
3a717ee
SMJ inner range spill over implementation and tests
zecevicp Aug 10, 2018
64437e5
External unsafe row dequeue test extension
zecevicp Aug 10, 2018
0a5c8de
A dot
zecevicp Aug 11, 2018
07ff4d3
Merge branch 'master' into branch-pz-smj
zecevicp Dec 10, 2018
@@ -41,7 +41,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
*/
def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = {
plan match {
- case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) =>
+ case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) =>
(leftKeys ++ rightKeys).exists {
case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey)
case _ => false
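Note on the pattern change above: ExtractEquiJoinKeys now returns a 7-tuple instead of a 6-tuple, so every match site gains one position. A minimal sketch of the mechanical update callers need (a hypothetical helper, not code from this PR; assumes the usual catalyst imports):

// rangeConditions (a Seq[BinaryComparison]) is the new 4th element of ReturnType;
// callers that do not use it simply ignore it with an extra `_`.
def isEquiJoin(plan: LogicalPlan): Boolean = plan match {
  case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) =>
    leftKeys.nonEmpty && rightKeys.nonEmpty
  case _ => false
}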
@@ -18,9 +18,8 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
- import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID}
+ import org.apache.spark.sql.catalyst.expressions.{Attribute, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
- import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
@@ -17,12 +17,15 @@

package org.apache.spark.sql.catalyst.planning

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

/**
* A pattern that matches any number of project or filter operations on top of another relational
@@ -98,9 +101,10 @@ object PhysicalOperation extends PredicateHelper {
* value).
*/
object ExtractEquiJoinKeys extends Logging with PredicateHelper {
- /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
+ /** (joinType, leftKeys, rightKeys, rangeConditions, condition, leftChild, rightChild) */
type ReturnType =
- (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan)
+ (JoinType, Seq[Expression], Seq[Expression], Seq[BinaryComparison],
+ Option[Expression], LogicalPlan, LogicalPlan)

def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
case join @ Join(left, right, joinType, condition) =>
@@ -132,13 +136,97 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {

if (joinKeys.nonEmpty) {
val (leftKeys, rightKeys) = joinKeys.unzip
logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
- Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
// Find any simple range expressions between two columns
// (and involving only those two columns) of the two tables being joined,
// which are not used in the equijoin expressions,
// and which can be used for secondary sort optimizations.
// rangePreds will contain the original expressions to be filtered out later.
val rangePreds = mutable.Set.empty[Expression]
var rangeConditions: Seq[BinaryComparison] =
if (SQLConf.get.useSmjInnerRangeOptimization) {
otherPredicates.flatMap {
case p@LessThan(l, r) => checkRangeConditions(l, r, left, right, joinKeys).map {
case true => rangePreds.add(p); GreaterThan(r, l)
case false => rangePreds.add(p); p
}
case p@LessThanOrEqual(l, r) =>
checkRangeConditions(l, r, left, right, joinKeys).map {
case true => rangePreds.add(p); GreaterThanOrEqual(r, l)
case false => rangePreds.add(p); p
}
case p@GreaterThan(l, r) => checkRangeConditions(l, r, left, right, joinKeys).map {
case true => rangePreds.add(p); LessThan(r, l)
case false => rangePreds.add(p); p
}
case p@GreaterThanOrEqual(l, r) =>
checkRangeConditions(l, r, left, right, joinKeys).map {
case true => rangePreds.add(p); LessThanOrEqual(r, l)
case false => rangePreds.add(p); p
}
case _ => None
}
} else {
Nil
}

// The secondary sort optimization is only used when both lower and upper range conditions
// are specified (e.g. t1.a < t2.b + x and t1.a > t2.b - x)
if (rangeConditions.size != 2 ||
// Looking for one < and one > comparison:
rangeConditions.forall(x => !x.isInstanceOf[LessThan] &&
!x.isInstanceOf[LessThanOrEqual]) ||
rangeConditions.forall(x => !x.isInstanceOf[GreaterThan] &&
!x.isInstanceOf[GreaterThanOrEqual]) ||
// Check if both comparisons reference the same columns:
rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct.size != 1 ||
rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct.size != 1) {
logDebug("Inner range optimization conditions not met. Clearing range conditions")
rangeConditions = Nil
rangePreds.clear()
}

Some((joinType, leftKeys, rightKeys, rangeConditions,
otherPredicates.filterNot(rangePreds.contains(_)).reduceOption(And), left, right))
} else {
None
}
case _ => None
}

/**
* Checks whether l and r form a valid range condition:
* - l and r should each reference exactly one column
* - the referenced columns should not be part of joinKeys
* If these conditions are not met, the function returns None.
*
* Otherwise, the function checks whether l can be evaluated using the left plan's output and r
* using the right plan's output, or vice versa. If the expressions need to be switched (l belongs
* to the right plan and r to the left), the function returns Some(true), and Some(false) otherwise.
*/
private def checkRangeConditions(l : Expression, r : Expression,
left : LogicalPlan, right : LogicalPlan,
joinKeys : Seq[(Expression, Expression)]): Option[Boolean] = {
val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq)
if (lattrs.size != 1 || rattrs.size != 1) {
None
} else if (canEvaluate(l, left) && canEvaluate(r, right)) {
if (joinKeys.exists { case (ljk : Expression, rjk : Expression) =>
ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0)) }) {
None
} else {
Some(false)
}
} else if (canEvaluate(l, right) && canEvaluate(r, left)) {
if (joinKeys.exists{ case (ljk : Expression, rjk : Expression) =>
rjk.references.toSeq.contains(lattrs(0)) && ljk.references.toSeq.contains(rattrs(0)) }) {
None
} else {
Some(true)
}
} else {
None
}
}
}

/**
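To make the new extraction concrete, here is a hedged worked example (the tables and column names are illustrative assumptions, not part of this PR):

// Join: SELECT * FROM t1 JOIN t2
//       ON t1.x = t2.x AND t1.y > t2.y - 5 AND t1.y < t2.y + 5
//
// With the optimization enabled, ExtractEquiJoinKeys should yield roughly:
//   leftKeys        = Seq(t1.x)
//   rightKeys       = Seq(t2.x)
//   rangeConditions = Seq(GreaterThan(t1.y, t2.y - 5), LessThan(t1.y, t2.y + 5))
//   condition       = None  // both range predicates moved into rangeConditions
//
// A predicate written the other way around, e.g. t2.y - 5 < t1.y, is flipped
// (checkRangeConditions returns Some(true)) so that the left plan's column
// always ends up on the left-hand side of each normalized comparison.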
@@ -56,7 +56,7 @@ case class JoinEstimation(join: Join) extends Logging {
case _ if !rowCountsExist(join.left, join.right) =>
None

- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) =>
// 1. Compute join selectivity
val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs)
@@ -1492,6 +1492,19 @@ object SQLConf {
.intConf
.createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH)

[Review thread]
Member: Yes, at best make this internal. Are there conditions where you would not want to apply this? Is it just a safety valve?
Author: It's just a safety valve. In case there are some queries that I don't foresee now where this could get in the way.

val USE_SMJ_INNER_RANGE_OPTIMIZATION =
buildConf("spark.sql.join.smj.useInnerRangeOptimization")
.internal()
.doc("Sort-merge join 'inner range optimization' is applicable in cases where the join " +
"condition includes equality expressions on pairs of columns and a range expression " +
"involving two other columns, (e.g. t1.x = t2.x AND t1.y BETWEEN t2.y - d AND t2.y + d)." +
" If the inner range optimization is enabled, the number of rows considered for each " +
"match of equality conditions can be reduced considerably because a moving window, " +
"corresponding to the range conditions, will be used for iterating over matched rows " +
"in the right relation.")
.booleanConf
.createWithDefault(true)

object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -1818,6 +1831,8 @@ class SQLConf extends Serializable with Logging {

def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD)

def useSmjInnerRangeOptimization: Boolean = getConf(USE_SMJ_INNER_RANGE_OPTIMIZATION)

def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)

/**
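Because the flag is internal but still settable, disabling the optimization is a one-line session change. A minimal sketch, assuming a live SparkSession named spark and illustrative tables t1 and t2:

// Safety valve: switch the inner range optimization off for this session.
spark.conf.set("spark.sql.join.smj.useInnerRangeOptimization", "false")

// Shape of query the optimization targets when enabled:
// an equality on x plus a two-sided range on y.
val joined = spark.sql(
  """SELECT *
    |FROM t1 JOIN t2
    |ON t1.x = t2.x AND t1.y BETWEEN t2.y - 5 AND t2.y + 5""".stripMargin)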
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution

import java.util.ConcurrentModificationException

- import scala.collection.mutable.ArrayBuffer
+ import scala.collection.mutable.{ArrayBuffer, Queue}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
@@ -41,12 +41,16 @@ import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf
* - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to
* excessive disk writes. This may lead to a performance regression compared to the normal case
* of using an [[ArrayBuffer]] or [[Array]].
*
* If [[asQueue]] is set to true, the class will function as a queue, supporting peek() and
* dequeue() operations.
*/
private[sql] class ExternalAppendOnlyUnsafeRowArray(
taskMemoryManager: TaskMemoryManager,
blockManager: BlockManager,
serializerManager: SerializerManager,
taskContext: TaskContext,
asQueue: Boolean,
initialSize: Int,
pageSizeBytes: Long,
numRowsInMemoryBufferThreshold: Int,
@@ -58,6 +62,20 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
TaskContext.get(),
false,
1024,
SparkEnv.get.memoryManager.pageSizeBytes,
numRowsInMemoryBufferThreshold,
numRowsSpillThreshold)
}

def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int, asQueue: Boolean) {
this(
TaskContext.get().taskMemoryManager(),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
TaskContext.get(),
asQueue,
1024,
SparkEnv.get.memoryManager.pageSizeBytes,
numRowsInMemoryBufferThreshold,
@@ -67,7 +85,13 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
private val initialSizeOfInMemoryBuffer =
Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold)

- private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
+ private val inMemoryQueue = if (asQueue && initialSizeOfInMemoryBuffer > 0) {
+ new Queue[UnsafeRow]()
+ } else {
+ null
+ }
+
+ private val inMemoryBuffer = if (!asQueue && initialSizeOfInMemoryBuffer > 0) {
new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
} else {
null
@@ -76,6 +100,9 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
private var spillableArray: UnsafeExternalSorter = _
private var numRows = 0

// Used when functioning as a queue to allow skipping 'dequeued' items
private var spillableArrayOffset = 0

// A counter to keep track of total modifications done to this array since its creation.
// This helps to invalidate iterators when there are changes done to the backing array.
private var modificationsCount: Long = 0
@@ -95,17 +122,60 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
// inside `UnsafeExternalSorter`
spillableArray.cleanupResources()
spillableArray = null
spillableArrayOffset = 0
} else if (inMemoryBuffer != null) {
inMemoryBuffer.clear()
} else if (inMemoryQueue != null) {
inMemoryQueue.clear()
}
numFieldsPerRow = 0
numRows = 0
modificationsCount += 1
}

def dequeue(): Option[UnsafeRow] = {
if (!asQueue) {
throw new IllegalStateException("Not instantiated as a queue!")
}
if (numRows == 0) {
None
} else if (spillableArray != null) {
// Rows have been spilled: read the current head via an iterator, then advance
// spillableArrayOffset so later reads skip the dequeued row (the sorter itself
// does not support removal).
val retval = Some(generateIterator().next())
numRows -= 1
modificationsCount += 1
spillableArrayOffset += 1
retval
} else {
numRows -= 1
modificationsCount += 1
Some(inMemoryQueue.dequeue())
}
}

def peek(): Option[UnsafeRow] = {
if (!asQueue) {
throw new IllegalStateException("Not instantiated as a queue!")
}
if (numRows == 0) {
None
} else if (spillableArray != null) {
// Read the current head without advancing spillableArrayOffset.
Some(generateIterator().next())
} else {
Some(inMemoryQueue(0))
}
}

def add(unsafeRow: UnsafeRow): Unit = {
- if (numRows < numRowsInMemoryBufferThreshold) {
- inMemoryBuffer += unsafeRow.copy()
+ if (spillableArray == null && numRows < numRowsInMemoryBufferThreshold) {
+ if (asQueue) {
+ inMemoryQueue += unsafeRow.copy()
+ } else {
+ inMemoryBuffer += unsafeRow.copy()
+ }
} else {
if (spillableArray == null) {
logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " +
@@ -124,8 +194,21 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
numRowsSpillThreshold,
false)

spillableArrayOffset = 0

// populate with existing in-memory buffered rows
- if (inMemoryBuffer != null) {
+ if (asQueue && inMemoryQueue != null) {
+ inMemoryQueue.foreach(existingUnsafeRow =>
+ spillableArray.insertRecord(
+ existingUnsafeRow.getBaseObject,
+ existingUnsafeRow.getBaseOffset,
+ existingUnsafeRow.getSizeInBytes,
+ 0,
+ false)
+ )
+ inMemoryQueue.clear()
+ }
+ if (!asQueue && inMemoryBuffer != null) {
inMemoryBuffer.foreach(existingUnsafeRow =>
spillableArray.insertRecord(
existingUnsafeRow.getBaseObject,
@@ -168,7 +251,8 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
if (spillableArray == null) {
new InMemoryBufferIterator(startIndex)
} else {
- new SpillableArrayIterator(spillableArray.getIterator(startIndex), numFieldsPerRow)
+ new SpillableArrayIterator(spillableArray.getIterator(startIndex + spillableArrayOffset),
+ numFieldsPerRow)
}
}

@@ -198,7 +282,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(

override def next(): UnsafeRow = {
throwExceptionIfModified()
- val result = inMemoryBuffer(currentIndex)
+ val result = if (asQueue) inMemoryQueue(currentIndex) else inMemoryBuffer(currentIndex)
currentIndex += 1
result
}
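For reference, a minimal usage sketch of the new queue mode (thresholds and variable names are arbitrary assumptions; the three-argument constructor must run inside a task, since it reads TaskContext.get() and SparkEnv.get):

// rows: Iterator[UnsafeRow] produced by the surrounding operator.
val queue = new ExternalAppendOnlyUnsafeRowArray(
  4096,  // numRowsInMemoryBufferThreshold: rows kept in the in-memory Queue
  65536, // numRowsSpillThreshold: beyond this, UnsafeExternalSorter spills to disk
  true)  // asQueue = true enables peek() and dequeue()

rows.foreach(queue.add)     // each row is copied on add
val head = queue.peek()     // Option[UnsafeRow]: inspect the head without consuming it
val first = queue.dequeue() // Option[UnsafeRow]: consume the head
// After a spill, dequeue() advances spillableArrayOffset rather than physically
// removing the row; generateIterator() then starts past all dequeued rows.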