diff --git a/LICENSE b/LICENSE
index 7950dd6ceb6db..c21032a1fd274 100644
--- a/LICENSE
+++ b/LICENSE
@@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(MIT License) RowsGroup (http://datatables.net/license/mit)
(MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html)
(MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE)
+ (MIT License) machinist (https://github.com/typelevel/machinist)
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 459254d271a58..af7cbdccf5d5d 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -288,18 +288,18 @@ test_that("spark.mlp", {
c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9))
mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
expect_equal(head(mlpPredictions$prediction, 10),
- c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0"))
+ c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0"))
model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights =
c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0))
mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
expect_equal(head(mlpPredictions$prediction, 10),
- c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0"))
+ c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0"))
model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2)
mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
expect_equal(head(mlpPredictions$prediction, 10),
- c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0"))
+ c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "0.0", "1.0", "0.0"))
# Test formula works well
df <- suppressWarnings(createDataFrame(iris))
@@ -310,8 +310,8 @@ test_that("spark.mlp", {
expect_equal(summary$numOfOutputs, 3)
expect_equal(summary$layers, c(4, 3))
expect_equal(length(summary$weights), 15)
- expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413,
- -10.2376130), tolerance = 1e-6)
+ expect_equal(head(summary$weights, 5), list(-0.5793153, -4.652961, 6.216155, -6.649478,
+ -10.51147), tolerance = 1e-3)
})
test_that("spark.naiveBayes", {
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 9d8607d9137c6..742a4a1531e71 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 8657af744c069..066970f24205f 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index 24c10fb1ddb9f..2de882adcb582 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml
index 5e5a80bd44467..a8488d8d1b704 100644
--- a/common/network-yarn/pom.xml
+++ b/common/network-yarn/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
index 1356c4723b662..6b81fc2b2b040 100644
--- a/common/sketch/pom.xml
+++ b/common/sketch/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/common/tags/pom.xml b/common/tags/pom.xml
index 9345dc8f0cc4b..f7e586ee777e1 100644
--- a/common/tags/pom.xml
+++ b/common/tags/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index f03a4da5e7152..680d0413b1616 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/core/pom.xml b/core/pom.xml
index 24ce36deeb169..7f245b5b6384a 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 99efc4893fda4..0ec1bdd39b2f5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1350,7 +1350,7 @@ class SparkContext(config: SparkConf) extends Logging {
@deprecated("use AccumulatorV2", "2.0.0")
def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T])
: Accumulator[T] = {
- val acc = new Accumulator(initialValue, param, Some(name))
+ val acc = new Accumulator(initialValue, param, Option(name))
cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging {
@deprecated("use AccumulatorV2", "2.0.0")
def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T])
: Accumulable[R, T] = {
- val acc = new Accumulable(initialValue, param, Some(name))
+ val acc = new Accumulable(initialValue, param, Option(name))
cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1414,7 +1414,7 @@ class SparkContext(config: SparkConf) extends Logging {
* @note Accumulators must be registered before use, or it will throw exception.
*/
def register(acc: AccumulatorV2[_, _], name: String): Unit = {
- acc.register(this, name = Some(name))
+ acc.register(this, name = Option(name))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
index d7d82800b8b55..6d8758a3d3b1d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
@@ -86,7 +86,7 @@ private[history] abstract class ApplicationHistoryProvider {
* @return Count of application event logs that are currently under process
*/
def getEventLogsUnderProcess(): Int = {
- return 0;
+ 0
}
/**
@@ -95,7 +95,7 @@ private[history] abstract class ApplicationHistoryProvider {
* @return 0 if this is undefined or unsupported, otherwise the last updated time in millis
*/
def getLastUpdatedTime(): Long = {
- return 0;
+ 0
}
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 54f39f7620e5d..d9c8fda99ef97 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -301,6 +301,14 @@ object HistoryServer extends Logging {
logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}")
config.set(SecurityManager.SPARK_AUTH_CONF, "false")
}
+
+ if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) {
+ logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " +
+ "only using spark.history.ui.acl.enable")
+ config.set("spark.acls.enable", "false")
+ config.set("spark.ui.acls.enable", "false")
+ }
+
new SecurityManager(config)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e524675332d1b..63a87e7f09d85 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
-import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
SamplingUtils}
@@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag](
val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
- queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
+ queue ++= collectionUtils.takeOrdered(items, num)(ord)
Iterator.single(queue)
}
if (mapRDDs.partitions.length == 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
similarity index 97%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
index 145dc22b7428e..ab72addb2466b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.rdd.util
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer
/**
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
index 00f918c09c66b..f17b637754826 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
@@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends ApiRequestContext {
@Path("applications/{appId}/logs")
def getEventLogs(
@PathParam("appId") appId: String): EventLogDownloadResource = {
- new EventLogDownloadResource(uiRoot, appId, None)
+ try {
+ // withSparkUI will throw NotFoundException if attemptId exists for this application.
+ // So we need to try again with attempt id "1".
+ withSparkUI(appId, None) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, None)
+ }
+ } catch {
+ case _: NotFoundException =>
+ withSparkUI(appId, Some("1")) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, None)
+ }
+ }
}
@Path("applications/{appId}/{attemptId}/logs")
def getEventLogs(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): EventLogDownloadResource = {
- new EventLogDownloadResource(uiRoot, appId, Some(attemptId))
+ withSparkUI(appId, Some(attemptId)) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, Some(attemptId))
+ }
}
@Path("version")
@@ -291,7 +304,6 @@ private[v1] trait ApiRequestContext {
case None => throw new NotFoundException("no such app: " + appId)
}
}
-
}
private[v1] class ForbiddenException(msg: String) extends WebApplicationException(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
similarity index 95%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
index 4dd498cd91b4e..ce06e18879a49 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.util
import scala.collection.mutable
@@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
-private[mllib] abstract class PeriodicCheckpointer[T](
+private[spark] abstract class PeriodicCheckpointer[T](
val checkpointInterval: Int,
val sc: SparkContext) extends Logging {
@@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T](
/** Get list of checkpoint files for this given Dataset */
protected def getCheckpointFiles(data: T): Iterable[String]
+ /**
+ * Call this to unpersist the Dataset.
+ */
+ def unpersistDataSet(): Unit = {
+ while (persistedQueue.nonEmpty) {
+ val dataToUnpersist = persistedQueue.dequeue()
+ unpersist(dataToUnpersist)
+ }
+ }
+
/**
* Call this at the end to delete any remaining checkpoint files.
*/
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index 764156c3edc41..95acb9a54440f 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -565,13 +565,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
assert(jobcount === getNumJobs("/jobs"))
// no need to retain the test dir now the tests complete
- logDir.deleteOnExit();
-
+ logDir.deleteOnExit()
}
test("ui and api authorization checks") {
- val appId = "app-20161115172038-0000"
- val owner = "jose"
+ val appId = "local-1430917381535"
+ val owner = "irashid"
val admin = "root"
val other = "alice"
@@ -590,8 +589,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
val port = server.boundPort
val testUrls = Seq(
- s"http://localhost:$port/api/v1/applications/$appId/jobs",
- s"http://localhost:$port/history/$appId/jobs/")
+ s"http://localhost:$port/api/v1/applications/$appId/1/jobs",
+ s"http://localhost:$port/history/$appId/1/jobs/",
+ s"http://localhost:$port/api/v1/applications/$appId/logs",
+ s"http://localhost:$port/api/v1/applications/$appId/1/logs",
+ s"http://localhost:$port/api/v1/applications/$appId/2/logs")
tests.foreach { case (user, expectedCode) =>
testUrls.foreach { url =>
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
index f9a7f151823a2..7f20206202cb9 100644
--- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w
}
test("get a range of elements in an array not partitioned by a range partitioner") {
- val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
+ val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
val pairs = sc.parallelize(pairArr, 10)
val range = pairs.filterByRange(200, 800).collect()
assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
similarity index 96%
rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
rename to core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
index 14adf8c29fc6b..f9e1b791c86ea 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
@@ -15,18 +15,18 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.utils
import org.apache.hadoop.fs.Path
-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext {
import PeriodicRDDCheckpointerSuite._
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 73dc1f9a1398c..9287bd47cf113 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
bcprov-jdk15on-1.51.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.11-0.12.jar
-breeze_2.11-0.12.jar
+breeze-macros_2.11-0.13.1.jar
+breeze_2.11-0.13.1.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
@@ -129,6 +129,8 @@ libfb303-0.9.3.jar
libthrift-0.9.3.jar
log4j-1.2.17.jar
lz4-1.3.0.jar
+machinist_2.11-0.6.1.jar
+macro-compat_2.11-1.1.1.jar
mail-1.4.7.jar
mesos-1.0.0-shaded-protobuf.jar
metrics-core-3.1.2.jar
@@ -162,13 +164,13 @@ scala-parser-combinators_2.11-1.0.4.jar
scala-reflect-2.11.8.jar
scala-xml_2.11-1.0.2.jar
scalap-2.11.8.jar
-shapeless_2.11-2.0.0.jar
+shapeless_2.11-2.3.2.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snappy-0.2.jar
snappy-java-1.1.2.6.jar
-spire-macros_2.11-0.7.4.jar
-spire_2.11-0.7.4.jar
+spire-macros_2.11-0.13.0.jar
+spire_2.11-0.13.0.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 6bf0923a1d751..ab1de3d3dd8ad 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
bcprov-jdk15on-1.51.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.11-0.12.jar
-breeze_2.11-0.12.jar
+breeze-macros_2.11-0.13.1.jar
+breeze_2.11-0.13.1.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
@@ -130,6 +130,8 @@ libfb303-0.9.3.jar
libthrift-0.9.3.jar
log4j-1.2.17.jar
lz4-1.3.0.jar
+machinist_2.11-0.6.1.jar
+macro-compat_2.11-1.1.1.jar
mail-1.4.7.jar
mesos-1.0.0-shaded-protobuf.jar
metrics-core-3.1.2.jar
@@ -163,13 +165,13 @@ scala-parser-combinators_2.11-1.0.4.jar
scala-reflect-2.11.8.jar
scala-xml_2.11-1.0.2.jar
scalap-2.11.8.jar
-shapeless_2.11-2.0.0.jar
+shapeless_2.11-2.3.2.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snappy-0.2.jar
snappy-java-1.1.2.6.jar
-spire-macros_2.11-0.7.4.jar
-spire_2.11-0.7.4.jar
+spire-macros_2.11-0.13.0.jar
+spire_2.11-0.13.0.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
diff --git a/docs/_config.yml b/docs/_config.yml
index 83bb30598d153..21255ef7a5c45 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -14,8 +14,8 @@ include:
# These allow the documentation to be updated with newer releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 2.2.0-SNAPSHOT
-SPARK_VERSION_SHORT: 2.2.0
+SPARK_VERSION: 2.3.0-SNAPSHOT
+SPARK_VERSION_SHORT: 2.3.0
SCALA_BINARY_VERSION: "2.11"
SCALA_VERSION: "2.11.7"
MESOS_VERSION: 1.0.0
diff --git a/docs/building-spark.md b/docs/building-spark.md
index e99b70f7a8b47..0f551bc66b8c9 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -232,7 +232,7 @@ Once installed, the `docker` service needs to be started, if not already running
On Linux, this can be done by `sudo service docker start`.
./build/mvn install -DskipTests
- ./build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11
+ ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11
or
diff --git a/docs/configuration.md b/docs/configuration.md
index 6b65d2bcb83e5..87b76322cae51 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -2149,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE)
+### GraphX
+
+
+| Property Name | Default | Meaning |
+
+ spark.graphx.pregel.checkpointInterval |
+ -1 |
+
+ Checkpoint interval for graph and message in Pregel. It used to avoid stackOverflowError due to long lineage chains
+ after lots of iterations. The checkpoint is disabled by default.
+ |
+
+
+
### Deploy
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index e271b28fb4f28..76aa7b405e18c 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -708,7 +708,9 @@ messages remaining.
> messaging function. These constraints allow additional optimization within GraphX.
The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch*
-of its implementation (note calls to graph.cache have been removed):
+of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally
+checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number,
+say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)):
{% highlight scala %}
class GraphOps[VD, ED] {
@@ -722,6 +724,7 @@ class GraphOps[VD, ED] {
: Graph[VD, ED] = {
// Receive the initial message at each vertex
var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache()
+
// compute the messages
var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
var activeMessages = messages.count()
@@ -734,8 +737,8 @@ class GraphOps[VD, ED] {
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
- messages = g.mapReduceTriplets(
- sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+ messages = GraphXUtils.mapReduceTriplets(
+ g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
activeMessages = messages.count()
i += 1
}
diff --git a/examples/pom.xml b/examples/pom.xml
index 91c2e81ebed2f..e674e799f24a3 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml
index 8948df2da89e2..0fa87a697454b 100644
--- a/external/docker-integration-tests/pom.xml
+++ b/external/docker-integration-tests/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml
index f8ef8a991316d..71016bc645ca7 100644
--- a/external/flume-assembly/pom.xml
+++ b/external/flume-assembly/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 6d547c46d6a2d..12630840e79dc 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 46901d64eda97..87a09642405a7 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml
index 295142cbfdff9..75df886ca44f6 100644
--- a/external/kafka-0-10-assembly/pom.xml
+++ b/external/kafka-0-10-assembly/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml
index 6cf448e65e8b4..557d27296345f 100644
--- a/external/kafka-0-10-sql/pom.xml
+++ b/external/kafka-0-10-sql/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml
index 88499240cd569..6c98cb04fcfa6 100644
--- a/external/kafka-0-10/pom.xml
+++ b/external/kafka-0-10/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml
index 3fedd9eda1959..f9c2dcb38dc0e 100644
--- a/external/kafka-0-8-assembly/pom.xml
+++ b/external/kafka-0-8-assembly/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml
index 8368a1f12218d..849c8b465f99e 100644
--- a/external/kafka-0-8/pom.xml
+++ b/external/kafka-0-8/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml
index 90bb0e4987c82..48783d65826aa 100644
--- a/external/kinesis-asl-assembly/pom.xml
+++ b/external/kinesis-asl-assembly/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml
index daa79e79163b9..40a751a652fa9 100644
--- a/external/kinesis-asl/pom.xml
+++ b/external/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml
index 7da27817ebafd..36d555066b181 100644
--- a/external/spark-ganglia-lgpl/pom.xml
+++ b/external/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 8df33660ea9d1..cb30e4a4af4bc 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 646462b4a8350..755c6febc48e6 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -19,7 +19,10 @@ package org.apache.spark.graphx
import scala.reflect.ClassTag
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
/**
* Implements a Pregel-like bulk-synchronous message-passing API.
@@ -122,27 +125,39 @@ object Pregel extends Logging {
require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
s" but got ${maxIterations}")
- var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
+ val checkpointInterval = graph.vertices.sparkContext.getConf
+ .getInt("spark.graphx.pregel.checkpointInterval", -1)
+ var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
+ val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
+ checkpointInterval, graph.vertices.sparkContext)
+ graphCheckpointer.update(g)
+
// compute the messages
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
+ val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
+ checkpointInterval, graph.vertices.sparkContext)
+ messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
var activeMessages = messages.count()
+
// Loop
var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
// Receive the messages and update the vertices.
prevG = g
- g = g.joinVertices(messages)(vprog).cache()
+ g = g.joinVertices(messages)(vprog)
+ graphCheckpointer.update(g)
val oldMessages = messages
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
messages = GraphXUtils.mapReduceTriplets(
- g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+ g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
// The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
// (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
// and the vertices of g).
+ messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
activeMessages = messages.count()
logInfo("Pregel finished iteration " + i)
@@ -154,7 +169,9 @@ object Pregel extends Logging {
// count the iteration
i += 1
}
- messages.unpersist(blocking = false)
+ messageCheckpointer.unpersistDataSet()
+ graphCheckpointer.deleteAllCheckpoints()
+ messageCheckpointer.deleteAllCheckpoints()
g
} // end of apply
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
similarity index 91%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
rename to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
index 80074897567eb..fda501aa757d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer
/**
@@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
- * TODO: Move this out of MLlib?
*/
-private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+private[spark] class PeriodicGraphCheckpointer[VD, ED](
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
@@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
override protected def persist(data: Graph[VD, ED]): Unit = {
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
- data.vertices.persist()
+ /* We need to use cache because persist does not honor the default storage level requested
+ * when constructing the graph. Only cache does that.
+ */
+ data.vertices.cache()
}
if (data.edges.getStorageLevel == StorageLevel.NONE) {
- data.edges.persist()
+ data.edges.cache()
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
similarity index 70%
rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
rename to graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
index a13e7f63a9296..e0c65e6940f66 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
@@ -15,77 +15,81 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.graphx.{Edge, Graph}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext {
import PeriodicGraphCheckpointerSuite._
test("Persisting") {
var graphsToCheck = Seq.empty[GraphToCheck]
- val graph1 = createGraph(sc)
- val checkpointer =
- new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
- checkpointer.update(graph1)
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
- checkPersistence(graphsToCheck, 1)
-
- var iteration = 2
- while (iteration < 9) {
- val graph = createGraph(sc)
- checkpointer.update(graph)
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
- checkPersistence(graphsToCheck, iteration)
- iteration += 1
+ withSpark { sc =>
+ val graph1 = createGraph(sc)
+ val checkpointer =
+ new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkPersistence(graphsToCheck, 1)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.update(graph)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkPersistence(graphsToCheck, iteration)
+ iteration += 1
+ }
}
}
test("Checkpointing") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
- val checkpointInterval = 2
- var graphsToCheck = Seq.empty[GraphToCheck]
- sc.setCheckpointDir(path)
- val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
- checkpointInterval, graph1.vertices.sparkContext)
- checkpointer.update(graph1)
- graph1.edges.count()
- graph1.vertices.count()
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
- checkCheckpoint(graphsToCheck, 1, checkpointInterval)
-
- var iteration = 2
- while (iteration < 9) {
- val graph = createGraph(sc)
- checkpointer.update(graph)
- graph.vertices.count()
- graph.edges.count()
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
- checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
- iteration += 1
- }
+ withSpark { sc =>
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val checkpointInterval = 2
+ var graphsToCheck = Seq.empty[GraphToCheck]
+ sc.setCheckpointDir(path)
+ val graph1 = createGraph(sc)
+ val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+ checkpointInterval, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
+ graph1.edges.count()
+ graph1.vertices.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkCheckpoint(graphsToCheck, 1, checkpointInterval)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.update(graph)
+ graph.vertices.count()
+ graph.edges.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
+ iteration += 1
+ }
- checkpointer.deleteAllCheckpoints()
- graphsToCheck.foreach { graph =>
- confirmCheckpointRemoved(graph.graph)
- }
+ checkpointer.deleteAllCheckpoints()
+ graphsToCheck.foreach { graph =>
+ confirmCheckpointRemoved(graph.graph)
+ }
- Utils.deleteRecursively(tempDir)
+ Utils.deleteRecursively(tempDir)
+ }
}
}
private object PeriodicGraphCheckpointerSuite {
+ private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER
case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)
@@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite {
Edge[Double](3, 4, 0))
def createGraph(sc: SparkContext): Graph[Double, Double] = {
- Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
+ Graph.fromEdges[Double, Double](
+ sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel)
}
def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
@@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite {
assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
assert(graph.edges.getStorageLevel == StorageLevel.NONE)
} else {
- assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
- assert(graph.edges.getStorageLevel != StorageLevel.NONE)
+ assert(graph.vertices.getStorageLevel == defaultStorageLevel)
+ assert(graph.edges.getStorageLevel == defaultStorageLevel)
}
} catch {
case _: AssertionError =>
diff --git a/launcher/pom.xml b/launcher/pom.xml
index 025cd84f20f0e..e9b46c4cf0ffa 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml
index 663f7fb0b010d..043d13609fd26 100644
--- a/mllib-local/pom.xml
+++ b/mllib-local/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 82f840b0fc269..572670dc11b42 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index f76b14eeeb542..7507c7539d4ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -458,9 +458,7 @@ private class LinearSVCAggregator(
*/
def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
- require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
- require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
- s" Expecting $numFeatures but got ${features.size}.")
+
if (weight == 0.0) return this
val localFeaturesStd = bcFeaturesStd.value
val localCoefficients = coefficientsArray
@@ -512,6 +510,7 @@ private class LinearSVCAggregator(
* @return This LinearSVCAggregator object.
*/
def merge(other: LinearSVCAggregator): this.type = {
+
if (other.weightSum != 0.0) {
weightSum += other.weightSum
lossSum += other.lossSum
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 2f50dc7c85f35..e3026c8efa823 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -36,7 +36,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
-import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -45,9 +44,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.PeriodicCheckpointer
import org.apache.spark.util.VersionUtils
-
private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
with HasSeed with HasCheckpointInterval {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index d6093a01c671c..bff0d9bbb46ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -894,10 +894,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
private[regression] object Probit extends Link("probit") {
- override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu)
+ override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).inverseCdf(mu)
override def deriv(mu: Double): Double = {
- 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu))
+ 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).inverseCdf(mu))
}
override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f7e3c8fa5b6e6..eaad54985229e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -971,9 +971,6 @@ private class LeastSquaresAggregator(
*/
def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
- require(dim == features.size, s"Dimensions mismatch when adding new sample." +
- s" Expecting $dim but got ${features.size}.")
- require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
@@ -1005,8 +1002,6 @@ private class LeastSquaresAggregator(
* @return This LeastSquaresAggregator object.
*/
def merge(other: LeastSquaresAggregator): this.type = {
- require(dim == other.dim, s"Dimensions mismatch when merging with another " +
- s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
if (other.weightSum != 0) {
totalCnt += other.totalCnt
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 4c525c0714ec5..ce2bd7b430f43 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
-import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 7fd722a332923..15b723dadcff7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -788,20 +788,14 @@ class DistributedLDAModel private[clustering] (
@Since("1.5.0")
def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = {
graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
- // TODO: Remove work-around for the breeze bug.
- // https://github.com/scalanlp/breeze/issues/561
- val topIndices = if (k == topicCounts.length) {
- Seq.range(0, k)
- } else {
- argtopk(topicCounts, k)
- }
+ val topIndices = argtopk(topicCounts, k)
val sumCounts = sum(topicCounts)
val weights = if (sumCounts != 0) {
- topicCounts(topIndices) / sumCounts
+ topicCounts(topIndices).toArray.map(_ / sumCounts)
} else {
- topicCounts(topIndices)
+ topicCounts(topIndices).toArray
}
- (docID.toLong, topIndices.toArray, weights.toArray)
+ (docID.toLong, topIndices.toArray, weights)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 48bae4276c480..3697a9b46dd84 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.graphx._
-import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index 572959200f47f..3d6a9f8d84cac 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -191,8 +191,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
// With smaller convergenceTol, it takes more steps.
assert(lossLBFGS3.length > lossLBFGS2.length)
- // Based on observation, lossLBFGS2 runs 5 iterations, no theoretically guaranteed.
- assert(lossLBFGS3.length == 6)
+ // Based on observation, lossLBFGS3 runs 7 iterations, no theoretically guaranteed.
+ assert(lossLBFGS3.length == 7)
assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol)
}
diff --git a/pom.xml b/pom.xml
index c1174593c1922..b6654c1411d25 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
pom
Spark Project Parent POM
http://spark.apache.org/
@@ -658,7 +658,7 @@
org.scalanlp
breeze_${scala.binary.version}
- 0.12
+ 0.13.1
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index feae76a087dec..dbf933f28a784 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._
*/
object MimaExcludes {
+ // Exclude rules for 2.3.x
+ lazy val v23excludes = v22excludes ++ Seq(
+ )
+
// Exclude rules for 2.2.x
lazy val v22excludes = v21excludes ++ Seq(
// [SPARK-19652][UI] Do auth checks for REST API access.
@@ -1003,6 +1007,7 @@ object MimaExcludes {
}
def excludes(version: String) = version match {
+ case v if v.startsWith("2.3") => v23excludes
case v if v.startsWith("2.2") => v22excludes
case v if v.startsWith("2.1") => v21excludes
case v if v.startsWith("2.0") => v20excludes
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index b4fc357e42d71..864968390ace9 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -190,9 +190,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
>>> blorModel = blor.fit(bdf)
>>> blorModel.coefficients
- DenseVector([5.5...])
+ DenseVector([5.4...])
>>> blorModel.intercept
- -2.68...
+ -2.63...
>>> mdf = sc.parallelize([
... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])),
@@ -200,12 +200,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight",
... family="multinomial")
>>> mlorModel = mlor.fit(mdf)
- >>> print(mlorModel.coefficientMatrix)
- DenseMatrix([[-2.3...],
- [ 0.2...],
- [ 2.1... ]])
+ >>> mlorModel.coefficientMatrix
+ DenseMatrix(3, 1, [-2.3..., 0.2..., 2.1...], 1)
>>> mlorModel.interceptVector
- DenseVector([2.0..., 0.8..., -2.8...])
+ DenseVector([2.1..., 0.6..., -2.8...])
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
>>> result = blorModel.transform(test0).head()
>>> result.prediction
@@ -213,7 +211,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> result.probability
DenseVector([0.99..., 0.00...])
>>> result.rawPrediction
- DenseVector([8.22..., -8.22...])
+ DenseVector([8.12..., -8.12...])
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
>>> blorModel.transform(test1).head().prediction
1.0
@@ -1490,9 +1488,9 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
>>> ovr = OneVsRest(classifier=lr)
>>> model = ovr.fit(df)
>>> [x.coefficients for x in model.models]
- [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])]
+ [DenseVector([4.9791, 2.426]), DenseVector([-4.1198, -5.9326]), DenseVector([-3.314, 5.2423])]
>>> [x.intercept for x in model.models]
- [-3.64747..., 2.55078..., -1.10165...]
+ [-5.06544..., 2.30341..., -1.29133...]
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF()
>>> model.transform(test0).head().prediction
1.0
diff --git a/repl/pom.xml b/repl/pom.xml
index a256ae3b84183..6d133a3cfff7d 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml
index 03846d9f5a3be..20b53f2d8f987 100644
--- a/resource-managers/mesos/pom.xml
+++ b/resource-managers/mesos/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml
index a1b641c8eeb84..71d4ad681e169 100644
--- a/resource-managers/yarn/pom.xml
+++ b/resource-managers/yarn/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 765c92b8d3b9e..8d80f8eca5dba 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 80ab75cc17fab..dcccbd0ed8d6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
abstract class AbstractSqlParser extends ParserInterface with Logging {
/** Creates/Resolves DataType for a given SQL string. */
- def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
- // TODO add this to the parser interface.
+ override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
}
@@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
}
/** Creates FunctionIdentifier for a given SQL string. */
- def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser =>
- astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+ parse(sqlText) { parser =>
+ astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
index db3598bde04d3..75240d2196222 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
@@ -17,30 +17,51 @@
package org.apache.spark.sql.catalyst.parser
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* Interface for a parser.
*/
+@DeveloperApi
trait ParserInterface {
- /** Creates LogicalPlan for a given SQL string. */
+ /**
+ * Parse a string to a [[LogicalPlan]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a LogicalPlan")
def parsePlan(sqlText: String): LogicalPlan
- /** Creates Expression for a given SQL string. */
+ /**
+ * Parse a string to an [[Expression]].
+ */
+ @throws[ParseException]("Text cannot be parsed to an Expression")
def parseExpression(sqlText: String): Expression
- /** Creates TableIdentifier for a given SQL string. */
+ /**
+ * Parse a string to a [[TableIdentifier]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a TableIdentifier")
def parseTableIdentifier(sqlText: String): TableIdentifier
- /** Creates FunctionIdentifier for a given SQL string. */
+ /**
+ * Parse a string to a [[FunctionIdentifier]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
def parseFunctionIdentifier(sqlText: String): FunctionIdentifier
/**
- * Creates StructType for a given SQL string, which is a comma separated list of field
- * definitions which will preserve the correct Hive metadata.
+ * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
+ * of field definitions which will preserve the correct Hive metadata.
*/
+ @throws[ParseException]("Text cannot be parsed to a schema")
def parseTableSchema(sqlText: String): StructType
+
+ /**
+ * Parse a string to a [[DataType]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a DataType")
+ def parseDataType(sqlText: String): DataType
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index af1a9cee2962a..c6c0a605d89ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -81,4 +81,10 @@ object StaticSQLConf {
"SQL configuration and the current database.")
.booleanConf
.createWithDefault(false)
+
+ val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
+ .doc("Name of the class used to configure Spark Session extensions. The class should " +
+ "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
+ .stringConf
+ .createOptional
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index b203f31a76f03..e170133f0f0bf 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c6dcd93bbda66..06dd5500718de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1726,15 +1726,23 @@ class Dataset[T] private[sql](
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
- // ordering deterministic.
- // MapType cannot be sorted.
- val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType])
- .map(SortOrder(_, Ascending)), global = false, logicalPlan)
+ // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+ // from the sort order.
+ val sortOrder = logicalPlan.output
+ .filter(attr => RowOrdering.isOrderable(attr.dataType))
+ .map(SortOrder(_, Ascending))
+ val plan = if (sortOrder.nonEmpty) {
+ Sort(sortOrder, global = false, logicalPlan)
+ } else {
+ // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
+ cache()
+ logicalPlan
+ }
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new Dataset[T](
- sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
+ sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder)
}.toArray
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 95f3463dfe62b..a519492ed8f4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState}
+import org.apache.spark.sql.internal._
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
@@ -77,11 +77,12 @@ import org.apache.spark.util.Utils
class SparkSession private(
@transient val sparkContext: SparkContext,
@transient private val existingSharedState: Option[SharedState],
- @transient private val parentSessionState: Option[SessionState])
+ @transient private val parentSessionState: Option[SessionState],
+ @transient private[sql] val extensions: SparkSessionExtensions)
extends Serializable with Closeable with Logging { self =>
private[sql] def this(sc: SparkContext) {
- this(sc, None, None)
+ this(sc, None, None, new SparkSessionExtensions)
}
sparkContext.assertNotStopped()
@@ -219,7 +220,7 @@ class SparkSession private(
* @since 2.0.0
*/
def newSession(): SparkSession = {
- new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+ new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions)
}
/**
@@ -235,7 +236,7 @@ class SparkSession private(
* implementation is Hive, this will initialize the metastore, which may take some time.
*/
private[sql] def cloneSession(): SparkSession = {
- val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+ val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions)
result.sessionState // force copy of SessionState
result
}
@@ -754,6 +755,8 @@ object SparkSession {
private[this] val options = new scala.collection.mutable.HashMap[String, String]
+ private[this] val extensions = new SparkSessionExtensions
+
private[this] var userSuppliedContext: Option[SparkContext] = None
private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
@@ -847,6 +850,17 @@ object SparkSession {
}
}
+ /**
+ * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules,
+ * Optimizer rules, Planning Strategies or a customized parser.
+ *
+ * @since 2.2.0
+ */
+ def withExtensions(f: SparkSessionExtensions => Unit): Builder = {
+ f(extensions)
+ this
+ }
+
/**
* Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
* one based on the options set in this builder.
@@ -903,7 +917,26 @@ object SparkSession {
}
sc
}
- session = new SparkSession(sparkContext)
+
+ // Initialize extensions if the user has defined a configurator class.
+ val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
+ if (extensionConfOption.isDefined) {
+ val extensionConfClassName = extensionConfOption.get
+ try {
+ val extensionConfClass = Utils.classForName(extensionConfClassName)
+ val extensionConf = extensionConfClass.newInstance()
+ .asInstanceOf[SparkSessionExtensions => Unit]
+ extensionConf(extensions)
+ } catch {
+ // Ignore the error if we cannot find the class or when the class has the wrong type.
+ case e @ (_: ClassCastException |
+ _: ClassNotFoundException |
+ _: NoClassDefFoundError) =>
+ logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+ }
+ }
+
+ session = new SparkSession(sparkContext, None, None, extensions)
options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
defaultSession.set(session)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
new file mode 100644
index 0000000000000..f99c108161f94
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * :: Experimental ::
+ * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability
+ * regarding binary compatibility and source compatibility of methods here.
+ *
+ * This current provides the following extension points:
+ * - Analyzer Rules.
+ * - Check Analysis Rules
+ * - Optimizer Rules.
+ * - Planning Strategies.
+ * - Customized Parser.
+ * - (External) Catalog listeners.
+ *
+ * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for
+ * example:
+ * {{{
+ * SparkSession.builder()
+ * .master("...")
+ * .conf("...", true)
+ * .withExtensions { extensions =>
+ * extensions.injectResolutionRule { session =>
+ * ...
+ * }
+ * extensions.injectParser { (session, parser) =>
+ * ...
+ * }
+ * }
+ * .getOrCreate()
+ * }}}
+ *
+ * Note that none of the injected builders should assume that the [[SparkSession]] is fully
+ * initialized and should not touch the session's internals (e.g. the SessionState).
+ */
+@DeveloperApi
+@Experimental
+@InterfaceStability.Unstable
+class SparkSessionExtensions {
+ type RuleBuilder = SparkSession => Rule[LogicalPlan]
+ type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
+ type StrategyBuilder = SparkSession => Strategy
+ type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
+
+ private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+ /**
+ * Build the analyzer resolution `Rule`s using the given [[SparkSession]].
+ */
+ private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ resolutionRuleBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer
+ * rules will be executed as part of the resolution phase of analysis.
+ */
+ def injectResolutionRule(builder: RuleBuilder): Unit = {
+ resolutionRuleBuilders += builder
+ }
+
+ private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+ /**
+ * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]].
+ */
+ private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ postHocResolutionRuleBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer
+ * rules will be executed after resolution.
+ */
+ def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
+ postHocResolutionRuleBuilders += builder
+ }
+
+ private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder]
+
+ /**
+ * Build the check analysis `Rule`s using the given [[SparkSession]].
+ */
+ private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = {
+ checkRuleBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject an check analysis `Rule` builder into the [[SparkSession]]. The injected rules will
+ * be executed after the analysis phase. A check analysis rule is used to detect problems with a
+ * LogicalPlan and should throw an exception when a problem is found.
+ */
+ def injectCheckRule(builder: CheckRuleBuilder): Unit = {
+ checkRuleBuilders += builder
+ }
+
+ private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]
+
+ private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ optimizerRules.map(_.apply(session))
+ }
+
+ /**
+ * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be
+ * executed during the operator optimization batch. An optimizer rule is used to improve the
+ * quality of an analyzed logical plan; these rules should never modify the result of the
+ * LogicalPlan.
+ */
+ def injectOptimizerRule(builder: RuleBuilder): Unit = {
+ optimizerRules += builder
+ }
+
+ private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
+
+ private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
+ plannerStrategyBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will
+ * be used to convert a `LogicalPlan` into a executable
+ * [[org.apache.spark.sql.execution.SparkPlan]].
+ */
+ def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
+ plannerStrategyBuilders += builder
+ }
+
+ private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder]
+
+ private[sql] def buildParser(
+ session: SparkSession,
+ initial: ParserInterface): ParserInterface = {
+ parserBuilders.foldLeft(initial) { (parser, builder) =>
+ builder(session, parser)
+ }
+ }
+
+ /**
+ * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session
+ * and an initial parser. The latter allows for a user to create a partial parser and to delegate
+ * to the underlying parser for completeness. If a user injects more parsers, then the parsers
+ * are stacked on top of each other.
+ */
+ def injectParser(builder: ParserBuilder): Unit = {
+ parserBuilders += builder
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index df7c3678b7807..2a801d87b12eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -18,8 +18,8 @@ package org.apache.spark.sql.internal
import org.apache.spark.SparkConf
import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone}
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
@@ -63,6 +63,11 @@ abstract class BaseSessionStateBuilder(
*/
protected def newBuilder: NewBuilder
+ /**
+ * Session extensions defined in the [[SparkSession]].
+ */
+ protected def extensions: SparkSessionExtensions = session.extensions
+
/**
* Extract entries from `SparkConf` and put them in the `SQLConf`
*/
@@ -108,7 +113,9 @@ abstract class BaseSessionStateBuilder(
*
* Note: this depends on the `conf` field.
*/
- protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+ protected lazy val sqlParser: ParserInterface = {
+ extensions.buildParser(session, new SparkSqlParser(conf))
+ }
/**
* ResourceLoader that is used to load function resources and jars.
@@ -171,7 +178,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `analyzer` function.
*/
- protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+ protected def customResolutionRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildResolutionRules(session)
+ }
/**
* Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of
@@ -179,7 +188,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `analyzer` function.
*/
- protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+ protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildPostHocResolutionRules(session)
+ }
/**
* Custom check rules to add to the Analyzer. Prefer overriding this instead of creating
@@ -187,7 +198,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `analyzer` function.
*/
- protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil
+ protected def customCheckRules: Seq[LogicalPlan => Unit] = {
+ extensions.buildCheckRules(session)
+ }
/**
* Logical query plan optimizer.
@@ -207,7 +220,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `optimizer` function.
*/
- protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+ protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildOptimizerRules(session)
+ }
/**
* Planner that converts optimized logical plans to physical plans.
@@ -227,7 +242,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `planner` function.
*/
- protected def customPlanningStrategies: Seq[Strategy] = Nil
+ protected def customPlanningStrategies: Seq[Strategy] = {
+ extensions.buildPlannerStrategies(session)
+ }
/**
* Create a query execution object.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 97890a035a62f..dd118f88e3bb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}
test("randomSplit on reordered partitions") {
- // This test ensures that randomSplit does not create overlapping splits even when the
- // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
- // rows in each partition.
- val data =
- sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
- val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
- assert(splits.length == 2, "wrong number of splits")
+ def testNonOverlappingSplits(data: DataFrame): Unit = {
+ val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+ assert(splits.length == 2, "wrong number of splits")
+
+ // Verify that the splits span the entire dataset
+ assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
- // Verify that the splits span the entire dataset
- assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+ // Verify that the splits don't overlap
+ assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)
- // Verify that the splits don't overlap
- assert(splits(0).intersect(splits(1)).collect().isEmpty)
+ // Verify that the results are deterministic across multiple runs
+ val firstRun = splits.toSeq.map(_.collect().toSeq)
+ val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+ assert(firstRun == secondRun)
+ }
- // Verify that the results are deterministic across multiple runs
- val firstRun = splits.toSeq.map(_.collect().toSeq)
- val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
- assert(firstRun == secondRun)
+ // This test ensures that randomSplit does not create overlapping splits even when the
+ // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+ // rows in each partition.
+ val dataWithInts = sparkContext.parallelize(1 to 600, 2)
+ .mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
+ val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
+ .map(i => (i, Map(i -> i.toString)))
+ .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
+ val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
+ .map(i => (i, Array(Map(i -> i.toString))))
+ .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")
+
+ testNonOverlappingSplits(dataWithInts)
+ testNonOverlappingSplits(dataWithMaps)
+ testNonOverlappingSplits(dataWithArrayOfMaps)
}
test("pearson correlation") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
new file mode 100644
index 0000000000000..43db79663322a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Test cases for the [[SparkSessionExtensions]].
+ */
+class SparkSessionExtensionSuite extends SparkFunSuite {
+ type ExtensionsBuilder = SparkSessionExtensions => Unit
+ private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder
+
+ private def stop(spark: SparkSession): Unit = {
+ spark.stop()
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ }
+
+ private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = {
+ val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate()
+ try f(spark) finally {
+ stop(spark)
+ }
+ }
+
+ test("inject analyzer rule") {
+ withSession(_.injectResolutionRule(MyRule)) { session =>
+ assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+ }
+ }
+
+ test("inject check analysis rule") {
+ withSession(_.injectCheckRule(MyCheckRule)) { session =>
+ assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
+ }
+ }
+
+ test("inject optimizer rule") {
+ withSession(_.injectOptimizerRule(MyRule)) { session =>
+ assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+ }
+ }
+
+ test("inject spark planner strategy") {
+ withSession(_.injectPlannerStrategy(MySparkStrategy)) { session =>
+ assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
+ }
+ }
+
+ test("inject parser") {
+ val extension = create { extensions =>
+ extensions.injectParser((_, _) => CatalystSqlParser)
+ }
+ withSession(extension) { session =>
+ assert(session.sessionState.sqlParser == CatalystSqlParser)
+ }
+ }
+
+ test("inject stacked parsers") {
+ val extension = create { extensions =>
+ extensions.injectParser((_, _) => CatalystSqlParser)
+ extensions.injectParser(MyParser)
+ extensions.injectParser(MyParser)
+ }
+ withSession(extension) { session =>
+ val parser = MyParser(session, MyParser(session, CatalystSqlParser))
+ assert(session.sessionState.sqlParser == parser)
+ }
+ }
+
+ test("use custom class for extensions") {
+ val session = SparkSession.builder()
+ .master("local[1]")
+ .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName)
+ .getOrCreate()
+ try {
+ assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
+ assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+ } finally {
+ stop(session)
+ }
+ }
+}
+
+case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan
+}
+
+case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) {
+ override def apply(plan: LogicalPlan): Unit = { }
+}
+
+case class MySparkStrategy(spark: SparkSession) extends SparkStrategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
+}
+
+case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface {
+ override def parsePlan(sqlText: String): LogicalPlan =
+ delegate.parsePlan(sqlText)
+
+ override def parseExpression(sqlText: String): Expression =
+ delegate.parseExpression(sqlText)
+
+ override def parseTableIdentifier(sqlText: String): TableIdentifier =
+ delegate.parseTableIdentifier(sqlText)
+
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+ delegate.parseFunctionIdentifier(sqlText)
+
+ override def parseTableSchema(sqlText: String): StructType =
+ delegate.parseTableSchema(sqlText)
+
+ override def parseDataType(sqlText: String): DataType =
+ delegate.parseDataType(sqlText)
+}
+
+class MyExtensions extends (SparkSessionExtensions => Unit) {
+ def apply(e: SparkSessionExtensions): Unit = {
+ e.injectPlannerStrategy(MySparkStrategy)
+ e.injectResolutionRule(MyRule)
+ }
+}
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 9c879218ddc0d..a5a8e2640586c 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 0f249d7d59351..09dcc4055e000 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../../pom.xml
diff --git a/streaming/pom.xml b/streaming/pom.xml
index de1be9c13e05f..fea882ad11230 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml
diff --git a/tools/pom.xml b/tools/pom.xml
index 938ba2f6ac201..7ba4dc9842f1b 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.3.0-SNAPSHOT
../pom.xml