diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 716d604ca31b4..066512d159d00 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -214,7 +214,6 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id" val jobPage = Option(request.getParameter(jobTag + ".page")).map(_.toInt).getOrElse(1) - val currentTime = System.currentTimeMillis() try { new JobPagedTable( @@ -226,7 +225,6 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We UIUtils.prependBaseUri(request, parent.basePath), "jobs", // subPath killEnabled, - currentTime, jobIdTitle ).table(jobPage) } catch { @@ -399,7 +397,6 @@ private[ui] class JobDataSource( store: AppStatusStore, jobs: Seq[v1.JobData], basePath: String, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) { @@ -410,15 +407,9 @@ private[ui] class JobDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc)) - private var _slicedJobIds: Set[Int] = null - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = { - val r = data.slice(from, to) - _slicedJobIds = r.map(_.jobData.jobId).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = data.slice(from, to) private def jobRow(jobData: v1.JobData): JobTableRowData = { val duration: Option[Long] = JobDataUtil.getDuration(jobData) @@ -479,17 +470,17 @@ private[ui] class JobPagedTable( basePath: String, subPath: String, killEnabled: Boolean, - currentTime: Long, jobIdTitle: 
String ) extends PagedTable[JobTableRowData] { + private val (sortColumn, desc, pageSize) = getTableParameters(request, jobTag, jobIdTitle) private val parameterPath = basePath + s"/$subPath/?" + getParameterOtherTable(request, jobTag) + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) override def tableId: String = jobTag + "-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageSizeFormField: String = jobTag + ".pageSize" @@ -499,13 +490,11 @@ private[ui] class JobPagedTable( store, data, basePath, - currentTime, pageSize, sortColumn, desc) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$jobTag.sort=$encodedSortColumn" + @@ -514,10 +503,8 @@ private[ui] class JobPagedTable( s"#$tableHeaderId" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { // Information for each header: title, sortable, tooltip diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 1b072274541c8..47ba951953cec 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -212,7 +212,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We stageData, UIUtils.prependBaseUri(request, parent.basePath) + s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}", - currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, 
desc = taskSortDesc, @@ -452,7 +451,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We private[ui] class TaskDataSource( stage: StageData, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, @@ -474,8 +472,6 @@ private[ui] class TaskDataSource( _tasksToShow } - def tasks: Seq[TaskData] = _tasksToShow - def executorLogs(id: String): Map[String, String] = { executorIdToLogs.getOrElseUpdate(id, store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) @@ -486,7 +482,6 @@ private[ui] class TaskDataSource( private[ui] class TaskPagedTable( stage: StageData, basePath: String, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, @@ -494,6 +489,8 @@ private[ui] class TaskPagedTable( import ApiHelper._ + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def tableId: String = "task-table" override def tableCssClass: String = @@ -505,14 +502,12 @@ private[ui] class TaskPagedTable( override val dataSource: TaskDataSource = new TaskDataSource( stage, - currentTime, pageSize, sortColumn, desc, store) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) basePath + s"&$pageNumberFormField=$page" + s"&task.sort=$encodedSortColumn" + @@ -520,10 +515,7 @@ private[ui] class TaskPagedTable( s"&$pageSizeFormField=$pageSize" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) - s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" - } + override def goButtonFormPath: String = s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" def headers: Seq[Node] = { import ApiHelper._ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index f9e84c2b2f4ec..9e6eb418fe134 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -116,8 +116,7 @@ private[ui] class StagePagedTable( override def tableId: String = stageTag + "-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageSizeFormField: String = stageTag + ".pageSize" @@ -125,7 +124,9 @@ private[ui] class StagePagedTable( private val (sortColumn, desc, pageSize) = getTableParameters(request, stageTag, "Stage Id") - val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + + private val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + getParameterOtherTable(request, stageTag) override val dataSource = new StageDataSource( @@ -138,7 +139,6 @@ private[ui] class StagePagedTable( ) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$stageTag.sort=$encodedSortColumn" + @@ -147,10 +147,8 @@ private[ui] class StagePagedTable( s"#$tableHeaderId" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { // stageHeadersAndCssClasses has three parts: header title, sortable and tooltip information. 
@@ -311,15 +309,9 @@ private[ui] class StageDataSource( // table so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) - private var _slicedStageIds: Set[Int] = _ - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = { - val r = data.slice(from, to) - _slicedStageIds = r.map(_.stageId).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = data.slice(from, to) private def stageRow(stageData: v1.StageData): StageTableRowData = { val formattedSubmissionTime = stageData.submissionTime match { @@ -350,7 +342,6 @@ private[ui] class StageDataSource( val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" - new StageTableRowData( stageData, Some(stageData), diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 26bbff5e54d25..844d9b7cf2c27 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -487,6 +487,7 @@ private[spark] object JsonProtocol { ("Callsite" -> rddInfo.callSite) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ + ("Barrier" -> rddInfo.isBarrier) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ ("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~ ("Memory Size" -> rddInfo.memSize) ~ diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 6191e41b4118f..54899bfcf34fa 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -25,6 +25,7 @@ import 
org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -37,10 +38,10 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with .setAppName("test-cluster") .set(TEST_NO_STAGE_RETRY, true) sc = new SparkContext(conf) + TestUtils.waitUntilExecutorsUp(sc, numWorker, 60000) } - // TODO (SPARK-31730): re-enable it - ignore("global sync by barrier() call") { + test("global sync by barrier() call") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => @@ -57,10 +58,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with } test("share messages with allGather() call") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() @@ -78,10 +76,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with } test("throw exception if we attempt to synchronize with different blocking calls") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() @@ -100,10 +95,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with } test("successively sync with allGather and barrier") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + 
initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() @@ -129,8 +121,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with assert(times2.max - times2.min <= 1000) } - // TODO (SPARK-31730): re-enable it - ignore("support multiple barrier() call within a single task") { + test("support multiple barrier() call within a single task") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => @@ -285,6 +276,9 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with test("SPARK-31485: barrier stage should fail if only partial tasks are launched") { initLocalClusterSparkContext(2) + // It's required to reset the delay timer when a task is scheduled, otherwise all the tasks + // could get scheduled at ANY level. + sc.conf.set(config.LEGACY_LOCALITY_WAIT_RESET, true) val rdd0 = sc.parallelize(Seq(0, 1, 2, 3), 2) val dep = new OneToOneDependency[Int](rdd0) // set up a barrier stage with 2 tasks and both tasks prefer executor 0 (only 1 core) for diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 61ea21fa86c5a..7c23e4449f461 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.deploy.history.{EventLogFileReader, SingleEventLogFileWr import org.apache.spark.deploy.history.EventLogTestHelper._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{EVENT_LOG_DIR, EVENT_LOG_ENABLED} import org.apache.spark.io._ import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} import 
org.apache.spark.resource.ResourceProfile @@ -100,6 +101,49 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit testStageExecutorMetricsEventLogging() } + test("SPARK-31764: isBarrier should be logged in event log") { + val conf = new SparkConf() + conf.set(EVENT_LOG_ENABLED, true) + conf.set(EVENT_LOG_DIR, testDirPath.toString) + val sc = new SparkContext("local", "test-SPARK-31764", conf) + val appId = sc.applicationId + + sc.parallelize(1 to 10) + .barrier() + .mapPartitions(_.map(elem => (elem, elem))) + .filter(elem => elem._1 % 2 == 0) + .reduceByKey(_ + _) + .collect + sc.stop() + + val eventLogStream = EventLogFileReader.openEventLog(new Path(testDirPath, appId), fileSystem) + val events = readLines(eventLogStream).map(line => JsonProtocol.sparkEventFromJson(parse(line))) + val jobStartEvents = events + .filter(event => event.isInstanceOf[SparkListenerJobStart]) + .map(_.asInstanceOf[SparkListenerJobStart]) + + assert(jobStartEvents.size === 1) + val stageInfos = jobStartEvents.head.stageInfos + assert(stageInfos.size === 2) + + val stage0 = stageInfos(0) + val rddInfosInStage0 = stage0.rddInfos + assert(rddInfosInStage0.size === 3) + val sortedRddInfosInStage0 = rddInfosInStage0.sortBy(_.scope.get.name) + assert(sortedRddInfosInStage0(0).scope.get.name === "filter") + assert(sortedRddInfosInStage0(0).isBarrier === true) + assert(sortedRddInfosInStage0(1).scope.get.name === "mapPartitions") + assert(sortedRddInfosInStage0(1).isBarrier === true) + assert(sortedRddInfosInStage0(2).scope.get.name === "parallelize") + assert(sortedRddInfosInStage0(2).isBarrier === false) + + val stage1 = stageInfos(1) + val rddInfosInStage1 = stage1.rddInfos + assert(rddInfosInStage1.size === 1) + assert(rddInfosInStage1(0).scope.get.name === "reduceByKey") + assert(rddInfosInStage1(0).isBarrier === false) // reduceByKey + } + /* ----------------- * * Actual test logic * * ----------------- */ diff --git 
a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 5d34a56473375..3d52199b01327 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -98,7 +98,6 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { val taskTable = new TaskPagedTable( stageData, basePath = "/a/b/c", - currentTime = 0, pageSize = 10, sortColumn = "Index", desc = false, diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index bc7f8b5d719db..248142a5ad633 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -1100,6 +1100,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 201, | "Number of Cached Partitions": 301, | "Memory Size": 401, @@ -1623,6 +1624,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 200, | "Number of Cached Partitions": 300, | "Memory Size": 400, @@ -1668,6 +1670,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 400, | "Number of Cached Partitions": 600, | "Memory Size": 800, @@ -1684,6 +1687,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 401, | "Number of Cached Partitions": 601, | "Memory Size": 801, @@ -1729,6 +1733,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | 
"Number of Partitions": 600, | "Number of Cached Partitions": 900, | "Memory Size": 1200, @@ -1745,6 +1750,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 601, | "Number of Cached Partitions": 901, | "Memory Size": 1201, @@ -1761,6 +1767,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 602, | "Number of Cached Partitions": 902, | "Memory Size": 1202, @@ -1806,6 +1813,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 800, | "Number of Cached Partitions": 1200, | "Memory Size": 1600, @@ -1822,6 +1830,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 801, | "Number of Cached Partitions": 1201, | "Memory Size": 1601, @@ -1838,6 +1847,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 802, | "Number of Cached Partitions": 1202, | "Memory Size": 1602, @@ -1854,6 +1864,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 803, | "Number of Cached Partitions": 1203, | "Memory Size": 1603, diff --git a/docs/sql-ref-datetime-pattern.md b/docs/sql-ref-datetime-pattern.md index 4275f03335b33..48e85b450e6b2 100644 --- a/docs/sql-ref-datetime-pattern.md +++ b/docs/sql-ref-datetime-pattern.md @@ -76,7 +76,7 @@ The count of pattern letters determines the format. - Year: The count of letters determines the minimum field width below which padding is used. 
If the count of letters is two, then a reduced two digit form is used. For printing, this outputs the rightmost two digits. For parsing, this will parse using the base value of 2000, resulting in a year within the range 2000 to 2099 inclusive. If the count of letters is less than four (but not two), then the sign is only output for negative years. Otherwise, the sign is output if the pad width is exceeded when 'G' is not present. -- Month: If the number of pattern letters is 3 or more, the month is interpreted as text; otherwise, it is interpreted as a number. The text form is depend on letters - 'M' denotes the 'standard' form, and 'L' is for 'stand-alone' form. The difference between the 'standard' and 'stand-alone' forms is trickier to describe as there is no difference in English. However, in other languages there is a difference in the word used when the text is used alone, as opposed to in a complete date. For example, the word used for a month when used alone in a date picker is different to the word used for month in association with a day and year in a date. In Russian, 'Июль' is the stand-alone form of July, and 'Июля' is the standard form. Here are examples for all supported pattern letters (more than 4 letters is invalid): +- Month: It follows the rule of Number/Text. The text form depends on the letters - 'M' denotes the 'standard' form, and 'L' is for 'stand-alone' form. These two forms are different only in certain languages. For example, in Russian, 'Июль' is the stand-alone form of July, and 'Июля' is the standard form. Here are examples for all supported pattern letters: - `'M'` or `'L'`: Month number in a year starting from 1. There is no difference between 'M' and 'L'. Month from 1 to 9 are printed without padding. ```sql spark-sql> select date_format(date '1970-01-01', "M"); @@ -107,8 +107,8 @@ The count of pattern letters determines the format. ``` - `'MMMM'`: full textual month representation in the standard form.
It is used for parsing/formatting months as a part of dates/timestamps. ```sql - spark-sql> select date_format(date '1970-01-01', "MMMM yyyy"); - January 1970 + spark-sql> select date_format(date '1970-01-01', "d MMMM"); + 1 January spark-sql> select to_csv(named_struct('date', date '1970-01-01'), map('dateFormat', 'd MMMM', 'locale', 'RU')); 1 января ``` diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index febeba7e13fcb..e0b128e369816 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.Since import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -71,7 +72,7 @@ private[ml] trait PredictorParams extends Params val w = this match { case p: HasWeightCol => if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) { - col($(p.weightCol)).cast(DoubleType) + checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType))) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 5459a0fab9135..e65295dbdaf55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -22,6 +22,7 @@ import org.json4s.DefaultFormats import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import 
org.apache.spark.ml.param.shared.HasWeightCol @@ -179,7 +180,7 @@ class NaiveBayes @Since("1.5.0") ( } val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } @@ -259,7 +260,7 @@ class NaiveBayes @Since("1.5.0") ( import spark.implicits._ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 6c7112b80569f..b09f11dcfe156 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -280,7 +281,7 @@ class BisectingKMeans @Since("2.0.0") ( val handlePersistence = dataset.storageLevel == StorageLevel.NONE val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 6d4137b638dcc..18fd220b4ca9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import 
org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON} import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ @@ -417,7 +418,7 @@ class GaussianMixture @Since("2.0.0") ( instr.logNumFeatures(numFeatures) val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a42c920e24987..806015b633c23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -336,7 +337,7 @@ class KMeans @Since("1.5.0") ( val handlePersistence = dataset.storageLevel == StorageLevel.NONE val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index fac4d92b1810c..52be22f714981 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation 
import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -131,7 +132,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType), if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) - else col($(weightCol)).cast(DoubleType)).rdd.map { + else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map { case Row(rawPrediction: Vector, label: Double, weight: Double) => (rawPrediction(1), label, weight) case Row(rawPrediction: Double, label: Double, weight: Double) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 19790fd270619..fa2c25a5912a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util._ @@ -139,7 +140,7 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str } else { dataset.select(col($(predictionCol)), vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata), - col(weightColName).cast(DoubleType)) + checkNonNegativeWeight(col(weightColName).cast(DoubleType))) } val metrics = new ClusteringMetrics(df) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala index 8bf4ee1ecadfb..a785d063f1476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala @@ -300,7 +300,6 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double), (features, squaredNorm, weight) ) => - require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.") BLAS.axpy(weight, features, featureSum) (featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight) }, @@ -503,7 +502,6 @@ private[evaluation] object CosineSilhouette extends Silhouette { seqOp = { case ((normalizedFeaturesSum: DenseVector, weightSum: Double), (normalizedFeatures, weight)) => - require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.") BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum) (normalizedFeaturesSum, weightSum + weight) }, diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index ad1b70915e157..3d77792c4fc88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -186,7 +187,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkNumericType(schema, $(labelCol)) val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - 
col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index aca017762deca..f0b7c345c3285 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} @@ -122,7 +123,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui val predictionAndLabelsWithWeights = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType), - if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) + else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))) .rdd .map { case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala index 0f03231079866..a0b6d11a46be9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala @@ -71,4 +71,10 @@ object functions { ) } } + + private[ml] def checkNonNegativeWeight = udf { + value: Double => + require(value >= 0, s"illegal weight value: $value. 
weight must be >= 0.0.") + value + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index fa41a98749f32..0ee895a95a288 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.{Instance, OffsetInstance} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ @@ -399,7 +400,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " + "set to false. To fit a model with 0 features, fitIntercept must be set to true." 
) - val w = if (!hasWeightCol) lit(1.0) else col($(weightCol)) + val w = if (!hasWeightCol) lit(1.0) else checkNonNegativeWeight(col($(weightCol))) val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index fe4de57de60f2..ec2640e9ef225 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -87,11 +88,11 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { col($(featuresCol)) } - val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0) + val w = + if (hasWeightCol) checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0) dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map { - case Row(label: Double, feature: Double, weight: Double) => - (label, feature, weight) + case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight) } } diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a3ce87096e790..65b902cf3c4d5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2219,6 +2219,20 @@ def semanticHash(self): """ return self._jdf.semanticHash() + @since(3.1) + def inputFiles(self): + """ + Returns a best-effort snapshot of the files 
that compose this :class:`DataFrame`. + This method simply asks each constituent BaseRelation for its respective files and + takes the union of all results. Depending on the source relations, this may not find + all input files. Duplicates are removed. + + >>> df = spark.read.load("examples/src/main/resources/people.json", format="json") + >>> len(df.inputFiles()) + 1 + """ + return list(self._jdf.inputFiles()) + where = copy_func( filter, sinceversion=1.3, diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 4dd15d14b9c53..ff0b10a9306cf 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -154,6 +154,9 @@ def create_array(s, t): # Ensure timestamp series are in expected form for Spark internal representation if t is not None and pa.types.is_timestamp(t): s = _check_series_convert_timestamps_internal(s, self._timezone) + elif type(s.dtype) == pd.CategoricalDtype: + # Note: This can be removed once minimum pyarrow version is >= 0.16.1 + s = s.astype(s.dtypes.categories.dtype) try: array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) except pa.ArrowException as e: diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index d1edf3f9c47c1..4b70c8a2e95e1 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -114,6 +114,8 @@ def from_arrow_type(at): return StructType( [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) for field in at]) + elif types.is_dictionary(at): + spark_type = from_arrow_type(at.value_type) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 004c79f290213..c3c9fb0f12a25 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -415,6 +415,32 @@ def 
run_test(num_records, num_parts, max_records, use_delay=False): for case in cases: run_test(*case) + def test_createDateFrame_with_category_type(self): + pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]}) + pdf["B"] = pdf["A"].astype('category') + category_first_element = dict(enumerate(pdf['B'].cat.categories))[0] + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}): + arrow_df = self.spark.createDataFrame(pdf) + arrow_type = arrow_df.dtypes[1][1] + result_arrow = arrow_df.toPandas() + arrow_first_category_element = result_arrow["B"][0] + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + df = self.spark.createDataFrame(pdf) + spark_type = df.dtypes[1][1] + result_spark = df.toPandas() + spark_first_category_element = result_spark["B"][0] + + assert_frame_equal(result_spark, result_arrow) + + # ensure original category elements are string + assert isinstance(category_first_element, str) + # spark data frame and arrow execution mode enabled data frame type must match pandas + assert spark_type == arrow_type == 'string' + assert isinstance(arrow_first_category_element, str) + assert isinstance(spark_first_category_element, str) + @unittest.skipIf( not have_pandas or not have_pyarrow, diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 9861178158f85..062e61663a332 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -17,6 +17,8 @@ import os import pydoc +import shutil +import tempfile import time import unittest @@ -820,6 +822,22 @@ def test_same_semantics_error(self): with self.assertRaisesRegexp(ValueError, "should be of DataFrame.*int"): self.spark.range(10).sameSemantics(1) + def test_input_files(self): + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + self.spark.range(1, 100, 1, 10).write.parquet(tpath) + # read parquet file and get the input files list + input_files_list = 
self.spark.read.parquet(tpath).inputFiles() + + # input files list should contain 10 entries + self.assertEquals(len(input_files_list), 10) + # all file paths in list must contain tpath + for file_path in input_files_list: + self.assertTrue(tpath in file_path) + finally: + shutil.rmtree(tpath) + class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 7260e80e2cfca..ae6b8d520f735 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -897,6 +897,27 @@ def test_timestamp_dst(self): result = df.withColumn('time', foo_udf(df.time)) self.assertEquals(df.collect(), result.collect()) + def test_udf_category_type(self): + + @pandas_udf('string') + def to_category_func(x): + return x.astype('category') + + pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]}) + df = self.spark.createDataFrame(pdf) + df = df.withColumn("B", to_category_func(df['A'])) + result_spark = df.toPandas() + + spark_type = df.dtypes[1][1] + # spark data frame and arrow execution mode enabled data frame type must match pandas + assert spark_type == 'string' + + # Check result value of column 'B' must be equal to column 'A' + for i in range(0, len(result_spark["A"])): + assert result_spark["A"][i] == result_spark["B"][i] + assert isinstance(result_spark["A"][i], str) + assert isinstance(result_spark["B"][i], str) + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") def test_type_annotation(self): # Regression test to check if type hints can be used. See SPARK-23569. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index b65221c236bfe..85c6600685bd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -208,3 +208,161 @@ object ExtractPythonUDFFromJoinCondition extends Rule[LogicalPlan] with Predicat } } } + +sealed abstract class BuildSide + +case object BuildRight extends BuildSide + +case object BuildLeft extends BuildSide + +trait JoinSelectionHelper { + + def getBroadcastBuildSide( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + hint: JoinHint, + hintOnly: Boolean, + conf: SQLConf): Option[BuildSide] = { + val buildLeft = if (hintOnly) { + hintToBroadcastLeft(hint) + } else { + canBroadcastBySize(left, conf) && !hintToNotBroadcastLeft(hint) + } + val buildRight = if (hintOnly) { + hintToBroadcastRight(hint) + } else { + canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint) + } + getBuildSide( + canBuildLeft(joinType) && buildLeft, + canBuildRight(joinType) && buildRight, + left, + right + ) + } + + def getShuffleHashJoinBuildSide( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + hint: JoinHint, + hintOnly: Boolean, + conf: SQLConf): Option[BuildSide] = { + val buildLeft = if (hintOnly) { + hintToShuffleHashJoinLeft(hint) + } else { + canBuildLocalHashMapBySize(left, conf) && muchSmaller(left, right) + } + val buildRight = if (hintOnly) { + hintToShuffleHashJoinRight(hint) + } else { + canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left) + } + getBuildSide( + canBuildLeft(joinType) && buildLeft, + canBuildRight(joinType) && buildRight, + left, + right + ) + } + + def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else 
BuildLeft + } + + /** + * Matches a plan whose output should be small enough to be used in broadcast join. + */ + def canBroadcastBySize(plan: LogicalPlan, conf: SQLConf): Boolean = { + plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold + } + + def canBuildLeft(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | RightOuter => true + case _ => false + } + } + + def canBuildRight(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true + case _ => false + } + } + + def hintToBroadcastLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(BROADCAST)) + } + + def hintToBroadcastRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(BROADCAST)) + } + + def hintToNotBroadcastLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH)) + } + + def hintToNotBroadcastRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH)) + } + + def hintToShuffleHashJoinLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) + } + + def hintToShuffleHashJoinRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) + } + + def hintToSortMergeJoin(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) + } + + def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) + } + + private def getBuildSide( + canBuildLeft: Boolean, + canBuildRight: Boolean, + left: LogicalPlan, + right: LogicalPlan): Option[BuildSide] = { + if (canBuildLeft && canBuildRight) { + // returns the smaller side base on its estimated physical size, if we want to build 
the + // both sides. + Some(getSmallerSide(left, right)) + } else if (canBuildLeft) { + Some(BuildLeft) + } else if (canBuildRight) { + Some(BuildRight) + } else { + None + } + } + + /** + * Matches a plan whose single partition should be small enough to build a hash table. + * + * Note: this assume that the number of partition is fixed, requires additional work if it's + * dynamic. + */ + private def canBuildLocalHashMapBySize(plan: LogicalPlan, conf: SQLConf): Boolean = { + plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + } + + /** + * Returns whether plan a is much smaller (3X) than plan b. + * + * The cost to build hash map is higher than sorting, we should only build hash map on a table + * that is much smaller than other one. Since we does not have the statistic for number of rows, + * use the size of bytes here as estimation. + */ + private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { + a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala index 0ea54c28cb285..353c074caa75e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala @@ -217,9 +217,18 @@ private object DateTimeFormatterHelper { toFormatter(builder, TimestampFormatter.defaultLocale) } + private final val bugInStandAloneForm = { + // Java 8 has a bug for stand-alone form. See https://bugs.openjdk.java.net/browse/JDK-8114833 + // Note: we only check the US locale so that it's a static check. It can produce false-negative + // as some locales are not affected by the bug. Since `L`/`q` is rarely used, we choose to not + // complicate the check here. + // TODO: remove it when we drop Java 8 support. 
+ val formatter = DateTimeFormatter.ofPattern("LLL qqq", Locale.US) + formatter.format(LocalDate.of(2000, 1, 1)) == "1 1" + } final val unsupportedLetters = Set('A', 'c', 'e', 'n', 'N', 'p') final val unsupportedNarrowTextStyle = - Set("GGGGG", "MMMMM", "LLLLL", "EEEEE", "uuuuu", "QQQQQ", "qqqqq", "uuuuu") + Seq("G", "M", "L", "E", "u", "Q", "q").map(_ * 5).toSet /** * In Spark 3.0, we switch to the Proleptic Gregorian calendar and use DateTimeFormatter for @@ -244,6 +253,12 @@ private object DateTimeFormatterHelper { for (style <- unsupportedNarrowTextStyle if patternPart.contains(style)) { throw new IllegalArgumentException(s"Too many pattern letters: ${style.head}") } + if (bugInStandAloneForm && (patternPart.contains("LLL") || patternPart.contains("qqq"))) { + throw new IllegalArgumentException("Java 8 has a bug to support stand-alone " + + "form (3 or more 'L' or 'q' in the pattern string). Please use 'M' or 'Q' instead, " + + "or upgrade your Java version. For more details, please read " + + "https://bugs.openjdk.java.net/browse/JDK-8114833") + } // The meaning of 'u' was day number of week in SimpleDateFormat, it was changed to year // in DateTimeFormatter. Substitute 'u' to 'e' and use DateTimeFormatter to parse the // string. 
If parsable, return the result; otherwise, fall back to 'u', and then use the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index de2fd312b7db5..8428964d45707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -127,16 +127,37 @@ class FractionTimestampFormatter(zoneId: ZoneId) override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter // The new formatter will omit the trailing 0 in the timestamp string, but the legacy formatter - // can't. Here we borrow the code from Spark 2.4 DateTimeUtils.timestampToString to omit the - // trailing 0 for the legacy formatter as well. + // can't. Here we use the legacy formatter to format the given timestamp up to seconds fractions, + // and custom implementation to format the fractional part without trailing zeros. override def format(ts: Timestamp): String = { - val timestampString = ts.toString val formatted = legacyFormatter.format(ts) - - if (timestampString.length > 19 && timestampString.substring(19) != ".0") { - formatted + timestampString.substring(19) - } else { + var nanos = ts.getNanos + if (nanos == 0) { formatted + } else { + // Formats non-zero seconds fraction w/o trailing zeros. For example: + // formatted = '2020-05:27 15:55:30' + // nanos = 001234000 + // Counts the length of the fractional part: 001234000 -> 6 + var fracLen = 9 + while (nanos % 10 == 0) { + nanos /= 10 + fracLen -= 1 + } + // Places `nanos` = 1234 after '2020-05:27 15:55:30.' + val fracOffset = formatted.length + 1 + val totalLen = fracOffset + fracLen + // The buffer for the final result: '2020-05:27 15:55:30.001234' + val buf = new Array[Char](totalLen) + formatted.getChars(0, formatted.length, buf, 0) + buf(formatted.length) = '.' 
+ var i = totalLen + do { + i -= 1 + buf(i) = ('0' + (nanos % 10)).toChar + nanos /= 10 + } while (i > fracOffset) + new String(buf) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e5bff7f7af007..6af995cab64fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -240,7 +240,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkCast(1.5, "1.5") checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) - checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) } test("cast from string") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 02d6d847dc063..1ca7380ead413 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -792,7 +792,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } // Test escaping of format - GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote")) :: Nil) + GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote"), UTC_OPT) :: Nil) } test("unix_timestamp") { @@ -862,7 +862,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // Test escaping of format GenerateUnsafeProjection.generate( - UnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil) + UnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil) } test("to_unix_timestamp") { @@ 
-940,7 +940,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // Test escaping of format GenerateUnsafeProjection.generate( - ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil) + ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil) } test("datediff") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala new file mode 100644 index 0000000000000..3513cfa14808f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH} +import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan +import org.apache.spark.sql.internal.SQLConf + +class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { + + private val left = StatsTestPlan( + outputList = Seq('a.int, 'b.int, 'c.int), + rowCount = 20000000, + size = Some(20000000), + attributeStats = AttributeMap(Seq())) + + private val right = StatsTestPlan( + outputList = Seq('d.int), + rowCount = 1000, + size = Some(1000), + attributeStats = AttributeMap(Seq())) + + private val hintBroadcast = Some(HintInfo(Some(BROADCAST))) + private val hintNotToBroadcast = Some(HintInfo(Some(NO_BROADCAST_HASH))) + private val hintShuffleHash = Some(HintInfo(Some(SHUFFLE_HASH))) + + test("getBroadcastBuildSide (hintOnly = true) return BuildLeft with only a left hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(hintBroadcast, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildLeft)) + } + + test("getBroadcastBuildSide (hintOnly = true) return BuildRight with only a right hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, hintBroadcast), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = true) return smaller side with both having hints") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(hintBroadcast, hintBroadcast), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = true) return None 
when no side has a hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === None) + } + + test("getBroadcastBuildSide (hintOnly = false) return BuildRight when right is broadcastable") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = false) return None when right has no broadcast hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, hintNotToBroadcast ), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === None) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildLeft with only a left hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(hintShuffleHash, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildLeft)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildRight with only a right hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(None, hintShuffleHash), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return smaller side when both have hints") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(hintShuffleHash, hintShuffleHash), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return None when no side has a hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === None) + } + + test("getShuffleHashJoinBuildSide (hintOnly = false) 
return BuildRight when right is smaller") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getSmallerSide should return BuildRight") { + assert(getSmallerSide(left, right) === BuildRight) + } + + test("canBroadcastBySize should return true if the plan size is less than 10MB") { + assert(canBroadcastBySize(left, SQLConf.get) === false) + assert(canBroadcastBySize(right, SQLConf.get) === true) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index 4324d3cff63d7..7ff9b46bc6719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -135,6 +135,9 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers test("format fraction of second") { val formatter = TimestampFormatter.getFractionFormatter(UTC) Seq( + -999999 -> "1969-12-31 23:59:59.000001", + -999900 -> "1969-12-31 23:59:59.0001", + -1 -> "1969-12-31 23:59:59.999999", 0 -> "1970-01-01 00:00:00", 1 -> "1970-01-01 00:00:00.000001", 1000 -> "1970-01-01 00:00:00.001", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 12a1a1e7fc16e..302aae08d588b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper, NormalizeFloatingNumbers} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlan @@ -135,93 +134,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Supports both equi-joins and non-equi-joins. * Supports only inner like joins. */ - object JoinSelection extends Strategy with PredicateHelper { - - /** - * Matches a plan whose output should be small enough to be used in broadcast join. - */ - private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold - } - - /** - * Matches a plan whose single partition should be small enough to build a hash table. - * - * Note: this assume that the number of partition is fixed, requires additional work if it's - * dynamic. - */ - private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = { - plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions - } - - /** - * Returns whether plan a is much smaller (3X) than plan b. - * - * The cost to build hash map is higher than sorting, we should only build hash map on a table - * that is much smaller than other one. 
Since we does not have the statistic for number of rows, - * use the size of bytes here as estimation. - */ - private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { - a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes - } - - private def canBuildRight(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true - case _ => false - } - - private def canBuildLeft(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | RightOuter => true - case _ => false - } - - private def getBuildSide( - wantToBuildLeft: Boolean, - wantToBuildRight: Boolean, - left: LogicalPlan, - right: LogicalPlan): Option[BuildSide] = { - if (wantToBuildLeft && wantToBuildRight) { - // returns the smaller side base on its estimated physical size, if we want to build the - // both sides. - Some(getSmallerSide(left, right)) - } else if (wantToBuildLeft) { - Some(BuildLeft) - } else if (wantToBuildRight) { - Some(BuildRight) - } else { - None - } - } - - private def getSmallerSide(left: LogicalPlan, right: LogicalPlan) = { - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - } - - private def hintToBroadcastLeft(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(BROADCAST)) - } - - private def hintToBroadcastRight(hint: JoinHint): Boolean = { - hint.rightHint.exists(_.strategy.contains(BROADCAST)) - } - - private def hintToShuffleHashLeft(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) - } - - private def hintToShuffleHashRight(hint: JoinHint): Boolean = { - hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) - } - - private def hintToSortMergeJoin(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || - hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) - } - - private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { - 
hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || - hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) - } + object JoinSelection extends Strategy + with PredicateHelper + with JoinSelectionHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -245,33 +160,31 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have // other choice. case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) => - def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = { - val wantToBuildLeft = canBuildLeft(joinType) && buildLeft - val wantToBuildRight = canBuildRight(joinType) && buildRight - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => - Seq(joins.BroadcastHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - planLater(left), - planLater(right))) + def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = { + getBroadcastBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map { + buildSide => + Seq(joins.BroadcastHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) } } - def createShuffleHashJoin(buildLeft: Boolean, buildRight: Boolean) = { - val wantToBuildLeft = canBuildLeft(joinType) && buildLeft - val wantToBuildRight = canBuildRight(joinType) && buildRight - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => - Seq(joins.ShuffledHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - planLater(left), - planLater(right))) + def createShuffleHashJoin(onlyLookingAtHint: Boolean) = { + getShuffleHashJoinBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map { + buildSide => + Seq(joins.ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + 
planLater(right))) } } @@ -293,14 +206,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def createJoinWithoutHint() = { - createBroadcastHashJoin( - canBroadcast(left) && !hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH)), - canBroadcast(right) && !hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH))) + createBroadcastHashJoin(false) .orElse { if (!conf.preferSortMergeJoin) { - createShuffleHashJoin( - canBuildLocalHashMap(left) && muchSmaller(left, right), - canBuildLocalHashMap(right) && muchSmaller(right, left)) + createShuffleHashJoin(false) } else { None } @@ -315,9 +224,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) + createBroadcastHashJoin(true) .orElse { if (hintToSortMergeJoin(hint)) createSortMergeJoin() else None } - .orElse(createShuffleHashJoin(hintToShuffleHashLeft(hint), hintToShuffleHashRight(hint))) + .orElse(createShuffleHashJoin(true)) .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } .getOrElse(createJoinWithoutHint()) @@ -374,7 +283,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def createJoinWithoutHint() = { - createBroadcastNLJoin(canBroadcast(left), canBroadcast(right)) + createBroadcastNLJoin(canBroadcastBySize(left, conf), canBroadcastBySize(right, conf)) .orElse(createCartesianProduct()) .getOrElse { // This join could be very slow or OOM diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala index d60c3ca72f6f6..ac98342277bc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala @@ -19,10 +19,11 @@ package 
org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, BuildLeft, BuildRight} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} /** * Strategy for plans containing [[LogicalQueryStage]] nodes: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index 5416fde222cb6..3620f27058af2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.adaptive +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.internal.SQLConf /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index eb091758910cd..cfc653a23840d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.dynamicpruning import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, BindReferences, DynamicPruningExpression, DynamicPruningSubquery, Expression, ListQuery, Literal, PredicateHelper} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.rules.Rule diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 08128d8f69dab..707ed1402d1ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 
888e7af7c07ed..52b476f9cf134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7f90a51c1f234..c7c3e1672f034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ExplainUtils, RowIterator} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 755a63e545ef1..2b7cd65e7d96f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -23,6 +23,7 @@ import 
org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala deleted file mode 100644 index 134376628ae7f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -/** - * Physical execution operators for join operations. 
- */ -package object joins { - - sealed abstract class BuildSide - - case object BuildRight extends BuildSide - - case object BuildLeft extends BuildSide - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 8c23f2cbb86ba..33539c01ee5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -203,11 +203,10 @@ private[ui] class ExecutionPagedTable( private val (sortColumn, desc, pageSize) = getTableParameters(request, executionTag, "ID") + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override val dataSource = new ExecutionDataSource( - request, - parent, data, - basePath, currentTime, pageSize, sortColumn, @@ -222,11 +221,9 @@ private[ui] class ExecutionPagedTable( override def tableId: String = s"$executionTag-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$executionTag.sort=$encodedSortColumn" + @@ -239,10 +236,8 @@ private[ui] class ExecutionPagedTable( override def pageNumberFormField: String = s"$executionTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$executionTag.sort=$encodedSortColumn&$executionTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { // Information for each header: title, sortable, tooltip @@ -348,7 +343,6 @@ private[ui] class ExecutionPagedTable( 
private[ui] class ExecutionTableRowData( - val submissionTime: Long, val duration: Long, val executionUIData: SQLExecutionUIData, val runningJobData: Seq[Int], @@ -357,10 +351,7 @@ private[ui] class ExecutionTableRowData( private[ui] class ExecutionDataSource( - request: HttpServletRequest, - parent: SQLTab, executionData: Seq[SQLExecutionUIData], - basePath: String, currentTime: Long, pageSize: Int, sortColumn: String, @@ -373,20 +364,13 @@ private[ui] class ExecutionDataSource( // in the table so that we can avoid creating duplicate contents during sorting the data private val data = executionData.map(executionRow).sorted(ordering(sortColumn, desc)) - private var _sliceExecutionIds: Set[Int] = _ - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[ExecutionTableRowData] = { - val r = data.slice(from, to) - _sliceExecutionIds = r.map(_.executionUIData.executionId.toInt).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[ExecutionTableRowData] = data.slice(from, to) private def executionRow(executionUIData: SQLExecutionUIData): ExecutionTableRowData = { - val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()) - .getOrElse(currentTime) - submissionTime + .getOrElse(currentTime) - executionUIData.submissionTime val runningJobData = if (showRunningJobs) { executionUIData.jobs.filter { @@ -407,7 +391,6 @@ private[ui] class ExecutionDataSource( } else Seq.empty new ExecutionTableRowData( - submissionTime, duration, executionUIData, runningJobData, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index f68c416941266..234978b9ce176 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.log4j.Level -import 
org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, EliminateResolvedHint} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index a80fc410f5033..3aeb6c5063d20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -24,11 +24,12 @@ import org.apache.log4j.Level import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.{ReusedSubqueryExec, ShuffledRowRDD, SparkPlan} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1be9308c06d8c..dd231a52ec300 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.logical.BROADCAST import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AdaptiveTestUtils, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 5490246baceea..554990413c28c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 08898f80034e6..44ab3f7d023d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} @@ -133,7 +134,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildLeft), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -145,7 +146,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildRight), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -157,7 +158,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildLeft), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -169,7 +170,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { 
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildRight), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index a5ade0d8d7508..879f282e4d05d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 1f2d4b1b87773..8efbdb30c605c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -175,19 +175,19 @@ private[ui] class SqlStatsPagedTable( private val (sortColumn, desc, pageSize) = getTableParameters(request, sqlStatsTableTag, "Start Time") - override val dataSource = new SqlStatsTableDataSource(data, pageSize, sortColumn, desc) + private val encodedSortColumn = URLEncoder.encode(sortColumn, 
UTF_8.name()) private val parameterPath = s"$basePath/$subPath/?${getParameterOtherTable(request, sqlStatsTableTag)}" + override val dataSource = new SqlStatsTableDataSource(data, pageSize, sortColumn, desc) + override def tableId: String = sqlStatsTableTag override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$sqlStatsTableTag.sort=$encodedSortColumn" + @@ -200,11 +200,9 @@ private[ui] class SqlStatsPagedTable( override def pageNumberFormField: String = s"$sqlStatsTableTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$sqlStatsTableTag.sort=$encodedSortColumn" + s"&$sqlStatsTableTag.desc=$desc#$sqlStatsTableTag" - } override def headers: Seq[Node] = { val sqlTableHeadersAndTooltips: Seq[(String, Boolean, Option[String])] = @@ -307,19 +305,19 @@ private[ui] class SessionStatsPagedTable( private val (sortColumn, desc, pageSize) = getTableParameters(request, sessionStatsTableTag, "Start Time") - override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc) + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) private val parameterPath = s"$basePath/$subPath/?${getParameterOtherTable(request, sessionStatsTableTag)}" + override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc) + override def tableId: String = sessionStatsTableTag override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped 
table-head-clickable table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$sessionStatsTableTag.sort=$encodedSortColumn" + @@ -332,11 +330,9 @@ private[ui] class SessionStatsPagedTable( override def pageNumberFormField: String = s"$sessionStatsTableTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$sessionStatsTableTag.sort=$encodedSortColumn" + s"&$sessionStatsTableTag.desc=$desc#$sessionStatsTableTag" - } override def headers: Seq[Node] = { val sessionTableHeadersAndTooltips: Seq[(String, Boolean, Option[String])] = @@ -370,108 +366,94 @@ private[ui] class SessionStatsPagedTable( } } - private[ui] class SqlStatsTableRow( +private[ui] class SqlStatsTableRow( val jobId: Seq[String], val duration: Long, val executionTime: Long, val executionInfo: ExecutionInfo, val detail: String) - private[ui] class SqlStatsTableDataSource( +private[ui] class SqlStatsTableDataSource( info: Seq[ExecutionInfo], pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[SqlStatsTableRow](pageSize) { - // Convert ExecutionInfo to SqlStatsTableRow which contains the final contents to show in - // the table so that we can avoid creating duplicate contents during sorting the data - private val data = info.map(sqlStatsTableRow).sorted(ordering(sortColumn, desc)) - - private var _slicedStartTime: Set[Long] = null + // Convert ExecutionInfo to SqlStatsTableRow which contains the final contents to show in + // the table so that we can avoid creating duplicate contents during sorting the data + private val data = info.map(sqlStatsTableRow).sorted(ordering(sortColumn, desc)) - override def dataSize: Int = data.size + override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): 
Seq[SqlStatsTableRow] = { - val r = data.slice(from, to) - _slicedStartTime = r.map(_.executionInfo.startTimestamp).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[SqlStatsTableRow] = data.slice(from, to) - private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { - val duration = executionInfo.totalTime(executionInfo.closeTimestamp) - val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) - val detail = Option(executionInfo.detail).filter(!_.isEmpty) - .getOrElse(executionInfo.executePlan) - val jobId = executionInfo.jobId.toSeq.sorted + private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { + val duration = executionInfo.totalTime(executionInfo.closeTimestamp) + val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) + val detail = Option(executionInfo.detail).filter(!_.isEmpty) + .getOrElse(executionInfo.executePlan) + val jobId = executionInfo.jobId.toSeq.sorted - new SqlStatsTableRow(jobId, duration, executionTime, executionInfo, detail) + new SqlStatsTableRow(jobId, duration, executionTime, executionInfo, detail) + } + /** + * Return Ordering according to sortColumn and desc. 
+ */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[SqlStatsTableRow] = { + val ordering: Ordering[SqlStatsTableRow] = sortColumn match { + case "User" => Ordering.by(_.executionInfo.userName) + case "JobID" => Ordering by (_.jobId.headOption) + case "GroupID" => Ordering.by(_.executionInfo.groupId) + case "Start Time" => Ordering.by(_.executionInfo.startTimestamp) + case "Finish Time" => Ordering.by(_.executionInfo.finishTimestamp) + case "Close Time" => Ordering.by(_.executionInfo.closeTimestamp) + case "Execution Time" => Ordering.by(_.executionTime) + case "Duration" => Ordering.by(_.duration) + case "Statement" => Ordering.by(_.executionInfo.statement) + case "State" => Ordering.by(_.executionInfo.state) + case "Detail" => Ordering.by(_.detail) + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } - - /** - * Return Ordering according to sortColumn and desc. - */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[SqlStatsTableRow] = { - val ordering: Ordering[SqlStatsTableRow] = sortColumn match { - case "User" => Ordering.by(_.executionInfo.userName) - case "JobID" => Ordering by (_.jobId.headOption) - case "GroupID" => Ordering.by(_.executionInfo.groupId) - case "Start Time" => Ordering.by(_.executionInfo.startTimestamp) - case "Finish Time" => Ordering.by(_.executionInfo.finishTimestamp) - case "Close Time" => Ordering.by(_.executionInfo.closeTimestamp) - case "Execution Time" => Ordering.by(_.executionTime) - case "Duration" => Ordering.by(_.duration) - case "Statement" => Ordering.by(_.executionInfo.statement) - case "State" => Ordering.by(_.executionInfo.state) - case "Detail" => Ordering.by(_.detail) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } + if (desc) { + ordering.reverse + } else { + ordering } - } +} - private[ui] class SessionStatsTableDataSource( 
+private[ui] class SessionStatsTableDataSource( info: Seq[SessionInfo], pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[SessionInfo](pageSize) { - // Sorting SessionInfo data - private val data = info.sorted(ordering(sortColumn, desc)) - - private var _slicedStartTime: Set[Long] = null - - override def dataSize: Int = data.size - - override def sliceData(from: Int, to: Int): Seq[SessionInfo] = { - val r = data.slice(from, to) - _slicedStartTime = r.map(_.startTimestamp).toSet - r + // Sorting SessionInfo data + private val data = info.sorted(ordering(sortColumn, desc)) + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[SessionInfo] = data.slice(from, to) + + /** + * Return Ordering according to sortColumn and desc. + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = { + val ordering: Ordering[SessionInfo] = sortColumn match { + case "User" => Ordering.by(_.userName) + case "IP" => Ordering.by(_.ip) + case "Session ID" => Ordering.by(_.sessionId) + case "Start Time" => Ordering by (_.startTimestamp) + case "Finish Time" => Ordering.by(_.finishTimestamp) + case "Duration" => Ordering.by(_.totalTime) + case "Total Execute" => Ordering.by(_.totalExecution) + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } - - /** - * Return Ordering according to sortColumn and desc. 
- */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = { - val ordering: Ordering[SessionInfo] = sortColumn match { - case "User" => Ordering.by(_.userName) - case "IP" => Ordering.by(_.ip) - case "Session ID" => Ordering.by(_.sessionId) - case "Start Time" => Ordering by (_.startTimestamp) - case "Finish Time" => Ordering.by(_.finishTimestamp) - case "Duration" => Ordering.by(_.totalTime) - case "Total Execute" => Ordering.by(_.totalExecution) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } + if (desc) { + ordering.reverse + } else { + ordering } } +}