diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index f83f5278e8b8f..1484f29525a4e 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -153,7 +153,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( val weight = (1.0 / fraction).toFloat candidates ++= reSampled.map(x => (x, weight)) } - RangePartitioner.determineBounds(candidates, partitions) + RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size)) } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index a02cf30a5d831..e94babb846128 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -109,8 +109,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}")) - } + postMessage(name, message, (e) => { e match { + case e: RpcEnvStoppedException => logDebug (s"Message $message dropped. ${e.getMessage}") + case e: Throwable => logWarning(s"Message $message dropped. ${e.getMessage}") + }} + )} } /** Posts a message sent by a remote endpoint. */ diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index ae4a6003517cc..d32eba64e13e9 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -205,7 +205,12 @@ private[netty] class Inbox( try action catch { case NonFatal(e) => try endpoint.onError(e) catch { - case NonFatal(ee) => logError(s"Ignoring error", ee) + case NonFatal(ee) => + if (stopped) { + logDebug("Ignoring error", ee) + } else { + logError("Ignoring error", ee) + } } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index b316e5443f639..64898499246ac 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -185,7 +185,7 @@ private[netty] class NettyRpcEnv( try { dispatcher.postOneWayMessage(message) } catch { - case e: RpcEnvStoppedException => logWarning(e.getMessage) + case e: RpcEnvStoppedException => logDebug(e.getMessage) } } else { // Message to a remote RPC endpoint. @@ -203,7 +203,10 @@ private[netty] class NettyRpcEnv( def onFailure(e: Throwable): Unit = { if (!promise.tryFailure(e)) { - logWarning(s"Ignored failure: $e") + e match { + case e : RpcEnvStoppedException => logDebug (s"Ignored failure: $e") + case _ => logWarning(s"Ignored failure: $e") + } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index a7b7f58376f6b..b7e068aa68357 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -45,7 +45,7 @@ private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends Outbo override def onFailure(e: Throwable): Unit = { e match { - case e1: RpcEnvStoppedException => logWarning(e1.getMessage) + case e1: RpcEnvStoppedException => logDebug(e1.getMessage) case e1: Throwable => logWarning(s"Failed to send one-way RPC.", e1) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 0dd63d4392800..7d5e9809dd7b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -136,7 +136,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { def post(event: SparkListenerEvent): Unit = { if (stopped.get) { // Drop further events to make `listenerThread` exit ASAP - logError(s"$name has already stopped! Dropping event $event") + logDebug(s"$name has already stopped! Dropping event $event") return } metrics.numEventsPosted.inc() diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 34c017806fe10..dfe4c25670ce0 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -253,6 +253,12 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva // Add other tests here for classes that should be able to handle empty partitions correctly } + + test("Number of elements in RDD is less than number of partitions") { + val rdd = sc.parallelize(1 to 3).map(x => (x, x)) + val partitioner = new RangePartitioner(22, rdd) + assert(partitioner.numPartitions === 3) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index d55eb14d03456..0ab6eed959381 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -126,10 +126,26 @@ private[ml] object ValidatorParams { extraMetadata: Option[JObject] = None): Unit = { import org.json4s.JsonDSL._ + var numParamsNotJson = 0 val estimatorParamMapsJson = compact(render( instance.getEstimatorParamMaps.map { case paramMap => paramMap.toSeq.map { case ParamPair(p, v) => - Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + v match { + case writeableObj: DefaultParamsWritable => + val relativePath = "epm_" + p.name + numParamsNotJson + val paramPath = new Path(path, relativePath).toString + numParamsNotJson += 1 + writeableObj.save(paramPath) + Map("parent" -> p.parent, "name" -> p.name, + "value" -> compact(render(JString(relativePath))), + "isJson" -> compact(render(JBool(false)))) + case _: MLWritable => + throw new NotImplementedError("ValidatorParams.saveImpl does not handle parameters " + + "of type: MLWritable that are not DefaultParamsWritable") + case _ => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v), + "isJson" -> compact(render(JBool(true)))) + } } }.toSeq )) @@ -183,8 +199,17 @@ private[ml] object ValidatorParams { val paramPairs = pMap.map { case pInfo: Map[String, String] => val est = uidToParams(pInfo("parent")) val param = est.getParam(pInfo("name")) - val value = param.jsonDecode(pInfo("value")) - param -> value + // [Spark-21221] introduced the isJson field + if (!pInfo.contains("isJson") || + (pInfo.contains("isJson") && pInfo("isJson").toBoolean.booleanValue())) { + val value = param.jsonDecode(pInfo("value")) + param -> value + } else { + val relativePath = param.jsonDecode(pInfo("value")).toString + val value = DefaultParamsReader + .loadParamsInstance[MLWritable](new Path(path, relativePath).toString, sc) + param -> value + } } ParamMap(paramPairs: _*) }.toArray diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 2b4e6b53e4f81..2791ea715ace6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model, Pipeline} -import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.linalg.{DenseMatrix, Vectors} -import org.apache.spark.ml.param.{ParamMap, ParamPair} +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -153,7 +153,76 @@ class CrossValidatorSuite s" LogisticRegression but found ${other.getClass.getName}") } - CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + ValidatorParamsSuiteHelpers + .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with nested estimator") { + val ova = new OneVsRest().setClassifier(new LogisticRegression) + val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("accuracy") + val classifier1 = new LogisticRegression().setRegParam(2.0) + val classifier2 = new LogisticRegression().setRegParam(3.0) + // params that are not JSON serializable must inherit from Params + val paramMaps = new ParamGridBuilder() + .addGrid(ova.classifier, Array(classifier1, classifier2)) + .build() + val cv = new CrossValidator() + .setEstimator(ova) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) + + assert(cv2.getEvaluator.isInstanceOf[MulticlassClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[MulticlassClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case ova2: OneVsRest => + assert(ova.uid === ova2.uid) + val classifier = ova2.getClassifier + classifier match { + case lr: LogisticRegression => + assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter + === lr.getMaxIter) + case _ => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticREgression but found ${classifier.getClass.getName}") + } + + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" OneVsRest but found ${other.getClass.getName}") + } + + ValidatorParamsSuiteHelpers + .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: Persistence of nested estimator works if parent directory changes") { + val ova = new OneVsRest().setClassifier(new LogisticRegression) + val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("accuracy") + val classifier1 = new LogisticRegression().setRegParam(2.0) + val classifier2 = new LogisticRegression().setRegParam(3.0) + // params that are not JSON serializable must inherit from Params + val paramMaps = new ParamGridBuilder() + .addGrid(ova.classifier, Array(classifier1, classifier2)) + .build() + val cv = new CrossValidator() + .setEstimator(ova) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + ValidatorParamsSuiteHelpers.testFileMove(cv) } test("read/write: CrossValidator with complex estimator") { @@ -193,7 +262,8 @@ class CrossValidatorSuite assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) - CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + ValidatorParamsSuiteHelpers + .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) cv2.getEstimator match { case pipeline2: Pipeline => @@ -212,7 +282,8 @@ class CrossValidatorSuite assert(lrcv.uid === lrcv2.uid) assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) - CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + ValidatorParamsSuiteHelpers + .compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) case other => throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + " but found: " + other.map(_.getClass.getName).mkString(", ")) @@ -278,7 +349,8 @@ class CrossValidatorSuite s" LogisticRegression but found ${other.getClass.getName}") } - CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + ValidatorParamsSuiteHelpers + .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) cv2.bestModel match { case lrModel2: LogisticRegressionModel => @@ -296,21 +368,6 @@ class CrossValidatorSuite object CrossValidatorSuite extends SparkFunSuite { - /** - * Assert sequences of estimatorParamMaps are identical. - * Params must be simple types comparable with `===`. - */ - def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { - assert(pMaps.length === pMaps2.length) - pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => - assert(pMap.size === pMap2.size) - pMap.toSeq.foreach { case ParamPair(p, v) => - assert(pMap2.contains(p)) - assert(pMap2(p) === v) - } - } - } - abstract class MyModel extends Model[MyModel] class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index a34f930aa11c4..71a1776a2cdd0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamMap} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -95,7 +95,7 @@ class TrainValidationSplitSuite } test("transformSchema should check estimatorParamMaps") { - import TrainValidationSplitSuite._ + import TrainValidationSplitSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -134,6 +134,82 @@ class TrainValidationSplitSuite assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.getSeed === tvs2.getSeed) + + ValidatorParamsSuiteHelpers + .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) + + tvs2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + } + + test("read/write: TrainValidationSplit with nested estimator") { + val ova = new OneVsRest() + .setClassifier(new LogisticRegression) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val classifier1 = new LogisticRegression().setRegParam(2.0) + val classifier2 = new LogisticRegression().setRegParam(3.0) + val paramMaps = new ParamGridBuilder() + .addGrid(ova.classifier, Array(classifier1, classifier2)) + .build() + val tvs = new TrainValidationSplit() + .setEstimator(ova) + .setEvaluator(evaluator) + .setTrainRatio(0.5) + .setEstimatorParamMaps(paramMaps) + .setSeed(42L) + + val tvs2 = testDefaultReadWrite(tvs, testParams = false) + + assert(tvs.getTrainRatio === tvs2.getTrainRatio) + assert(tvs.getSeed === tvs2.getSeed) + + tvs2.getEstimator match { + case ova2: OneVsRest => + assert(ova.uid === ova2.uid) + val classifier = ova2.getClassifier + classifier match { + case lr: LogisticRegression => + assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter + === lr.getMaxIter) + case _ => + throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + + s" LogisticREgression but found ${classifier.getClass.getName}") + } + + case other => + throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + + s" OneVsRest but found ${other.getClass.getName}") + } + + ValidatorParamsSuiteHelpers + .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) + } + + test("read/write: Persistence of nested estimator works if parent directory changes") { + val ova = new OneVsRest() + .setClassifier(new LogisticRegression) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val classifier1 = new LogisticRegression().setRegParam(2.0) + val classifier2 = new LogisticRegression().setRegParam(3.0) + val paramMaps = new ParamGridBuilder() + .addGrid(ova.classifier, Array(classifier1, classifier2)) + .build() + val tvs = new TrainValidationSplit() + .setEstimator(ova) + .setEvaluator(evaluator) + .setTrainRatio(0.5) + .setEstimatorParamMaps(paramMaps) + .setSeed(42L) + + ValidatorParamsSuiteHelpers.testFileMove(tvs) } test("read/write: TrainValidationSplitModel") { @@ -160,7 +236,7 @@ class TrainValidationSplitSuite } } -object TrainValidationSplitSuite { +object TrainValidationSplitSuite extends SparkFunSuite{ abstract class MyModel extends Model[MyModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala new file mode 100644 index 0000000000000..1df673cf40162 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import java.io.File +import java.nio.file.{Files, StandardCopyOption} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLReader, MLWritable} + +object ValidatorParamsSuiteHelpers extends SparkFunSuite with DefaultReadWriteTest { + /** + * Assert sequences of estimatorParamMaps are identical. + * If the values for a parameter are not directly comparable with === + * and are instead Params types themselves then their corresponding paramMaps + * are compared against each other. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + val otherParam = pMap2(p) + v match { + case estimator: Params => + otherParam match { + case estimator2: Params => + val estimatorParamMap = Array(estimator.extractParamMap()) + val estimatorParamMap2 = Array(estimator2.extractParamMap()) + compareParamMaps(estimatorParamMap, estimatorParamMap2) + case other => + throw new AssertionError(s"Expected parameter of type Params but" + + s" found ${otherParam.getClass.getName}") + } + case _ => + assert(otherParam === v) + } + } + } + } + + /** + * When nested estimators (ex. OneVsRest) are saved within meta-algorithms such as + * CrossValidator and TrainValidationSplit, relative paths should be used to store + * the path of the estimator so that if the parent directory changes, loading the + * model still works. + */ + def testFileMove[T <: Params with MLWritable](instance: T): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("test") + + val subdir = new File(tempDir, subdirName) + val subDirWithUid = new File(subdir, uid) + + instance.save(subDirWithUid.getPath) + + val newSubdirName = Identifiable.randomUID("test_moved") + val newSubdir = new File(tempDir, newSubdirName) + val newSubdirWithUid = new File(newSubdir, uid) + + Files.createDirectory(newSubdir.toPath) + Files.createDirectory(newSubdirWithUid.toPath) + Files.move(subDirWithUid.toPath, newSubdirWithUid.toPath, StandardCopyOption.ATOMIC_MOVE) + + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] + val newInstance = loader.load(newSubdirWithUid.getPath) + assert(uid == newInstance.uid) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 27d606cb05dc2..4da95e74434ee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -55,7 +55,6 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance.write.overwrite().save(path) val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] val newInstance = loader.load(path) - assert(newInstance.uid === instance.uid) if (testParams) { instance.params.foreach { p => diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 948806a5c936c..82207f664480a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -25,7 +25,7 @@ from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper -from pyspark.ml.common import inherit_doc +from pyspark.ml.common import inherit_doc, _java2py, _py2java from pyspark.sql import DataFrame from pyspark.sql.functions import udf, when from pyspark.sql.types import ArrayType, DoubleType @@ -1472,7 +1472,7 @@ def getClassifier(self): @inherit_doc -class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): +class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1589,22 +1589,6 @@ def copy(self, extra=None): newOvr.setClassifier(self.getClassifier().copy(extra)) return newOvr - @since("2.0.0") - def write(self): - """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) - - @since("2.0.0") - def save(self, path): - """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" - self.write().save(path) - - @classmethod - @since("2.0.0") - def read(cls): - """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) - @classmethod def _from_java(cls, java_stage): """ @@ -1634,8 +1618,52 @@ def _to_java(self): _java_obj.setPredictionCol(self.getPredictionCol()) return _java_obj + def _make_java_param_pair(self, param, value): + """ + Makes a Java param pair. + """ + sc = SparkContext._active_spark_context + param = self._resolveParam(param) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest", + self.uid) + java_param = _java_obj.getParam(param.name) + if isinstance(value, JavaParams): + # used in the case of an estimator having another estimator as a parameter + # the reason why this is not in _py2java in common.py is that importing + # Estimator and Model in common.py results in a circular import with inherit_doc + java_value = value._to_java() + else: + java_value = _py2java(sc, value) + return java_param.w(java_value) -class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable): + def _transfer_param_map_to_java(self, pyParamMap): + """ + Transforms a Python ParamMap into a Java ParamMap. + """ + paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap") + for param in self.params: + if param in pyParamMap: + pair = self._make_java_param_pair(param, pyParamMap[param]) + paramMap.put([pair]) + return paramMap + + def _transfer_param_map_from_java(self, javaParamMap): + """ + Transforms a Java ParamMap into a Python ParamMap. + """ + sc = SparkContext._active_spark_context + paramMap = dict() + for pair in javaParamMap.toList(): + param = pair.param() + if self.hasParam(str(param.name())): + if param.name() == "classifier": + paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value()) + else: + paramMap[self.getParam(param.name())] = _java2py(sc, pair.value()) + return paramMap + + +class OneVsRestModel(Model, OneVsRestParams, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1650,6 +1678,16 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable): def __init__(self, models): super(OneVsRestModel, self).__init__() self.models = models + java_models = [model._to_java() for model in self.models] + sc = SparkContext._active_spark_context + java_models_array = JavaWrapper._new_java_array(java_models, + sc._gateway.jvm.org.apache.spark.ml + .classification.ClassificationModel) + # TODO: need to set metadata + metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata") + self._java_obj = \ + JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel", + self.uid, metadata.empty(), java_models_array) def _transform(self, dataset): # determine the input columns: these need to be passed through @@ -1715,22 +1753,6 @@ def copy(self, extra=None): newModel.models = [model.copy(extra) for model in self.models] return newModel - @since("2.0.0") - def write(self): - """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) - - @since("2.0.0") - def save(self, path): - """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" - self.write().save(path) - - @classmethod - @since("2.0.0") - def read(cls): - """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) - @classmethod def _from_java(cls, java_stage): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 7870047651601..6c71e69c9b5f9 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -49,7 +49,8 @@ from pyspark.ml.classification import * from pyspark.ml.clustering import * from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator +from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ + MulticlassClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ @@ -678,7 +679,7 @@ def test_fit_maximize_metric(self): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") - def test_save_load(self): + def test_save_load_trained_model(self): # This tests saving and loading the trained model only. # Save/load for CrossValidator will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() @@ -702,6 +703,76 @@ def test_save_load(self): self.assertEqual(loadedLrModel.uid, lrModel.uid) self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + def test_save_load_simple_estimator(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + + # test save/load of CrossValidatorModel + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + + def test_save_load_nested_estimator(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + ova = OneVsRest(classifier=LogisticRegression()) + lr1 = LogisticRegression().setMaxIter(100) + lr2 = LogisticRegression().setMaxIter(150) + grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() + evaluator = MulticlassClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + + originalParamMap = cv.getEstimatorParamMaps() + loadedParamMap = loadedCV.getEstimatorParamMaps() + for i, param in enumerate(loadedParamMap): + for p in param: + if p.name == "classifier": + self.assertEqual(param[p].uid, originalParamMap[i][p].uid) + else: + self.assertEqual(param[p], originalParamMap[i][p]) + + # test save/load of CrossValidatorModel + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + class TrainValidationSplitTests(SparkSessionTestCase): @@ -759,7 +830,7 @@ def test_fit_maximize_metric(self): "validationMetrics has the same size of grid parameter") self.assertEqual(1.0, max(validationMetrics)) - def test_save_load(self): + def test_save_load_trained_model(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() @@ -783,6 +854,74 @@ def test_save_load(self): self.assertEqual(loadedLrModel.uid, lrModel.uid) self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + def test_save_load_simple_estimator(self): + # This tests saving and loading the trained model only. + # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + + def test_save_load_nested_estimator(self): + # This tests saving and loading the trained model only. + # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + ova = OneVsRest(classifier=LogisticRegression()) + lr1 = LogisticRegression().setMaxIter(100) + lr2 = LogisticRegression().setMaxIter(150) + grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() + evaluator = MulticlassClassificationEvaluator() + + tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + + originalParamMap = tvs.getEstimatorParamMaps() + loadedParamMap = loadedTvs.getEstimatorParamMaps() + for i, param in enumerate(loadedParamMap): + for p in param: + if p.name == "classifier": + self.assertEqual(param[p].uid, originalParamMap[i][p].uid) + else: + self.assertEqual(param[p], originalParamMap[i][p]) + + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + def test_copy(self): dataset = self.spark.createDataFrame([ (10, 10.0), diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index b64858214d20d..00c348aa9f7de 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -20,8 +20,11 @@ from pyspark import since, keyword_only from pyspark.ml import Estimator, Model +from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit', @@ -137,8 +140,37 @@ def getEvaluator(self): """ return self.getOrDefault(self.evaluator) + @classmethod + def _from_java_impl(cls, java_stage): + """ + Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams. + """ + + # Load information from java_stage to the instance. + estimator = JavaParams._from_java(java_stage.getEstimator()) + evaluator = JavaParams._from_java(java_stage.getEvaluator()) + epms = [estimator._transfer_param_map_from_java(epm) + for epm in java_stage.getEstimatorParamMaps()] + return estimator, epms, evaluator + + def _to_java_impl(self): + """ + Return Java estimator, estimatorParamMaps, and evaluator from this Python instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap + + java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps())) + for idx, epm in enumerate(self.getEstimatorParamMaps()): + java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm) -class CrossValidator(Estimator, ValidatorParams): + java_estimator = self.getEstimator()._to_java() + java_evaluator = self.getEvaluator()._to_java() + return java_estimator, java_epms, java_evaluator + + +class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -263,8 +295,53 @@ def copy(self, extra=None): newCV.setEvaluator(self.getEvaluator().copy(extra)) return newCV + @since("2.3.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @classmethod + @since("2.3.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java CrossValidator, create and return a Python wrapper of it. + Used for ML persistence. + """ -class CrossValidatorModel(Model, ValidatorParams): + estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage) + numFolds = java_stage.getNumFolds() + seed = java_stage.getSeed() + # Create a new instance of this stage. + py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, + numFolds=numFolds, seed=seed) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java CrossValidator. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() + + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) + _java_obj.setEstimatorParamMaps(epms) + _java_obj.setEvaluator(evaluator) + _java_obj.setEstimator(estimator) + _java_obj.setSeed(self.getSeed()) + _java_obj.setNumFolds(self.getNumFolds()) + + return _java_obj + + +class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): """ CrossValidatorModel contains the model with the highest average cross-validation @@ -302,8 +379,55 @@ def copy(self, extra=None): avgMetrics = self.avgMetrics return CrossValidatorModel(bestModel, avgMetrics) + @since("2.3.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @classmethod + @since("2.3.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) -class TrainValidationSplit(Estimator, ValidatorParams): + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java CrossValidatorModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + + bestModel = JavaParams._from_java(java_stage.bestModel()) + estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) + + py_stage = cls(bestModel=bestModel).setEstimator(estimator) + py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java CrossValidatorModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + sc = SparkContext._active_spark_context + # TODO: persist average metrics as well + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) + estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() + + _java_obj.set("evaluator", evaluator) + _java_obj.set("estimator", estimator) + _java_obj.set("estimatorParamMaps", epms) + return _java_obj + + +class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): """ .. note:: Experimental @@ -418,8 +542,53 @@ def copy(self, extra=None): newTVS.setEvaluator(self.getEvaluator().copy(extra)) return newTVS + @since("2.3.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @classmethod + @since("2.3.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java TrainValidationSplit, create and return a Python wrapper of it. + Used for ML persistence. + """ + + estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage) + trainRatio = java_stage.getTrainRatio() + seed = java_stage.getSeed() + # Create a new instance of this stage. + py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, + trainRatio=trainRatio, seed=seed) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java TrainValidationSplit. Used for ML persistence. + :return: Java object equivalent to this instance. + """ + + estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() -class TrainValidationSplitModel(Model, ValidatorParams): + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", + self.uid) + _java_obj.setEstimatorParamMaps(epms) + _java_obj.setEvaluator(evaluator) + _java_obj.setEstimator(estimator) + _java_obj.setTrainRatio(self.getTrainRatio()) + _java_obj.setSeed(self.getSeed()) + + return _java_obj + + +class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): """ .. note:: Experimental @@ -456,6 +625,55 @@ def copy(self, extra=None): validationMetrics = list(self.validationMetrics) return TrainValidationSplitModel(bestModel, validationMetrics) + @since("2.3.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @classmethod + @since("2.3.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java TrainValidationSplitModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + + # Load information from java_stage to the instance. + bestModel = JavaParams._from_java(java_stage.bestModel()) + estimator, epms, evaluator = super(TrainValidationSplitModel, + cls)._from_java_impl(java_stage) + # Create a new instance of this stage. + py_stage = cls(bestModel=bestModel).setEstimator(estimator) + py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence. + :return: Java object equivalent to this instance. + """ + + sc = SparkContext._active_spark_context + # TODO: persst validation metrics as well + _java_obj = JavaParams._new_java_obj( + "org.apache.spark.ml.tuning.TrainValidationSplitModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) + estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl() + + _java_obj.set("evaluator", evaluator) + _java_obj.set("estimator", estimator) + _java_obj.set("estimatorParamMaps", epms) + return _java_obj + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 80a0b31cd88d9..ee6301ef19a43 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -106,7 +106,7 @@ def __del__(self): def _make_java_param_pair(self, param, value): """ - Makes a Java parm pair. + Makes a Java param pair. """ sc = SparkContext._active_spark_context param = self._resolveParam(param) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d45ff63355de3..2c8c8e2d80f09 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2087,10 +2087,22 @@ def _wrapped(self): """ Wrap this udf with a function and attach docstring from func """ - @functools.wraps(self.func) + + # It is possible for a callable instance without __name__ attribute or/and + # __module__ attribute to be wrapped here. For example, functools.partial. In this case, + # we should avoid wrapping the attributes from the wrapped function to the wrapper + # function. So, we take out these attribute names from the default names to set and + # then manually assign it after being wrapped. + assignments = tuple( + a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') + + @functools.wraps(self.func, assigned=assignments) def wrapper(*args): return self(*args) + wrapper.__name__ = self._name + wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') + else self.func.__class__.__module__) wrapper.func = self.func wrapper.returnType = self.returnType diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 29e48a6ccf763..be5495ca019a2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -679,6 +679,27 @@ def f(x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + class F(object): + """Identity""" + def __call__(self, x): + return x + + f = F() + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + + f = functools.partial(f, x=1) + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4868180569778..ce290c399d9f2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -90,6 +90,23 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + private val userClassLoader = { + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + + if (isClusterMode) { + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } + } + // Lock for controlling the allocator (heartbeat) thread. private val allocatorLock = new Object() @@ -242,16 +259,27 @@ private[spark] class ApplicationMaster( // If the credentials file config is present, we must periodically renew tokens. So create // a new AMDelegationTokenRenewer - if (sparkConf.contains(CREDENTIALS_FILE_PATH.key)) { - // If a principal and keytab have been set, use that to create new credentials for executors - // periodically - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - yarnConf, - YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, yarnConf)) - - val credentialRenewer = new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) - credentialRenewer.scheduleLoginFromKeytab() + if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { + // Start a short-lived thread for AMCredentialRenewer, the only purpose is to set the + // classloader so that main jar and secondary jars could be used by AMCredentialRenewer. + val credentialRenewerThread = new Thread { + setName("AMCredentialRenewerStarter") + setContextClassLoader(userClassLoader) + + override def run(): Unit = { + val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + yarnConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, yarnConf)) + + val credentialRenewer = + new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) + credentialRenewer.scheduleLoginFromKeytab() + } + } + + credentialRenewerThread.start() + credentialRenewerThread.join() } if (isClusterMode) { @@ -609,17 +637,6 @@ private[spark] class ApplicationMaster( private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") - val classpath = Client.getUserClasspath(sparkConf) - val urls = classpath.map { entry => - new URL("file:" + new File(entry.getPath()).getAbsolutePath()) - } - val userClassLoader = - if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { - new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - var userArgs = args.userArgs if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { // When running pyspark, the app is run using PythonRunner. The second argument is the list diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index cbc6e60e839c1..8452f43774194 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.atomic.{AtomicBoolean} + import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} import scala.util.control.NonFatal @@ -40,6 +42,8 @@ private[spark] abstract class YarnSchedulerBackend( sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { + private val stopped = new AtomicBoolean(false) + override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { 0.8 @@ -93,6 +97,7 @@ private[spark] abstract class YarnSchedulerBackend( requestTotalExecutors(0, 0, Map.empty) super.stop() } finally { + stopped.set(true) services.stop() } } @@ -206,8 +211,10 @@ private[spark] abstract class YarnSchedulerBackend( */ override def onDisconnected(rpcAddress: RpcAddress): Unit = { addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + if (!stopped.get) { + if (disableExecutor(executorId)) { + yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index ff2414b174acb..a5b9855e959d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -17,17 +17,13 @@ package org.apache.spark.sql.catalyst -import java.net.URLClassLoader import java.sql.{Date, Timestamp} -import scala.reflect.runtime.universe.typeOf - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils case class PrimitiveData( intField: Int, @@ -339,39 +335,4 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } - private val dataTypeForComplexData = dataTypeFor[ComplexData] - private val typeOfComplexData = typeOf[ComplexData] - - Seq( - ("mirror", () => mirror), - ("dataTypeFor", () => dataTypeFor[ComplexData]), - ("constructorFor", () => deserializerFor[ComplexData]), - ("extractorsFor", { - val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false) - () => serializerFor[ComplexData](inputObject) - }), - ("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])), - ("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])), - ("getClassFromType", () => getClassFromType(typeOfComplexData)), - ("schemaFor", () => schemaFor[ComplexData]), - ("localTypeOf", () => localTypeOf[ComplexData]), - ("getClassNameFromType", () => getClassNameFromType(typeOfComplexData)), - ("getParameterTypes", () => getParameterTypes(() => ())), - ("getConstructorParameters(tpe)", () => getClassNameFromType(typeOfComplexData))).foreach { - case (name, exec) => - test(s"SPARK-13640: thread safety of ${name}") { - (0 until 100).foreach { _ => - val loader = new URLClassLoader(Array.empty, Utils.getContextOrSparkClassLoader) - (0 until 10).par.foreach { _ => - val cl = Thread.currentThread.getContextClassLoader - try { - Thread.currentThread.setContextClassLoader(loader) - exec() - } finally { - Thread.currentThread.setContextClassLoader(cl) - } - } - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 41d40aa926fbb..b97fa54446e0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -409,6 +409,42 @@ object HiveOnlyCheck extends (LogicalPlan => Unit) { } } + +/** + * A rule to do various checks before reading a table. + */ +object PreReadCheck extends (LogicalPlan => Unit) { + def apply(plan: LogicalPlan): Unit = { + plan.foreach { + case operator: LogicalPlan => + operator transformExpressionsUp { + case e @ (_: InputFileName | _: InputFileBlockLength | _: InputFileBlockStart) => + checkNumInputFileBlockSources(e, operator) + e + } + } + } + + private def checkNumInputFileBlockSources(e: Expression, operator: LogicalPlan): Int = { + operator match { + case _: CatalogRelation => 1 + case _ @ LogicalRelation(_: HadoopFsRelation, _, _) => 1 + case _: LeafNode => 0 + // UNION ALL has multiple children, but these children do not concurrently use InputFileBlock. + case u: Union => + if (u.children.map(checkNumInputFileBlockSources(e, _)).sum >= 1) 1 else 0 + case o => + val numInputFileBlockSources = o.children.map(checkNumInputFileBlockSources(e, _)).sum + if (numInputFileBlockSources > 1) { + e.failAnalysis(s"'${e.prettyName}' does not support more than one sources") + } else { + numInputFileBlockSources + } + } + } +} + + /** * A rule to do various checks before inserting into or writing to a data source table. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 2b3c5f054893f..1400452833039 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -133,7 +133,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val validUdfs = udfs.filter { udf => // Check to make sure that the UDF can be evaluated with only the input of this child. udf.references.subsetOf(child.outputSet) - }.toArray + } if (validUdfs.nonEmpty) { val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 9dcac33b4107c..ab690fd5fbbca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -136,7 +136,7 @@ case class FlatMapGroupsWithStateExec( outputIterator, { store.commit() - longMetric("numTotalStateRows") += store.numKeys() + setStoreMetrics(store) } ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index a4e4ca821374c..1887b07c49b73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -186,18 +186,10 @@ trait ProgressReporter extends Logging { if (lastExecution == null) return Nil // lastExecution could belong to one of the previous triggers if `!hasNewData`. // Walking the plan again should be inexpensive. - val stateNodes = lastExecution.executedPlan.collect { - case p if p.isInstanceOf[StateStoreWriter] => p - } - stateNodes.map { node => - val numRowsUpdated = if (hasNewData) { - node.metrics.get("numUpdatedStateRows").map(_.value).getOrElse(0L) - } else { - 0L - } - new StateOperatorProgress( - numRowsTotal = node.metrics.get("numTotalStateRows").map(_.value).getOrElse(0L), - numRowsUpdated = numRowsUpdated) + lastExecution.executedPlan.collect { + case p if p.isInstanceOf[StateStoreWriter] => + val progress = p.asInstanceOf[StateStoreWriter].getProgress() + if (hasNewData) progress else progress.copy(newNumRowsUpdated = 0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index bae7a15165e43..fa4c99c01916f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -35,7 +35,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils +import org.apache.spark.util.{SizeEstimator, Utils} /** @@ -172,7 +172,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } } - override def numKeys(): Long = mapToUpdate.size() + override def metrics: StateStoreMetrics = { + StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate), Map.empty) + } /** * Whether all updates have been committed @@ -230,6 +232,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit loadedMaps.values.foreach(_.clear()) } + override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = { + Nil + } + override def toString(): String = { s"HDFSStateStoreProvider[" + s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 86886466c4f56..9da610e359f90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -94,8 +94,8 @@ trait StateStore { def iterator(): Iterator[UnsafeRowPair] - /** Number of keys in the state store */ - def numKeys(): Long + /** Current metrics of the state store */ + def metrics: StateStoreMetrics /** * Whether all updates have been committed @@ -103,6 +103,24 @@ trait StateStore { def hasCommitted: Boolean } +/** + * Metrics reported by a state store + * @param numKeys Number of keys in the state store + * @param memoryUsedBytes Memory used by the state store + * @param customMetrics Custom implementation-specific metrics + * The metrics reported through this must have the same `name` as those + * reported by `StateStoreProvider.customMetrics`. + */ +case class StateStoreMetrics( + numKeys: Long, + memoryUsedBytes: Long, + customMetrics: Map[StateStoreCustomMetric, Long]) + +/** + * Name and description of custom implementation-specific metrics that a + * state store may wish to expose. + */ +case class StateStoreCustomMetric(name: String, desc: String) /** * Trait representing a provider that provide [[StateStore]] instances representing @@ -158,22 +176,36 @@ trait StateStoreProvider { /** Optional method for providers to allow for background maintenance (e.g. compactions) */ def doMaintenance(): Unit = { } + + /** + * Optional custom metrics that the implementation may want to report. + * @note The StateStore objects created by this provider must report the same custom metrics + * (specifically, same names) through `StateStore.metrics`. + */ + def supportedCustomMetrics: Seq[StateStoreCustomMetric] = Nil } object StateStoreProvider { + + /** + * Return a instance of the given provider class name. The instance will not be initialized. + */ + def create(providerClassName: String): StateStoreProvider = { + val providerClass = Utils.classForName(providerClassName) + providerClass.newInstance().asInstanceOf[StateStoreProvider] + } + /** - * Return a provider instance of the given provider class. - * The instance will be already initialized. + * Return a instance of the required provider, initialized with the given configurations. */ - def instantiate( + def createAndInit( stateStoreId: StateStoreId, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): StateStoreProvider = { - val providerClass = Utils.classForName(storeConf.providerClass) - val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider] + val provider = create(storeConf.providerClass) provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) provider } @@ -298,7 +330,7 @@ object StateStore extends Logging { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( storeProviderId, - StateStoreProvider.instantiate( + StateStoreProvider.createAndInit( storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) ) reportActiveStoreInstance(storeProviderId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c5722466a33af..77b1160a063fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit._ +import scala.collection.JavaConverters._ + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -29,9 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ import org.apache.spark.util.{CompletionIterator, NextIterator} @@ -73,8 +75,21 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"), "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"), "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"), - "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes") - ) + "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes"), + "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by state") + ) ++ stateStoreCustomMetrics + + /** + * Get the progress made by this stateful operator after execution. This should be called in + * the driver after this SparkPlan has been executed and metrics have been updated. + */ + def getProgress(): StateOperatorProgress = { + new StateOperatorProgress( + numRowsTotal = longMetric("numTotalStateRows").value, + numRowsUpdated = longMetric("numUpdatedStateRows").value, + memoryUsedBytes = longMetric("stateMemory").value, + numPartitions = this.sqlContext.conf.numShufflePartitions) + } /** Records the duration of running `body` for the next query progress update. */ protected def timeTakenMs(body: => Unit): Long = { @@ -83,6 +98,26 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => val endTime = System.nanoTime() math.max(NANOSECONDS.toMillis(endTime - startTime), 0) } + + /** + * Set the SQL metrics related to the state store. + * This should be called in that task after the store has been updated. + */ + protected def setStoreMetrics(store: StateStore): Unit = { + + val storeMetrics = store.metrics + longMetric("numTotalStateRows") += storeMetrics.numKeys + longMetric("stateMemory") += storeMetrics.memoryUsedBytes + storeMetrics.customMetrics.foreach { case (metric, value) => + longMetric(metric.name) += value + } + } + + private def stateStoreCustomMetrics: Map[String, SQLMetric] = { + val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass) + provider.supportedCustomMetrics.map { m => + m.name -> SQLMetrics.createTimingMetric(sparkContext, m.desc) }.toMap + } } /** An operator that supports watermark. */ @@ -197,7 +232,6 @@ case class StateStoreSaveExec( Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") - val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") @@ -218,7 +252,7 @@ case class StateStoreSaveExec( commitTimeMs += timeTakenMs { store.commit() } - numTotalStateRows += store.numKeys() + setStoreMetrics(store) store.iterator().map { rowPair => numOutputRows += 1 rowPair.value @@ -261,7 +295,7 @@ case class StateStoreSaveExec( override protected def close(): Unit = { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) commitTimeMs += timeTakenMs { store.commit() } - numTotalStateRows += store.numKeys() + setStoreMetrics(store) } } @@ -285,7 +319,7 @@ case class StateStoreSaveExec( // Remove old aggregates if watermark specified allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } commitTimeMs += timeTakenMs { store.commit() } - numTotalStateRows += store.numKeys() + setStoreMetrics(store) false } else { true @@ -368,7 +402,7 @@ case class StreamingDeduplicateExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } commitTimeMs += timeTakenMs { store.commit() } - numTotalStateRows += store.numKeys() + setStoreMetrics(store) }) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 267f76217df84..37f4f8d4ab65e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -168,6 +168,7 @@ abstract class BaseSessionStateBuilder( override val extendedCheckRules: Seq[LogicalPlan => Unit] = PreWriteCheck +: + PreReadCheck +: HiveOnlyCheck +: customCheckRules } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index fb590e7df996b..81a2387b80396 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -37,7 +37,10 @@ import org.apache.spark.annotation.InterfaceStability @InterfaceStability.Evolving class StateOperatorProgress private[sql]( val numRowsTotal: Long, - val numRowsUpdated: Long) extends Serializable { + val numRowsUpdated: Long, + val memoryUsedBytes: Long, + val numPartitions: Long + ) extends Serializable { /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -45,9 +48,14 @@ class StateOperatorProgress private[sql]( /** The pretty (i.e. indented) JSON representation of this progress. */ def prettyJson: String = pretty(render(jsonValue)) + private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress = + new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, numPartitions) + private[sql] def jsonValue: JValue = { ("numRowsTotal" -> JInt(numRowsTotal)) ~ - ("numRowsUpdated" -> JInt(numRowsUpdated)) + ("numRowsUpdated" -> JInt(numRowsUpdated)) ~ + ("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~ + ("numPartitions" -> JInt(numPartitions)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bc708ca88d7e1..7c45be21961d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -530,6 +530,63 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) } + test("input_file_name, input_file_block_start, input_file_block_length - more than one source") { + withTempView("tempView1") { + withTable("tab1", "tab2") { + val data = sparkContext.parallelize(0 to 9).toDF("id") + data.write.saveAsTable("tab1") + data.write.saveAsTable("tab2") + data.createOrReplaceTempView("tempView1") + Seq("input_file_name", "input_file_block_start", "input_file_block_length").foreach { f => + val e = intercept[AnalysisException] { + sql(s"SELECT *, $f() FROM tab1 JOIN tab2 ON tab1.id = tab2.id") + }.getMessage + assert(e.contains(s"'$f' does not support more than one source")) + } + + def checkResult( + fromClause: String, + exceptionExpected: Boolean, + numExpectedRows: Int = 0): Unit = { + val stmt = s"SELECT *, input_file_name() FROM ($fromClause)" + if (exceptionExpected) { + val e = intercept[AnalysisException](sql(stmt)).getMessage + assert(e.contains("'input_file_name' does not support more than one source")) + } else { + assert(sql(stmt).count() == numExpectedRows) + } + } + + checkResult( + "SELECT * FROM tab1 UNION ALL SELECT * FROM tab2 UNION ALL SELECT * FROM tab2", + exceptionExpected = false, + numExpectedRows = 30) + + checkResult( + "(SELECT * FROM tempView1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2", + exceptionExpected = false, + numExpectedRows = 20) + + checkResult( + "(SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tempView1", + exceptionExpected = false, + numExpectedRows = 20) + + checkResult( + "(SELECT * FROM tempView1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2", + exceptionExpected = true) + + checkResult( + "(SELECT * FROM tab1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2", + exceptionExpected = true) + + checkResult( + "(SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2", + exceptionExpected = true) + } + } + } + test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 7cb86dc143844..c843b65020d8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -184,6 +185,15 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("reports memory usage") { + val provider = newStoreProvider() + val store = provider.getStore(0) + val noDataMemoryUsed = store.metrics.memoryUsedBytes + put(store, "a", 1) + store.commit() + assert(store.metrics.memoryUsedBytes > noDataMemoryUsed) + } + test("StateStore.get") { quietly { val dir = newDir() @@ -554,12 +564,12 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] assert(!store.hasCommitted) assert(get(store, "a") === None) assert(store.iterator().isEmpty) - assert(store.numKeys() === 0) + assert(store.metrics.numKeys === 0) // Verify state after updating put(store, "a", 1) assert(get(store, "a") === Some(1)) - assert(store.numKeys() === 1) + assert(store.metrics.numKeys === 1) assert(store.iterator().nonEmpty) assert(getLatestData(provider).isEmpty) @@ -567,9 +577,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] // Make updates, commit and then verify state put(store, "b", 2) put(store, "aa", 3) - assert(store.numKeys() === 3) + assert(store.metrics.numKeys === 3) remove(store, _.startsWith("a")) - assert(store.numKeys() === 1) + assert(store.metrics.numKeys === 1) assert(store.commit() === 1) assert(store.hasCommitted) @@ -587,9 +597,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] // New updates to the reloaded store with new version, and does not change old version val reloadedProvider = newStoreProvider(store.id) val reloadedStore = reloadedProvider.getStore(1) - assert(reloadedStore.numKeys() === 1) + assert(reloadedStore.metrics.numKeys === 1) put(reloadedStore, "c", 4) - assert(reloadedStore.numKeys() === 2) + assert(reloadedStore.metrics.numKeys === 2) assert(reloadedStore.commit() === 2) assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 6676099d426ba..a5cf40c3581c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -127,8 +127,7 @@ class FileStreamSinkSuite extends StreamTest { // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has // been inferred val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect { - case LogicalRelation(baseRelation, _, _) if baseRelation.isInstanceOf[HadoopFsRelation] => - baseRelation.asInstanceOf[HadoopFsRelation] + case LogicalRelation(baseRelation: HadoopFsRelation, _, _) => baseRelation } assert(hadoopdFsRelations.size === 1) assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 0d9ca81349be5..9f2f0d195de9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -1077,7 +1077,7 @@ object FlatMapGroupsWithStateSuite { override def abort(): Unit = { } override def id: StateStoreId = null override def version: Long = 0 - override def numKeys(): Long = map.size + override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) override def hasCommitted: Boolean = true } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 901cf34f289cc..d3cafac4f1755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -33,16 +33,10 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { - implicit class EqualsIgnoreCRLF(source: String) { - def equalsIgnoreCRLF(target: String): Boolean = { - source.replaceAll("\r\n|\r|\n", System.lineSeparator) === - target.replaceAll("\r\n|\r|\n", System.lineSeparator) - } - } - test("StreamingQueryProgress - prettyJson") { val json1 = testProgress1.prettyJson - assert(json1.equalsIgnoreCRLF( + assertJson( + json1, s""" |{ | "id" : "${testProgress1.id.toString}", @@ -62,7 +56,9 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | }, | "stateOperators" : [ { | "numRowsTotal" : 0, - | "numRowsUpdated" : 1 + | "numRowsUpdated" : 1, + | "memoryUsedBytes" : 2, + | "numPartitions" : 4 | } ], | "sources" : [ { | "description" : "source", @@ -75,13 +71,13 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "description" : "sink" | } |} - """.stripMargin.trim)) + """.stripMargin.trim) assert(compact(parse(json1)) === testProgress1.json) val json2 = testProgress2.prettyJson - assert( - json2.equalsIgnoreCRLF( - s""" + assertJson( + json2, + s""" |{ | "id" : "${testProgress2.id.toString}", | "runId" : "${testProgress2.runId.toString}", @@ -93,7 +89,9 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | }, | "stateOperators" : [ { | "numRowsTotal" : 0, - | "numRowsUpdated" : 1 + | "numRowsUpdated" : 1, + | "memoryUsedBytes" : 2, + | "numPartitions" : 4 | } ], | "sources" : [ { | "description" : "source", @@ -105,7 +103,7 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "description" : "sink" | } |} - """.stripMargin.trim)) + """.stripMargin.trim) assert(compact(parse(json2)) === testProgress2.json) } @@ -121,14 +119,15 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { test("StreamingQueryStatus - prettyJson") { val json = testStatus.prettyJson - assert(json.equalsIgnoreCRLF( + assertJson( + json, """ |{ | "message" : "active", | "isDataAvailable" : true, | "isTriggerActive" : false |} - """.stripMargin.trim)) + """.stripMargin.trim) } test("StreamingQueryStatus - json") { @@ -209,6 +208,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { } } } + + def assertJson(source: String, expected: String): Unit = { + assert( + source.replaceAll("\r\n|\r|\n", System.lineSeparator) === + expected.replaceAll("\r\n|\r|\n", System.lineSeparator)) + } } object StreamingQueryStatusAndProgressSuite { @@ -224,7 +229,8 @@ object StreamingQueryStatusAndProgressSuite { "min" -> "2016-12-05T20:54:20.827Z", "avg" -> "2016-12-05T20:54:20.827Z", "watermark" -> "2016-12-05T20:54:20.827Z").asJava), - stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)), + stateOperators = Array(new StateOperatorProgress( + numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2, numPartitions = 4)), sources = Array( new SourceProgress( description = "source", @@ -247,7 +253,8 @@ object StreamingQueryStatusAndProgressSuite { durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava), // empty maps should be handled correctly eventTime = new java.util.HashMap(Map.empty[String, String].asJava), - stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)), + stateOperators = Array(new StateOperatorProgress( + numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2, numPartitions = 4)), sources = Array( new SourceProgress( description = "source", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index e16c9e46b7723..92cb4ef11c9e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -69,22 +69,23 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override protected def analyzer: Analyzer = new Analyzer(catalog, conf) { override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = new ResolveHiveSerdeTable(session) +: - new FindDataSourceTable(session) +: - new ResolveSQLOnFile(session) +: - customResolutionRules + new FindDataSourceTable(session) +: + new ResolveSQLOnFile(session) +: + customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = new DetermineTableStats(session) +: - RelationConversions(conf, catalog) +: - PreprocessTableCreation(session) +: - PreprocessTableInsertion(conf) +: - DataSourceAnalysis(conf) +: - HiveAnalysis +: - customPostHocResolutionRules + RelationConversions(conf, catalog) +: + PreprocessTableCreation(session) +: + PreprocessTableInsertion(conf) +: + DataSourceAnalysis(conf) +: + HiveAnalysis +: + customPostHocResolutionRules override val extendedCheckRules: Seq[LogicalPlan => Unit] = PreWriteCheck +: - customCheckRules + PreReadCheck +: + customCheckRules } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index a29d7a7565ee1..2a522a1431f45 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -86,8 +86,8 @@ private[spark] object HiveUtils extends Logging { .createWithDefault("builtin") val CONVERT_METASTORE_PARQUET = buildConf("spark.sql.hive.convertMetastoreParquet") - .doc("When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + - "the built in support.") + .doc("When set to true, the built-in Parquet reader and writer are used to process " + + "parquet tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf .createWithDefault(true) @@ -101,8 +101,8 @@ private[spark] object HiveUtils extends Logging { val CONVERT_METASTORE_ORC = buildConf("spark.sql.hive.convertMetastoreOrc") .internal() - .doc("When set to false, Spark SQL will use the Hive SerDe for ORC tables instead of " + - "the built in support.") + .doc("When set to true, the built-in ORC reader and writer are used to process " + + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf .createWithDefault(false)