Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

@Override
public Object apply(Object r) {
// GenerateProjection does not work with UnsafeRows.
assert(!(r instanceof ${classOf[UnsafeRow].getName}));
return new SpecificRow((InternalRow) r);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,11 @@ case class Window(

// Get all relevant projections.
val result = createResultProjection(unboundExpressions)
val grouping = newProjection(partitionSpec, child.output)
val grouping = if (child.outputsUnsafeRows) {
UnsafeProjection.create(partitionSpec, child.output)
} else {
newProjection(partitionSpec, child.output)
}

// Manage the stream and the grouping.
var nextRow: InternalRow = EmptyRow
Expand All @@ -277,7 +281,8 @@ case class Window(
val numFrames = frames.length
private[this] def fetchNextPartition() {
// Collect all the rows in the current partition.
val currentGroup = nextGroup
// Before we start to fetch new input rows, make a copy of nextGroup.
val currentGroup = nextGroup.copy()
rows = new CompactBuffer
while (nextRowAvailable && nextGroup == currentGroup) {
rows += nextRow.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ case class SortMergeJoin(
// One required ordering per child, in child order: the first is built from the left join
// keys and the second from the right join keys.
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
  List(requiredOrders(leftKeys), requiredOrders(rightKeys))

@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)

protected[this] def isUnsafeMode: Boolean = {
(codegenEnabled && unsafeEnabled
&& UnsafeProjection.canSupport(leftKeys)
Expand All @@ -82,6 +79,28 @@ case class SortMergeJoin(

left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
new RowIterator {
// Extracts the join keys from each input row of the left child.
private[this] val leftKeyGenerator =
  if (!isUnsafeMode) {
    newProjection(leftKeys, left.output)
  } else {
    // It is very important to use UnsafeProjection if input rows are UnsafeRows.
    // Otherwise, GenerateProjection will cause wrong results.
    UnsafeProjection.create(leftKeys, left.output)
  }

// The projection used to extract keys from input rows of the right child.
private[this] val rightKeyGenerator = {
if (isUnsafeMode) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, can we move this check into newProjection?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had another discussion and will do the cleanup work in a follow-up PR.

// It is very important to use UnsafeProjection if input rows are UnsafeRows.
// Otherwise, GenerateProjection will cause wrong results.
UnsafeProjection.create(rightKeys, right.output)
} else {
newProjection(rightKeys, right.output)
}
}

// An ordering that can be used to compare keys from both sides.
private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
private[this] var currentLeftRow: InternalRow = _
Expand Down
28 changes: 28 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1781,4 +1781,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Seq(Row(1), Row(1)))
}
}

test("SortMergeJoin returns wrong results when using UnsafeRows") {
  // Regression test for https://issues.apache.org/jira/browse/SPARK-10737.
  // The bug is triggered when Tungsten is enabled and there are multiple
  // SortMergeJoin operators executed in the same task.
  val settings = Seq(
    SQLConf.SORTMERGE_JOIN.key -> "true",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
    SQLConf.TUNGSTEN_ENABLED.key -> "true")

  withSQLConf(settings: _*) {
    val base = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j")

    // First join: join the base data with its own key column.
    val keysOnly = base.select(base("i"))
    val selfJoined = base.join(keysOnly, "i").select(base("i"), base("j"))

    // Second join: join the first result with a renamed copy of itself, then compute the
    // difference between the two "j" columns.
    val renamed = selfJoined.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1")
    val joinedAgain = selfJoined
      .join(renamed, selfJoined("i") === renamed("i1"))
      .withColumn("diff", $"j" - $"j1")
      .select(selfJoined("i"), selfJoined("j"), $"diff")

    // The expected answer is the base data with a constant 0 in the "diff" column.
    val expected = base.withColumn("diff", lit(0))
    checkAnswer(joinedAgain, expected)
  }
}
}