From 7655581d2bd864ccb822e1ada0374f37aed277c5 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Sat, 21 Sep 2019 17:40:38 -0700 Subject: [PATCH 1/9] SPARK-21492 Fix memory leak issue in SMJ --- .../execution/UnsafeExternalRowIterator.java | 75 +++++++++++++++++++ .../sql/execution/BufferedRowIterator.java | 6 ++ .../execution/UnsafeExternalRowSorter.java | 37 ++------- .../apache/spark/sql/execution/SortExec.scala | 10 +++ .../sql/execution/WholeStageCodegenExec.scala | 14 +++- .../execution/joins/SortMergeJoinExec.scala | 69 ++++++++++++++++- 6 files changed, 173 insertions(+), 38 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java new file mode 100644 index 000000000000..41612de3779b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import scala.collection.AbstractIterator;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
+
+public abstract class UnsafeExternalRowIterator extends AbstractIterator implements Closeable {
+
+  private final UnsafeSorterIterator sortedIterator;
+  private UnsafeRow row;
+
+  UnsafeExternalRowIterator(StructType schema, UnsafeSorterIterator iterator) {
+    row = new UnsafeRow(schema.length());
+    sortedIterator = iterator;
+  }
+
+  @Override
+  public boolean hasNext() {
+    return sortedIterator.hasNext();
+  }
+
+  @Override
+  public UnsafeRow next() {
+    try {
+      sortedIterator.loadNext();
+      row.pointTo(
+        sortedIterator.getBaseObject(),
+        sortedIterator.getBaseOffset(),
+        sortedIterator.getRecordLength());
+      if (!hasNext()) {
+        UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page
+        row = null; // so that we don't keep references to the base object
+        close();
+        return copy;
+      } else {
+        return row;
+      }
+    } catch (IOException e) {
+      close();
+      // Scala iterators don't declare any checked exceptions, so we need to use this hack
+      // to re-throw the exception:
+      Platform.throwException(e);
+    }
+    throw new RuntimeException("Exception should have been re-thrown in next()");
+  }
+
+  /**
+   * Implementations should clean up the resources used by this iterator, to prevent memory leaks.
+   */
+  @Override
+  public abstract void close();
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 3d0511b7ba83..7ad7275d13b1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -95,4 +95,10 @@ public void incPeakExecutionMemory(long size) {
    * After it's called, if currentRow is still null, it means no more rows left.
    */
   protected abstract void processNext() throws IOException;
+
+  /**
+   * This enables the generated class to implement a method in order to properly release the resources
+   * if the iterator is not fully consumed. See SPARK-21492 for more details.
+   */
+  public void close() {}
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 863d80b5cb9c..27bf9c1df448 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -20,7 +20,6 @@
 import java.io.IOException;
 import java.util.function.Supplier;
 
-import scala.collection.AbstractIterator;
 import scala.collection.Iterator;
 import scala.math.Ordering;
 
@@ -169,39 +168,13 @@ public Iterator sort() throws IOException {
       // here in order to prevent memory leaks.
cleanupResources();
     }
-    return new AbstractIterator() {
-
-      private final int numFields = schema.length();
-      private UnsafeRow row = new UnsafeRow(numFields);
-
-      @Override
-      public boolean hasNext() {
-        return sortedIterator.hasNext();
-      }
+    return new UnsafeExternalRowIterator(schema, sortedIterator) {
 
       @Override
-      public UnsafeRow next() {
-        try {
-          sortedIterator.loadNext();
-          row.pointTo(
-            sortedIterator.getBaseObject(),
-            sortedIterator.getBaseOffset(),
-            sortedIterator.getRecordLength());
-          if (!hasNext()) {
-            UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page
-            row = null; // so that we don't keep references to the base object
-            cleanupResources();
-            return copy;
-          } else {
-            return row;
-          }
-        } catch (IOException e) {
-          cleanupResources();
-          // Scala iterators don't declare any checked exceptions, so we need to use this hack
-          // to re-throw the exception:
-          Platform.throwException(e);
-        }
-        throw new RuntimeException("Exception should have been re-thrown in next()");
+      public void close() {
+        // Caller should clean up resources if the iterator is not consumed in its entirety,
+        // for example, in a SortMergeJoin.
+        cleanupResources();
       }
     };
   } catch (IOException e) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 0a955d6a7523..4f2de890478b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -149,6 +149,16 @@ case class SortExec(
        | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
        | }
      """.stripMargin.trim)
+    // Override the close method in BufferedRowIterator to release resources if the sortedIterator
+    // is not fully consumed
+    ctx.addNewFunction("close",
+      s"""
+         | public void close() {
+         |   if ($sortedIterator != null) {
+         |     ((org.apache.spark.sql.execution.UnsafeExternalRowIterator)$sortedIterator).close();
+         |   }
+         | }
+       """.stripMargin, true)
 
     val outputRow = ctx.freshName("outputRow")
     val peakMemory = metricTerm(ctx, "peakMemory")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index f723fcfac6d0..4c9843849c9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -568,6 +568,14 @@ object WholeStageCodegenExec {
   }
 }
 
+/**
+ * A trait that extends Scala Iterator[InternalRow] and enables exposing the underlying
+ * BufferedRowIterator.
+ */
+trait ScalaIteratorWithBufferedIterator extends Iterator[InternalRow] {
+  def getBufferedRowIterator: BufferedRowIterator
+}
+
 /**
  * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java
 * function.
@@ -721,13 +729,14 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(iter)) - new Iterator[InternalRow] { + new ScalaIteratorWithBufferedIterator { override def hasNext: Boolean = { val v = buffer.hasNext if (!v) durationMs += buffer.durationMs() v } override def next: InternalRow = buffer.next() + override def getBufferedRowIterator: BufferedRowIterator = buffer } } } else { @@ -740,13 +749,14 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(leftIter, rightIter)) - new Iterator[InternalRow] { + new ScalaIteratorWithBufferedIterator { override def hasNext: Boolean = { val v = buffer.hasNext if (!v) durationMs += buffer.durationMs() v } override def next: InternalRow = buffer.next() + override def getBufferedRowIterator: BufferedRowIterator = buffer } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 189727a9bc88..939de957357f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -212,6 +212,7 @@ case class SortMergeJoinExec( currentRightMatches = null currentLeftRow = null rightMatchesIterator = null + smjScanner.close() return false } } @@ -221,6 +222,7 @@ case class SortMergeJoinExec( return true } } + smjScanner.close() false } @@ -302,6 +304,7 @@ case class SortMergeJoinExec( } } } + smjScanner.close() false } @@ -343,6 +346,7 @@ case class SortMergeJoinExec( return true } } + smjScanner.close() false } @@ -382,6 +386,7 @@ case class SortMergeJoinExec( numOutputRows += 1 return true } + smjScanner.close() false } @@ -640,6 +645,8 @@ case class SortMergeJoinExec( (evaluateVariables(leftVars), "") } + // The last two line of code generate in processNext here will handle properly + // releasing the resources if the input iterators are not fully consumed s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { | ${leftVarDecl.mkString("\n")} @@ -653,6 +660,10 @@ case class SortMergeJoinExec( | } | if (shouldStop()) return; |} + |((org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator)$leftInput) + | .getBufferedRowIterator().close(); + |((org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator)$rightInput) + | .getBufferedRowIterator().close(); """.stripMargin } } @@ -686,7 +697,7 @@ private[joins] class SortMergeJoinScanner( streamedIter: RowIterator, bufferedIter: RowIterator, inMemoryThreshold: Int, - spillThreshold: Int) { + spillThreshold: Int) extends CloseableScanner { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -802,6 +813,15 @@ private[joins] class SortMergeJoinScanner( } } + /** + * Once the join has been completed, the iterators for both relations + * should be closed, so that acquired memory can be released. 
+   */
+  def close(): Unit = {
+    closeIterator(streamedIter)
+    closeIterator(bufferedIter)
+  }
+
   // --- Private methods --------------------------------------------------------------------------
 
   /**
@@ -969,7 +989,11 @@ private abstract class OneSideOuterIterator(
 
   override def advanceNext(): Boolean = {
     val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream()
-    if (r) numOutputRows += 1
+    if (r) {
+      numOutputRows += 1
+    } else {
+      smjScanner.close()
+    }
     r
   }
 
@@ -984,7 +1008,7 @@ private class SortMergeFullOuterJoinScanner(
     rightIter: RowIterator,
     boundCondition: InternalRow => Boolean,
     leftNullRow: InternalRow,
-    rightNullRow: InternalRow) {
+    rightNullRow: InternalRow) extends CloseableScanner {
   private[this] val joinedRow: JoinedRow = new JoinedRow()
   private[this] var leftRow: InternalRow = _
   private[this] var leftRowKey: InternalRow = _
@@ -1147,6 +1171,15 @@ private class SortMergeFullOuterJoinScanner(
       false
     }
   }
+
+  /**
+   * Once the join has been completed, the iterators for both relations
+   * should be closed, so that acquired memory can be released.
+   */
+  def close(): Unit = {
+    closeIterator(leftIter)
+    closeIterator(rightIter)
+  }
 }
 
 private class FullOuterIterator(
@@ -1157,9 +1190,37 @@ private class FullOuterIterator(
 
   override def advanceNext(): Boolean = {
     val r = smjScanner.advanceNext()
-    if (r) numRows += 1
+    if (r) {
+      numRows += 1
+    } else {
+      smjScanner.close()
+    }
     r
   }
 
   override def getRow: InternalRow = resultProj(joinedRow)
 }
+
+/**
+ * This trait enables the SMJ scanner to properly release resources when either the
+ * left or right iterator is not fully consumed. It only takes effect when codegen is
+ * not enabled, i.e. for non-inner joins, and for inner joins when whole-stage codegen
+ * is disabled. For inner joins with whole-stage codegen, this is handled separately in
+ * the generated code.
+ */
+trait CloseableScanner {
+  def closeIterator(iter: RowIterator): Unit = {
+    iter match {
+      case rowIter: RowIteratorFromScala =>
+        val underlyingIter = rowIter.toScala
+        underlyingIter match {
+          case toClose: UnsafeExternalRowIterator =>
+            toClose.close()
+          case toClose: ScalaIteratorWithBufferedIterator =>
+            toClose.getBufferedRowIterator.close()
+          case _ =>
+        }
+      case _ =>
+    }
+  }
+}

From 1f2ca09e5f859f1ff554d4c9b9dc87e971fed437 Mon Sep 17 00:00:00 2001
From: Min Shen
Date: Sat, 21 Sep 2019 19:18:30 -0700
Subject: [PATCH 2/9] Fix Java style issues

---
 .../apache/spark/sql/execution/UnsafeExternalRowIterator.java | 3 ++-
 .../org/apache/spark/sql/execution/BufferedRowIterator.java   | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java
index 41612de3779b..b170fa2c6fd0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java
@@ -27,7 +27,8 @@
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
 
-public abstract class UnsafeExternalRowIterator extends AbstractIterator implements Closeable {
+public abstract class UnsafeExternalRowIterator extends AbstractIterator
+    implements Closeable {
 
   private final UnsafeSorterIterator sortedIterator;
   private UnsafeRow row;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 7ad7275d13b1..42751c7c8f50 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -97,8 +97,8 @@ public void incPeakExecutionMemory(long size) {
   protected abstract void processNext() throws IOException;
 
   /**
-   * This enables the generated class to implement a method in order to properly release the resources
-   * if the iterator is not fully consumed. See SPARK-21492 for more details.
+   * This enables the generated class to implement a method in order to properly release the
+   * resources if the iterator is not fully consumed. See SPARK-21492 for more details.
   */
   public void close() {}
 }

From e081989e6ba5aa7db3925379de8e72ff9ea7f24f Mon Sep 17 00:00:00 2001
From: Min Shen
Date: Sat, 21 Sep 2019 20:26:26 -0700
Subject: [PATCH 3/9] Fix more coding style issues

---
 .../org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 27bf9c1df448..db188d8ae62d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -31,7 +31,6 @@
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.types.StructType;
-import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
 import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;

From 9f78b70ab164b5e3ab5ea4e3dc728034082937dd Mon Sep 17 00:00:00 2001
From: Min Shen
Date: Fri, 27 Sep 2019 09:52:22 -0700
Subject: [PATCH 4/9] Fix a corner case with SMJ inner join

---
 .../spark/sql/execution/joins/SortMergeJoinExec.scala | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 939de957357f..be0f3031fa25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -466,6 +466,14 @@ case class SortMergeJoinExec(
     // Inline mutable state since not many join operations in a task
     val matches = ctx.addMutableState(clsName, "matches",
       v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
+    // Maintain a boolean flag to track if right iterator has already been fully consumed.
+    // The following findNextInnerJoinRows could potentially invoke hasNext twice on the
+    // right iterator after it has already been fully consumed. The 1st time this happens,
+    // it would trigger releasing the resources of both left and right iterator. When it
+    // gets invoked a 2nd time, it could potentially lead to NPE. This boolean flag makes
+    // sure that does not happen.
+    val rightIterExhausted = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "rightIterExhausted",
+      v => s"$v = false;", forceInline = true)
     // Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -494,7 +502,8 @@ case class SortMergeJoinExec( | | do { | if ($rightRow == null) { - | if (!rightIter.hasNext()) { + | if ($rightIterExhausted || !rightIter.hasNext()) { + | $rightIterExhausted = true; | ${matchedKeyVars.map(_.code).mkString("\n")} | return !$matches.isEmpty(); | } From c20b3b961693688deb5ef412bfa884fcccdd0741 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Fri, 27 Sep 2019 18:34:58 -0700 Subject: [PATCH 5/9] Fix the corner case for SMJ inner join --- .../execution/joins/SortMergeJoinExec.scala | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index be0f3031fa25..df8a0458515a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -654,8 +654,10 @@ case class SortMergeJoinExec( (evaluateVariables(leftVars), "") } - // The last two line of code generate in processNext here will handle properly - // releasing the resources if the input iterators are not fully consumed + // The last two lines of code generated in processNext here will attempt to handle + // releasing the resources if the input iterators are not fully consumed. It only + // attempts to release the resources of an iterator if the associated child operator + // is codegened s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { | ${leftVarDecl.mkString("\n")} @@ -669,10 +671,16 @@ case class SortMergeJoinExec( | } | if (shouldStop()) return; |} - |((org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator)$leftInput) - | .getBufferedRowIterator().close(); - |((org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator)$rightInput) - | .getBufferedRowIterator().close(); + |if ($leftInput instanceof + | org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator) { + | ((org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator)$leftInput) + | .getBufferedRowIterator().close(); + |} + |if ($rightInput instanceof + | org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator) { + | ((org.apache.spark.sql.execution.ScalaIteratorWithBufferedIterator)$rightInput) + | .getBufferedRowIterator().close(); + |} """.stripMargin } } From 445098170af17455a54279de99064535c1f1d59d Mon Sep 17 00:00:00 2001 From: Min Shen Date: Mon, 14 Oct 2019 09:44:37 -0700 Subject: [PATCH 6/9] Ensure UnsafeSorterIterator would not attempt to get the next record after the resources have been cleaned. 
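
The guard works because every iterator handed out by the sorter keeps a
reference back to its owning UnsafeExternalSorter and consults it before
reporting more records. Once the SMJ has closed both sides, a second
hasNext() call simply returns false instead of dereferencing state that
cleanupResources() already released, which is why the rightIterExhausted
workaround from the earlier commit can be reverted below. A minimal sketch
of the pattern in plain Scala (illustrative names only; the actual change is
in the Java classes UnsafeExternalSorter, SpillableIterator and
UnsafeSorterSpillMerger):

    // Sketch: once the owner's resources are freed, hasNext reports
    // exhaustion, so next() can never touch freed pages or a nulled sorter.
    class Owner {
      private var resourceCleaned = false
      def cleanupResources(): Unit = {
        // free memory pages, delete spill files, drop the in-memory sorter
        resourceCleaned = true
      }
      def isResourceCleaned: Boolean = resourceCleaned
    }

    class GuardedIterator(owner: Owner, records: Iterator[Long]) extends Iterator[Long] {
      // The guard this patch adds: check the owner first, then the data.
      override def hasNext: Boolean = !owner.isResourceCleaned && records.hasNext
      override def next(): Long = records.next()
    }

A GuardedIterator yields nothing more as soon as owner.cleanupResources()
runs, no matter how many records remain buffered.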
---
 .../unsafe/sort/UnsafeExternalSorter.java     | 19 ++++++++++++++-----
 .../unsafe/sort/UnsafeSorterSpillMerger.java  |  5 +++--
 .../execution/joins/SortMergeJoinExec.scala   | 11 +----------
 3 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 55e4e609c3c7..68a173d912ae 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -74,6 +74,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
    */
   private final int numElementsForSpillThreshold;
 
+  private boolean resourceCleaned = false;
+
   /**
    * Memory pages that hold the records being sorted. The pages in this list are freed when
    * spilling, although in principle we could recycle these pages across spills (on the other hand,
@@ -324,6 +326,7 @@ public void cleanupResources() {
     synchronized (this) {
       deleteSpillFiles();
       freeMemory();
+      this.resourceCleaned = true;
       if (inMemSorter != null) {
         inMemSorter.free();
         inMemSorter = null;
@@ -331,6 +334,10 @@ public void cleanupResources() {
     }
   }
 
+  public boolean isResourceCleaned() {
+    return resourceCleaned;
+  }
+
   /**
    * Checks whether there is enough space to insert an additional record in to the sort pointer
    * array and grows the array if additional space is required. If the required space cannot be
@@ -464,7 +471,7 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
       assert(recordComparatorSupplier != null);
       if (spillWriters.isEmpty()) {
         assert(inMemSorter != null);
-        readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+        readingIterator = new SpillableIterator(inMemSorter.getSortedIterator(), this);
         return readingIterator;
       } else {
         final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(
@@ -473,10 +480,10 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
           spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
         }
         if (inMemSorter != null) {
-          readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+          readingIterator = new SpillableIterator(inMemSorter.getSortedIterator(), this);
           spillMerger.addSpillIfNotEmpty(readingIterator);
         }
-        return spillMerger.getSortedIterator();
+        return spillMerger.getSortedIterator(this);
       }
     }
 
@@ -503,12 +510,14 @@ class SpillableIterator extends UnsafeSorterIterator {
     private UnsafeSorterIterator upstream;
     private UnsafeSorterIterator nextUpstream = null;
     private MemoryBlock lastPage = null;
+    private UnsafeExternalSorter sorter;
     private boolean loaded = false;
    private int numRecords = 0;
 
-    SpillableIterator(UnsafeSorterIterator inMemIterator) {
+    SpillableIterator(UnsafeSorterIterator inMemIterator, UnsafeExternalSorter sorter) {
      this.upstream = inMemIterator;
       this.numRecords = inMemIterator.getNumRecords();
+      this.sorter = sorter;
     }
 
     @Override
@@ -566,7 +575,7 @@ public long spill() throws IOException {
 
     @Override
     public boolean hasNext() {
-      return numRecords > 0;
+      return !sorter.isResourceCleaned() && numRecords > 0;
     }
 
     @Override
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index ab800288dcb4..c7e7ae08a34a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -60,7 +60,7 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOExcept
     }
   }
 
-  public UnsafeSorterIterator getSortedIterator() throws IOException {
+  public UnsafeSorterIterator getSortedIterator(UnsafeExternalSorter sorter) throws IOException {
     return new UnsafeSorterIterator() {
 
       private UnsafeSorterIterator spillReader;
@@ -72,7 +72,8 @@ public int getNumRecords() {
 
       @Override
       public boolean hasNext() {
-        return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
+        return !sorter.isResourceCleaned()
+          && (!priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()));
       }
 
       @Override
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index df8a0458515a..73a9a7ac3150 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -466,14 +466,6 @@ case class SortMergeJoinExec(
     // Inline mutable state since not many join operations in a task
     val matches = ctx.addMutableState(clsName, "matches",
       v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
-    // Maintain a boolean flag to track if right iterator has already been fully consumed.
-    // The following findNextInnerJoinRows could potentially invoke hasNext twice on the
-    // right iterator after it has already been fully consumed. The 1st time this happens,
-    // it would trigger releasing the resources of both left and right iterator. When it
-    // gets invoked a 2nd time, it could potentially lead to NPE. This boolean flag makes
-    // sure that does not happen.
-    val rightIterExhausted = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "rightIterExhausted",
-      v => s"$v = false;", forceInline = true)
     // Copy the left keys as class members so they could be used in next function call.
     val matchedKeyVars = copyKeys(ctx, leftKeyVars)
 
@@ -502,8 +494,7 @@ case class SortMergeJoinExec(
 
        |
        | do {
        |   if ($rightRow == null) {
-       |     if ($rightIterExhausted || !rightIter.hasNext()) {
-       |       $rightIterExhausted = true;
+       |     if (!rightIter.hasNext()) {
        |       ${matchedKeyVars.map(_.code).mkString("\n")}
        |       return !$matches.isEmpty();
        |     }

From 7ebfb3eb122bccdb6897629238c1043f9043d15b Mon Sep 17 00:00:00 2001
From: Liang Tang
Date: Fri, 4 Oct 2019 17:56:09 -0700
Subject: [PATCH 7/9] A comprehensive test suite to identify corner cases of
 SMJ memory leak

---
 .../joins/SMJMemoryLeakTestSuite.scala        | 159 ++++++++++++++++++
 1 file changed, 159 insertions(+)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala
new file mode 100644
index 000000000000..6814a1296f2c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.execution.SparkPlanTest
+import org.apache.spark.sql.functions.{col, rand}
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * This test suite permutes three sample data sets, different join strategies,
+ * and enabling/disabling of spilling during UnsafeExternalSort, generating more than 400
+ * test cases, in order to discover potential memory leaks that could occur in corner-case
+ * combinations of these possibilities.
+ */
+
+class SMJMemoryLeakTestSuite extends SparkPlanTest with SharedSQLContext {
+
+  /**
+   * Calculates all permutations taking n elements of the input List,
+   * with repetitions.
+   * Precondition: input.length > 0 && n > 0
+   */
+  def permutationsWithRepetitions[T](input : List[T], n : Int) : List[List[T]] = {
+    require(input.length > 0 && n > 0)
+    n match {
+      case 1 => for (el <- input) yield List(el)
+      case _ => for (el <- input; perm <- permutationsWithRepetitions(input, n - 1)) yield el :: perm
+    }
+  }
+
+  private lazy val df0 = spark.range(1, 1001).select(col("id"))
+    .withColumn("value", rand()).coalesce(1)
+
+  private lazy val df1 = spark.range(1000, 2001).select(col("id"))
+    .withColumn("value", rand()).coalesce(1)
+
+  private lazy val df2 = spark.range(1, 2001).select(col("id"))
+    .withColumn("value", rand()).coalesce(1)
+
+  private val SMJwithSortSpillingConf = Seq(
+    ("spark.sql.join.preferSortMergeJoin", "true"),
+    ("spark.sql.autoBroadcastJoinThreshold", "-1"),
+    ("spark.sql.shuffle.partitions", "200"),
+    ("spark.sql.sortMergeJoinExec.buffer.spill.threshold", "1"),
+    ("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold", "0")
+  )
+
+  private val SMJwithoutSortSpillingConf = Seq(
+    ("spark.sql.join.preferSortMergeJoin", "true"),
+    ("spark.sql.autoBroadcastJoinThreshold", "-1"),
+    ("spark.sql.shuffle.partitions", "200"),
+    ("spark.sql.sortMergeJoinExec.buffer.spill.threshold", "2000"),
+    ("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold", "2000")
+  )
+
+  {
+    val list = List(0, 1, 2)
+    val joinTypes = List("inner", "leftsemi")
+
+    // Permute data with duplicates and join strategies
+    for (dataPerm <- permutationsWithRepetitions(list, 3);
+         joinPerm <- permutationsWithRepetitions(joinTypes, 2)) {
+
+      val leftIndex = dataPerm(0)
+      val midIndex = dataPerm(1)
+      val rightIndex = dataPerm(2)
+
+      // Enable spilling during SMJ
+      val testNameEnableSpilling = s"df$leftIndex, df$midIndex, and df$rightIndex. " +
+        s"jointype1: ${joinPerm(0)} and jointype2: ${joinPerm(1)} " +
+        s"spilling enabled "
+      joinUtility(testNameEnableSpilling, leftIndex, midIndex, rightIndex,
+        joinPerm(0), joinPerm(1), SMJwithSortSpillingConf: _*)
+
+      // Disable spilling during SMJ
+      val testNameDisableSpilling = s"df$leftIndex, df$midIndex, and df$rightIndex. " +
+        s"jointype1: ${joinPerm(0)} and jointype2: ${joinPerm(1)} " +
+        s"spilling disabled "
+      joinUtility(testNameDisableSpilling, leftIndex, midIndex, rightIndex,
+        joinPerm(0), joinPerm(1), SMJwithoutSortSpillingConf: _*)
+    }
+  }
+
+  private def joinUtility(testName: String,
+                          leftIndex: Int,
+                          midIndex: Int,
+                          rightIndex: Int,
+                          joinType1: String,
+                          joinType2: String,
+                          sqlConf: (String, String)*) {
+
+    // One nested SMJ with the inner SMJ occurring in the left sub tree.
+    //      SMJ
+    //     /  \
+    //   SMJ   DF
+    //   /  \
+    //  DF   DF
+    test(testName + " left sub tree") {
+      withSQLConf(sqlConf: _*) {
+        val list = Seq(df0, df1, df2)
+        val joined = list(leftIndex).join(list(midIndex), Seq("id"), joinType1).coalesce(1)
+        val joined2 = joined.join(list(rightIndex), Seq("id"), joinType2).coalesce(1)
+
+        val cacheJoined = joined2.cache()
+        cacheJoined.explain()
+        cacheJoined.count()
+      }
+    }
+
+    // One nested SMJ with the inner SMJ occurring in the right sub tree.
+    //      SMJ
+    //     /  \
+    //    DF   SMJ
+    //        /  \
+    //       DF   DF
+    test(testName + " right sub tree") {
+      withSQLConf(sqlConf: _*) {
+        val list = Seq(df0, df1, df2)
+        val joined = list(midIndex).join(list(rightIndex), Seq("id"), joinType2).coalesce(1)
+        val joined2 = list(leftIndex).join(joined, Seq("id"), joinType1).coalesce(1)
+
+        val cacheJoined = joined2.cache()
+        cacheJoined.explain()
+        cacheJoined.count()
+      }
+    }
+  }
+
+  test("SPARK-21492 memory leak reproduction") {
+    spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")
+    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+    spark.conf.set("spark.sql.shuffle.partitions", 200)
+    spark.conf.set("spark.sql.codegen.wholeStage", false)
+
+    var r1 = spark.range(1, 1001).select(col("id").alias("timestamp1"))
+    r1 = r1.withColumn("value", rand())
+    var r2 = spark.range(1000, 2001).select(col("id").alias("timestamp2"))
+    r2 = r2.withColumn("value2", rand())
+    var joined = r1.join(r2, col("timestamp1") === col("timestamp2"), "inner")
+    joined = joined.coalesce(1)
+    joined.explain()
+    joined.count()
+  }
+}

From 830adffe6d8eca65fe0aabf629043a3dbd06bd6b Mon Sep 17 00:00:00 2001
From: Min Shen
Date: Mon, 14 Oct 2019 10:16:55 -0700
Subject: [PATCH 8/9] Fix Scala style issue

---
 .../spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala
index 6814a1296f2c..06f43b99139a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala
@@ -39,7 +39,8 @@ class SMJMemoryLeakTestSuite extends SparkPlanTest with SharedSQLContext {
     require(input.length > 0 && n > 0)
     n match {
       case 1 => for (el <- input) yield List(el)
-      case _ => for (el <- input; perm <- permutationsWithRepetitions(input, n - 1)) yield el :: perm
+      case _ =>
+        for (el <- input; perm <- permutationsWithRepetitions(input, n - 1)) yield el :: perm
     }
   }
 

From 265cd931dad27dcefc8ac3d793c6bdb4d18d9710 Mon Sep 17 00:00:00 2001
From: Min Shen
Date: Mon, 14 Oct 2019 13:54:57 -0700
Subject: [PATCH 9/9] Fix compatibility with current HEAD of trunk.
---
 .../spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala
index 06f43b99139a..6e0e705a0fd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SMJMemoryLeakTestSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.execution.SparkPlanTest
 import org.apache.spark.sql.functions.{col, rand}
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SharedSparkSession
 
 /**
  * This test suite permutes three sample data sets, different join strategies,