diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala new file mode 100644 index 000000000000..97f3803aafce --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.File +import java.nio.file.Files + +import scala.collection.mutable.HashMap + +import org.apache.commons.io.FileUtils +import org.apache.commons.lang3.StringUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.util.MutableURLClassLoader + +private[deploy] object DependencyUtils { + + def resolveMavenDependencies( + packagesExclusions: String, + packages: String, + repositories: String, + ivyRepoPath: String): String = { + val exclusions: Seq[String] = + if (!StringUtils.isBlank(packagesExclusions)) { + packagesExclusions.split(",") + } else { + Nil + } + // Create the IvySettings, either load from file or build defaults + val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => + SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) + }.getOrElse { + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + } + + SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) + } + + def createTempDir(): File = { + val targetDir = Files.createTempDirectory("tmp").toFile + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = { + FileUtils.deleteQuietly(targetDir) + } + }) + // scalastyle:on runtimeaddshutdownhook + targetDir + } + + def resolveAndDownloadJars(jars: String, userJar: String): String = { + val targetDir = DependencyUtils.createTempDir() + val hadoopConf = new Configuration() + val sparkProperties = new HashMap[String, String]() + val securityProperties = List("spark.ssl.fs.trustStore", "spark.ssl.trustStore", + "spark.ssl.fs.trustStorePassword", "spark.ssl.trustStorePassword", + "spark.ssl.fs.protocol", "spark.ssl.protocol") + + securityProperties.foreach { pName => + sys.props.get(pName).foreach { pValue => + sparkProperties.put(pName, pValue) + } + } + + Option(jars) + .map { + SparkSubmit.resolveGlobPaths(_, hadoopConf) + .split(",") + .filterNot(_.contains(userJar.split("/").last)) + .mkString(",") + } + .filterNot(_ == "") + .map(SparkSubmit.downloadFileList(_, targetDir, sparkProperties, hadoopConf)) + .orNull + } + + def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = { + if (jars != null) { + for (jar <- jars.split(",")) { + SparkSubmit.addJarToClasspath(jar, loader) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0ea14361b2f7..019780076e7e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -20,7 +20,6 @@ package org.apache.spark.deploy import java.io._ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL -import java.nio.file.Files import java.security.{KeyStore, PrivilegedExceptionAction} import java.security.cert.X509Certificate import java.text.ParseException @@ -31,7 +30,6 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.Properties import com.google.common.io.ByteStreams -import org.apache.commons.io.FileUtils import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} import org.apache.hadoop.fs.{FileSystem, Path} @@ -300,28 +298,13 @@ object SparkSubmit extends CommandLineUtils { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER + val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER - if (!isMesosCluster) { + if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code - val exclusions: Seq[String] = - if (!StringUtils.isBlank(args.packagesExclusions)) { - args.packagesExclusions.split(",") - } else { - Nil - } - - // Create the IvySettings, either load from file or build defaults - val ivySettings = args.sparkProperties.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(args.repositories), - Option(args.ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(args.repositories), Option(args.ivyRepoPath)) - } - - val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, - ivySettings, exclusions = exclusions) - + val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) @@ -338,14 +321,7 @@ object SparkSubmit extends CommandLineUtils { } val hadoopConf = new HadoopConfiguration() - val targetDir = Files.createTempDirectory("tmp").toFile - // scalastyle:off runtimeaddshutdownhook - Runtime.getRuntime.addShutdownHook(new Thread() { - override def run(): Unit = { - FileUtils.deleteQuietly(targetDir) - } - }) - // scalastyle:on runtimeaddshutdownhook + val targetDir = DependencyUtils.createTempDir() // Resolve glob path for different resources. args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull @@ -473,11 +449,13 @@ object SparkSubmit extends CommandLineUtils { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Mesos only - propagate attributes for dependency resolution at the driver side - OptionAssigner(args.packages, MESOS, CLUSTER, sysProp = "spark.jars.packages"), - OptionAssigner(args.repositories, MESOS, CLUSTER, sysProp = "spark.jars.repositories"), - OptionAssigner(args.ivyRepoPath, MESOS, CLUSTER, sysProp = "spark.jars.ivy"), - OptionAssigner(args.packagesExclusions, MESOS, CLUSTER, sysProp = "spark.jars.excludes"), + // Propagate attributes for dependency resolution at the driver side + OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.packages"), + OptionAssigner(args.repositories, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.jars.repositories"), + OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.ivy"), + OptionAssigner(args.packagesExclusions, STANDALONE | MESOS, + CLUSTER, sysProp = "spark.jars.excludes"), // Yarn only OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), @@ -780,7 +758,7 @@ object SparkSubmit extends CommandLineUtils { } } - private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { + private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { val uri = Utils.resolveURI(localJar) uri.getScheme match { case "file" | "local" => @@ -845,7 +823,7 @@ object SparkSubmit extends CommandLineUtils { * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. */ - private def mergeFileLists(lists: String*): String = { + private[deploy] def mergeFileLists(lists: String*): String = { val merged = lists.filterNot(StringUtils.isBlank) .flatMap(_.split(",")) .mkString(",") @@ -968,7 +946,7 @@ object SparkSubmit extends CommandLineUtils { } } - private def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = { + private[deploy] def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = { require(paths != null, "paths cannot be null.") paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path => val uri = Utils.resolveURI(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 6799f78ec0c1..cd3e361530c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,7 +19,10 @@ package org.apache.spark.deploy.worker import java.io.File +import org.apache.commons.lang3.StringUtils + import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.{DependencyUtils, SparkSubmit} import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -51,6 +54,7 @@ object DriverWrapper { new MutableURLClassLoader(Array(userJarUrl), currentLoader) } Thread.currentThread.setContextClassLoader(loader) + setupDependencies(loader, userJar) // Delegate to supplied main class val clazz = Utils.classForName(mainClass) @@ -66,4 +70,23 @@ object DriverWrapper { System.exit(-1) } } + + private def setupDependencies(loader: MutableURLClassLoader, userJar: String): Unit = { + val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = + Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") + .map(sys.props.get(_).orNull) + + val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, + packages, repositories, ivyRepoPath) + val jars = { + val jarsProp = sys.props.get("spark.jars").orNull + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates) + } else { + jarsProp + } + } + val localJars = DependencyUtils.resolveAndDownloadJars(jars, userJar) + DependencyUtils.addJarsToClassPath(localJars, loader) + } } diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 79974df2603f..65fa38387b9e 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -94,14 +94,16 @@ private[ui] trait PagedTable[T] { val _dataSource = dataSource try { val PageData(totalPages, data) = _dataSource.pageData(page) + val pageNavi = pageNavigation(page, _dataSource.pageSize, totalPages)
- {pageNavigation(page, _dataSource.pageSize, totalPages)} + {pageNavi} {headers} {data.map(row)}
+ {pageNavi}
} catch { case e: IndexOutOfBoundsException => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8ed51746ab9d..633e740b9c9b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -835,7 +835,8 @@ private[ui] class TaskTableRowData( val speculative: Boolean, val status: String, val taskLocality: String, - val executorIdAndHost: String, + val executorId: String, + val host: String, val launchTime: Long, val duration: Long, val formatDuration: String, @@ -1017,7 +1018,8 @@ private[ui] class TaskDataSource( info.speculative, info.status, info.taskLocality.toString, - s"${info.executorId} / ${info.host}", + info.executorId, + info.host, info.launchTime, duration, formatDuration, @@ -1047,7 +1049,8 @@ private[ui] class TaskDataSource( case "Attempt" => Ordering.by(_.attempt) case "Status" => Ordering.by(_.status) case "Locality Level" => Ordering.by(_.taskLocality) - case "Executor ID / Host" => Ordering.by(_.executorIdAndHost) + case "Executor ID" => Ordering.by(_.executorId) + case "Host" => Ordering.by(_.host) case "Launch Time" => Ordering.by(_.launchTime) case "Duration" => Ordering.by(_.duration) case "Scheduler Delay" => Ordering.by(_.schedulerDelay) @@ -1200,7 +1203,7 @@ private[ui] class TaskPagedTable( val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""), ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), ("GC Time", ""), @@ -1271,8 +1274,9 @@ private[ui] class TaskPagedTable( {if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString} {task.status} {task.taskLocality} + {task.executorId} -
{task.executorIdAndHost}
+
{task.host}
{ task.logs.map { diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index a357fbf59f6c..e6afb1855885 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -26,7 +26,7 @@ Function InstallR { } $urlPath = "" - $latestVer = $(ConvertFrom-JSON $(Invoke-WebRequest http://rversions.r-pkg.org/r-release-win).Content).version + $latestVer = $(ConvertFrom-JSON $(Invoke-WebRequest https://rversions.r-pkg.org/r-release-win).Content).version If ($rVer -ne $latestVer) { $urlPath = ("old/" + $rVer + "/") } diff --git a/dev/check-license b/dev/check-license index 678e73fd60f1..8cee09a53e08 100755 --- a/dev/check-license +++ b/dev/check-license @@ -20,7 +20,7 @@ acquire_rat_jar () { - URL="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" + URL="https://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" JAR="$rat_jar" @@ -58,7 +58,7 @@ else declare java_cmd=java fi -export RAT_VERSION=0.11 +export RAT_VERSION=0.12 export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar mkdir -p "$FWDIR"/lib diff --git a/dev/mima b/dev/mima index 85b09dbb1bf2..5501589b7900 100755 --- a/dev/mima +++ b/dev/mima @@ -41,7 +41,7 @@ $JAVA_CMD \ -cp "$TOOLS_CLASSPATH:$OLD_DEPS_CLASSPATH" \ org.apache.spark.tools.GenerateMIMAIgnore -echo -e "q\n" | build/sbt -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving" +echo -e "q\n" | build/sbt -mem 4096 -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving" ret_val=$? if [ $ret_val != 0 ]; then diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2ac2383d699c..ee231a934a3a 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1308,6 +1308,13 @@ the following case-insensitive options: + + sessionInitStatement + + After each database session is opened to the remote DB and before starting to read data, this option executes a custom SQL statement (or a PL/SQL block). Use this to implement session initialization code. Example: option("sessionInitStatement", """BEGIN execute immediate 'alter session set "_serial_direct_read"=true'; END;""") + + + truncate diff --git a/project/build.properties b/project/build.properties index d339865ab915..b19518fd7aa1 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.13 +sbt.version=0.13.16 diff --git a/project/plugins.sbt b/project/plugins.sbt index 84d123999085..2b49c297ff9c 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,25 +1,34 @@ +// need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.0.1") +// sbt 1.0.0 support: https://github.com/typesafehub/sbteclipse/issues/343 +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.1.0") +// sbt 1.0.0 support: https://github.com/jrudolph/sbt-dependency-graph/issues/134 addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") +// need to make changes to uptake sbt 1.0 support in "org.scalastyle" %% "scalastyle-sbt-plugin" % "0.9.0" addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.12") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.17") +// sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6 addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") +// need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-unidoc" % "0.4.1" addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") +// need to make changes to uptake sbt 1.0 support in "com.cavorite" % "sbt-avro-1-7" % "1.1.2" addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") +// sbt 1.0.0 support: https://github.com/spray/sbt-revolver/issues/62 addSbtPlugin("io.spray" % "sbt-revolver" % "0.8.0") libraryDependencies += "org.ow2.asm" % "asm" % "5.1" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.1" +// sbt 1.0.0 support: https://github.com/ihji/sbt-antlr4/issues/14 addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11") // Spark uses a custom fork of the sbt-pom-reader plugin which contains a patch to fix issues diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a8dc76b846c2..097530230cbc 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -16,6 +16,7 @@ # import sys +import os if sys.version > '3': basestring = str @@ -23,7 +24,7 @@ from pyspark import since, keyword_only, SparkContext from pyspark.ml.base import Estimator, Model, Transformer from pyspark.ml.param import Param, Params -from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.ml.common import inherit_doc @@ -130,13 +131,16 @@ def copy(self, extra=None): @since("2.0.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages()) + if allStagesAreJava: + return JavaMLWriter(self) + return PipelineWriter(self) @classmethod @since("2.0.0") def read(cls): """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + return PipelineReader(cls) @classmethod def _from_java(cls, java_stage): @@ -171,6 +175,76 @@ def _to_java(self): return _java_obj +@inherit_doc +class PipelineWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types + """ + + def __init__(self, instance): + super(PipelineWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + stages = self.instance.getStages() + PipelineSharedReadWrite.validateStages(stages) + PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path) + + +@inherit_doc +class PipelineReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types + """ + + def __init__(self, cls): + super(PipelineReader, self).__init__() + self.cls = cls + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python': + return JavaMLReader(self.cls).load(path) + else: + uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path) + return Pipeline(stages=stages)._resetUid(uid) + + +@inherit_doc +class PipelineModelWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types + """ + + def __init__(self, instance): + super(PipelineModelWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + stages = self.instance.stages + PipelineSharedReadWrite.validateStages(stages) + PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path) + + +@inherit_doc +class PipelineModelReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types + """ + + def __init__(self, cls): + super(PipelineModelReader, self).__init__() + self.cls = cls + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python': + return JavaMLReader(self.cls).load(path) + else: + uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path) + return PipelineModel(stages=stages)._resetUid(uid) + + @inherit_doc class PipelineModel(Model, MLReadable, MLWritable): """ @@ -204,13 +278,16 @@ def copy(self, extra=None): @since("2.0.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages) + if allStagesAreJava: + return JavaMLWriter(self) + return PipelineModelWriter(self) @classmethod @since("2.0.0") def read(cls): """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + return PipelineModelReader(cls) @classmethod def _from_java(cls, java_stage): @@ -242,3 +319,72 @@ def _to_java(self): JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) return _java_obj + + +@inherit_doc +class PipelineSharedReadWrite(): + """ + .. note:: DeveloperApi + + Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between + :py:class:`Pipeline` and :py:class:`PipelineModel` + + .. versionadded:: 2.3.0 + """ + + @staticmethod + def checkStagesForJava(stages): + return all(isinstance(stage, JavaMLWritable) for stage in stages) + + @staticmethod + def validateStages(stages): + """ + Check that all stages are Writable + """ + for stage in stages: + if not isinstance(stage, MLWritable): + raise ValueError("Pipeline write will fail on this pipeline " + + "because stage %s of type %s is not MLWritable", + stage.uid, type(stage)) + + @staticmethod + def saveImpl(instance, stages, sc, path): + """ + Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel` + - save metadata to path/metadata + - save stages to stages/IDX_UID + """ + stageUids = [stage.uid for stage in stages] + jsonParams = {'stageUids': stageUids, 'language': 'Python'} + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams) + stagesDir = os.path.join(path, "stages") + for index, stage in enumerate(stages): + stage.write().save(PipelineSharedReadWrite + .getStagePath(stage.uid, index, len(stages), stagesDir)) + + @staticmethod + def load(metadata, sc, path): + """ + Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel` + + :return: (UID, list of stages) + """ + stagesDir = os.path.join(path, "stages") + stageUids = metadata['paramMap']['stageUids'] + stages = [] + for index, stageUid in enumerate(stageUids): + stagePath = \ + PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir) + stage = DefaultParamsReader.loadParamsInstance(stagePath, sc) + stages.append(stage) + return (metadata['uid'], stages) + + @staticmethod + def getStagePath(stageUid, stageIdx, numStages, stagesDir): + """ + Get path for saving the given stage. + """ + stageIdxDigits = len(str(numStages)) + stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid + stagePath = os.path.join(stagesDir, stageDir) + return stagePath diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6aecc7fe8707..0495973d2f62 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -123,7 +123,7 @@ def _transform(self, dataset): return dataset -class MockUnaryTransformer(UnaryTransformer): +class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): shift = Param(Params._dummy(), "shift", "The amount by which to shift " + "data in a DataFrame", @@ -150,7 +150,7 @@ def outputDataType(self): def validateInputType(self, inputType): if inputType != DoubleType(): raise TypeError("Bad input type: {}. ".format(inputType) + - "Requires Integer.") + "Requires Double.") class MockEstimator(Estimator, HasFake): @@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2): """ self.assertEqual(m1.uid, m2.uid) self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaParams): + if isinstance(m1, JavaParams) or isinstance(m1, Transformer): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: self._compare_params(m1, m2, p) @@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self): except OSError: pass + def test_python_transformer_pipeline_persistence(self): + """ + Pipeline[MockUnaryTransformer, Binarizer] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.range(0, 10).toDF('input') + tf = MockUnaryTransformer(shiftVal=2)\ + .setInputCol("input").setOutputCol("shiftedInput") + tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") + pl = Pipeline(stages=[tf, tf2]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + def test_onevsrest(self): temp_path = tempfile.mkdtemp() df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ecb941c5fa9e..733d80e9d46c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -844,24 +844,47 @@ object SQLConf { .stringConf .createWithDefaultFunction(() => TimeZone.getDefault.getID) + val WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the window operator") + .intConf + .createWithDefault(4096) + val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.windowExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in window operator") + .doc("Threshold for number of rows to be spilled by window operator") .intConf - .createWithDefault(4096) + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + + val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the sort merge " + + "join operator") + .intConf + .createWithDefault(Int.MaxValue) val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in sort merge join operator") + .doc("Threshold for number of rows to be spilled by sort merge join operator") .intConf - .createWithDefault(Int.MaxValue) + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the cartesian " + + "product operator") + .intConf + .createWithDefault(4096) val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in cartesian product operator") + .doc("Threshold for number of rows to be spilled by cartesian product operator") .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) @@ -1137,11 +1160,19 @@ class SQLConf extends Serializable with Logging { def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER) + def windowExecBufferInMemoryThreshold: Int = getConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + def sortMergeJoinExecBufferInMemoryThreshold: Int = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def sortMergeJoinExecBufferSpillThreshold: Int = getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + def cartesianProductExecBufferInMemoryThreshold: Int = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def cartesianProductExecBufferSpillThreshold: Int = getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index c4d383421f97..ac282ea2e94f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -31,16 +31,16 @@ import org.apache.spark.storage.BlockManager import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} /** - * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined - * threshold of rows is reached. + * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array + * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which + * would flush to disk after [[numRowsSpillThreshold]] is met (or before if there is + * excessive memory consumption). Setting these threshold involves following trade-offs: * - * Setting spill threshold faces following trade-off: - * - * - If the spill threshold is too high, the in-memory array may occupy more memory than is - * available, resulting in OOM. - * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes. - * This may lead to a performance regression compared to the normal case of using an - * [[ArrayBuffer]] or [[Array]]. + * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array may occupy more memory + * than is available, resulting in OOM. + * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to + * excessive disk writes. This may lead to a performance regression compared to the normal case + * of using an [[ArrayBuffer]] or [[Array]]. */ private[sql] class ExternalAppendOnlyUnsafeRowArray( taskMemoryManager: TaskMemoryManager, @@ -49,9 +49,10 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( taskContext: TaskContext, initialSize: Int, pageSizeBytes: Long, + numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) extends Logging { - def this(numRowsSpillThreshold: Int) { + def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { this( TaskContext.get().taskMemoryManager(), SparkEnv.get.blockManager, @@ -59,11 +60,12 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( TaskContext.get(), 1024, SparkEnv.get.memoryManager.pageSizeBytes, + numRowsInMemoryBufferThreshold, numRowsSpillThreshold) } private val initialSizeOfInMemoryBuffer = - Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold) + Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold) private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) @@ -102,11 +104,11 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( } def add(unsafeRow: UnsafeRow): Unit = { - if (numRows < numRowsSpillThreshold) { + if (numRows < numRowsInMemoryBufferThreshold) { inMemoryBuffer += unsafeRow.copy() } else { if (spillableArray == null) { - logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " + + logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " + s"${classOf[UnsafeExternalSorter].getName}") // We will not sort the rows, so prefixComparator and recordComparator are null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index b56fbd4284d2..4accf54a1823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} +import org.apache.spark.sql.execution.joins.ReorderJoinPredicates import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils @@ -103,6 +104,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { protected def preparations: Seq[Rule[SparkPlan]] = Seq( python.ExtractPythonUDFs, PlanSubqueries(sparkSession), + new ReorderJoinPredicates, EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 96a8a51da18e..05b00058618a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -138,6 +138,8 @@ class JDBCOptions( case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE } + // An option to execute custom SQL before fetching data from the remote DB + val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) } object JDBCOptions { @@ -161,4 +163,5 @@ object JDBCOptions { val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") + val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 24e13697c0c9..3274be91d481 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -273,6 +273,21 @@ private[jdbc] class JDBCRDD( import scala.collection.JavaConverters._ dialect.beforeFetch(conn, options.asProperties.asScala.toMap) + // This executes a generic SQL statement (or PL/SQL block) before reading + // the table/query via JDBC. Use this feature to initialize the database + // session environment, e.g. for optimizations and/or troubleshooting. + options.sessionInitStatement match { + case Some(sql) => + val statement = conn.prepareStatement(sql) + logInfo(s"Executing sessionInitStatement: $sql") + try { + statement.execute() + } finally { + statement.close() + } + case None => + } + // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to // talk about a table in a completely portable way. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index f38098695131..4d261dd422bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -35,11 +35,12 @@ class UnsafeCartesianRDD( left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int, + inMemoryBufferThreshold: Int, spillThreshold: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold) val partition = split.asInstanceOf[CartesianPartition] rdd2.iterator(partition.s2, context).foreach(rowArray.add) @@ -71,9 +72,12 @@ case class CartesianProductExec( val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold - - val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold) + val pair = new UnsafeCartesianRDD( + leftResults, + rightResults, + right.output.size, + sqlContext.conf.cartesianProductExecBufferInMemoryThreshold, + sqlContext.conf.cartesianProductExecBufferSpillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala new file mode 100644 index 000000000000..534d8c5689c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +/** + * When the physical operators are created for JOIN, the ordering of join keys is based on order + * in which the join keys appear in the user query. That might not match with the output + * partitioning of the join node's children (thus leading to extra sort / shuffle being + * introduced). This rule will change the ordering of the join keys to match with the + * partitioning of the join nodes' children. + */ +class ReorderJoinPredicates extends Rule[SparkPlan] { + private def reorderJoinKeys( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Partitioning, + rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { + + def reorder( + expectedOrderOfKeys: Seq[Expression], + currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val leftKeysBuffer = ArrayBuffer[Expression]() + val rightKeysBuffer = ArrayBuffer[Expression]() + + expectedOrderOfKeys.foreach(expression => { + val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + leftKeysBuffer.append(leftKeys(index)) + rightKeysBuffer.append(rightKeys(index)) + }) + (leftKeysBuffer, rightKeysBuffer) + } + + if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { + leftPartitioning match { + case HashPartitioning(leftExpressions, _) + if leftExpressions.length == leftKeys.length && + leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => + reorder(leftExpressions, leftKeys) + + case _ => rightPartitioning match { + case HashPartitioning(rightExpressions, _) + if rightExpressions.length == rightKeys.length && + rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => + reorder(rightExpressions, rightKeys) + + case _ => (leftKeys, rightKeys) + } + } + } else { + (leftKeys, rightKeys) + } + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + + case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + + case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index f41fa14213df..91d214e1978e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -130,9 +130,14 @@ case class SortMergeJoinExec( sqlContext.conf.sortMergeJoinExecBufferSpillThreshold } + private def getInMemoryThreshold: Int = { + sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => @@ -158,6 +163,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -201,6 +207,7 @@ case class SortMergeJoinExec( keyOrdering, streamedIter = RowIterator.fromScala(leftIter), bufferedIter = RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) val rightNullRow = new GenericInternalRow(right.output.length) @@ -214,6 +221,7 @@ case class SortMergeJoinExec( keyOrdering, streamedIter = RowIterator.fromScala(rightIter), bufferedIter = RowIterator.fromScala(leftIter), + inMemoryThreshold, spillThreshold ) val leftNullRow = new GenericInternalRow(left.output.length) @@ -247,6 +255,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -281,6 +290,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -322,6 +332,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -420,8 +431,10 @@ case class SortMergeJoinExec( val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold - ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);") + ctx.addMutableState(clsName, matches, + s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);") // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -626,6 +639,9 @@ case class SortMergeJoinExec( * @param streamedIter an input whose rows will be streamed. * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that * have the same join key. + * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by + * internal buffer + * @param spillThreshold Threshold for number of rows to be spilled by internal buffer */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, @@ -633,7 +649,8 @@ private[joins] class SortMergeJoinScanner( keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, bufferedIter: RowIterator, - bufferThreshold: Int) { + inMemoryThreshold: Int, + spillThreshold: Int) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -644,7 +661,8 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold) + private[this] val bufferedMatches = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index f8bb667e3006..800a2ea3f399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -292,6 +292,7 @@ case class WindowExec( // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold // Start processing. @@ -322,7 +323,8 @@ case class WindowExec( val inputFields = child.output.length val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 895ca196a7a5..0008d503a2cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -665,7 +665,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("test SortMergeJoin (with spill)") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", - "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") { + "spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "0", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "1") { assertSpilled(sparkContext, "inner join") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 031ac38c17d7..efe28afab08e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -67,7 +67,10 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray( + ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + numSpillThreshold) + rows.foreach(x => array.add(x)) val iterator = array.generateIterator() @@ -143,7 +146,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold) rows.foreach(x => array.add(x)) val iterator = array.generateIterator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index 53c41639942b..ecc7264d7944 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -31,7 +31,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar override def afterAll(): Unit = TaskContext.unset() - private def withExternalArray(spillThreshold: Int) + private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { sc = new SparkContext("local", "test", new SparkConf(false)) @@ -45,6 +45,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar taskContext, 1024, SparkEnv.get.memoryManager.pageSizeBytes, + inMemoryThreshold, spillThreshold) try f(array) finally { array.clear() @@ -109,9 +110,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar assert(getNumBytesSpilled > 0) } - test("insert rows less than the spillThreshold") { - val spillThreshold = 100 - withExternalArray(spillThreshold) { array => + test("insert rows less than the inMemoryThreshold") { + val (inMemoryThreshold, spillThreshold) = (100, 50) + withExternalArray(inMemoryThreshold, spillThreshold) { array => assert(array.isEmpty) val expectedValues = populateRows(array, 1) @@ -122,8 +123,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) // Verify that NO spill has happened - populateRows(array, spillThreshold - 1, expectedValues) - assert(array.length == spillThreshold) + populateRows(array, inMemoryThreshold - 1, expectedValues) + assert(array.length == inMemoryThreshold) assertNoSpill() val iterator2 = validateData(array, expectedValues) @@ -133,20 +134,42 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } } - test("insert rows more than the spillThreshold to force spill") { - val spillThreshold = 100 - withExternalArray(spillThreshold) { array => - val numValuesInserted = 20 * spillThreshold - + test("insert rows more than the inMemoryThreshold but less than spillThreshold") { + val (inMemoryThreshold, spillThreshold) = (10, 50) + withExternalArray(inMemoryThreshold, spillThreshold) { array => assert(array.isEmpty) - val expectedValues = populateRows(array, 1) - assert(array.length == 1) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) + val iterator1 = validateData(array, expectedValues) + assertNoSpill() + + // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a + // spill to happen. Verify that NO spill has happened + populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues) + assert(array.length == spillThreshold - 1) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + } + + test("insert rows enough to force spill") { + val (inMemoryThreshold, spillThreshold) = (20, 10) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + assert(array.isEmpty) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) val iterator1 = validateData(array, expectedValues) + assertNoSpill() - // Populate more rows to trigger spill. Verify that spill has happened - populateRows(array, numValuesInserted - 1, expectedValues) - assert(array.length == numValuesInserted) + // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen. + // Verify that spill has happened + populateRows(array, 2, expectedValues) + assert(array.length == inMemoryThreshold + 1) assertSpill() val iterator2 = validateData(array, expectedValues) @@ -158,7 +181,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("iterator on an empty array should be empty") { - withExternalArray(spillThreshold = 10) { array => + withExternalArray(inMemoryThreshold = 4, spillThreshold = 10) { array => val iterator = array.generateIterator() assert(array.isEmpty) assert(array.length == 0) @@ -167,7 +190,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with negative start index") { - withExternalArray(spillThreshold = 2) { array => + withExternalArray(inMemoryThreshold = 100, spillThreshold = 56) { array => val exception = intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) @@ -178,8 +201,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with start index exceeding array's size (without spill)") { - val spillThreshold = 2 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => populateRows(array, spillThreshold / 2) val exception = @@ -191,8 +214,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with start index exceeding array's size (with spill)") { - val spillThreshold = 2 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => populateRows(array, spillThreshold * 2) val exception = @@ -205,10 +228,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with custom start index (without spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => - val expectedValues = populateRows(array, spillThreshold) - val startIndex = spillThreshold / 2 + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + val expectedValues = populateRows(array, inMemoryThreshold) + val startIndex = inMemoryThreshold / 2 val iterator = array.generateIterator(startIndex = startIndex) for (i <- startIndex until expectedValues.length) { checkIfValueExists(iterator, expectedValues(i)) @@ -217,8 +240,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with custom start index (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => val expectedValues = populateRows(array, spillThreshold * 10) val startIndex = spillThreshold * 2 val iterator = array.generateIterator(startIndex = startIndex) @@ -229,7 +252,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("test iterator invalidation (without spill)") { - withExternalArray(spillThreshold = 10) { array => + withExternalArray(inMemoryThreshold = 10, spillThreshold = 100) { array => // insert 2 rows, iterate until the first row populateRows(array, 2) @@ -254,9 +277,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("test iterator invalidation (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => - // Populate enough rows so that spill has happens + val (inMemoryThreshold, spillThreshold) = (2, 10) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + // Populate enough rows so that spill happens populateRows(array, spillThreshold * 2) assertSpill() @@ -281,7 +304,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("clear on an empty the array") { - withExternalArray(spillThreshold = 2) { array => + withExternalArray(inMemoryThreshold = 2, spillThreshold = 3) { array => val iterator = array.generateIterator() assert(!iterator.hasNext) @@ -299,10 +322,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("clear array (without spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (10, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => // Populate rows ... but not enough to trigger spill - populateRows(array, spillThreshold / 2) + populateRows(array, inMemoryThreshold / 2) assertNoSpill() // Clear the array @@ -311,21 +334,21 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar // Re-populate few rows so that there is no spill // Verify the data. Verify that there was no spill - val expectedValues = populateRows(array, spillThreshold / 3) + val expectedValues = populateRows(array, inMemoryThreshold / 2) validateData(array, expectedValues) assertNoSpill() // Populate more rows .. enough to not trigger a spill. // Verify the data. Verify that there was no spill - populateRows(array, spillThreshold / 3, expectedValues) + populateRows(array, inMemoryThreshold / 2, expectedValues) validateData(array, expectedValues) assertNoSpill() } } test("clear array (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (10, 20) + withExternalArray(inMemoryThreshold, spillThreshold) { array => // Populate enough rows to trigger spill populateRows(array, spillThreshold * 2) val bytesSpilled = getNumBytesSpilled diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index a9f3fb355c77..a57514c256b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -477,7 +477,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) """.stripMargin) - withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") { + withSQLConf("spark.sql.windowExec.buffer.in.memory.threshold" -> "1", + "spark.sql.windowExec.buffer.spill.threshold" -> "2") { assertSpilled(sparkContext, "test with low buffer spill threshold") { checkAnswer(actual, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 4c4364688941..8dc11d80c306 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1044,4 +1044,35 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").count() == 3) } } + + test("SPARK-21519: option sessionInitStatement, run SQL to initialize the database session.") { + val initSQL1 = "SET @MYTESTVAR 21519" + val df1 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "(SELECT NVL(@MYTESTVAR, -1))") + .option("sessionInitStatement", initSQL1) + .load() + assert(df1.collect() === Array(Row(21519))) + + val initSQL2 = "SET SCHEMA DUMMY" + val df2 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PEOPLE") + .option("sessionInitStatement", initSQL2) + .load() + val e = intercept[SparkException] {df2.collect()}.getMessage + assert(e.contains("""Schema "DUMMY" not found""")) + + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW test_sessionInitStatement + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$urlWithUserAndPass', + |dbtable '(SELECT NVL(@MYTESTVAR1, -1), NVL(@MYTESTVAR2, -1))', + |sessionInitStatement 'SET @MYTESTVAR1 21519; SET @MYTESTVAR2 1234') + """.stripMargin) + + val df3 = sql("SELECT * FROM test_sessionInitStatement") + assert(df3.collect() === Array(Row(21519, 1234))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index ba0ca666b5c1..eb9e6458fc61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -543,6 +543,65 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { ) } + test("SPARK-19122 Re-order join predicates if they match with the child's output partitioning") { + val bucketedTableTestSpec = BucketedTableTestSpec( + Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))), + numPartitions = 1, + expectedShuffle = false, + expectedSort = false) + + // If the set of join columns is equal to the set of bucketed + sort columns, then + // the order of join keys in the query should not matter and there should not be any shuffle + // and sort added in the query plan + Seq( + Seq("i", "j", "k"), + Seq("i", "k", "j"), + Seq("j", "k", "i"), + Seq("j", "i", "k"), + Seq("k", "j", "i"), + Seq("k", "i", "j") + ).foreach(joinKeys => { + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec, + bucketedTableTestSpecRight = bucketedTableTestSpec, + joinCondition = joinCondition(joinKeys) + ) + }) + } + + test("SPARK-19122 No re-ordering should happen if set of join columns != set of child's " + + "partitioning columns") { + + // join predicates is a super set of child's partitioning columns + val bucketedTableTestSpec1 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec1, + bucketedTableTestSpecRight = bucketedTableTestSpec1, + joinCondition = joinCondition(Seq("i", "j", "k")) + ) + + // child's partitioning columns is a super set of join predicates + val bucketedTableTestSpec2 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))), + numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec2, + bucketedTableTestSpecRight = bucketedTableTestSpec2, + joinCondition = joinCondition(Seq("i", "j")) + ) + + // set of child's partitioning columns != set join predicates (despite the lengths of the + // sets are same) + val bucketedTableTestSpec3 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec3, + bucketedTableTestSpecRight = bucketedTableTestSpec3, + joinCondition = joinCondition(Seq("j", "k")) + ) + } + test("error if there exists any malformed bucket files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")