From 2387f1e316a90fc9a392ab69ee3c2257b622af4d Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 11 Aug 2017 14:57:06 +0100 Subject: [PATCH 1/8] [SPARK-21675][WEBUI] Add a navigation bar at the bottom of the Details for Stage Page ## What changes were proposed in this pull request? 1. In Spark Web UI, the Details for Stage Page don't have a navigation bar at the bottom. When we drop down to the bottom, it is better for us to see a navi bar right there to go wherever we what. 2. Executor ID is not equivalent to Host, it may be better to separate them, and then we can group the tasks by Hosts . ## How was this patch tested? manually test ![wx20170809-165606](https://user-images.githubusercontent.com/8326978/29114161-f82b4920-7d25-11e7-8d0c-0c036b008a78.png) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Kent Yao Closes #18893 from yaooqinn/SPARK-21675. --- .../scala/org/apache/spark/ui/PagedTable.scala | 4 +++- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 14 +++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 79974df2603fd..65fa38387b9ee 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -94,14 +94,16 @@ private[ui] trait PagedTable[T] { val _dataSource = dataSource try { val PageData(totalPages, data) = _dataSource.pageData(page) + val pageNavi = pageNavigation(page, _dataSource.pageSize, totalPages)
- {pageNavigation(page, _dataSource.pageSize, totalPages)} + {pageNavi} {headers} {data.map(row)}
+ {pageNavi}
} catch { case e: IndexOutOfBoundsException => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8ed51746ab9d0..633e740b9c9bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -835,7 +835,8 @@ private[ui] class TaskTableRowData( val speculative: Boolean, val status: String, val taskLocality: String, - val executorIdAndHost: String, + val executorId: String, + val host: String, val launchTime: Long, val duration: Long, val formatDuration: String, @@ -1017,7 +1018,8 @@ private[ui] class TaskDataSource( info.speculative, info.status, info.taskLocality.toString, - s"${info.executorId} / ${info.host}", + info.executorId, + info.host, info.launchTime, duration, formatDuration, @@ -1047,7 +1049,8 @@ private[ui] class TaskDataSource( case "Attempt" => Ordering.by(_.attempt) case "Status" => Ordering.by(_.status) case "Locality Level" => Ordering.by(_.taskLocality) - case "Executor ID / Host" => Ordering.by(_.executorIdAndHost) + case "Executor ID" => Ordering.by(_.executorId) + case "Host" => Ordering.by(_.host) case "Launch Time" => Ordering.by(_.launchTime) case "Duration" => Ordering.by(_.duration) case "Scheduler Delay" => Ordering.by(_.schedulerDelay) @@ -1200,7 +1203,7 @@ private[ui] class TaskPagedTable( val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""), ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), ("GC Time", ""), @@ -1271,8 +1274,9 @@ private[ui] class TaskPagedTable( {if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString} {task.status} {task.taskLocality} + {task.executorId} -
{task.executorIdAndHost}
+
{task.host}
{ task.logs.map { From 0377338bf76c0d55963a338dd73076113a922de9 Mon Sep 17 00:00:00 2001 From: LucaCanali Date: Fri, 11 Aug 2017 12:03:37 -0700 Subject: [PATCH 2/8] [SPARK-21519][SQL] Add an option to the JDBC data source to initialize the target DB environment Add an option to the JDBC data source to initialize the environment of the remote database session ## What changes were proposed in this pull request? This proposes an option to the JDBC datasource, tentatively called " sessionInitStatement" to implement the functionality of session initialization present for example in the Sqoop connector for Oracle (see https://sqoop.apache.org/docs/1.4.6/SqoopUserGuide.html#_oraoop_oracle_session_initialization_statements ) . After each database session is opened to the remote DB, and before starting to read data, this option executes a custom SQL statement (or a PL/SQL block in the case of Oracle). See also https://issues.apache.org/jira/browse/SPARK-21519 ## How was this patch tested? Manually tested using Spark SQL data source and Oracle JDBC Author: LucaCanali Closes #18724 from LucaCanali/JDBC_datasource_sessionInitStatement. --- docs/sql-programming-guide.md | 7 +++++ .../datasources/jdbc/JDBCOptions.scala | 3 ++ .../execution/datasources/jdbc/JDBCRDD.scala | 15 +++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 31 +++++++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2ac2383d699c4..ee231a934a3af 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1308,6 +1308,13 @@ the following case-insensitive options: + + sessionInitStatement + + After each database session is opened to the remote DB and before starting to read data, this option executes a custom SQL statement (or a PL/SQL block). Use this to implement session initialization code. Example: option("sessionInitStatement", """BEGIN execute immediate 'alter session set "_serial_direct_read"=true'; END;""") + + + truncate diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 96a8a51da18e5..05b00058618a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -138,6 +138,8 @@ class JDBCOptions( case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE } + // An option to execute custom SQL before fetching data from the remote DB + val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) } object JDBCOptions { @@ -161,4 +163,5 @@ object JDBCOptions { val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") + val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 24e13697c0c9f..3274be91d4817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -273,6 +273,21 @@ private[jdbc] class JDBCRDD( import scala.collection.JavaConverters._ dialect.beforeFetch(conn, options.asProperties.asScala.toMap) + // This executes a generic SQL statement (or PL/SQL block) before reading + // the table/query via JDBC. Use this feature to initialize the database + // session environment, e.g. for optimizations and/or troubleshooting. + options.sessionInitStatement match { + case Some(sql) => + val statement = conn.prepareStatement(sql) + logInfo(s"Executing sessionInitStatement: $sql") + try { + statement.execute() + } finally { + statement.close() + } + case None => + } + // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to // talk about a table in a completely portable way. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 4c43646889418..8dc11d80c3063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1044,4 +1044,35 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").count() == 3) } } + + test("SPARK-21519: option sessionInitStatement, run SQL to initialize the database session.") { + val initSQL1 = "SET @MYTESTVAR 21519" + val df1 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "(SELECT NVL(@MYTESTVAR, -1))") + .option("sessionInitStatement", initSQL1) + .load() + assert(df1.collect() === Array(Row(21519))) + + val initSQL2 = "SET SCHEMA DUMMY" + val df2 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PEOPLE") + .option("sessionInitStatement", initSQL2) + .load() + val e = intercept[SparkException] {df2.collect()}.getMessage + assert(e.contains("""Schema "DUMMY" not found""")) + + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW test_sessionInitStatement + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$urlWithUserAndPass', + |dbtable '(SELECT NVL(@MYTESTVAR1, -1), NVL(@MYTESTVAR2, -1))', + |sessionInitStatement 'SET @MYTESTVAR1 21519; SET @MYTESTVAR2 1234') + """.stripMargin) + + val df3 = sql("SELECT * FROM test_sessionInitStatement") + assert(df3.collect() === Array(Row(21519, 1234))) + } } From 94439997d57875838a8283c543f9b44705d3a503 Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Fri, 11 Aug 2017 22:01:00 +0200 Subject: [PATCH 3/8] [SPARK-21595] Separate thresholds for buffering and spilling in ExternalAppendOnlyUnsafeRowArray ## What changes were proposed in this pull request? [SPARK-21595](https://issues.apache.org/jira/browse/SPARK-21595) reported that there is excessive spilling to disk due to default spill threshold for `ExternalAppendOnlyUnsafeRowArray` being quite small for WINDOW operator. Old behaviour of WINDOW operator (pre https://github.com/apache/spark/pull/16909) would hold data in an array for first 4096 records post which it would switch to `UnsafeExternalSorter` and start spilling to disk after reaching `spark.shuffle.spill.numElementsForceSpillThreshold` (or earlier if there was paucity of memory due to excessive consumers). Currently the (switch from in-memory to `UnsafeExternalSorter`) and (`UnsafeExternalSorter` spilling to disk) for `ExternalAppendOnlyUnsafeRowArray` is controlled by a single threshold. This PR aims to separate that to have more granular control. ## How was this patch tested? Added unit tests Author: Tejas Patil Closes #18843 from tejasapatil/SPARK-21595. --- .../apache/spark/sql/internal/SQLConf.scala | 41 ++++++- .../ExternalAppendOnlyUnsafeRowArray.scala | 28 ++--- .../joins/CartesianProductExec.scala | 12 +- .../execution/joins/SortMergeJoinExec.scala | 24 +++- .../sql/execution/window/WindowExec.scala | 4 +- .../org/apache/spark/sql/JoinSuite.scala | 3 +- ...nalAppendOnlyUnsafeRowArrayBenchmark.scala | 7 +- ...xternalAppendOnlyUnsafeRowArraySuite.scala | 103 +++++++++++------- .../execution/SQLWindowFunctionSuite.scala | 3 +- 9 files changed, 155 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ecb941c5fa9e6..733d80e9d46cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -844,24 +844,47 @@ object SQLConf { .stringConf .createWithDefaultFunction(() => TimeZone.getDefault.getID) + val WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the window operator") + .intConf + .createWithDefault(4096) + val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.windowExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in window operator") + .doc("Threshold for number of rows to be spilled by window operator") .intConf - .createWithDefault(4096) + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + + val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the sort merge " + + "join operator") + .intConf + .createWithDefault(Int.MaxValue) val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in sort merge join operator") + .doc("Threshold for number of rows to be spilled by sort merge join operator") .intConf - .createWithDefault(Int.MaxValue) + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the cartesian " + + "product operator") + .intConf + .createWithDefault(4096) val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in cartesian product operator") + .doc("Threshold for number of rows to be spilled by cartesian product operator") .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) @@ -1137,11 +1160,19 @@ class SQLConf extends Serializable with Logging { def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER) + def windowExecBufferInMemoryThreshold: Int = getConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + def sortMergeJoinExecBufferInMemoryThreshold: Int = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def sortMergeJoinExecBufferSpillThreshold: Int = getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + def cartesianProductExecBufferInMemoryThreshold: Int = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def cartesianProductExecBufferSpillThreshold: Int = getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index c4d383421f976..ac282ea2e94f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -31,16 +31,16 @@ import org.apache.spark.storage.BlockManager import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} /** - * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined - * threshold of rows is reached. + * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array + * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which + * would flush to disk after [[numRowsSpillThreshold]] is met (or before if there is + * excessive memory consumption). Setting these threshold involves following trade-offs: * - * Setting spill threshold faces following trade-off: - * - * - If the spill threshold is too high, the in-memory array may occupy more memory than is - * available, resulting in OOM. - * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes. - * This may lead to a performance regression compared to the normal case of using an - * [[ArrayBuffer]] or [[Array]]. + * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array may occupy more memory + * than is available, resulting in OOM. + * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to + * excessive disk writes. This may lead to a performance regression compared to the normal case + * of using an [[ArrayBuffer]] or [[Array]]. */ private[sql] class ExternalAppendOnlyUnsafeRowArray( taskMemoryManager: TaskMemoryManager, @@ -49,9 +49,10 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( taskContext: TaskContext, initialSize: Int, pageSizeBytes: Long, + numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) extends Logging { - def this(numRowsSpillThreshold: Int) { + def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { this( TaskContext.get().taskMemoryManager(), SparkEnv.get.blockManager, @@ -59,11 +60,12 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( TaskContext.get(), 1024, SparkEnv.get.memoryManager.pageSizeBytes, + numRowsInMemoryBufferThreshold, numRowsSpillThreshold) } private val initialSizeOfInMemoryBuffer = - Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold) + Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold) private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) @@ -102,11 +104,11 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( } def add(unsafeRow: UnsafeRow): Unit = { - if (numRows < numRowsSpillThreshold) { + if (numRows < numRowsInMemoryBufferThreshold) { inMemoryBuffer += unsafeRow.copy() } else { if (spillableArray == null) { - logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " + + logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " + s"${classOf[UnsafeExternalSorter].getName}") // We will not sort the rows, so prefixComparator and recordComparator are null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index f380986951317..4d261dd422bc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -35,11 +35,12 @@ class UnsafeCartesianRDD( left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int, + inMemoryBufferThreshold: Int, spillThreshold: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold) val partition = split.asInstanceOf[CartesianPartition] rdd2.iterator(partition.s2, context).foreach(rowArray.add) @@ -71,9 +72,12 @@ case class CartesianProductExec( val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold - - val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold) + val pair = new UnsafeCartesianRDD( + leftResults, + rightResults, + right.output.size, + sqlContext.conf.cartesianProductExecBufferInMemoryThreshold, + sqlContext.conf.cartesianProductExecBufferSpillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index f41fa14213df5..91d214e1978e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -130,9 +130,14 @@ case class SortMergeJoinExec( sqlContext.conf.sortMergeJoinExecBufferSpillThreshold } + private def getInMemoryThreshold: Int = { + sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => @@ -158,6 +163,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -201,6 +207,7 @@ case class SortMergeJoinExec( keyOrdering, streamedIter = RowIterator.fromScala(leftIter), bufferedIter = RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) val rightNullRow = new GenericInternalRow(right.output.length) @@ -214,6 +221,7 @@ case class SortMergeJoinExec( keyOrdering, streamedIter = RowIterator.fromScala(rightIter), bufferedIter = RowIterator.fromScala(leftIter), + inMemoryThreshold, spillThreshold ) val leftNullRow = new GenericInternalRow(left.output.length) @@ -247,6 +255,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -281,6 +290,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -322,6 +332,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -420,8 +431,10 @@ case class SortMergeJoinExec( val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold - ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);") + ctx.addMutableState(clsName, matches, + s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);") // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -626,6 +639,9 @@ case class SortMergeJoinExec( * @param streamedIter an input whose rows will be streamed. * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that * have the same join key. + * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by + * internal buffer + * @param spillThreshold Threshold for number of rows to be spilled by internal buffer */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, @@ -633,7 +649,8 @@ private[joins] class SortMergeJoinScanner( keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, bufferedIter: RowIterator, - bufferThreshold: Int) { + inMemoryThreshold: Int, + spillThreshold: Int) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -644,7 +661,8 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold) + private[this] val bufferedMatches = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index f8bb667e30064..800a2ea3f3996 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -292,6 +292,7 @@ case class WindowExec( // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold // Start processing. @@ -322,7 +323,8 @@ case class WindowExec( val inputFields = child.output.length val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 895ca196a7a51..0008d503a2cbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -665,7 +665,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("test SortMergeJoin (with spill)") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", - "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") { + "spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "0", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "1") { assertSpilled(sparkContext, "inner join") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 031ac38c17d7b..efe28afab08e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -67,7 +67,10 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray( + ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + numSpillThreshold) + rows.foreach(x => array.add(x)) val iterator = array.generateIterator() @@ -143,7 +146,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold) rows.foreach(x => array.add(x)) val iterator = array.generateIterator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index 53c41639942b4..ecc7264d79442 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -31,7 +31,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar override def afterAll(): Unit = TaskContext.unset() - private def withExternalArray(spillThreshold: Int) + private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { sc = new SparkContext("local", "test", new SparkConf(false)) @@ -45,6 +45,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar taskContext, 1024, SparkEnv.get.memoryManager.pageSizeBytes, + inMemoryThreshold, spillThreshold) try f(array) finally { array.clear() @@ -109,9 +110,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar assert(getNumBytesSpilled > 0) } - test("insert rows less than the spillThreshold") { - val spillThreshold = 100 - withExternalArray(spillThreshold) { array => + test("insert rows less than the inMemoryThreshold") { + val (inMemoryThreshold, spillThreshold) = (100, 50) + withExternalArray(inMemoryThreshold, spillThreshold) { array => assert(array.isEmpty) val expectedValues = populateRows(array, 1) @@ -122,8 +123,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) // Verify that NO spill has happened - populateRows(array, spillThreshold - 1, expectedValues) - assert(array.length == spillThreshold) + populateRows(array, inMemoryThreshold - 1, expectedValues) + assert(array.length == inMemoryThreshold) assertNoSpill() val iterator2 = validateData(array, expectedValues) @@ -133,20 +134,42 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } } - test("insert rows more than the spillThreshold to force spill") { - val spillThreshold = 100 - withExternalArray(spillThreshold) { array => - val numValuesInserted = 20 * spillThreshold - + test("insert rows more than the inMemoryThreshold but less than spillThreshold") { + val (inMemoryThreshold, spillThreshold) = (10, 50) + withExternalArray(inMemoryThreshold, spillThreshold) { array => assert(array.isEmpty) - val expectedValues = populateRows(array, 1) - assert(array.length == 1) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) + val iterator1 = validateData(array, expectedValues) + assertNoSpill() + + // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a + // spill to happen. Verify that NO spill has happened + populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues) + assert(array.length == spillThreshold - 1) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + } + + test("insert rows enough to force spill") { + val (inMemoryThreshold, spillThreshold) = (20, 10) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + assert(array.isEmpty) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) val iterator1 = validateData(array, expectedValues) + assertNoSpill() - // Populate more rows to trigger spill. Verify that spill has happened - populateRows(array, numValuesInserted - 1, expectedValues) - assert(array.length == numValuesInserted) + // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen. + // Verify that spill has happened + populateRows(array, 2, expectedValues) + assert(array.length == inMemoryThreshold + 1) assertSpill() val iterator2 = validateData(array, expectedValues) @@ -158,7 +181,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("iterator on an empty array should be empty") { - withExternalArray(spillThreshold = 10) { array => + withExternalArray(inMemoryThreshold = 4, spillThreshold = 10) { array => val iterator = array.generateIterator() assert(array.isEmpty) assert(array.length == 0) @@ -167,7 +190,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with negative start index") { - withExternalArray(spillThreshold = 2) { array => + withExternalArray(inMemoryThreshold = 100, spillThreshold = 56) { array => val exception = intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) @@ -178,8 +201,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with start index exceeding array's size (without spill)") { - val spillThreshold = 2 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => populateRows(array, spillThreshold / 2) val exception = @@ -191,8 +214,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with start index exceeding array's size (with spill)") { - val spillThreshold = 2 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => populateRows(array, spillThreshold * 2) val exception = @@ -205,10 +228,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with custom start index (without spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => - val expectedValues = populateRows(array, spillThreshold) - val startIndex = spillThreshold / 2 + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + val expectedValues = populateRows(array, inMemoryThreshold) + val startIndex = inMemoryThreshold / 2 val iterator = array.generateIterator(startIndex = startIndex) for (i <- startIndex until expectedValues.length) { checkIfValueExists(iterator, expectedValues(i)) @@ -217,8 +240,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with custom start index (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => val expectedValues = populateRows(array, spillThreshold * 10) val startIndex = spillThreshold * 2 val iterator = array.generateIterator(startIndex = startIndex) @@ -229,7 +252,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("test iterator invalidation (without spill)") { - withExternalArray(spillThreshold = 10) { array => + withExternalArray(inMemoryThreshold = 10, spillThreshold = 100) { array => // insert 2 rows, iterate until the first row populateRows(array, 2) @@ -254,9 +277,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("test iterator invalidation (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => - // Populate enough rows so that spill has happens + val (inMemoryThreshold, spillThreshold) = (2, 10) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + // Populate enough rows so that spill happens populateRows(array, spillThreshold * 2) assertSpill() @@ -281,7 +304,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("clear on an empty the array") { - withExternalArray(spillThreshold = 2) { array => + withExternalArray(inMemoryThreshold = 2, spillThreshold = 3) { array => val iterator = array.generateIterator() assert(!iterator.hasNext) @@ -299,10 +322,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("clear array (without spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (10, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => // Populate rows ... but not enough to trigger spill - populateRows(array, spillThreshold / 2) + populateRows(array, inMemoryThreshold / 2) assertNoSpill() // Clear the array @@ -311,21 +334,21 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar // Re-populate few rows so that there is no spill // Verify the data. Verify that there was no spill - val expectedValues = populateRows(array, spillThreshold / 3) + val expectedValues = populateRows(array, inMemoryThreshold / 2) validateData(array, expectedValues) assertNoSpill() // Populate more rows .. enough to not trigger a spill. // Verify the data. Verify that there was no spill - populateRows(array, spillThreshold / 3, expectedValues) + populateRows(array, inMemoryThreshold / 2, expectedValues) validateData(array, expectedValues) assertNoSpill() } } test("clear array (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (10, 20) + withExternalArray(inMemoryThreshold, spillThreshold) { array => // Populate enough rows to trigger spill populateRows(array, spillThreshold * 2) val bytesSpilled = getNumBytesSpilled diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index a9f3fb355c775..a57514c256b90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -477,7 +477,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) """.stripMargin) - withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") { + withSQLConf("spark.sql.windowExec.buffer.in.memory.threshold" -> "1", + "spark.sql.windowExec.buffer.spill.threshold" -> "2") { assertSpilled(sparkContext, "test with low buffer spill threshold") { checkAnswer(actual, expected) } From 7f16c6910700ab90fe8e382b4a99022f67696317 Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Fri, 11 Aug 2017 15:13:42 -0700 Subject: [PATCH 4/8] [SPARK-19122][SQL] Unnecessary shuffle+sort added if join predicates ordering differ from bucketing and sorting order ## What changes were proposed in this pull request? Jira : https://issues.apache.org/jira/browse/SPARK-19122 `leftKeys` and `rightKeys` in `SortMergeJoinExec` are altered based on the ordering of join keys in the child's `outputPartitioning`. This is done everytime `requiredChildDistribution` is invoked during query planning. ## How was this patch tested? - Added new test case - Existing tests Author: Tejas Patil Closes #16985 from tejasapatil/SPARK-19122_join_order_shuffle. --- .../spark/sql/execution/QueryExecution.scala | 2 + .../joins/ReorderJoinPredicates.scala | 94 +++++++++++++++++++ .../spark/sql/sources/BucketedReadSuite.scala | 59 ++++++++++++ 3 files changed, 155 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index b56fbd4284d2f..4accf54a18232 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} +import org.apache.spark.sql.execution.joins.ReorderJoinPredicates import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils @@ -103,6 +104,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { protected def preparations: Seq[Rule[SparkPlan]] = Seq( python.ExtractPythonUDFs, PlanSubqueries(sparkSession), + new ReorderJoinPredicates, EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala new file mode 100644 index 0000000000000..534d8c5689c27 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +/** + * When the physical operators are created for JOIN, the ordering of join keys is based on order + * in which the join keys appear in the user query. That might not match with the output + * partitioning of the join node's children (thus leading to extra sort / shuffle being + * introduced). This rule will change the ordering of the join keys to match with the + * partitioning of the join nodes' children. + */ +class ReorderJoinPredicates extends Rule[SparkPlan] { + private def reorderJoinKeys( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Partitioning, + rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { + + def reorder( + expectedOrderOfKeys: Seq[Expression], + currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val leftKeysBuffer = ArrayBuffer[Expression]() + val rightKeysBuffer = ArrayBuffer[Expression]() + + expectedOrderOfKeys.foreach(expression => { + val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + leftKeysBuffer.append(leftKeys(index)) + rightKeysBuffer.append(rightKeys(index)) + }) + (leftKeysBuffer, rightKeysBuffer) + } + + if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { + leftPartitioning match { + case HashPartitioning(leftExpressions, _) + if leftExpressions.length == leftKeys.length && + leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => + reorder(leftExpressions, leftKeys) + + case _ => rightPartitioning match { + case HashPartitioning(rightExpressions, _) + if rightExpressions.length == rightKeys.length && + rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => + reorder(rightExpressions, rightKeys) + + case _ => (leftKeys, rightKeys) + } + } + } else { + (leftKeys, rightKeys) + } + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + + case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + + case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index ba0ca666b5c14..eb9e6458fc61c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -543,6 +543,65 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { ) } + test("SPARK-19122 Re-order join predicates if they match with the child's output partitioning") { + val bucketedTableTestSpec = BucketedTableTestSpec( + Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))), + numPartitions = 1, + expectedShuffle = false, + expectedSort = false) + + // If the set of join columns is equal to the set of bucketed + sort columns, then + // the order of join keys in the query should not matter and there should not be any shuffle + // and sort added in the query plan + Seq( + Seq("i", "j", "k"), + Seq("i", "k", "j"), + Seq("j", "k", "i"), + Seq("j", "i", "k"), + Seq("k", "j", "i"), + Seq("k", "i", "j") + ).foreach(joinKeys => { + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec, + bucketedTableTestSpecRight = bucketedTableTestSpec, + joinCondition = joinCondition(joinKeys) + ) + }) + } + + test("SPARK-19122 No re-ordering should happen if set of join columns != set of child's " + + "partitioning columns") { + + // join predicates is a super set of child's partitioning columns + val bucketedTableTestSpec1 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec1, + bucketedTableTestSpecRight = bucketedTableTestSpec1, + joinCondition = joinCondition(Seq("i", "j", "k")) + ) + + // child's partitioning columns is a super set of join predicates + val bucketedTableTestSpec2 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))), + numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec2, + bucketedTableTestSpecRight = bucketedTableTestSpec2, + joinCondition = joinCondition(Seq("i", "j")) + ) + + // set of child's partitioning columns != set join predicates (despite the lengths of the + // sets are same) + val bucketedTableTestSpec3 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec3, + bucketedTableTestSpecRight = bucketedTableTestSpec3, + joinCondition = joinCondition(Seq("j", "k")) + ) + } + test("error if there exists any malformed bucket files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") From da8c59bdeabd9ac9eace9b95eda8b8edecc8937e Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Fri, 11 Aug 2017 15:49:58 -0700 Subject: [PATCH 5/8] [SPARK-12559][SPARK SUBMIT] fix --packages for stand-alone cluster mode Fixes --packages flag for the stand-alone case in cluster mode. Adds to the driver classpath the jars that are resolved via ivy along with any other jars passed to `spark.jars`. Jars not resolved by ivy are downloaded explicitly to a tmp folder on the driver node. Similar code is available in SparkSubmit so we refactored part of it to use it at the DriverWrapper class which is responsible for launching driver in standalone cluster mode. Note: In stand-alone mode `spark.jars` contains the user jar so it can be fetched later on at the executor side. Manually by submitting a driver in cluster mode within a standalone cluster and checking if dependencies were resolved at the driver side. Author: Stavros Kontopoulos Closes #18630 from skonto/fix_packages_stand_alone_cluster. --- .../apache/spark/deploy/DependencyUtils.scala | 99 +++++++++++++++++++ .../org/apache/spark/deploy/SparkSubmit.scala | 52 +++------- .../spark/deploy/worker/DriverWrapper.scala | 23 +++++ 3 files changed, 137 insertions(+), 37 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala new file mode 100644 index 0000000000000..97f3803aafce4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.File +import java.nio.file.Files + +import scala.collection.mutable.HashMap + +import org.apache.commons.io.FileUtils +import org.apache.commons.lang3.StringUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.util.MutableURLClassLoader + +private[deploy] object DependencyUtils { + + def resolveMavenDependencies( + packagesExclusions: String, + packages: String, + repositories: String, + ivyRepoPath: String): String = { + val exclusions: Seq[String] = + if (!StringUtils.isBlank(packagesExclusions)) { + packagesExclusions.split(",") + } else { + Nil + } + // Create the IvySettings, either load from file or build defaults + val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => + SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) + }.getOrElse { + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + } + + SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) + } + + def createTempDir(): File = { + val targetDir = Files.createTempDirectory("tmp").toFile + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = { + FileUtils.deleteQuietly(targetDir) + } + }) + // scalastyle:on runtimeaddshutdownhook + targetDir + } + + def resolveAndDownloadJars(jars: String, userJar: String): String = { + val targetDir = DependencyUtils.createTempDir() + val hadoopConf = new Configuration() + val sparkProperties = new HashMap[String, String]() + val securityProperties = List("spark.ssl.fs.trustStore", "spark.ssl.trustStore", + "spark.ssl.fs.trustStorePassword", "spark.ssl.trustStorePassword", + "spark.ssl.fs.protocol", "spark.ssl.protocol") + + securityProperties.foreach { pName => + sys.props.get(pName).foreach { pValue => + sparkProperties.put(pName, pValue) + } + } + + Option(jars) + .map { + SparkSubmit.resolveGlobPaths(_, hadoopConf) + .split(",") + .filterNot(_.contains(userJar.split("/").last)) + .mkString(",") + } + .filterNot(_ == "") + .map(SparkSubmit.downloadFileList(_, targetDir, sparkProperties, hadoopConf)) + .orNull + } + + def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = { + if (jars != null) { + for (jar <- jars.split(",")) { + SparkSubmit.addJarToClasspath(jar, loader) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0ea14361b2f77..019780076e7e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -20,7 +20,6 @@ package org.apache.spark.deploy import java.io._ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL -import java.nio.file.Files import java.security.{KeyStore, PrivilegedExceptionAction} import java.security.cert.X509Certificate import java.text.ParseException @@ -31,7 +30,6 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.Properties import com.google.common.io.ByteStreams -import org.apache.commons.io.FileUtils import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} import org.apache.hadoop.fs.{FileSystem, Path} @@ -300,28 +298,13 @@ object SparkSubmit extends CommandLineUtils { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER + val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER - if (!isMesosCluster) { + if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code - val exclusions: Seq[String] = - if (!StringUtils.isBlank(args.packagesExclusions)) { - args.packagesExclusions.split(",") - } else { - Nil - } - - // Create the IvySettings, either load from file or build defaults - val ivySettings = args.sparkProperties.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(args.repositories), - Option(args.ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(args.repositories), Option(args.ivyRepoPath)) - } - - val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, - ivySettings, exclusions = exclusions) - + val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) @@ -338,14 +321,7 @@ object SparkSubmit extends CommandLineUtils { } val hadoopConf = new HadoopConfiguration() - val targetDir = Files.createTempDirectory("tmp").toFile - // scalastyle:off runtimeaddshutdownhook - Runtime.getRuntime.addShutdownHook(new Thread() { - override def run(): Unit = { - FileUtils.deleteQuietly(targetDir) - } - }) - // scalastyle:on runtimeaddshutdownhook + val targetDir = DependencyUtils.createTempDir() // Resolve glob path for different resources. args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull @@ -473,11 +449,13 @@ object SparkSubmit extends CommandLineUtils { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Mesos only - propagate attributes for dependency resolution at the driver side - OptionAssigner(args.packages, MESOS, CLUSTER, sysProp = "spark.jars.packages"), - OptionAssigner(args.repositories, MESOS, CLUSTER, sysProp = "spark.jars.repositories"), - OptionAssigner(args.ivyRepoPath, MESOS, CLUSTER, sysProp = "spark.jars.ivy"), - OptionAssigner(args.packagesExclusions, MESOS, CLUSTER, sysProp = "spark.jars.excludes"), + // Propagate attributes for dependency resolution at the driver side + OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.packages"), + OptionAssigner(args.repositories, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.jars.repositories"), + OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.ivy"), + OptionAssigner(args.packagesExclusions, STANDALONE | MESOS, + CLUSTER, sysProp = "spark.jars.excludes"), // Yarn only OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), @@ -780,7 +758,7 @@ object SparkSubmit extends CommandLineUtils { } } - private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { + private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { val uri = Utils.resolveURI(localJar) uri.getScheme match { case "file" | "local" => @@ -845,7 +823,7 @@ object SparkSubmit extends CommandLineUtils { * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. */ - private def mergeFileLists(lists: String*): String = { + private[deploy] def mergeFileLists(lists: String*): String = { val merged = lists.filterNot(StringUtils.isBlank) .flatMap(_.split(",")) .mkString(",") @@ -968,7 +946,7 @@ object SparkSubmit extends CommandLineUtils { } } - private def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = { + private[deploy] def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = { require(paths != null, "paths cannot be null.") paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path => val uri = Utils.resolveURI(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 6799f78ec0c19..cd3e361530c18 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,7 +19,10 @@ package org.apache.spark.deploy.worker import java.io.File +import org.apache.commons.lang3.StringUtils + import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.{DependencyUtils, SparkSubmit} import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -51,6 +54,7 @@ object DriverWrapper { new MutableURLClassLoader(Array(userJarUrl), currentLoader) } Thread.currentThread.setContextClassLoader(loader) + setupDependencies(loader, userJar) // Delegate to supplied main class val clazz = Utils.classForName(mainClass) @@ -66,4 +70,23 @@ object DriverWrapper { System.exit(-1) } } + + private def setupDependencies(loader: MutableURLClassLoader, userJar: String): Unit = { + val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = + Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") + .map(sys.props.get(_).orNull) + + val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, + packages, repositories, ivyRepoPath) + val jars = { + val jarsProp = sys.props.get("spark.jars").orNull + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates) + } else { + jarsProp + } + } + val localJars = DependencyUtils.resolveAndDownloadJars(jars, userJar) + DependencyUtils.addJarsToClassPath(localJars, loader) + } } From b0bdfce9cae986096f327e2c7a5bdaa900dedc32 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 12 Aug 2017 14:31:05 +0900 Subject: [PATCH 6/8] [MINOR][BUILD] Download RAT and R version info over HTTPS; use RAT 0.12 ## What changes were proposed in this pull request? This is trivial, but bugged me. We should download software over HTTPS. And we can use RAT 0.12 while at it to pick up bug fixes. ## How was this patch tested? N/A Author: Sean Owen Closes #18927 from srowen/Rat012. --- dev/appveyor-install-dependencies.ps1 | 2 +- dev/check-license | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index a357fbf59f6c8..e6afb18558852 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -26,7 +26,7 @@ Function InstallR { } $urlPath = "" - $latestVer = $(ConvertFrom-JSON $(Invoke-WebRequest http://rversions.r-pkg.org/r-release-win).Content).version + $latestVer = $(ConvertFrom-JSON $(Invoke-WebRequest https://rversions.r-pkg.org/r-release-win).Content).version If ($rVer -ne $latestVer) { $urlPath = ("old/" + $rVer + "/") } diff --git a/dev/check-license b/dev/check-license index 678e73fd60f1f..8cee09a53e087 100755 --- a/dev/check-license +++ b/dev/check-license @@ -20,7 +20,7 @@ acquire_rat_jar () { - URL="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" + URL="https://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" JAR="$rat_jar" @@ -58,7 +58,7 @@ else declare java_cmd=java fi -export RAT_VERSION=0.11 +export RAT_VERSION=0.12 export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar mkdir -p "$FWDIR"/lib From 35db3b9fe38dadfb8afb0b0857c09f83196398be Mon Sep 17 00:00:00 2001 From: Ajay Saini Date: Fri, 11 Aug 2017 23:57:08 -0700 Subject: [PATCH 7/8] [SPARK-17025][ML][PYTHON] Persistence for Pipelines with Python-only Stages ## What changes were proposed in this pull request? Implemented a Python-only persistence framework for pipelines containing stages that cannot be saved using Java. ## How was this patch tested? Created a custom Python-only UnaryTransformer, included it in a Pipeline, and saved/loaded the pipeline. The loaded pipeline was compared against the original using _compare_pipelines() in tests.py. Author: Ajay Saini Closes #18888 from ajaysaini725/PythonPipelines. --- python/pyspark/ml/pipeline.py | 156 ++++++++++++++++++++++++++++++++-- python/pyspark/ml/tests.py | 35 +++++++- 2 files changed, 183 insertions(+), 8 deletions(-) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a8dc76b846c24..097530230cbca 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -16,6 +16,7 @@ # import sys +import os if sys.version > '3': basestring = str @@ -23,7 +24,7 @@ from pyspark import since, keyword_only, SparkContext from pyspark.ml.base import Estimator, Model, Transformer from pyspark.ml.param import Param, Params -from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.ml.common import inherit_doc @@ -130,13 +131,16 @@ def copy(self, extra=None): @since("2.0.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages()) + if allStagesAreJava: + return JavaMLWriter(self) + return PipelineWriter(self) @classmethod @since("2.0.0") def read(cls): """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + return PipelineReader(cls) @classmethod def _from_java(cls, java_stage): @@ -171,6 +175,76 @@ def _to_java(self): return _java_obj +@inherit_doc +class PipelineWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types + """ + + def __init__(self, instance): + super(PipelineWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + stages = self.instance.getStages() + PipelineSharedReadWrite.validateStages(stages) + PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path) + + +@inherit_doc +class PipelineReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types + """ + + def __init__(self, cls): + super(PipelineReader, self).__init__() + self.cls = cls + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python': + return JavaMLReader(self.cls).load(path) + else: + uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path) + return Pipeline(stages=stages)._resetUid(uid) + + +@inherit_doc +class PipelineModelWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types + """ + + def __init__(self, instance): + super(PipelineModelWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + stages = self.instance.stages + PipelineSharedReadWrite.validateStages(stages) + PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path) + + +@inherit_doc +class PipelineModelReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types + """ + + def __init__(self, cls): + super(PipelineModelReader, self).__init__() + self.cls = cls + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python': + return JavaMLReader(self.cls).load(path) + else: + uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path) + return PipelineModel(stages=stages)._resetUid(uid) + + @inherit_doc class PipelineModel(Model, MLReadable, MLWritable): """ @@ -204,13 +278,16 @@ def copy(self, extra=None): @since("2.0.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages) + if allStagesAreJava: + return JavaMLWriter(self) + return PipelineModelWriter(self) @classmethod @since("2.0.0") def read(cls): """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + return PipelineModelReader(cls) @classmethod def _from_java(cls, java_stage): @@ -242,3 +319,72 @@ def _to_java(self): JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) return _java_obj + + +@inherit_doc +class PipelineSharedReadWrite(): + """ + .. note:: DeveloperApi + + Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between + :py:class:`Pipeline` and :py:class:`PipelineModel` + + .. versionadded:: 2.3.0 + """ + + @staticmethod + def checkStagesForJava(stages): + return all(isinstance(stage, JavaMLWritable) for stage in stages) + + @staticmethod + def validateStages(stages): + """ + Check that all stages are Writable + """ + for stage in stages: + if not isinstance(stage, MLWritable): + raise ValueError("Pipeline write will fail on this pipeline " + + "because stage %s of type %s is not MLWritable", + stage.uid, type(stage)) + + @staticmethod + def saveImpl(instance, stages, sc, path): + """ + Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel` + - save metadata to path/metadata + - save stages to stages/IDX_UID + """ + stageUids = [stage.uid for stage in stages] + jsonParams = {'stageUids': stageUids, 'language': 'Python'} + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams) + stagesDir = os.path.join(path, "stages") + for index, stage in enumerate(stages): + stage.write().save(PipelineSharedReadWrite + .getStagePath(stage.uid, index, len(stages), stagesDir)) + + @staticmethod + def load(metadata, sc, path): + """ + Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel` + + :return: (UID, list of stages) + """ + stagesDir = os.path.join(path, "stages") + stageUids = metadata['paramMap']['stageUids'] + stages = [] + for index, stageUid in enumerate(stageUids): + stagePath = \ + PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir) + stage = DefaultParamsReader.loadParamsInstance(stagePath, sc) + stages.append(stage) + return (metadata['uid'], stages) + + @staticmethod + def getStagePath(stageUid, stageIdx, numStages, stagesDir): + """ + Get path for saving the given stage. + """ + stageIdxDigits = len(str(numStages)) + stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid + stagePath = os.path.join(stagesDir, stageDir) + return stagePath diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6aecc7fe87074..0495973d2f625 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -123,7 +123,7 @@ def _transform(self, dataset): return dataset -class MockUnaryTransformer(UnaryTransformer): +class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): shift = Param(Params._dummy(), "shift", "The amount by which to shift " + "data in a DataFrame", @@ -150,7 +150,7 @@ def outputDataType(self): def validateInputType(self, inputType): if inputType != DoubleType(): raise TypeError("Bad input type: {}. ".format(inputType) + - "Requires Integer.") + "Requires Double.") class MockEstimator(Estimator, HasFake): @@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2): """ self.assertEqual(m1.uid, m2.uid) self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaParams): + if isinstance(m1, JavaParams) or isinstance(m1, Transformer): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: self._compare_params(m1, m2, p) @@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self): except OSError: pass + def test_python_transformer_pipeline_persistence(self): + """ + Pipeline[MockUnaryTransformer, Binarizer] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.range(0, 10).toDF('input') + tf = MockUnaryTransformer(shiftVal=2)\ + .setInputCol("input").setOutputCol("shiftedInput") + tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") + pl = Pipeline(stages=[tf, tf2]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + def test_onevsrest(self): temp_path = tempfile.mkdtemp() df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), From c0e333dbed2571ed995c27576c3ef616446a61ce Mon Sep 17 00:00:00 2001 From: "pj.fanning" Date: Sat, 12 Aug 2017 20:01:20 +0100 Subject: [PATCH 8/8] [SPARK-21709][BUILD] sbt 0.13.16 and some plugin updates ## What changes were proposed in this pull request? Update sbt version to 0.13.16. I think this is a useful stepping stone to getting to sbt 1.0.0. ## How was this patch tested? Existing Build. Author: pj.fanning Closes #18921 from pjfanning/SPARK-21709. --- dev/mima | 2 +- project/build.properties | 2 +- project/plugins.sbt | 13 +++++++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/dev/mima b/dev/mima index 85b09dbb1bf20..5501589b7900a 100755 --- a/dev/mima +++ b/dev/mima @@ -41,7 +41,7 @@ $JAVA_CMD \ -cp "$TOOLS_CLASSPATH:$OLD_DEPS_CLASSPATH" \ org.apache.spark.tools.GenerateMIMAIgnore -echo -e "q\n" | build/sbt -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving" +echo -e "q\n" | build/sbt -mem 4096 -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving" ret_val=$? if [ $ret_val != 0 ]; then diff --git a/project/build.properties b/project/build.properties index d339865ab915a..b19518fd7aa1c 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.13 +sbt.version=0.13.16 diff --git a/project/plugins.sbt b/project/plugins.sbt index 84d123999085e..2b49c297ff9cb 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,25 +1,34 @@ +// need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.0.1") +// sbt 1.0.0 support: https://github.com/typesafehub/sbteclipse/issues/343 +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.1.0") +// sbt 1.0.0 support: https://github.com/jrudolph/sbt-dependency-graph/issues/134 addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") +// need to make changes to uptake sbt 1.0 support in "org.scalastyle" %% "scalastyle-sbt-plugin" % "0.9.0" addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.12") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.17") +// sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6 addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") +// need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-unidoc" % "0.4.1" addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") +// need to make changes to uptake sbt 1.0 support in "com.cavorite" % "sbt-avro-1-7" % "1.1.2" addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") +// sbt 1.0.0 support: https://github.com/spray/sbt-revolver/issues/62 addSbtPlugin("io.spray" % "sbt-revolver" % "0.8.0") libraryDependencies += "org.ow2.asm" % "asm" % "5.1" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.1" +// sbt 1.0.0 support: https://github.com/ihji/sbt-antlr4/issues/14 addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11") // Spark uses a custom fork of the sbt-pom-reader plugin which contains a patch to fix issues