diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
new file mode 100644
index 000000000000..97f3803aafce
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io.File
+import java.nio.file.Files
+
+import scala.collection.mutable.HashMap
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.lang3.StringUtils
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.util.MutableURLClassLoader
+
+private[deploy] object DependencyUtils {
+
+  /**
+   * Resolves Maven coordinates through Ivy.
+   *
+   * Ivy settings are loaded from the file named by the `spark.jars.ivySettings` JVM system
+   * property when it is set; otherwise they are built from `repositories` and `ivyRepoPath`.
+   *
+   * @param packagesExclusions comma-separated exclusions; null or blank means no exclusions
+   * @param packages coordinates handed to `SparkSubmitUtils.resolveMavenCoordinates`
+   * @param repositories extra repositories for the Ivy settings, may be null
+   * @param ivyRepoPath Ivy repository path for the Ivy settings, may be null
+   * @return whatever `SparkSubmitUtils.resolveMavenCoordinates` returns for these inputs
+   */
+  def resolveMavenDependencies(
+      packagesExclusions: String,
+      packages: String,
+      repositories: String,
+      ivyRepoPath: String): String = {
+    val exclusions: Seq[String] =
+      if (!StringUtils.isBlank(packagesExclusions)) {
+        packagesExclusions.split(",")
+      } else {
+        Nil
+      }
+    // Create the IvySettings, either load from file or build defaults
+    val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile =>
+      SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath))
+    }.getOrElse {
+      SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath))
+    }
+
+    SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions)
+  }
+
+  /**
+   * Creates a temporary directory (prefix "tmp") and registers a JVM shutdown hook that
+   * quietly deletes it.
+   *
+   * @return the newly created directory
+   */
+  def createTempDir(): File = {
+    val targetDir = Files.createTempDirectory("tmp").toFile
+    // scalastyle:off runtimeaddshutdownhook
+    Runtime.getRuntime.addShutdownHook(new Thread() {
+      override def run(): Unit = {
+        FileUtils.deleteQuietly(targetDir)
+      }
+    })
+    // scalastyle:on runtimeaddshutdownhook
+    targetDir
+  }
+
+  /**
+   * Expands glob patterns in `jars`, drops every entry whose path contains the file name of
+   * `userJar`, and hands the remaining list to `SparkSubmit.downloadFileList` together with a
+   * fresh temporary directory.
+   *
+   * @param jars comma-separated jar paths, may be null
+   * @param userJar path of the user jar; its last "/"-separated segment is used for filtering
+   * @return the result of `SparkSubmit.downloadFileList`, or null when `jars` is null or no
+   *         entry is left after filtering
+   */
+  def resolveAndDownloadJars(jars: String, userJar: String): String = {
+    val targetDir = createTempDir()
+    val hadoopConf = new Configuration()
+    val sparkProperties = new HashMap[String, String]()
+    val securityProperties = List("spark.ssl.fs.trustStore", "spark.ssl.trustStore",
+      "spark.ssl.fs.trustStorePassword", "spark.ssl.trustStorePassword",
+      "spark.ssl.fs.protocol", "spark.ssl.protocol")
+
+    // Forward only the SSL-related system properties that are actually set.
+    securityProperties.foreach { pName =>
+      sys.props.get(pName).foreach { pValue =>
+        sparkProperties.put(pName, pValue)
+      }
+    }
+
+    // The user jar's file name does not depend on the candidate being tested, so compute it
+    // at most once per call. It is lazy so that it is still only evaluated on first use by
+    // the filter below, exactly as when it was inlined in the predicate.
+    lazy val userJarName = userJar.split("/").last
+
+    Option(jars)
+      .map {
+        SparkSubmit.resolveGlobPaths(_, hadoopConf)
+          .split(",")
+          .filterNot(_.contains(userJarName))
+          .mkString(",")
+      }
+      .filterNot(_ == "")
+      .map(SparkSubmit.downloadFileList(_, targetDir, sparkProperties, hadoopConf))
+      .orNull
+  }
+
+  /**
+   * Adds every jar of a comma-separated list to `loader` via `SparkSubmit.addJarToClasspath`.
+   *
+   * @param jars comma-separated jar paths; nothing happens when this is null
+   * @param loader class loader that receives the jars
+   */
+  def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = {
+    if (jars != null) {
+      for (jar <- jars.split(",")) {
+        SparkSubmit.addJarToClasspath(jar, loader)
+      }
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 0ea14361b2f7..019780076e7e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -20,7 +20,6 @@ package org.apache.spark.deploy
import java.io._
import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
-import java.nio.file.Files
import java.security.{KeyStore, PrivilegedExceptionAction}
import java.security.cert.X509Certificate
import java.text.ParseException
@@ -31,7 +30,6 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.Properties
import com.google.common.io.ByteStreams
-import org.apache.commons.io.FileUtils
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -300,28 +298,13 @@ object SparkSubmit extends CommandLineUtils {
}
val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER
+ val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER
- if (!isMesosCluster) {
+ if (!isMesosCluster && !isStandAloneCluster) {
// Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files
// too for packages that include Python code
- val exclusions: Seq[String] =
- if (!StringUtils.isBlank(args.packagesExclusions)) {
- args.packagesExclusions.split(",")
- } else {
- Nil
- }
-
- // Create the IvySettings, either load from file or build defaults
- val ivySettings = args.sparkProperties.get("spark.jars.ivySettings").map { ivySettingsFile =>
- SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(args.repositories),
- Option(args.ivyRepoPath))
- }.getOrElse {
- SparkSubmitUtils.buildIvySettings(Option(args.repositories), Option(args.ivyRepoPath))
- }
-
- val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages,
- ivySettings, exclusions = exclusions)
-
+ val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(
+ args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath)
if (!StringUtils.isBlank(resolvedMavenCoordinates)) {
args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates)
@@ -338,14 +321,7 @@ object SparkSubmit extends CommandLineUtils {
}
val hadoopConf = new HadoopConfiguration()
- val targetDir = Files.createTempDirectory("tmp").toFile
- // scalastyle:off runtimeaddshutdownhook
- Runtime.getRuntime.addShutdownHook(new Thread() {
- override def run(): Unit = {
- FileUtils.deleteQuietly(targetDir)
- }
- })
- // scalastyle:on runtimeaddshutdownhook
+ val targetDir = DependencyUtils.createTempDir()
// Resolve glob path for different resources.
args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull
@@ -473,11 +449,13 @@ object SparkSubmit extends CommandLineUtils {
OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
sysProp = "spark.driver.extraLibraryPath"),
- // Mesos only - propagate attributes for dependency resolution at the driver side
- OptionAssigner(args.packages, MESOS, CLUSTER, sysProp = "spark.jars.packages"),
- OptionAssigner(args.repositories, MESOS, CLUSTER, sysProp = "spark.jars.repositories"),
- OptionAssigner(args.ivyRepoPath, MESOS, CLUSTER, sysProp = "spark.jars.ivy"),
- OptionAssigner(args.packagesExclusions, MESOS, CLUSTER, sysProp = "spark.jars.excludes"),
+ // Propagate attributes for dependency resolution at the driver side
+ OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.packages"),
+ OptionAssigner(args.repositories, STANDALONE | MESOS, CLUSTER,
+ sysProp = "spark.jars.repositories"),
+ OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.ivy"),
+ OptionAssigner(args.packagesExclusions, STANDALONE | MESOS,
+ CLUSTER, sysProp = "spark.jars.excludes"),
// Yarn only
OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"),
@@ -780,7 +758,7 @@ object SparkSubmit extends CommandLineUtils {
}
}
- private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) {
+ private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) {
val uri = Utils.resolveURI(localJar)
uri.getScheme match {
case "file" | "local" =>
@@ -845,7 +823,7 @@ object SparkSubmit extends CommandLineUtils {
* Merge a sequence of comma-separated file lists, some of which may be null to indicate
* no files, into a single comma-separated string.
*/
- private def mergeFileLists(lists: String*): String = {
+ private[deploy] def mergeFileLists(lists: String*): String = {
val merged = lists.filterNot(StringUtils.isBlank)
.flatMap(_.split(","))
.mkString(",")
@@ -968,7 +946,7 @@ object SparkSubmit extends CommandLineUtils {
}
}
- private def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = {
+ private[deploy] def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = {
require(paths != null, "paths cannot be null.")
paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path =>
val uri = Utils.resolveURI(path)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index 6799f78ec0c1..cd3e361530c1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -19,7 +19,10 @@ package org.apache.spark.deploy.worker
import java.io.File
+import org.apache.commons.lang3.StringUtils
+
import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.deploy.{DependencyUtils, SparkSubmit}
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
@@ -51,6 +54,7 @@ object DriverWrapper {
new MutableURLClassLoader(Array(userJarUrl), currentLoader)
}
Thread.currentThread.setContextClassLoader(loader)
+ setupDependencies(loader, userJar)
// Delegate to supplied main class
val clazz = Utils.classForName(mainClass)
@@ -66,4 +70,23 @@ object DriverWrapper {
System.exit(-1)
}
}
+
+  /**
+   * Resolves the driver's extra dependencies from JVM system properties and adds them to
+   * `loader`: Maven packages described by the `spark.jars.*` properties are resolved, merged
+   * with the `spark.jars` list, downloaded, and put on the class path.
+   */
+  private def setupDependencies(loader: MutableURLClassLoader, userJar: String): Unit = {
+    // Reads a system property, yielding null when it is not set.
+    def prop(key: String): String = sys.props.get(key).orNull
+
+    val exclusionsProp = prop("spark.jars.excludes")
+    val packagesProp = prop("spark.jars.packages")
+    val repositoriesProp = prop("spark.jars.repositories")
+    val ivyPathProp = prop("spark.jars.ivy")
+
+    val mavenJars = DependencyUtils.resolveMavenDependencies(
+      exclusionsProp, packagesProp, repositoriesProp, ivyPathProp)
+
+    val declaredJars = prop("spark.jars")
+    val allJars =
+      if (StringUtils.isBlank(mavenJars)) declaredJars
+      else SparkSubmit.mergeFileLists(declaredJars, mavenJars)
+
+    val downloadedJars = DependencyUtils.resolveAndDownloadJars(allJars, userJar)
+    DependencyUtils.addJarsToClassPath(downloadedJars, loader)
+  }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
index 79974df2603f..65fa38387b9e 100644
--- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
@@ -94,14 +94,16 @@ private[ui] trait PagedTable[T] {
val _dataSource = dataSource
try {
val PageData(totalPages, data) = _dataSource.pageData(page)
+ val pageNavi = pageNavigation(page, _dataSource.pageSize, totalPages)
truncate |
diff --git a/project/build.properties b/project/build.properties
index d339865ab915..b19518fd7aa1 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-sbt.version=0.13.13
+sbt.version=0.13.16
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 84d123999085..2b49c297ff9c 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -1,25 +1,34 @@
+// need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "0.14.5"
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")
-addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.0.1")
+// sbt 1.0.0 support: https://github.com/typesafehub/sbteclipse/issues/343
+addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.1.0")
+// sbt 1.0.0 support: https://github.com/jrudolph/sbt-dependency-graph/issues/134
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2")
+// need to make changes to uptake sbt 1.0 support in "org.scalastyle" %% "scalastyle-sbt-plugin" % "0.9.0"
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0")
-addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.12")
+addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.17")
+// sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6
addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1")
+// need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-unidoc" % "0.4.1"
addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3")
+// need to make changes to uptake sbt 1.0 support in "com.cavorite" % "sbt-avro-1-7" % "1.1.2"
addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2")
+// sbt 1.0.0 support: https://github.com/spray/sbt-revolver/issues/62
addSbtPlugin("io.spray" % "sbt-revolver" % "0.8.0")
libraryDependencies += "org.ow2.asm" % "asm" % "5.1"
libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.1"
+// sbt 1.0.0 support: https://github.com/ihji/sbt-antlr4/issues/14
addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11")
// Spark uses a custom fork of the sbt-pom-reader plugin which contains a patch to fix issues
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a8dc76b846c2..097530230cbc 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -16,6 +16,7 @@
#
import sys
+import os
if sys.version > '3':
basestring = str
@@ -23,7 +24,7 @@
from pyspark import since, keyword_only, SparkContext
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc
@@ -130,13 +131,16 @@ def copy(self, extra=None):
@since("2.0.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
+ allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages())
+ if allStagesAreJava:
+ return JavaMLWriter(self)
+ return PipelineWriter(self)
@classmethod
@since("2.0.0")
def read(cls):
"""Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
+ return PipelineReader(cls)
@classmethod
def _from_java(cls, java_stage):
@@ -171,6 +175,76 @@ def _to_java(self):
return _java_obj
+@inherit_doc
+class PipelineWriter(MLWriter):
+ """
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
+ """
+
+ def __init__(self, instance):
+ super(PipelineWriter, self).__init__()
+ self.instance = instance
+
+ def saveImpl(self, path):
+ stages = self.instance.getStages()
+ PipelineSharedReadWrite.validateStages(stages)
+ PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineReader(MLReader):
+ """
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
+ """
+
+ def __init__(self, cls):
+ super(PipelineReader, self).__init__()
+ self.cls = cls
+
+ def load(self, path):
+ metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+ return JavaMLReader(self.cls).load(path)
+ else:
+ uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+ return Pipeline(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
+class PipelineModelWriter(MLWriter):
+ """
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
+ """
+
+ def __init__(self, instance):
+ super(PipelineModelWriter, self).__init__()
+ self.instance = instance
+
+ def saveImpl(self, path):
+ stages = self.instance.stages
+ PipelineSharedReadWrite.validateStages(stages)
+ PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineModelReader(MLReader):
+ """
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
+ """
+
+ def __init__(self, cls):
+ super(PipelineModelReader, self).__init__()
+ self.cls = cls
+
+ def load(self, path):
+ metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+ return JavaMLReader(self.cls).load(path)
+ else:
+ uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+ return PipelineModel(stages=stages)._resetUid(uid)
+
+
@inherit_doc
class PipelineModel(Model, MLReadable, MLWritable):
"""
@@ -204,13 +278,16 @@ def copy(self, extra=None):
@since("2.0.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
+ allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages)
+ if allStagesAreJava:
+ return JavaMLWriter(self)
+ return PipelineModelWriter(self)
@classmethod
@since("2.0.0")
def read(cls):
"""Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
+ return PipelineModelReader(cls)
@classmethod
def _from_java(cls, java_stage):
@@ -242,3 +319,72 @@ def _to_java(self):
JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
return _java_obj
+
+
+@inherit_doc
+class PipelineSharedReadWrite():
+ """
+ .. note:: DeveloperApi
+
+ Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
+ :py:class:`Pipeline` and :py:class:`PipelineModel`
+
+ .. versionadded:: 2.3.0
+ """
+
+ @staticmethod
+ def checkStagesForJava(stages):
+ return all(isinstance(stage, JavaMLWritable) for stage in stages)
+
+ @staticmethod
+ def validateStages(stages):
+ """
+ Check that all stages are Writable
+ """
+ for stage in stages:
+ if not isinstance(stage, MLWritable):
+ raise ValueError("Pipeline write will fail on this pipeline " +
+ "because stage %s of type %s is not MLWritable",
+ stage.uid, type(stage))
+
+ @staticmethod
+ def saveImpl(instance, stages, sc, path):
+ """
+ Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+ - save metadata to path/metadata
+ - save stages to stages/IDX_UID
+ """
+ stageUids = [stage.uid for stage in stages]
+ jsonParams = {'stageUids': stageUids, 'language': 'Python'}
+ DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
+ stagesDir = os.path.join(path, "stages")
+ for index, stage in enumerate(stages):
+ stage.write().save(PipelineSharedReadWrite
+ .getStagePath(stage.uid, index, len(stages), stagesDir))
+
+ @staticmethod
+ def load(metadata, sc, path):
+ """
+ Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+
+ :return: (UID, list of stages)
+ """
+ stagesDir = os.path.join(path, "stages")
+ stageUids = metadata['paramMap']['stageUids']
+ stages = []
+ for index, stageUid in enumerate(stageUids):
+ stagePath = \
+ PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
+ stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
+ stages.append(stage)
+ return (metadata['uid'], stages)
+
+ @staticmethod
+ def getStagePath(stageUid, stageIdx, numStages, stagesDir):
+ """
+ Get path for saving the given stage.
+ """
+ stageIdxDigits = len(str(numStages))
+ stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
+ stagePath = os.path.join(stagesDir, stageDir)
+ return stagePath
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6aecc7fe8707..0495973d2f62 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -123,7 +123,7 @@ def _transform(self, dataset):
return dataset
-class MockUnaryTransformer(UnaryTransformer):
+class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
"data in a DataFrame",
@@ -150,7 +150,7 @@ def outputDataType(self):
def validateInputType(self, inputType):
if inputType != DoubleType():
raise TypeError("Bad input type: {}. ".format(inputType) +
- "Requires Integer.")
+ "Requires Double.")
class MockEstimator(Estimator, HasFake):
@@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2):
"""
self.assertEqual(m1.uid, m2.uid)
self.assertEqual(type(m1), type(m2))
- if isinstance(m1, JavaParams):
+ if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self._compare_params(m1, m2, p)
@@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self):
except OSError:
pass
+ def test_python_transformer_pipeline_persistence(self):
+ """
+ Pipeline[MockUnaryTransformer, Binarizer]
+ """
+ # Scratch directory for the saved pipeline and model; removed in the finally block.
+ temp_path = tempfile.mkdtemp()
+
+ try:
+ df = self.spark.range(0, 10).toDF('input')
+ # Combine the MockUnaryTransformer defined in this test module with Binarizer.
+ tf = MockUnaryTransformer(shiftVal=2)\
+ .setInputCol("input").setOutputCol("shiftedInput")
+ tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
+ pl = Pipeline(stages=[tf, tf2])
+ model = pl.fit(df)
+
+ # Round-trip the unfitted Pipeline and compare it with the original.
+ pipeline_path = temp_path + "/pipeline"
+ pl.save(pipeline_path)
+ loaded_pipeline = Pipeline.load(pipeline_path)
+ self._compare_pipelines(pl, loaded_pipeline)
+
+ # Round-trip the fitted PipelineModel and compare it with the original.
+ model_path = temp_path + "/pipeline-model"
+ model.save(model_path)
+ loaded_model = PipelineModel.load(model_path)
+ self._compare_pipelines(model, loaded_model)
+ finally:
+ # Best-effort cleanup: an OSError from rmtree is deliberately ignored.
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
+
def test_onevsrest(self):
temp_path = tempfile.mkdtemp()
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ecb941c5fa9e..733d80e9d46c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -844,24 +844,47 @@ object SQLConf {
.stringConf
.createWithDefaultFunction(() => TimeZone.getDefault.getID)
+ val WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+ buildConf("spark.sql.windowExec.buffer.in.memory.threshold")
+ .internal()
+ .doc("Threshold for number of rows guaranteed to be held in memory by the window operator")
+ .intConf
+ .createWithDefault(4096)
+
val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD =
buildConf("spark.sql.windowExec.buffer.spill.threshold")
.internal()
- .doc("Threshold for number of rows buffered in window operator")
+ .doc("Threshold for number of rows to be spilled by window operator")
.intConf
- .createWithDefault(4096)
+ .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+
+ val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+ buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
+ .internal()
+ .doc("Threshold for number of rows guaranteed to be held in memory by the sort merge " +
+ "join operator")
+ .intConf
+ .createWithDefault(Int.MaxValue)
val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD =
buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold")
.internal()
- .doc("Threshold for number of rows buffered in sort merge join operator")
+ .doc("Threshold for number of rows to be spilled by sort merge join operator")
.intConf
- .createWithDefault(Int.MaxValue)
+ .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+
+ val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+ buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold")
+ .internal()
+ .doc("Threshold for number of rows guaranteed to be held in memory by the cartesian " +
+ "product operator")
+ .intConf
+ .createWithDefault(4096)
val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD =
buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold")
.internal()
- .doc("Threshold for number of rows buffered in cartesian product operator")
+ .doc("Threshold for number of rows to be spilled by cartesian product operator")
.intConf
.createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
@@ -1137,11 +1160,19 @@ class SQLConf extends Serializable with Logging {
def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER)
+ def windowExecBufferInMemoryThreshold: Int = getConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)
+ def sortMergeJoinExecBufferInMemoryThreshold: Int =
+ getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
def sortMergeJoinExecBufferSpillThreshold: Int =
getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)
+ def cartesianProductExecBufferInMemoryThreshold: Int =
+ getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
def cartesianProductExecBufferSpillThreshold: Int =
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
index c4d383421f97..ac282ea2e94f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -31,16 +31,16 @@ import org.apache.spark.storage.BlockManager
import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
/**
- * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined
- * threshold of rows is reached.
+ * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array
+ * until [[numRowsInMemoryBufferThreshold]] is reached, after which it switches to a mode that
+ * flushes to disk once [[numRowsSpillThreshold]] is met (or earlier if there is excessive
+ * memory consumption). Setting these thresholds involves the following trade-offs:
*
- * Setting spill threshold faces following trade-off:
- *
- * - If the spill threshold is too high, the in-memory array may occupy more memory than is
- * available, resulting in OOM.
- * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes.
- * This may lead to a performance regression compared to the normal case of using an
- * [[ArrayBuffer]] or [[Array]].
+ * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array may occupy more memory
+ * than is available, resulting in OOM.
+ * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to
+ * excessive disk writes. This may lead to a performance regression compared to the normal case
+ * of using an [[ArrayBuffer]] or [[Array]].
*/
private[sql] class ExternalAppendOnlyUnsafeRowArray(
taskMemoryManager: TaskMemoryManager,
@@ -49,9 +49,10 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
taskContext: TaskContext,
initialSize: Int,
pageSizeBytes: Long,
+ numRowsInMemoryBufferThreshold: Int,
numRowsSpillThreshold: Int) extends Logging {
- def this(numRowsSpillThreshold: Int) {
+ def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) {
this(
TaskContext.get().taskMemoryManager(),
SparkEnv.get.blockManager,
@@ -59,11 +60,12 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
TaskContext.get(),
1024,
SparkEnv.get.memoryManager.pageSizeBytes,
+ numRowsInMemoryBufferThreshold,
numRowsSpillThreshold)
}
private val initialSizeOfInMemoryBuffer =
- Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold)
+ Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold)
private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
@@ -102,11 +104,11 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
}
def add(unsafeRow: UnsafeRow): Unit = {
- if (numRows < numRowsSpillThreshold) {
+ if (numRows < numRowsInMemoryBufferThreshold) {
inMemoryBuffer += unsafeRow.copy()
} else {
if (spillableArray == null) {
- logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " +
+ logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " +
s"${classOf[UnsafeExternalSorter].getName}")
// We will not sort the rows, so prefixComparator and recordComparator are null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b56fbd4284d2..4accf54a1823 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
+import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
import org.apache.spark.util.Utils
@@ -103,6 +104,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
PlanSubqueries(sparkSession),
+ new ReorderJoinPredicates,
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 96a8a51da18e..05b00058618a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -138,6 +138,8 @@ class JDBCOptions(
case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ
case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE
}
+ // An option to execute custom SQL before fetching data from the remote DB
+ val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT)
}
object JDBCOptions {
@@ -161,4 +163,5 @@ object JDBCOptions {
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
+ val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 24e13697c0c9..3274be91d481 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -273,6 +273,21 @@ private[jdbc] class JDBCRDD(
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, options.asProperties.asScala.toMap)
+ // This executes a generic SQL statement (or PL/SQL block) before reading
+ // the table/query via JDBC. Use this feature to initialize the database
+ // session environment, e.g. for optimizations and/or troubleshooting.
+ options.sessionInitStatement match {
+ case Some(sql) =>
+ val statement = conn.prepareStatement(sql)
+ logInfo(s"Executing sessionInitStatement: $sql")
+ try {
+ statement.execute()
+ } finally {
+ statement.close()
+ }
+ case None =>
+ }
+
// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
// talk about a table in a completely portable way.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index f38098695131..4d261dd422bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -35,11 +35,12 @@ class UnsafeCartesianRDD(
left : RDD[UnsafeRow],
right : RDD[UnsafeRow],
numFieldsOfRight: Int,
+ inMemoryBufferThreshold: Int,
spillThreshold: Int)
extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
- val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+ val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold)
val partition = split.asInstanceOf[CartesianPartition]
rdd2.iterator(partition.s2, context).foreach(rowArray.add)
@@ -71,9 +72,12 @@ case class CartesianProductExec(
val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
- val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold
-
- val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold)
+ val pair = new UnsafeCartesianRDD(
+ leftResults,
+ rightResults,
+ right.output.size,
+ sqlContext.conf.cartesianProductExecBufferInMemoryThreshold,
+ sqlContext.conf.cartesianProductExecBufferSpillThreshold)
pair.mapPartitionsWithIndexInternal { (index, iter) =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
val filtered = if (condition.isDefined) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
new file mode 100644
index 000000000000..534d8c5689c2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * When the physical operators are created for JOIN, the ordering of join keys is based on order
+ * in which the join keys appear in the user query. That might not match with the output
+ * partitioning of the join node's children (thus leading to extra sort / shuffle being
+ * introduced). This rule will change the ordering of the join keys to match with the
+ * partitioning of the join nodes' children.
+ */
+class ReorderJoinPredicates extends Rule[SparkPlan] {
+  /**
+   * Permutes `leftKeys` and `rightKeys` in lock-step so that they follow the expression
+   * order of the left child's hash partitioning or, failing that, of the right child's.
+   *
+   * The keys are returned unchanged when any key is non-deterministic, or when neither
+   * partitioning is a [[HashPartitioning]] whose expressions pair up one-to-one with the
+   * keys of its side.
+   */
+  private def reorderJoinKeys(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      leftPartitioning: Partitioning,
+      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+
+    // Gives every expected expression a distinct, still unused position of
+    // `currentOrderOfKeys` and reorders both key sequences by those positions. Returns None
+    // unless this is a full permutation: with repeated keys or repeated partitioning
+    // expressions, a first-match lookup would either index with -1 or emit one key pair twice
+    // and drop another, which changes the join condition.
+    def reorder(
+        expectedOrderOfKeys: Seq[Expression],
+        currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = {
+      val current = currentOrderOfKeys.toIndexedSeq
+      val unused = ArrayBuffer(current.indices: _*)
+      val chosen = ArrayBuffer[Int]()
+      expectedOrderOfKeys.foreach { expected =>
+        val slot = unused.indexWhere(i => current(i).semanticEquals(expected))
+        if (slot >= 0) {
+          chosen += unused.remove(slot)
+        }
+      }
+      if (unused.isEmpty && chosen.length == expectedOrderOfKeys.length) {
+        Some((chosen.map(i => leftKeys(i)), chosen.map(i => rightKeys(i))))
+      } else {
+        None
+      }
+    }
+
+    // Attempts the reordering only for a hash partitioning that has as many expressions as
+    // there are keys and among whose expressions every key occurs.
+    def reorderFor(
+        partitioning: Partitioning,
+        keys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = {
+      partitioning match {
+        case HashPartitioning(expressions, _)
+            if expressions.length == keys.length &&
+              keys.forall(key => expressions.exists(_.semanticEquals(key))) =>
+          reorder(expressions, keys)
+        case _ => None
+      }
+    }
+
+    if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+      reorderFor(leftPartitioning, leftKeys)
+        .orElse(reorderFor(rightPartitioning, rightKeys))
+        .getOrElse((leftKeys, rightKeys))
+    } else {
+      (leftKeys, rightKeys)
+    }
+  }
+
+  /**
+   * Rewrites each broadcast-hash, shuffled-hash and sort-merge join found by `transformUp`
+   * so that it carries the join keys returned by `reorderJoinKeys`.
+   */
+  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+        left, right)
+
+    case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+        left, right)
+
+    case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index f41fa14213df..91d214e1978e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -130,9 +130,14 @@ case class SortMergeJoinExec(
sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
}
+ // In-memory row threshold handed to this join's row buffers, read from the session conf.
+ private def getInMemoryThreshold: Int =
+   sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
+
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val spillThreshold = getSpillThreshold
+ val inMemoryThreshold = getInMemoryThreshold
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
val boundCondition: (InternalRow) => Boolean = {
condition.map { cond =>
@@ -158,6 +163,7 @@ case class SortMergeJoinExec(
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -201,6 +207,7 @@ case class SortMergeJoinExec(
keyOrdering,
streamedIter = RowIterator.fromScala(leftIter),
bufferedIter = RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
val rightNullRow = new GenericInternalRow(right.output.length)
@@ -214,6 +221,7 @@ case class SortMergeJoinExec(
keyOrdering,
streamedIter = RowIterator.fromScala(rightIter),
bufferedIter = RowIterator.fromScala(leftIter),
+ inMemoryThreshold,
spillThreshold
)
val leftNullRow = new GenericInternalRow(left.output.length)
@@ -247,6 +255,7 @@ case class SortMergeJoinExec(
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -281,6 +290,7 @@ case class SortMergeJoinExec(
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -322,6 +332,7 @@ case class SortMergeJoinExec(
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -420,8 +431,10 @@ case class SortMergeJoinExec(
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
val spillThreshold = getSpillThreshold
+ val inMemoryThreshold = getInMemoryThreshold
- ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
+ ctx.addMutableState(clsName, matches,
+ s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);")
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
@@ -626,6 +639,9 @@ case class SortMergeJoinExec(
* @param streamedIter an input whose rows will be streamed.
* @param bufferedIter an input whose rows will be buffered to construct sequences of rows that
* have the same join key.
+ * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by
+ * internal buffer
+ * @param spillThreshold Threshold for number of rows to be spilled by internal buffer
*/
private[joins] class SortMergeJoinScanner(
streamedKeyGenerator: Projection,
@@ -633,7 +649,8 @@ private[joins] class SortMergeJoinScanner(
keyOrdering: Ordering[InternalRow],
streamedIter: RowIterator,
bufferedIter: RowIterator,
- bufferThreshold: Int) {
+ inMemoryThreshold: Int,
+ spillThreshold: Int) {
private[this] var streamedRow: InternalRow = _
private[this] var streamedRowKey: InternalRow = _
private[this] var bufferedRow: InternalRow = _
@@ -644,7 +661,8 @@ private[joins] class SortMergeJoinScanner(
*/
private[this] var matchJoinKey: InternalRow = _
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
- private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold)
+ private[this] val bufferedMatches =
+ new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
// Initialization (note: do _not_ want to advance streamed here).
advancedBufferedToRowWithNullFreeJoinKey()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index f8bb667e3006..800a2ea3f399 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -292,6 +292,7 @@ case class WindowExec(
// Unwrap the expressions and factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+ val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold
val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
// Start processing.
@@ -322,7 +323,8 @@ case class WindowExec(
val inputFields = child.output.length
val buffer: ExternalAppendOnlyUnsafeRowArray =
- new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+ new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+
var bufferIterator: Iterator[UnsafeRow] = _
val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 895ca196a7a5..0008d503a2cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -665,7 +665,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
test("test SortMergeJoin (with spill)") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
- "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") {
+ "spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "0",
+ "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "1") {
assertSpilled(sparkContext, "inner join") {
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
index 031ac38c17d7..efe28afab08e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
@@ -67,7 +67,10 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark {
benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
var sum = 0L
for (_ <- 0L until iterations) {
- val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+ val array = new ExternalAppendOnlyUnsafeRowArray(
+ ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
+ numSpillThreshold)
+
rows.foreach(x => array.add(x))
val iterator = array.generateIterator()
@@ -143,7 +146,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark {
benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
var sum = 0L
for (_ <- 0L until iterations) {
- val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+ val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold)
rows.foreach(x => array.add(x))
val iterator = array.generateIterator()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
index 53c41639942b..ecc7264d7944 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -31,7 +31,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
override def afterAll(): Unit = TaskContext.unset()
- private def withExternalArray(spillThreshold: Int)
+ private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int)
(f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = {
sc = new SparkContext("local", "test", new SparkConf(false))
@@ -45,6 +45,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
taskContext,
1024,
SparkEnv.get.memoryManager.pageSizeBytes,
+ inMemoryThreshold,
spillThreshold)
try f(array) finally {
array.clear()
@@ -109,9 +110,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
assert(getNumBytesSpilled > 0)
}
- test("insert rows less than the spillThreshold") {
- val spillThreshold = 100
- withExternalArray(spillThreshold) { array =>
+ test("insert rows less than the inMemoryThreshold") {
+ val (inMemoryThreshold, spillThreshold) = (100, 50)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
assert(array.isEmpty)
val expectedValues = populateRows(array, 1)
@@ -122,8 +123,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
// Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]])
// Verify that NO spill has happened
- populateRows(array, spillThreshold - 1, expectedValues)
- assert(array.length == spillThreshold)
+ populateRows(array, inMemoryThreshold - 1, expectedValues)
+ assert(array.length == inMemoryThreshold)
assertNoSpill()
val iterator2 = validateData(array, expectedValues)
@@ -133,20 +134,42 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
}
- test("insert rows more than the spillThreshold to force spill") {
- val spillThreshold = 100
- withExternalArray(spillThreshold) { array =>
- val numValuesInserted = 20 * spillThreshold
-
+ test("insert rows more than the inMemoryThreshold but less than spillThreshold") {
+ val (inMemoryThreshold, spillThreshold) = (10, 50)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
assert(array.isEmpty)
- val expectedValues = populateRows(array, 1)
- assert(array.length == 1)
+ val expectedValues = populateRows(array, inMemoryThreshold - 1)
+ assert(array.length == (inMemoryThreshold - 1))
+ val iterator1 = validateData(array, expectedValues)
+ assertNoSpill()
+
+ // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a
+ // spill to happen. Verify that NO spill has happened
+ populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues)
+ assert(array.length == spillThreshold - 1)
+ assertNoSpill()
+
+ val iterator2 = validateData(array, expectedValues)
+ assert(!iterator2.hasNext)
+ assert(!iterator1.hasNext)
+ intercept[ConcurrentModificationException](iterator1.next())
+ }
+ }
+
+ test("insert rows enough to force spill") {
+ val (inMemoryThreshold, spillThreshold) = (20, 10)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
+ assert(array.isEmpty)
+ val expectedValues = populateRows(array, inMemoryThreshold - 1)
+ assert(array.length == (inMemoryThreshold - 1))
val iterator1 = validateData(array, expectedValues)
+ assertNoSpill()
- // Populate more rows to trigger spill. Verify that spill has happened
- populateRows(array, numValuesInserted - 1, expectedValues)
- assert(array.length == numValuesInserted)
+ // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen.
+ // Verify that spill has happened
+ populateRows(array, 2, expectedValues)
+ assert(array.length == inMemoryThreshold + 1)
assertSpill()
val iterator2 = validateData(array, expectedValues)
@@ -158,7 +181,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("iterator on an empty array should be empty") {
- withExternalArray(spillThreshold = 10) { array =>
+ withExternalArray(inMemoryThreshold = 4, spillThreshold = 10) { array =>
val iterator = array.generateIterator()
assert(array.isEmpty)
assert(array.length == 0)
@@ -167,7 +190,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with negative start index") {
- withExternalArray(spillThreshold = 2) { array =>
+ withExternalArray(inMemoryThreshold = 100, spillThreshold = 56) { array =>
val exception =
intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10))
@@ -178,8 +201,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with start index exceeding array's size (without spill)") {
- val spillThreshold = 2
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (20, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
populateRows(array, spillThreshold / 2)
val exception =
@@ -191,8 +214,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with start index exceeding array's size (with spill)") {
- val spillThreshold = 2
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (20, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
populateRows(array, spillThreshold * 2)
val exception =
@@ -205,10 +228,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with custom start index (without spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
- val expectedValues = populateRows(array, spillThreshold)
- val startIndex = spillThreshold / 2
+ val (inMemoryThreshold, spillThreshold) = (20, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
+ val expectedValues = populateRows(array, inMemoryThreshold)
+ val startIndex = inMemoryThreshold / 2
val iterator = array.generateIterator(startIndex = startIndex)
for (i <- startIndex until expectedValues.length) {
checkIfValueExists(iterator, expectedValues(i))
@@ -217,8 +240,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with custom start index (with spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (20, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
val expectedValues = populateRows(array, spillThreshold * 10)
val startIndex = spillThreshold * 2
val iterator = array.generateIterator(startIndex = startIndex)
@@ -229,7 +252,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("test iterator invalidation (without spill)") {
- withExternalArray(spillThreshold = 10) { array =>
+ withExternalArray(inMemoryThreshold = 10, spillThreshold = 100) { array =>
// insert 2 rows, iterate until the first row
populateRows(array, 2)
@@ -254,9 +277,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("test iterator invalidation (with spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
- // Populate enough rows so that spill has happens
+ val (inMemoryThreshold, spillThreshold) = (2, 10)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
+ // Populate enough rows so that spill happens
populateRows(array, spillThreshold * 2)
assertSpill()
@@ -281,7 +304,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("clear on an empty the array") {
- withExternalArray(spillThreshold = 2) { array =>
+ withExternalArray(inMemoryThreshold = 2, spillThreshold = 3) { array =>
val iterator = array.generateIterator()
assert(!iterator.hasNext)
@@ -299,10 +322,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("clear array (without spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (10, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
// Populate rows ... but not enough to trigger spill
- populateRows(array, spillThreshold / 2)
+ populateRows(array, inMemoryThreshold / 2)
assertNoSpill()
// Clear the array
@@ -311,21 +334,21 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
// Re-populate few rows so that there is no spill
// Verify the data. Verify that there was no spill
- val expectedValues = populateRows(array, spillThreshold / 3)
+ val expectedValues = populateRows(array, inMemoryThreshold / 2)
validateData(array, expectedValues)
assertNoSpill()
// Populate more rows .. enough to not trigger a spill.
// Verify the data. Verify that there was no spill
- populateRows(array, spillThreshold / 3, expectedValues)
+ populateRows(array, inMemoryThreshold / 2, expectedValues)
validateData(array, expectedValues)
assertNoSpill()
}
}
test("clear array (with spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (10, 20)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
// Populate enough rows to trigger spill
populateRows(array, spillThreshold * 2)
val bytesSpilled = getNumBytesSpilled
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index a9f3fb355c77..a57514c256b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -477,7 +477,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
|WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW)
""".stripMargin)
- withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") {
+ withSQLConf("spark.sql.windowExec.buffer.in.memory.threshold" -> "1",
+ "spark.sql.windowExec.buffer.spill.threshold" -> "2") {
assertSpilled(sparkContext, "test with low buffer spill threshold") {
checkAnswer(actual, expected)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 4c4364688941..8dc11d80c306 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -1044,4 +1044,35 @@ class JDBCSuite extends SparkFunSuite
assert(sql("select * from people_view").count() == 3)
}
}
+
+ test("SPARK-21519: option sessionInitStatement, run SQL to initialize the database session.") {
+ val initSQL1 = "SET @MYTESTVAR 21519" // session variable read back by the query below
+ val df1 = spark.read.format("jdbc")
+ .option("url", urlWithUserAndPass)
+ .option("dbtable", "(SELECT NVL(@MYTESTVAR, -1))")
+ .option("sessionInitStatement", initSQL1)
+ .load()
+ assert(df1.collect() === Array(Row(21519)))
+ // An init statement that selects a non-existent schema must make the read itself fail.
+ val initSQL2 = "SET SCHEMA DUMMY"
+ val df2 = spark.read.format("jdbc")
+ .option("url", urlWithUserAndPass)
+ .option("dbtable", "TEST.PEOPLE")
+ .option("sessionInitStatement", initSQL2)
+ .load()
+ val e = intercept[SparkException] {df2.collect()}.getMessage
+ assert(e.contains("""Schema "DUMMY" not found"""))
+ // The option is also accepted in a view's OPTIONS clause; here it carries two statements.
+ sql(
+ s"""
+ |CREATE OR REPLACE TEMPORARY VIEW test_sessionInitStatement
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$urlWithUserAndPass',
+ |dbtable '(SELECT NVL(@MYTESTVAR1, -1), NVL(@MYTESTVAR2, -1))',
+ |sessionInitStatement 'SET @MYTESTVAR1 21519; SET @MYTESTVAR2 1234')
+ """.stripMargin)
+ // Both variables must be set, i.e. both statements of the init string were executed.
+ val df3 = sql("SELECT * FROM test_sessionInitStatement")
+ assert(df3.collect() === Array(Row(21519, 1234)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index ba0ca666b5c1..eb9e6458fc61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -543,6 +543,65 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
)
}
+ test("SPARK-19122 Re-order join predicates if they match with the child's output partitioning") {
+   // Both tables are bucketed and sorted on (i, j, k); no shuffle or sort is expected.
+   val bucketColumns = Seq("i", "j", "k")
+   val spec = BucketedTableTestSpec(
+     Some(BucketSpec(8, bucketColumns, bucketColumns)),
+     numPartitions = 1,
+     expectedShuffle = false,
+     expectedSort = false)
+
+   // The join columns equal the bucketing and sort columns as a set, so the order in which
+   // the query lists the join keys must not matter.
+   val joinKeyOrderings = Seq(
+     Seq("i", "j", "k"),
+     Seq("i", "k", "j"),
+     Seq("j", "k", "i"),
+     Seq("j", "i", "k"),
+     Seq("k", "j", "i"),
+     Seq("k", "i", "j"))
+   for (joinKeys <- joinKeyOrderings) {
+     testBucketing(
+       bucketedTableTestSpecLeft = spec,
+       bucketedTableTestSpecRight = spec,
+       joinCondition = joinCondition(joinKeys))
+   }
+ }
+
+ test("SPARK-19122 No re-ordering should happen if set of join columns != set of child's " +
+   "partitioning columns") {
+
+   // Runs the bucketed join check with both tables bucketed and sorted by `bucketColumns`
+   // and joined on `joinKeys`. The shuffle and sort expectations of the spec are not set
+   // here, so they keep their defaults.
+   def check(bucketColumns: Seq[String], joinKeys: Seq[String]): Unit = {
+     val spec = BucketedTableTestSpec(
+       Some(BucketSpec(8, bucketColumns, bucketColumns)),
+       numPartitions = 1)
+     testBucketing(
+       bucketedTableTestSpecLeft = spec,
+       bucketedTableTestSpecRight = spec,
+       joinCondition = joinCondition(joinKeys)
+     )
+   }
+
+   // The join predicates are a superset of the child's partitioning columns.
+   check(
+     bucketColumns = Seq("i", "j"),
+     joinKeys = Seq("i", "j", "k"))
+
+   // The child's partitioning columns are a superset of the join predicates.
+   check(
+     bucketColumns = Seq("i", "j", "k"),
+     joinKeys = Seq("i", "j"))
+
+   // The two sets have the same size but contain different columns.
+   check(
+     bucketColumns = Seq("i", "j"),
+     joinKeys = Seq("j", "k"))
+ }
+
test("error if there exists any malformed bucket files") {
withTable("bucketed_table") {
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
|