diff --git a/LICENSE b/LICENSE index 7950dd6ceb6db..c21032a1fd274 100644 --- a/LICENSE +++ b/LICENSE @@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) RowsGroup (http://datatables.net/license/mit) (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) + (MIT License) machinist (https://github.com/typelevel/machinist) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 459254d271a58..af7cbdccf5d5d 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -288,18 +288,18 @@ test_that("spark.mlp", { c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "0.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +310,8 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, - -10.2376130), tolerance = 1e-6) + expect_equal(head(summary$weights, 5), list(-0.5793153, -4.652961, 6.216155, -6.649478, + -10.51147), tolerance = 1e-3) }) test_that("spark.naiveBayes", { diff --git a/assembly/pom.xml b/assembly/pom.xml index 9d8607d9137c6..742a4a1531e71 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8657af744c069..066970f24205f 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 24c10fb1ddb9f..2de882adcb582 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 5e5a80bd44467..a8488d8d1b704 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 1356c4723b662..6b81fc2b2b040 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 9345dc8f0cc4b..f7e586ee777e1 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index f03a4da5e7152..680d0413b1616 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 24ce36deeb169..7f245b5b6384a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 99efc4893fda4..0ec1bdd39b2f5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1350,7 +1350,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Some(name)) + val acc = new Accumulable(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1414,7 +1414,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { - acc.register(this, name = Some(name)) + acc.register(this, name = Option(name)) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index d7d82800b8b55..6d8758a3d3b1d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -86,7 +86,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return Count of application event logs that are currently under process */ def getEventLogsUnderProcess(): Int = { - return 0; + 0 } /** @@ -95,7 +95,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis */ def getLastUpdatedTime(): Long = { - return 0; + 0 } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 54f39f7620e5d..d9c8fda99ef97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -301,6 +301,14 @@ object HistoryServer extends Logging { logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}") config.set(SecurityManager.SPARK_AUTH_CONF, "false") } + + if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) { + logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " + + "only using spark.history.ui.acl.enable") + config.set("spark.acls.enable", "false") + config.set("spark.ui.acls.enable", "false") + } + new SecurityManager(config) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e524675332d1b..63a87e7f09d85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag]( val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + queue ++= collectionUtils.takeOrdered(items, num)(ord) Iterator.single(queue) } if (mapRDDs.partitions.length == 0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index 145dc22b7428e..ab72addb2466b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.rdd.util import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 00f918c09c66b..f17b637754826 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}/logs") def getEventLogs( @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) + try { + // withSparkUI will throw NotFoundException if attemptId exists for this application. + // So we need to try again with attempt id "1". + withSparkUI(appId, None) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } catch { + case _: NotFoundException => + withSparkUI(appId, Some("1")) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + withSparkUI(appId, Some(attemptId)) { _ => + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } @Path("version") @@ -291,7 +304,6 @@ private[v1] trait ApiRequestContext { case None => throw new NotFoundException("no such app: " + appId) } } - } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala index 4dd498cd91b4e..ce06e18879a49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.util import scala.collection.mutable @@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ -private[mllib] abstract class PeriodicCheckpointer[T]( +private[spark] abstract class PeriodicCheckpointer[T]( val checkpointInterval: Int, val sc: SparkContext) extends Logging { @@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T]( /** Get list of checkpoint files for this given Dataset */ protected def getCheckpointFiles(data: T): Iterable[String] + /** + * Call this to unpersist the Dataset. + */ + def unpersistDataSet(): Unit = { + while (persistedQueue.nonEmpty) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + } + /** * Call this at the end to delete any remaining checkpoint files. */ diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 764156c3edc41..95acb9a54440f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -565,13 +565,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit(); - + logDir.deleteOnExit() } test("ui and api authorization checks") { - val appId = "app-20161115172038-0000" - val owner = "jose" + val appId = "local-1430917381535" + val owner = "irashid" val admin = "root" val other = "alice" @@ -590,8 +589,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val port = server.boundPort val testUrls = Seq( - s"http://localhost:$port/api/v1/applications/$appId/jobs", - s"http://localhost:$port/history/$appId/jobs/") + s"http://localhost:$port/api/v1/applications/$appId/1/jobs", + s"http://localhost:$port/history/$appId/1/jobs/", + s"http://localhost:$port/api/v1/applications/$appId/logs", + s"http://localhost:$port/api/v1/applications/$appId/1/logs", + s"http://localhost:$port/api/v1/applications/$appId/2/logs") tests.foreach { case (user, expectedCode) => testUrls.foreach { url => diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f151823a2..7f20206202cb9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala similarity index 96% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala rename to core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala index 14adf8c29fc6b..f9e1b791c86ea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala @@ -15,18 +15,18 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.utils import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext { import PeriodicRDDCheckpointerSuite._ diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 73dc1f9a1398c..9287bd47cf113 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -129,6 +129,8 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.0.0-shaded-protobuf.jar metrics-core-3.1.2.jar @@ -162,13 +164,13 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 6bf0923a1d751..ab1de3d3dd8ad 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -130,6 +130,8 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.0.0-shaded-protobuf.jar metrics-core-3.1.2.jar @@ -163,13 +165,13 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/docs/_config.yml b/docs/_config.yml index 83bb30598d153..21255ef7a5c45 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.2.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.2.0 +SPARK_VERSION: 2.3.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.3.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.7" MESOS_VERSION: 1.0.0 diff --git a/docs/building-spark.md b/docs/building-spark.md index e99b70f7a8b47..0f551bc66b8c9 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -232,7 +232,7 @@ Once installed, the `docker` service needs to be started, if not already running On Linux, this can be done by `sudo service docker start`. ./build/mvn install -DskipTests - ./build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 or diff --git a/docs/configuration.md b/docs/configuration.md index 6b65d2bcb83e5..87b76322cae51 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2149,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE) +### GraphX + + + + + + + + +
Property NameDefaultMeaning
spark.graphx.pregel.checkpointInterval-1 + Checkpoint interval for graph and message in Pregel. It used to avoid stackOverflowError due to long lineage chains + after lots of iterations. The checkpoint is disabled by default. +
+ ### Deploy diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e271b28fb4f28..76aa7b405e18c 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,9 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note calls to graph.cache have been removed): +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, +say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): {% highlight scala %} class GraphOps[VD, ED] { @@ -722,6 +724,7 @@ class GraphOps[VD, ED] { : Graph[VD, ED] = { // Receive the initial message at each vertex var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache() + // compute the messages var messages = g.mapReduceTriplets(sendMsg, mergeMsg) var activeMessages = messages.count() @@ -734,8 +737,8 @@ class GraphOps[VD, ED] { // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. - messages = g.mapReduceTriplets( - sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + messages = GraphXUtils.mapReduceTriplets( + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() activeMessages = messages.count() i += 1 } diff --git a/examples/pom.xml b/examples/pom.xml index 91c2e81ebed2f..e674e799f24a3 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 8948df2da89e2..0fa87a697454b 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index f8ef8a991316d..71016bc645ca7 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 6d547c46d6a2d..12630840e79dc 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 46901d64eda97..87a09642405a7 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 295142cbfdff9..75df886ca44f6 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 6cf448e65e8b4..557d27296345f 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 88499240cd569..6c98cb04fcfa6 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 3fedd9eda1959..f9c2dcb38dc0e 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 8368a1f12218d..849c8b465f99e 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 90bb0e4987c82..48783d65826aa 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index daa79e79163b9..40a751a652fa9 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 7da27817ebafd..36d555066b181 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8df33660ea9d1..cb30e4a4af4bc 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 646462b4a8350..755c6febc48e6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -19,7 +19,10 @@ package org.apache.spark.graphx import scala.reflect.ClassTag +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer /** * Implements a Pregel-like bulk-synchronous message-passing API. @@ -122,27 +125,39 @@ object Pregel extends Logging { require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") - var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + val checkpointInterval = graph.vertices.sparkContext.getConf + .getInt("spark.graphx.pregel.checkpointInterval", -1) + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)) + val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED]( + checkpointInterval, graph.vertices.sparkContext) + graphCheckpointer.update(g) + // compute the messages var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg) + val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)]( + checkpointInterval, graph.vertices.sparkContext) + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) var activeMessages = messages.count() + // Loop var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { // Receive the messages and update the vertices. prevG = g - g = g.joinVertices(messages)(vprog).cache() + g = g.joinVertices(messages)(vprog) + graphCheckpointer.update(g) val oldMessages = messages // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. messages = GraphXUtils.mapReduceTriplets( - g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))) // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages // and the vertices of g). + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) activeMessages = messages.count() logInfo("Pregel finished iteration " + i) @@ -154,7 +169,9 @@ object Pregel extends Logging { // count the iteration i += 1 } - messages.unpersist(blocking = false) + messageCheckpointer.unpersistDataSet() + graphCheckpointer.deleteAllCheckpoints() + messageCheckpointer.deleteAllCheckpoints() g } // end of apply diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala similarity index 91% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala rename to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala index 80074897567eb..fda501aa757d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** @@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Move this out of MLlib? */ -private[mllib] class PeriodicGraphCheckpointer[VD, ED]( +private[spark] class PeriodicGraphCheckpointer[VD, ED]( checkpointInterval: Int, sc: SparkContext) extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { @@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.vertices.persist() + /* We need to use cache because persist does not honor the default storage level requested + * when constructing the graph. Only cache does that. + */ + data.vertices.cache() } if (data.edges.getStorageLevel == StorageLevel.NONE) { - data.edges.persist() + data.edges.cache() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala similarity index 70% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala rename to graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala index a13e7f63a9296..e0c65e6940f66 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala @@ -15,77 +15,81 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.graphx.{Edge, Graph} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext { import PeriodicGraphCheckpointerSuite._ test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] - val graph1 = createGraph(sc) - val checkpointer = - new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkPersistence(graphsToCheck, 1) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkPersistence(graphsToCheck, iteration) - iteration += 1 + withSpark { sc => + val graph1 = createGraph(sc) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } } } test("Checkpointing") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - val checkpointInterval = 2 - var graphsToCheck = Seq.empty[GraphToCheck] - sc.setCheckpointDir(path) - val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( - checkpointInterval, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graph1.edges.count() - graph1.vertices.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkCheckpoint(graphsToCheck, 1, checkpointInterval) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graph.vertices.count() - graph.edges.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkCheckpoint(graphsToCheck, iteration, checkpointInterval) - iteration += 1 - } + withSpark { sc => + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + sc.setCheckpointDir(path) + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } - checkpointer.deleteAllCheckpoints() - graphsToCheck.foreach { graph => - confirmCheckpointRemoved(graph.graph) - } + checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } - Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(tempDir) + } } } private object PeriodicGraphCheckpointerSuite { + private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) @@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite { Edge[Double](3, 4, 0)) def createGraph(sc: SparkContext): Graph[Double, Double] = { - Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + Graph.fromEdges[Double, Double]( + sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel) } def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { @@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite { assert(graph.vertices.getStorageLevel == StorageLevel.NONE) assert(graph.edges.getStorageLevel == StorageLevel.NONE) } else { - assert(graph.vertices.getStorageLevel != StorageLevel.NONE) - assert(graph.edges.getStorageLevel != StorageLevel.NONE) + assert(graph.vertices.getStorageLevel == defaultStorageLevel) + assert(graph.edges.getStorageLevel == defaultStorageLevel) } } catch { case _: AssertionError => diff --git a/launcher/pom.xml b/launcher/pom.xml index 025cd84f20f0e..e9b46c4cf0ffa 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 663f7fb0b010d..043d13609fd26 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 82f840b0fc269..572670dc11b42 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index f76b14eeeb542..7507c7539d4ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -458,9 +458,7 @@ private class LinearSVCAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $numFeatures but got ${features.size}.") + if (weight == 0.0) return this val localFeaturesStd = bcFeaturesStd.value val localCoefficients = coefficientsArray @@ -512,6 +510,7 @@ private class LinearSVCAggregator( * @return This LinearSVCAggregator object. */ def merge(other: LinearSVCAggregator): this.type = { + if (other.weightSum != 0.0) { weightSum += other.weightSum lossSum += other.lossSum diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 2f50dc7c85f35..e3026c8efa823 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -36,7 +36,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -45,9 +44,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils - private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter with HasSeed with HasCheckpointInterval { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index d6093a01c671c..bff0d9bbb46ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -894,10 +894,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] object Probit extends Link("probit") { - override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) + override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).inverseCdf(mu) override def deriv(mu: Double): Double = { - 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu)) + 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).inverseCdf(mu)) } override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7e3c8fa5b6e6..eaad54985229e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -971,9 +971,6 @@ private class LeastSquaresAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(dim == features.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1005,8 +1002,6 @@ private class LeastSquaresAggregator( * @return This LeastSquaresAggregator object. */ def merge(other: LeastSquaresAggregator): this.type = { - require(dim == other.dim, s"Dimensions mismatch when merging with another " + - s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") if (other.weightSum != 0) { totalCnt += other.totalCnt diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 4c525c0714ec5..ce2bd7b430f43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} -import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 7fd722a332923..15b723dadcff7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -788,20 +788,14 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => - // TODO: Remove work-around for the breeze bug. - // https://github.com/scalanlp/breeze/issues/561 - val topIndices = if (k == topicCounts.length) { - Seq.range(0, k) - } else { - argtopk(topicCounts, k) - } + val topIndices = argtopk(topicCounts, k) val sumCounts = sum(topicCounts) val weights = if (sumCounts != 0) { - topicCounts(topIndices) / sumCounts + topicCounts(topIndices).toArray.map(_ / sumCounts) } else { - topicCounts(topIndices) + topicCounts(topIndices).toArray } - (docID.toLong, topIndices.toArray, weights.toArray) + (docID.toLong, topIndices.toArray, weights) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 48bae4276c480..3697a9b46dd84 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 572959200f47f..3d6a9f8d84cac 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -191,8 +191,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers // With smaller convergenceTol, it takes more steps. assert(lossLBFGS3.length > lossLBFGS2.length) - // Based on observation, lossLBFGS2 runs 5 iterations, no theoretically guaranteed. - assert(lossLBFGS3.length == 6) + // Based on observation, lossLBFGS3 runs 7 iterations, no theoretically guaranteed. + assert(lossLBFGS3.length == 7) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } diff --git a/pom.xml b/pom.xml index c1174593c1922..b6654c1411d25 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.sparkspark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOTpomSpark Project Parent POMhttp://spark.apache.org/ @@ -658,7 +658,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.12 + 0.13.1 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index feae76a087dec..dbf933f28a784 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.3.x + lazy val v23excludes = v22excludes ++ Seq( + ) + // Exclude rules for 2.2.x lazy val v22excludes = v21excludes ++ Seq( // [SPARK-19652][UI] Do auth checks for REST API access. @@ -1003,6 +1007,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes case v if v.startsWith("2.0") => v20excludes diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b4fc357e42d71..864968390ace9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -190,9 +190,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> blorModel = blor.fit(bdf) >>> blorModel.coefficients - DenseVector([5.5...]) + DenseVector([5.4...]) >>> blorModel.intercept - -2.68... + -2.63... >>> mdf = sc.parallelize([ ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), @@ -200,12 +200,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", ... family="multinomial") >>> mlorModel = mlor.fit(mdf) - >>> print(mlorModel.coefficientMatrix) - DenseMatrix([[-2.3...], - [ 0.2...], - [ 2.1... ]]) + >>> mlorModel.coefficientMatrix + DenseMatrix(3, 1, [-2.3..., 0.2..., 2.1...], 1) >>> mlorModel.interceptVector - DenseVector([2.0..., 0.8..., -2.8...]) + DenseVector([2.1..., 0.6..., -2.8...]) >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> result = blorModel.transform(test0).head() >>> result.prediction @@ -213,7 +211,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> result.probability DenseVector([0.99..., 0.00...]) >>> result.rawPrediction - DenseVector([8.22..., -8.22...]) + DenseVector([8.12..., -8.12...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> blorModel.transform(test1).head().prediction 1.0 @@ -1490,9 +1488,9 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> ovr = OneVsRest(classifier=lr) >>> model = ovr.fit(df) >>> [x.coefficients for x in model.models] - [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])] + [DenseVector([4.9791, 2.426]), DenseVector([-4.1198, -5.9326]), DenseVector([-3.314, 5.2423])] >>> [x.intercept for x in model.models] - [-3.64747..., 2.55078..., -1.10165...] + [-5.06544..., 2.30341..., -1.29133...] >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF() >>> model.transform(test0).head().prediction 1.0 diff --git a/repl/pom.xml b/repl/pom.xml index a256ae3b84183..6d133a3cfff7d 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 03846d9f5a3be..20b53f2d8f987 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index a1b641c8eeb84..71d4ad681e169 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 765c92b8d3b9e..8d80f8eca5dba 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 80ab75cc17fab..dcccbd0ed8d6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class AbstractSqlParser extends ParserInterface with Logging { /** Creates/Resolves DataType for a given SQL string. */ - def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => - // TODO add this to the parser interface. + override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => astBuilder.visitSingleDataType(parser.singleDataType()) } @@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { } /** Creates FunctionIdentifier for a given SQL string. */ - def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser => - astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + parse(sqlText) { parser => + astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index db3598bde04d3..75240d2196222 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -17,30 +17,51 @@ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** * Interface for a parser. */ +@DeveloperApi trait ParserInterface { - /** Creates LogicalPlan for a given SQL string. */ + /** + * Parse a string to a [[LogicalPlan]]. + */ + @throws[ParseException]("Text cannot be parsed to a LogicalPlan") def parsePlan(sqlText: String): LogicalPlan - /** Creates Expression for a given SQL string. */ + /** + * Parse a string to an [[Expression]]. + */ + @throws[ParseException]("Text cannot be parsed to an Expression") def parseExpression(sqlText: String): Expression - /** Creates TableIdentifier for a given SQL string. */ + /** + * Parse a string to a [[TableIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a TableIdentifier") def parseTableIdentifier(sqlText: String): TableIdentifier - /** Creates FunctionIdentifier for a given SQL string. */ + /** + * Parse a string to a [[FunctionIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier") def parseFunctionIdentifier(sqlText: String): FunctionIdentifier /** - * Creates StructType for a given SQL string, which is a comma separated list of field - * definitions which will preserve the correct Hive metadata. + * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list + * of field definitions which will preserve the correct Hive metadata. */ + @throws[ParseException]("Text cannot be parsed to a schema") def parseTableSchema(sqlText: String): StructType + + /** + * Parse a string to a [[DataType]]. + */ + @throws[ParseException]("Text cannot be parsed to a DataType") + def parseDataType(sqlText: String): DataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index af1a9cee2962a..c6c0a605d89ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -81,4 +81,10 @@ object StaticSQLConf { "SQL configuration and the current database.") .booleanConf .createWithDefault(false) + + val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions") + .doc("Name of the class used to configure Spark Session extensions. The class should " + + "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") + .stringConf + .createOptional } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index b203f31a76f03..e170133f0f0bf 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c6dcd93bbda66..06dd5500718de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1726,15 +1726,23 @@ class Dataset[T] private[sql]( // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the - // ordering deterministic. - // MapType cannot be sorted. - val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType]) - .map(SortOrder(_, Ascending)), global = false, logicalPlan) + // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out + // from the sort order. + val sortOrder = logicalPlan.output + .filter(attr => RowOrdering.isOrderable(attr.dataType)) + .map(SortOrder(_, Ascending)) + val plan = if (sortOrder.nonEmpty) { + Sort(sortOrder, global = false, logicalPlan) + } else { + // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism + cache() + logicalPlan + } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 95f3463dfe62b..a519492ed8f4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -77,11 +77,12 @@ import org.apache.spark.util.Utils class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState], - @transient private val parentSessionState: Option[SessionState]) + @transient private val parentSessionState: Option[SessionState], + @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { - this(sc, None, None) + this(sc, None, None, new SparkSessionExtensions) } sparkContext.assertNotStopped() @@ -219,7 +220,7 @@ class SparkSession private( * @since 2.0.0 */ def newSession(): SparkSession = { - new SparkSession(sparkContext, Some(sharedState), parentSessionState = None) + new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions) } /** @@ -235,7 +236,7 @@ class SparkSession private( * implementation is Hive, this will initialize the metastore, which may take some time. */ private[sql] def cloneSession(): SparkSession = { - val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState)) + val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions) result.sessionState // force copy of SessionState result } @@ -754,6 +755,8 @@ object SparkSession { private[this] val options = new scala.collection.mutable.HashMap[String, String] + private[this] val extensions = new SparkSessionExtensions + private[this] var userSuppliedContext: Option[SparkContext] = None private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { @@ -847,6 +850,17 @@ object SparkSession { } } + /** + * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules, + * Optimizer rules, Planning Strategies or a customized parser. + * + * @since 2.2.0 + */ + def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + f(extensions) + this + } + /** * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new * one based on the options set in this builder. @@ -903,7 +917,26 @@ object SparkSession { } sc } - session = new SparkSession(sparkContext) + + // Initialize extensions if the user has defined a configurator class. + val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) + if (extensionConfOption.isDefined) { + val extensionConfClassName = extensionConfOption.get + try { + val extensionConfClass = Utils.classForName(extensionConfClassName) + val extensionConf = extensionConfClass.newInstance() + .asInstanceOf[SparkSessionExtensions => Unit] + extensionConf(extensions) + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | + _: ClassNotFoundException | + _: NoClassDefFoundError) => + logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e) + } + } + + session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala new file mode 100644 index 0000000000000..f99c108161f94 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.mutable + +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * :: Experimental :: + * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability + * regarding binary compatibility and source compatibility of methods here. + * + * This current provides the following extension points: + * - Analyzer Rules. + * - Check Analysis Rules + * - Optimizer Rules. + * - Planning Strategies. + * - Customized Parser. + * - (External) Catalog listeners. + * + * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for + * example: + * {{{ + * SparkSession.builder() + * .master("...") + * .conf("...", true) + * .withExtensions { extensions => + * extensions.injectResolutionRule { session => + * ... + * } + * extensions.injectParser { (session, parser) => + * ... + * } + * } + * .getOrCreate() + * }}} + * + * Note that none of the injected builders should assume that the [[SparkSession]] is fully + * initialized and should not touch the session's internals (e.g. the SessionState). + */ +@DeveloperApi +@Experimental +@InterfaceStability.Unstable +class SparkSessionExtensions { + type RuleBuilder = SparkSession => Rule[LogicalPlan] + type CheckRuleBuilder = SparkSession => LogicalPlan => Unit + type StrategyBuilder = SparkSession => Strategy + type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + + private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + resolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed as part of the resolution phase of analysis. + */ + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + postHocResolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed after resolution. + */ + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder] + + /** + * Build the check analysis `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = { + checkRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an check analysis `Rule` builder into the [[SparkSession]]. The injected rules will + * be executed after the analysis phase. A check analysis rule is used to detect problems with a + * LogicalPlan and should throw an exception when a problem is found. + */ + def injectCheckRule(builder: CheckRuleBuilder): Unit = { + checkRuleBuilders += builder + } + + private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder] + + private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + optimizerRules.map(_.apply(session)) + } + + /** + * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be + * executed during the operator optimization batch. An optimizer rule is used to improve the + * quality of an analyzed logical plan; these rules should never modify the result of the + * LogicalPlan. + */ + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + + private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { + plannerStrategyBuilders.map(_.apply(session)) + } + + /** + * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will + * be used to convert a `LogicalPlan` into a executable + * [[org.apache.spark.sql.execution.SparkPlan]]. + */ + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder] + + private[sql] def buildParser( + session: SparkSession, + initial: ParserInterface): ParserInterface = { + parserBuilders.foldLeft(initial) { (parser, builder) => + builder(session, parser) + } + } + + /** + * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session + * and an initial parser. The latter allows for a user to create a partial parser and to delegate + * to the underlying parser for completeness. If a user injects more parsers, then the parsers + * are stacked on top of each other. + */ + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index df7c3678b7807..2a801d87b12eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface @@ -63,6 +63,11 @@ abstract class BaseSessionStateBuilder( */ protected def newBuilder: NewBuilder + /** + * Session extensions defined in the [[SparkSession]]. + */ + protected def extensions: SparkSessionExtensions = session.extensions + /** * Extract entries from `SparkConf` and put them in the `SQLConf` */ @@ -108,7 +113,9 @@ abstract class BaseSessionStateBuilder( * * Note: this depends on the `conf` field. */ - protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + protected lazy val sqlParser: ParserInterface = { + extensions.buildParser(session, new SparkSqlParser(conf)) + } /** * ResourceLoader that is used to load function resources and jars. @@ -171,7 +178,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildResolutionRules(session) + } /** * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of @@ -179,7 +188,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildPostHocResolutionRules(session) + } /** * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating @@ -187,7 +198,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil + protected def customCheckRules: Seq[LogicalPlan => Unit] = { + extensions.buildCheckRules(session) + } /** * Logical query plan optimizer. @@ -207,7 +220,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `optimizer` function. */ - protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = { + extensions.buildOptimizerRules(session) + } /** * Planner that converts optimized logical plans to physical plans. @@ -227,7 +242,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `planner` function. */ - protected def customPlanningStrategies: Seq[Strategy] = Nil + protected def customPlanningStrategies: Seq[Strategy] = { + extensions.buildPlannerStrategies(session) + } /** * Create a query execution object. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 97890a035a62f..dd118f88e3bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("randomSplit on reordered partitions") { - // This test ensures that randomSplit does not create overlapping splits even when the - // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of - // rows in each partition. - val data = - sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") - val splits = data.randomSplit(Array[Double](2, 3), seed = 1) - assert(splits.length == 2, "wrong number of splits") + def testNonOverlappingSplits(data: DataFrame): Unit = { + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) - // Verify that the splits span the entire dataset - assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + // Verify that the splits don't overlap + assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty) - // Verify that the splits don't overlap - assert(splits(0).intersect(splits(1)).collect().isEmpty) + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } - // Verify that the results are deterministic across multiple runs - val firstRun = splits.toSeq.map(_.collect().toSeq) - val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) - assert(firstRun == secondRun) + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val dataWithInts = sparkContext.parallelize(1 to 600, 2) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int") + val dataWithMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Map(i -> i.toString))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map") + val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Array(Map(i -> i.toString)))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps") + + testNonOverlappingSplits(dataWithInts) + testNonOverlappingSplits(dataWithMaps) + testNonOverlappingSplits(dataWithArrayOfMaps) } test("pearson correlation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala new file mode 100644 index 0000000000000..43db79663322a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Test cases for the [[SparkSessionExtensions]]. + */ +class SparkSessionExtensionSuite extends SparkFunSuite { + type ExtensionsBuilder = SparkSessionExtensions => Unit + private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder + + private def stop(spark: SparkSession): Unit = { + spark.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = { + val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate() + try f(spark) finally { + stop(spark) + } + } + + test("inject analyzer rule") { + withSession(_.injectResolutionRule(MyRule)) { session => + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } + } + + test("inject check analysis rule") { + withSession(_.injectCheckRule(MyCheckRule)) { session => + assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) + } + } + + test("inject optimizer rule") { + withSession(_.injectOptimizerRule(MyRule)) { session => + assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) + } + } + + test("inject spark planner strategy") { + withSession(_.injectPlannerStrategy(MySparkStrategy)) { session => + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + } + } + + test("inject parser") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + } + withSession(extension) { session => + assert(session.sessionState.sqlParser == CatalystSqlParser) + } + } + + test("inject stacked parsers") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + extensions.injectParser(MyParser) + extensions.injectParser(MyParser) + } + withSession(extension) { session => + val parser = MyParser(session, MyParser(session, CatalystSqlParser)) + assert(session.sessionState.sqlParser == parser) + } + } + + test("use custom class for extensions") { + val session = SparkSession.builder() + .master("local[1]") + .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName) + .getOrCreate() + try { + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } finally { + stop(session) + } + } +} + +case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan +} + +case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) { + override def apply(plan: LogicalPlan): Unit = { } +} + +case class MySparkStrategy(spark: SparkSession) extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty +} + +case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface { + override def parsePlan(sqlText: String): LogicalPlan = + delegate.parsePlan(sqlText) + + override def parseExpression(sqlText: String): Expression = + delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = + delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = + delegate.parseDataType(sqlText) +} + +class MyExtensions extends (SparkSessionExtensions => Unit) { + def apply(e: SparkSessionExtensions): Unit = { + e.injectPlannerStrategy(MySparkStrategy) + e.injectResolutionRule(MyRule) + } +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 9c879218ddc0d..a5a8e2640586c 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 0f249d7d59351..09dcc4055e000 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index de1be9c13e05f..fea882ad11230 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 938ba2f6ac201..7ba4dc9842f1b 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml