From 0698e6c88ca11fdfd6e5498cab784cf6dbcdfacb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 11 May 2017 14:48:13 +0800 Subject: [PATCH 1/7] [SPARK-20606][ML] Revert "[] ML 2.2 QA: Remove deprecated methods for ML" This reverts commit b8733e0ad9f5a700f385e210450fd2c10137293e. Author: Yanbo Liang Closes #17944 from yanboliang/spark-20606-revert. --- .../DecisionTreeClassifier.scala | 18 +-- .../ml/classification/GBTClassifier.scala | 24 ++-- .../RandomForestClassifier.scala | 24 ++-- .../ml/regression/DecisionTreeRegressor.scala | 18 +-- .../spark/ml/regression/GBTRegressor.scala | 24 ++-- .../ml/regression/RandomForestRegressor.scala | 24 ++-- .../org/apache/spark/ml/tree/treeParams.scala | 105 ++++++++++++++++++ .../org/apache/spark/ml/util/ReadWrite.scala | 16 +++ project/MimaExcludes.scala | 68 ------------ python/pyspark/ml/util.py | 32 ++++++ 10 files changed, 219 insertions(+), 134 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 5fb105c6aff60..9f60f0896ec52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -54,27 +54,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -86,15 +86,15 @@ class DecisionTreeClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = set(impurity, value) + override def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 263ed10f19855..ade0960f87a0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -70,27 +70,27 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -102,7 +102,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. @@ -111,7 +111,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = { + override def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this } @@ -120,21 +120,21 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - def setMaxIter(value: Int): this.type = set(maxIter, value) + override def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - def setStepSize(value: Double): this.type = set(stepSize, value) + override def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTClassifierParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 441cfda899276..ab4c235209289 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -56,27 +56,27 @@ class RandomForestClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -88,31 +88,31 @@ class RandomForestClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = set(impurity, value) + override def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) + override def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - def setFeatureSubsetStrategy(value: String): this.type = + override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index c2b0358e8405d..01c5cc1c7efa9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -53,27 +53,27 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S // Override parameter setters from parent trait for Java API compatibility. /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -85,15 +85,15 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = set(impurity, value) + override def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 8d9b519efb142..08d175cb94442 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -68,27 +68,27 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -100,7 +100,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. @@ -109,7 +109,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = { + override def setImpurity(value: String): this.type = { logWarning("GBTRegressor.setImpurity should NOT be used") this } @@ -118,21 +118,21 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - def setMaxIter(value: Int): this.type = set(maxIter, value) + override def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - def setStepSize(value: Double): this.type = set(stepSize, value) + override def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTRegressorParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7b9ddf6e9521a..a58da50fad972 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -55,27 +55,27 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -87,31 +87,31 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = set(impurity, value) + override def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) + override def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - def setFeatureSubsetStrategy(value: String): this.type = + override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 5526d4d75bd73..cd1950bd76c05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -109,24 +109,80 @@ private[ml] trait DecisionTreeParams extends PredictorParams setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setMaxDepth(value: Int): this.type = set(maxDepth, value) + /** @group getParam */ final def getMaxDepth: Int = $(maxDepth) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setMaxBins(value: Int): this.type = set(maxBins, value) + /** @group getParam */ final def getMaxBins: Int = $(maxBins) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + /** @group getParam */ final def getMinInstancesPerNode: Int = $(minInstancesPerNode) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group expertSetParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + /** @group expertGetParam */ final def getMaxMemoryInMB: Int = $(maxMemoryInMB) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group expertSetParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], @@ -169,6 +225,13 @@ private[ml] trait TreeClassifierParams extends Params { setDefault(impurity -> "gini") + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setImpurity(value: String): this.type = set(impurity, value) + /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -213,6 +276,13 @@ private[ml] trait TreeRegressorParams extends Params { setDefault(impurity -> "variance") + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setImpurity(value: String): this.type = set(impurity, value) + /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -268,6 +338,13 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(subsamplingRate -> 1.0) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) @@ -305,6 +382,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(numTrees -> 20) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setNumTrees(value: Int): this.type = set(numTrees, value) + /** @group getParam */ final def getNumTrees: Int = $(numTrees) @@ -346,6 +430,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(featureSubsetStrategy -> "auto") + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + /** @group getParam */ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } @@ -380,6 +471,13 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") // validationTol -> 1e-5 + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + /** * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking * the contribution of each estimator. @@ -393,6 +491,13 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group getParam */ final def getStepSize: Double = $(stepSize) + /** + * @deprecated This method is deprecated and will be removed in 2.2.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + def setStepSize(value: Double): this.type = set(stepSize, value) + setDefault(maxIter -> 20, stepSize -> 0.1) /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index f7e570fd5cc94..a8b80031faf86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -42,6 +42,16 @@ import org.apache.spark.util.Utils private[util] sealed trait BaseReadWrite { private var optionSparkSession: Option[SparkSession] = None + /** + * Sets the Spark SQLContext to use for saving/loading. + */ + @Since("1.6.0") + @deprecated("Use session instead, This method will be removed in 2.2.0.", "2.0.0") + def context(sqlContext: SQLContext): this.type = { + optionSparkSession = Option(sqlContext.sparkSession) + this + } + /** * Sets the Spark Session to use for saving/loading. */ @@ -120,6 +130,9 @@ abstract class MLWriter extends BaseReadWrite with Logging { // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** @@ -175,6 +188,9 @@ abstract class MLReader[T] extends BaseReadWrite { // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2dff154967428..3cc089dcede38 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -1008,74 +1008,6 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy") - ) ++ Seq( - // [SPARK-20606] ML 2.2 QA: Remove deprecated methods for ML - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxIter"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setStepSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxIter"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setStepSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.context"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.context") ) } diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 688109ab11fd2..02016f172aebc 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -76,6 +76,13 @@ def overwrite(self): """Overwrites if the output path already exists.""" raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + def context(self, sqlContext): + """ + Sets the SQL context to use for saving. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + def session(self, sparkSession): """Sets the Spark Session to use for saving.""" raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) @@ -103,6 +110,15 @@ def overwrite(self): self._jwrite.overwrite() return self + def context(self, sqlContext): + """ + Sets the SQL context to use for saving. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") + self._jwrite.context(sqlContext._ssql_ctx) + return self + def session(self, sparkSession): """Sets the Spark Session to use for saving.""" self._jwrite.session(sparkSession._jsparkSession) @@ -149,6 +165,13 @@ def load(self, path): """Load the ML instance from the input path.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + def context(self, sqlContext): + """ + Sets the SQL context to use for loading. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + def session(self, sparkSession): """Sets the Spark Session to use for loading.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) @@ -174,6 +197,15 @@ def load(self, path): % self._clazz) return self._clazz._from_java(java_obj) + def context(self, sqlContext): + """ + Sets the SQL context to use for loading. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") + self._jread.context(sqlContext._ssql_ctx) + return self + def session(self, sparkSession): """Sets the Spark Session to use for loading.""" self._jread.session(sparkSession._jsparkSession) From 65accb813add9f58c1e9f1555863fe0bb1932ad8 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 11 May 2017 15:26:48 +0800 Subject: [PATCH 2/7] [SPARK-17029] make toJSON not go through rdd form but operate on dataset always ## What changes were proposed in this pull request? Don't convert toRdd when doing toJSON ## How was this patch tested? Existing unit tests Author: Robert Kruszewski Closes #14615 from robert3005/robertk/correct-tojson. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 8 +++----- .../sql/execution/datasources/json/JsonSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 61154e23b1b88..c75921e867f64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2806,7 +2806,7 @@ class Dataset[T] private[sql]( def toJSON: Dataset[String] = { val rowSchema = this.schema val sessionLocalTimeZone = sparkSession.sessionState.conf.sessionLocalTimeZone - val rdd: RDD[String] = queryExecution.toRdd.mapPartitions { iter => + mapPartitions { iter => val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records val gen = new JacksonGenerator(rowSchema, writer, @@ -2815,7 +2815,7 @@ class Dataset[T] private[sql]( new Iterator[String] { override def hasNext: Boolean = iter.hasNext override def next(): String = { - gen.write(iter.next()) + gen.write(exprEnc.toRow(iter.next())) gen.flush() val json = writer.toString @@ -2828,9 +2828,7 @@ class Dataset[T] private[sql]( json } } - } - import sparkSession.implicits.newStringEncoder - sparkSession.createDataset(rdd) + } (Encoders.STRING) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 2ab03819964be..5e7f7944bd845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf @@ -1326,6 +1327,15 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Dataset toJSON doesn't construct rdd") { + val containsRDD = spark.emptyDataFrame.toJSON.queryExecution.logical.find { + case ExternalRDD(_, _) => true + case _ => false + } + + assert(containsRDD.isEmpty, "Expected logical plan of toJSON to not contain an RDD") + } + test("JSONRelation equality test") { withTempPath(dir => { val path = dir.getCanonicalFile.toURI.toString From b4c99f43690f8cfba414af90fa2b3998a510bba8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 11 May 2017 00:41:15 -0700 Subject: [PATCH 3/7] [SPARK-20569][SQL] RuntimeReplaceable functions should not take extra parameters ## What changes were proposed in this pull request? `RuntimeReplaceable` always has a constructor with the expression to replace with, and this constructor should not be the function builder. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17876 from cloud-fan/minor. --- .../catalyst/analysis/FunctionRegistry.scala | 20 +++++++++++++------ .../org/apache/spark/sql/SQLQuerySuite.scala | 5 +++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e1d83a86f99dc..6fc154f8debcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Modifier + import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -455,8 +457,17 @@ object FunctionRegistry { private def expression[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { + // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main + // constructor and contains non-parameter `child` and should not be used as function builder. + val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) { + val all = tag.runtimeClass.getConstructors + val maxNumArgs = all.map(_.getParameterCount).max + all.filterNot(_.getParameterCount == maxNumArgs) + } else { + tag.runtimeClass.getConstructors + } // See if we can find a constructor that accepts Seq[Expression] - val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption + val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]])) val builder = (expressions: Seq[Expression]) => { if (varargCtor.isDefined) { // If there is an apply method that accepts Seq[Expression], use that one. @@ -470,11 +481,8 @@ object FunctionRegistry { } else { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) - val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { - case Success(e) => - e - case Failure(e) => - throw new AnalysisException(s"Invalid number of arguments for function $name") + val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { + throw new AnalysisException(s"Invalid number of arguments for function $name") } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3ecbf96b41961..cd14d24370bad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2619,4 +2619,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { new URL(jarFromInvalidFs) } } + + test("RuntimeReplaceable functions should not take extra parameters") { + val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)")) + assert(e.message.contains("Invalid number of arguments")) + } } From 8c67aa7f00e0186abe05a1628faf2232b364a61f Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 11 May 2017 18:09:31 +0800 Subject: [PATCH 4/7] [SPARK-20311][SQL] Support aliases for table value functions ## What changes were proposed in this pull request? This pr added parsing rules to support aliases in table value functions. The previous pr (#17666) has been reverted because of the regression. This new pr fixed the regression and add tests in `SQLQueryTestSuite`. ## How was this patch tested? Added tests in `PlanParserSuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro Closes #17928 from maropu/SPARK-20311-3. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 20 ++++++++---- .../ResolveTableValuedFunctions.scala | 22 +++++++++++-- .../sql/catalyst/analysis/unresolved.scala | 10 ++++-- .../sql/catalyst/parser/AstBuilder.scala | 17 +++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 14 +++++++- .../sql/catalyst/parser/PlanParserSuite.scala | 13 +++++++- .../sql-tests/inputs/inline-table.sql | 3 ++ .../inputs/table-valued-functions.sql | 3 ++ .../sql-tests/results/inline-table.sql.out | 32 ++++++++++++++++++- .../results/table-valued-functions.sql.out | 32 ++++++++++++++++++- 10 files changed, 147 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 14c511f670606..ed5450b494ccd 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -472,15 +472,23 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? strictIdentifier)? #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery - | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction + : tableIdentifier sample? (AS? strictIdentifier)? #tableName + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction ; inlineTable - : VALUES expression (',' expression)* (AS? identifier identifierList?)? + : VALUES expression (',' expression)* tableAlias + ; + +functionTable + : identifier '(' (expression (',' expression)*)? ')' tableAlias + ; + +tableAlias + : (AS? strictIdentifier identifierList?)? ; rowFormat diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index de6de24350f23..dad1340571cc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { + val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { @@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") } + + // If alias names assigned, add `Project` with the aliases + if (u.outputNames.nonEmpty) { + val outputAttrs = resolvedFunc.output + // Checks if the number of the aliases is equal to expected one + if (u.outputNames.size != outputAttrs.size) { + u.failAnalysis(s"expected ${outputAttrs.size} columns but " + + s"found ${u.outputNames.size} columns") + } + val aliases = outputAttrs.zip(u.outputNames).map { + case (attr, name) => Alias(attr, name)() + } + Project(aliases, resolvedFunc) + } else { + resolvedFunc + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 262b894e2a0a3..51bef6e20b9fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -66,10 +66,16 @@ case class UnresolvedInlineTable( /** * A table-valued function, e.g. * {{{ - * select * from range(10); + * select id from range(10); + * + * // Assign alias names + * select t.a from range(10) t(a); * }}} */ -case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) +case class UnresolvedTableValuedFunction( + functionName: String, + functionArgs: Seq[Expression], + outputNames: Seq[String]) extends LeafNode { override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2a9b4a9a9f59..046ea65d454a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { - UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + + val tvf = UnresolvedTableValuedFunction( + func.identifier.getText, func.expression.asScala.map(expression), aliases) + tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan) } /** @@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } - val aliases = if (ctx.identifierList != null) { - visitIdentifierList(ctx.identifierList) + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) } else { Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } val table = UnresolvedInlineTable(aliases, rows) - table.optionalMap(ctx.identifier)(aliasPlan) + table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 893bb1b74cea7..31047f688600b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } + + test("SPARK-20311 range(N) as alias") { + def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = { + SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames)) + .select(star()) + } + assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil)) + assertAnalysisError( + rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil), + Seq("expected 1 columns but found 2 columns")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 411777d6e85a2..cf137cfdf96e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest { test("table valued function") { assertEqual( "select * from range(2)", - UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star())) + } + + test("SPARK-20311 range(N) as alias") { + assertEqual( + "SELECT * FROM range(10) AS t", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty)) + .select(star())) + assertEqual( + "SELECT * FROM range(7) AS t(a)", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil)) + .select(star())) } test("inline table") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index b3ec956cd178e..41d316444ed6b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -49,3 +49,6 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b); -- string to timestamp select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); + +-- cross-join inline tables +EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index d0d2df7b243d5..72cd8ca9d8722 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -24,3 +24,6 @@ select * from RaNgE(2); -- Explain EXPLAIN select * from RaNgE(2); + +-- cross-join table valued functions +EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3); diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index 4e80f0bda5513..c065ce5012929 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 18 -- !query 0 @@ -151,3 +151,33 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991- struct> -- !query 16 output 1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0] + + +-- !query 17 +EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null) +-- !query 17 schema +struct +-- !query 17 output +== Parsed Logical Plan == +'Project [*] ++- 'Join Cross + :- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] + +- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] + +== Analyzed Logical Plan == +col1: string, col2: int, col1: string, col2: int +Project [col1#x, col2#x, col1#x, col2#x] ++- Join Cross + :- LocalRelation [col1#x, col2#x] + +- LocalRelation [col1#x, col2#x] + +== Optimized Logical Plan == +Join Cross +:- LocalRelation [col1#x, col2#x] ++- LocalRelation [col1#x, col2#x] + +== Physical Plan == +BroadcastNestedLoopJoin BuildRight, Cross +:- LocalTableScan [col1#x, col2#x] ++- BroadcastExchange IdentityBroadcastMode + +- LocalTableScan [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index e2ee970d35f60..a8bc6faf11262 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 10 -- !query 0 @@ -103,3 +103,33 @@ struct -- !query 8 output == Physical Plan == *Range (0, 2, step=1, splits=2) + + +-- !query 9 +EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3) +-- !query 9 schema +struct +-- !query 9 output +== Parsed Logical Plan == +'Project [*] ++- 'Join Cross + :- 'UnresolvedTableValuedFunction range, [3] + +- 'UnresolvedTableValuedFunction range, [3] + +== Analyzed Logical Plan == +id: bigint, id: bigint +Project [id#xL, id#xL] ++- Join Cross + :- Range (0, 3, step=1, splits=None) + +- Range (0, 3, step=1, splits=None) + +== Optimized Logical Plan == +Join Cross +:- Range (0, 3, step=1, splits=None) ++- Range (0, 3, step=1, splits=None) + +== Physical Plan == +BroadcastNestedLoopJoin BuildRight, Cross +:- *Range (0, 3, step=1, splits=2) ++- BroadcastExchange IdentityBroadcastMode + +- *Range (0, 3, step=1, splits=2) From 3aa4e464a8c81994c6b7f76d445640da719af6ed Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 11 May 2017 09:49:05 -0700 Subject: [PATCH 5/7] [SPARK-20416][SQL] Print UDF names in EXPLAIN ## What changes were proposed in this pull request? This pr added `withName` in `UserDefinedFunction` for printing UDF names in EXPLAIN ## How was this patch tested? Added tests in `UDFSuite`. Author: Takeshi Yamamuro Closes #17712 from maropu/SPARK-20416. --- .../apache/spark/ml/feature/Bucketizer.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 50 +++++++++---------- .../sql/expressions/UserDefinedFunction.scala | 13 +++++ .../scala/org/apache/spark/sql/UDFSuite.scala | 12 +++-- 4 files changed, 46 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index bb8f2a3aa5f71..46b512f8aea7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -114,7 +114,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String val bucketizer: UserDefinedFunction = udf { (feature: Double) => Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) - } + }.withName("bucketizer") val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) val newField = prepOutputField(filteredDataset.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 5fd7123af3a03..1bceac41b9de7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} -import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -114,7 +114,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try($inputTypes).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) }""") } @@ -147,7 +147,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -160,7 +160,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -173,7 +173,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -186,7 +186,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -199,7 +199,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -212,7 +212,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -225,7 +225,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -238,7 +238,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -251,7 +251,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -264,7 +264,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -277,7 +277,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -290,7 +290,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -303,7 +303,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -316,7 +316,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -329,7 +329,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -342,7 +342,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -355,7 +355,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -368,7 +368,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -381,7 +381,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -394,7 +394,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -407,7 +407,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -420,7 +420,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -433,7 +433,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 5a0f488149ea4..0c5f1b436591d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -47,6 +47,7 @@ case class UserDefinedFunction protected[sql] ( dataType: DataType, inputTypes: Option[Seq[DataType]]) { + private var _nameOption: Option[String] = None private var _nullable: Boolean = true /** @@ -67,15 +68,27 @@ case class UserDefinedFunction protected[sql] ( dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil), + udfName = _nameOption, nullable = _nullable)) } private def copyAll(): UserDefinedFunction = { val udf = copy() + udf._nameOption = _nameOption udf._nullable = _nullable udf } + /** + * Updates UserDefinedFunction with a given name. + * + * @since 2.3.0 + */ + def withName(name: String): this.type = { + this._nameOption = Option(name) + this + } + /** * Updates UserDefinedFunction with a given nullability. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 6f8723af91cea..b4f744b193ada 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -263,10 +263,12 @@ class UDFSuite extends QueryTest with SharedSQLContext { val sparkPlan = spark.sessionState.executePlan(explain).executedPlan sparkPlan.executeCollect().map(_.getString(0).trim).headOption.getOrElse("") } - val udf1 = "myUdf1" - val udf2 = "myUdf2" - spark.udf.register(udf1, (n: Int) => { n + 1 }) - spark.udf.register(udf2, (n: Int) => { n * 1 }) - assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1(UDF:$udf2(1))")) + val udf1Name = "myUdf1" + val udf2Name = "myUdf2" + val udf1 = spark.udf.register(udf1Name, (n: Int) => n + 1) + val udf2 = spark.udf.register(udf2Name, (n: Int) => n * 1) + assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) + assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) + .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } } From 7144b51809aa99ac076786c369389e2330142beb Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 11 May 2017 10:55:11 -0700 Subject: [PATCH 6/7] [SPARK-20600][SS] KafkaRelation should be pretty printed in web UI ## What changes were proposed in this pull request? User-friendly name of `KafkaRelation` in web UI (under Details for Query). ### Before spark-20600-before ### After spark-20600-after ## How was this patch tested? Local build ``` ./bin/spark-shell --jars ~/.m2/repository/org/apache/spark/spark-sql-kafka-0-10_2.11/2.3.0-SNAPSHOT/spark-sql-kafka-0-10_2.11-2.3.0-SNAPSHOT.jar --packages org.apache.kafka:kafka-clients:0.10.0.1 ``` Author: Jacek Laskowski Closes #17917 from jaceklaskowski/SPARK-20600-KafkaRelation-webUI. --- .../scala/org/apache/spark/sql/kafka010/KafkaRelation.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index 97bd283169323..7103709969c18 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -143,4 +143,7 @@ private[kafka010] class KafkaRelation( validateTopicPartitions(partitions, partitionOffsets) } } + + override def toString: String = + s"KafkaRelation(strategy=$strategy, start=$startingOffsets, end=$endingOffsets)" } From 04901dd03a3f8062fd39ea38d585935ff71a9248 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 11 May 2017 11:06:29 -0700 Subject: [PATCH 7/7] [SPARK-20431][SQL] Specify a schema by using a DDL-formatted string ## What changes were proposed in this pull request? This pr supported a DDL-formatted string in `DataFrameReader.schema`. This fix could make users easily define a schema without importing `o.a.spark.sql.types._`. ## How was this patch tested? Added tests in `DataFrameReaderWriterSuite`. Author: Takeshi Yamamuro Closes #17719 from maropu/SPARK-20431. --- python/pyspark/sql/readwriter.py | 23 ++++++++++++------- .../apache/spark/sql/DataFrameReader.scala | 12 ++++++++++ .../sql/test/DataFrameReaderWriterSuite.scala | 9 ++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 61a6b76a79aed..5cf719bd65ae4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -96,14 +96,18 @@ def schema(self, schema): By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading. - :param schema: a :class:`pyspark.sql.types.StructType` object + :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string + (For example ``col0 INT, col1 DOUBLE``). """ from pyspark.sql import SparkSession - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") spark = SparkSession.builder.getOrCreate() - jschema = spark._jsparkSession.parseDataType(schema.json()) - self._jreader = self._jreader.schema(jschema) + if isinstance(schema, StructType): + jschema = spark._jsparkSession.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + elif isinstance(schema, basestring): + self._jreader = self._jreader.schema(schema) + else: + raise TypeError("schema should be StructType or string") return self @since(1.5) @@ -137,7 +141,8 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, @@ -181,7 +186,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or a list of paths, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or + a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -324,7 +330,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``inferSchema`` option or specify the schema explicitly using ``schema``. :param path: string, or list of strings, for input path(s). - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c1b32917415ae..0f96e82cedf4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -67,6 +67,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { this } + /** + * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can + * infer the input schema automatically from data. By specifying the schema here, the underlying + * data source can skip the schema inference step, and thus speed up data loading. + * + * @since 2.3.0 + */ + def schema(schemaString: String): DataFrameReader = { + this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString)) + this + } + /** * Adds an input option for the underlying data source. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index fb15e7def6dbe..306aecb5bbc86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -128,6 +128,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be import testImplicits._ private val userSchema = new StructType().add("s", StringType) + private val userSchemaString = "s STRING" private val textSchema = new StructType().add("value", StringType) private val data = Seq("1", "2", "3") private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath @@ -678,4 +679,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be assert(e.contains("User specified schema not supported with `table`")) } } + + test("SPARK-20431: Specify a schema by using a DDL-formatted string") { + spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir) + testRead(spark.read.schema(userSchemaString).text(), Seq.empty, userSchema) + testRead(spark.read.schema(userSchemaString).text(dir), data, userSchema) + testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema) + testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema) + } }