From 8ac09108fcf3fb62a812333a5b386b566a9d98ec Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 1 Nov 2016 10:46:36 -0700 Subject: [PATCH 001/198] [SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit ## What changes were proposed in this pull request? 1, move cast to `Predictor` 2, and then, remove unnecessary cast ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15414 from zhengruifeng/move_cast. --- .../scala/org/apache/spark/ml/Predictor.scala | 12 ++- .../spark/ml/classification/Classifier.scala | 4 +- .../ml/classification/GBTClassifier.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 2 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../org/apache/spark/ml/PredictorSuite.scala | 82 +++++++++++++++++++ .../LogisticRegressionSuite.scala | 1 - 9 files changed, 98 insertions(+), 11 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e29d7f48a1d6b..aa92edde7acd1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * Abstraction for prediction problems (regression and classification). + * Abstraction for prediction problems (regression and classification). It accepts all NumericType + * labels and will automatically cast it to DoubleType in [[fit()]]. * * @tparam FeaturesType Type of features. * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. @@ -87,7 +88,12 @@ abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) - copyValues(train(dataset).setParent(this)) + + // Cast LabelCol to DoubleType and keep the metadata. + val labelMeta = dataset.schema($(labelCol)).metadata + val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + + copyValues(train(casted).setParent(this)) } override def copy(extra: ParamMap): Learner @@ -121,7 +127,7 @@ abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d1b21b16f2342..a3da3067e1b5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -71,7 +71,7 @@ abstract class Classifier[ * and put it in an RDD with strong types. * * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]]) - * and features ([[Vector]]). Labels are cast to [[DoubleType]]. + * and features ([[Vector]]). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). 
* @throws SparkException if any label is not an integer >= 0 @@ -79,7 +79,7 @@ abstract class Classifier[ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + s" $numClasses, but requires numClasses > 0.") - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + s" dataset with invalid label $label. Labels must be integers in range" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 8bffe0cda0327..f8f164e8c14bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") ( // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. val oldDataset: RDD[LabeledPoint] = - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + s" dataset with invalid label $label. Labels must be in {0,1}; note that" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8fdaae04c42ec..c4651054fd765 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") ( LogisticRegressionModel = { val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 994ed993c99df..b03a07a6bc1e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") ( // Aggregates term frequencies per label. // TODO: Calling aggregateByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. 
- val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( seqOp = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 33cb25c8c7f66..8656ecf609ea4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 519f3bdec82df..ae876b3839734 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala new file mode 100644 index 0000000000000..03e0c536a973e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PredictorSuite._ + + test("should support all NumericType labels and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF("label", "features") + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + + val predictor = new MockPredictor() + + types.foreach { t => + predictor.fit(df.select(col("label").cast(t), col("features"))) + } + + intercept[IllegalArgumentException] { + predictor.fit(df.select(col("label").cast(StringType), col("features"))) + } + } +} + +object PredictorSuite { + + class MockPredictor(override val uid: String) + extends Predictor[Vector, MockPredictor, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictor")) + + override def train(dataset: Dataset[_]): MockPredictionModel = { + require(dataset.schema("label").dataType == DoubleType) + new MockPredictionModel(uid) + } + + override def copy(extra: ParamMap): MockPredictor = + throw new NotImplementedError() + } + + class MockPredictionModel(override val uid: String) + extends PredictionModel[Vector, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictormodel")) + + override def predict(features: Vector): Double = + throw new NotImplementedError() + + override def copy(extra: ParamMap): MockPredictionModel = + throw new NotImplementedError() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bc631dc6d3149..8771fd2e9d2b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1807,7 +1807,6 @@ class LogisticRegressionSuite .objectiveHistory .sliding(2) .forall(x => x(0) >= x(1))) - } test("binary logistic regression with weighted data") { From 8cdf143f4b1ca5c6bc0256808e6f42d9ef299cbd Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Tue, 1 Nov 2016 11:17:35 -0700 Subject: [PATCH 002/198] [SPARK-18103][FOLLOW-UP][SQL][MINOR] Rename `MetadataLogFileCatalog` to `MetadataLogFileIndex` ## What changes were proposed in this pull request? This is a follow-up to https://github.com/apache/spark/pull/15634. ## How was this patch tested? N/A Author: Liwei Lin Closes #15712 from lw-lin/18103. 
--- .../{MetadataLogFileCatalog.scala => MetadataLogFileIndex.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{MetadataLogFileCatalog.scala => MetadataLogFileIndex.scala} (100%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala From 8a538c97b556f80f67c80519af0ce879557050d5 Mon Sep 17 00:00:00 2001 From: Ergin Seyfe Date: Tue, 1 Nov 2016 11:18:42 -0700 Subject: [PATCH 003/198] [SPARK-18189][SQL] Fix serialization issue in KeyValueGroupedDataset ## What changes were proposed in this pull request? Likewise [DataSet.scala](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L156) KeyValueGroupedDataset should mark the queryExecution as transient. As mentioned in the Jira ticket, without transient we saw serialization issues like ``` Caused by: java.io.NotSerializableException: org.apache.spark.sql.execution.QueryExecution Serialization stack: - object not serializable (class: org.apache.spark.sql.execution.QueryExecution, value: == ``` ## How was this patch tested? Run the query which is specified in the Jira ticket before and after: ``` val a = spark.createDataFrame(sc.parallelize(Seq((1,2),(3,4)))).as[(Int,Int)] val grouped = a.groupByKey( {x:(Int,Int)=>x._1} ) val mappedGroups = grouped.mapGroups((k,x)=> {(k,1)} ) val yyy = sc.broadcast(1) val last = mappedGroups.rdd.map(xx=> { val simpley = yyy.value 1 } ) ``` Author: Ergin Seyfe Closes #15706 from seyfe/keyvaluegrouped_serialization. 
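For readers who want the `@transient` mechanics in isolation, here is a minimal, self-contained sketch (the `QueryState`/`Holder` names are illustrative stand-ins, not Spark internals) of why excluding a non-serializable field from Java serialization avoids the `NotSerializableException` quoted above:

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Stand-in for QueryExecution: holds state that is not Serializable.
class QueryState {
  val handle = new Object() // java.lang.Object is not Serializable
}

// @transient excludes the field from Java serialization, so serializing
// Holder (e.g. when it is captured in a task closure) no longer drags
// QueryState along with it.
class Holder(@transient val state: QueryState) extends Serializable

object TransientDemo extends App {
  val out = new ObjectOutputStream(new ByteArrayOutputStream())
  // Succeeds; without @transient this throws NotSerializableException.
  out.writeObject(new Holder(new QueryState))
  out.close()
}
```

The trade-off is that a transient field is `null` after deserialization, which is acceptable here because `queryExecution` is only needed on the driver.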
--- .../scala/org/apache/spark/repl/ReplSuite.scala | 17 +++++++++++++++++ .../spark/sql/KeyValueGroupedDataset.scala | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 9262e938c2a60..96d2dfc2658b9 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -473,4 +473,21 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("AssertionError", output) assertDoesNotContain("Exception", output) } + + test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") { + val resultValue = 12345 + val output = runInterpreter("local", + s""" + |val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1) + |val mapGroups = keyValueGrouped.mapGroups((k, v) => (k, 1)) + |val broadcasted = sc.broadcast($resultValue) + | + |// Using broadcast triggers serialization issue in KeyValueGroupedDataset + |val dataset = mapGroups.map(_ => broadcasted.value) + |dataset.collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains(s": Array[Int] = Array($resultValue, $resultValue)", output) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 4cb0313aa9037..31ce8eb25e808 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.expressions.ReduceAggregator class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], - val queryExecution: QueryExecution, + @transient val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { From d0272b436512b71f04313e109d3d21a6e9deefca Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 1 Nov 2016 11:25:11 -0700 Subject: [PATCH 004/198] [SPARK-18148][SQL] Misleading Error Message for Aggregation Without Window/GroupBy ## What changes were proposed in this pull request? Aggregation Without Window/GroupBy expressions will fail in `checkAnalysis`, the error message is a bit misleading, we should generate a more specific error message for this case. For example, ``` spark.read.load("/some-data") .withColumn("date_dt", to_date($"date")) .withColumn("year", year($"date_dt")) .withColumn("week", weekofyear($"date_dt")) .withColumn("user_count", count($"userId")) .withColumn("daily_max_in_week", max($"user_count").over(weeklyWindow)) ) ``` creates the following output: ``` org.apache.spark.sql.AnalysisException: expression '`randomColumn`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; ``` In the error message above, `randomColumn` doesn't appear in the query(acturally it's added by function `withColumn`), so the message is not enough for the user to address the problem. ## How was this patch tested? Manually test Before: ``` scala> spark.sql("select col, count(col) from tbl") org.apache.spark.sql.AnalysisException: expression 'tbl.`col`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.;; ``` After: ``` scala> spark.sql("select col, count(col) from tbl") org.apache.spark.sql.AnalysisException: grouping expressions sequence is empty, and 'tbl.`col`' is not an aggregate function. Wrap '(count(col#231L) AS count(col)#239L)' in windowing function(s) or wrap 'tbl.`col`' in first() (or first_value) if you don't care which value you get.;; ``` Also add new test sqls in `group-by.sql`. Author: jiangxingbo Closes #15672 from jiangxb1987/groupBy-empty. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++ .../resources/sql-tests/inputs/group-by.sql | 41 +++++-- .../sql-tests/results/group-by.sql.out | 116 +++++++++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 35 ------ 4 files changed, 140 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9a7c2a944b588..3455a567b7786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -214,6 +214,18 @@ trait CheckAnalysis extends PredicateHelper { s"appear in the arguments of an aggregate function.") } } + case e: Attribute if groupingExprs.isEmpty => + // Collect all [[AggregateExpressions]]s. + val aggExprs = aggregateExprs.filter(_.collect { + case a: AggregateExpression => a + }.nonEmpty) + failAnalysis( + s"grouping expressions sequence is empty, " + + s"and '${e.sql}' is not an aggregate function. " + + s"Wrap '${aggExprs.map(_.sql).mkString("(", ", ", ")")}' in windowing " + + s"function(s) or wrap '${e.sql}' in first() (or first_value) " + + s"if you don't care which value you get." + ) case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.sql}' is neither present in the group by, " + diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 6741703d9d82c..d950ec83d98c3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -1,17 +1,34 @@ --- Temporary data. -create temporary view myview as values 128, 256 as v(int_col); +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) +AS testData(a, b); --- group by should produce all input rows, -select int_col, count(*) from myview group by int_col; +-- Aggregate with empty GroupBy expressions. +SELECT a, COUNT(b) FROM testData; +SELECT COUNT(a), COUNT(b) FROM testData; --- group by should produce a single row. -select 'foo', count(*) from myview group by 1; +-- Aggregate with non-empty GroupBy expressions. +SELECT a, COUNT(b) FROM testData GROUP BY a; +SELECT a, COUNT(b) FROM testData GROUP BY b; +SELECT COUNT(a), COUNT(b) FROM testData GROUP BY a; --- group-by should not produce any rows (whole stage code generation). -select 'foo' from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals. +SELECT 'foo', COUNT(a) FROM testData GROUP BY 1; --- group-by should not produce any rows (hash aggregate). -select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals (whole stage code generation). 
+SELECT 'foo' FROM testData WHERE a = 0 GROUP BY 1; --- group-by should not produce any rows (sort aggregate). -select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals (hash aggregate). +SELECT 'foo', APPROX_COUNT_DISTINCT(a) FROM testData WHERE a = 0 GROUP BY 1; + +-- Aggregate grouped by literals (sort aggregate). +SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1; + +-- Aggregate with complex GroupBy expressions. +SELECT a + b, COUNT(b) FROM testData GROUP BY a + b; +SELECT a + 2, COUNT(b) FROM testData GROUP BY a + 1; +SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1; + +-- Aggregate with nulls. +SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) +FROM testData; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 9127bd4dd4c6f..a91f04e098b18 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,9 +1,11 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 14 -- !query 0 -create temporary view myview as values 128, 256 as v(int_col) +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) +AS testData(a, b) -- !query 0 schema struct<> -- !query 0 output @@ -11,41 +13,121 @@ struct<> -- !query 1 -select int_col, count(*) from myview group by int_col +SELECT a, COUNT(b) FROM testData -- !query 1 schema -struct +struct<> -- !query 1 output -128 1 -256 1 +org.apache.spark.sql.AnalysisException +grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.; -- !query 2 -select 'foo', count(*) from myview group by 1 +SELECT COUNT(a), COUNT(b) FROM testData -- !query 2 schema -struct +struct -- !query 2 output -foo 2 +7 7 -- !query 3 -select 'foo' from myview where int_col == 0 group by 1 +SELECT a, COUNT(b) FROM testData GROUP BY a -- !query 3 schema -struct +struct -- !query 3 output - +1 2 +2 2 +3 2 +NULL 1 -- !query 4 -select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1 +SELECT a, COUNT(b) FROM testData GROUP BY b -- !query 4 schema -struct +struct<> -- !query 4 output - +org.apache.spark.sql.AnalysisException +expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.; -- !query 5 -select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 +SELECT COUNT(a), COUNT(b) FROM testData GROUP BY a -- !query 5 schema -struct> +struct -- !query 5 output +0 1 +2 2 +2 2 +3 2 + + +-- !query 6 +SELECT 'foo', COUNT(a) FROM testData GROUP BY 1 +-- !query 6 schema +struct +-- !query 6 output +foo 7 + + +-- !query 7 +SELECT 'foo' FROM testData WHERE a = 0 GROUP BY 1 +-- !query 7 schema +struct +-- !query 7 output + + +-- !query 8 +SELECT 'foo', APPROX_COUNT_DISTINCT(a) FROM testData WHERE a = 0 GROUP BY 1 +-- !query 8 schema +struct +-- !query 8 output + + + +-- !query 9 +SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1 +-- !query 9 schema +struct> +-- !query 9 output + + + +-- !query 10 +SELECT a + b, COUNT(b) FROM testData GROUP BY a + b +-- !query 10 schema +struct<(a + b):int,count(b):bigint> +-- !query 10 output +2 1 +3 2 +4 2 +5 1 +NULL 1 + + +-- !query 11 +SELECT a + 2, COUNT(b) FROM testData GROUP BY a + 1 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 12 +SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1 +-- !query 12 schema +struct<((a + 1) + 1):int,count(b):bigint> +-- !query 12 output +3 2 +4 2 +5 2 +NULL 1 + + +-- !query 13 +SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) +FROM testData +-- !query 13 schema +struct +-- !query 13 output +-0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1a43d0b2205ca..9a3d93cf17b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -463,20 +463,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } - test("agg") { - checkAnswer( - sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq(Row(1, 3), Row(2, 3), Row(3, 3))) - } - - test("aggregates with nulls") { - checkAnswer( - sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + - "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3) - ) - } - test("select *") { checkAnswer( sql("SELECT * FROM testData"), @@ -1178,27 +1164,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1)) } - test("throw errors for non-aggregate attributes with aggregation") { - def checkAggregation(query: String, isInvalidQuery: Boolean = true) { - if (isInvalidQuery) { - val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) - assert(e.getMessage contains "group by") - } else { - // Should not throw - sql(query).queryExecution.analyzed - } - } - - checkAggregation("SELECT key, COUNT(*) FROM testData") - checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", isInvalidQuery = false) - - checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") - checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) - - checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") - checkAggregation("SELECT key + 1 + 1, COUNT(*) 
FROM testData GROUP BY key + 1", false) - } - testQuietly( "SPARK-16748: SparkExceptions during planning should not wrapped in TreeNodeException") { intercept[SparkException] { From cfac17ee1cec414663b957228e469869eb7673c1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 1 Nov 2016 12:35:34 -0700 Subject: [PATCH 005/198] [SPARK-18167] Disable flaky SQLQuerySuite test We now know it's a persistent environmental issue that is causing this test to sometimes fail. One hypothesis is that some configuration is leaked from another suite, and depending on suite ordering this can cause this test to fail. I am planning on mining the jenkins logs to try to narrow down which suite could be causing this. For now, disable the test. Author: Eric Liang Closes #15720 from ericl/disable-flaky-test. --- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8b916932ff543..b9353b5b5d2a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1565,7 +1565,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } - test("SPARK-10562: partition by column with mixed case name") { + ignore("SPARK-10562: partition by column with mixed case name") { def runOnce() { withTable("tbl10562") { val df = Seq(2012 -> "a").toDF("Year", "val") From 01dd0083011741c2bbe5ae1d2a25f2c9a1302b76 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 1 Nov 2016 12:46:41 -0700 Subject: [PATCH 006/198] [SPARK-17764][SQL] Add `to_json` supporting to convert nested struct column to JSON string ## What changes were proposed in this pull request? This PR proposes to add `to_json` function in contrast with `from_json` in Scala, Java and Python. It'd be useful if we can convert a same column from/to json. Also, some datasources do not support nested types. If we are forced to save a dataframe into those data sources, we might be able to work around by this function. The usage is as below: ``` scala val df = Seq(Tuple1(Tuple1(1))).toDF("a") df.select(to_json($"a").as("json")).show() ``` ``` bash +--------+ | json| +--------+ |{"_1":1}| +--------+ ``` ## How was this patch tested? Unit tests in `JsonFunctionsSuite` and `JsonExpressionsSuite`. Author: hyukjinkwon Closes #15354 from HyukjinKwon/SPARK-17764. 
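Since the new function mirrors `from_json`, a short round-trip sketch may help (assuming a `spark` session with `spark.implicits._` imported; the schema below matches the `{"_1":1}` example above):

```scala
import org.apache.spark.sql.functions.{from_json, to_json}
import org.apache.spark.sql.types.{IntegerType, StructType}

// struct -> JSON string -> struct again
val df = Seq(Tuple1(Tuple1(1))).toDF("a")
val asJson = df.select(to_json($"a").as("json"))    // {"_1":1}
val schema = new StructType().add("_1", IntegerType)
val roundTripped = asJson.select(from_json($"json", schema).as("a"))
roundTripped.show()
```

Note that, as the `JacksonUtils.verifySchema` check added in this patch enforces, `to_json` fails analysis with an `AnalysisException` for unsupported column types such as `CalendarIntervalType`.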
--- python/pyspark/sql/functions.py | 23 +++++++++ python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 2 +- .../expressions/jsonExpressions.scala | 48 ++++++++++++++++++- .../sql/catalyst}/json/JacksonGenerator.scala | 5 +- .../sql/catalyst/json/JacksonUtils.scala | 26 ++++++++++ .../expressions/JsonExpressionsSuite.scala | 9 ++++ .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../datasources/json/JsonFileFormat.scala | 2 +- .../org/apache/spark/sql/functions.scala | 44 ++++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 30 +++++++++--- 11 files changed, 177 insertions(+), 16 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JacksonGenerator.scala (98%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7fa3fd2de7ddf..45e3c22bfc6a9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1744,6 +1744,29 @@ def from_json(col, schema, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.1) +def to_json(col, options={}): + """ + Converts a column containing a [[StructType]] into a JSON string. Throws an exception, + in the case of an unsupported type. + + :param col: name of column containing the struct + :param options: options to control converting. accepts the same options as the json datasource + + >>> from pyspark.sql import Row + >>> from pyspark.sql.types import * + >>> data = [(1, Row(name='Alice', age=2))] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'{"age":2,"name":"Alice"}')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.to_json(_to_java_column(col), options) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index bc786ef95ed03..b0c51b1e9992e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -161,7 +161,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ Loads a JSON file (`JSON Lines text format or newline-delimited JSON - <[http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per + `_) or an RDD of Strings storing JSON objects (one object per record) and returns the result as a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 559647bbabf67..1c94413e3c457 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -641,7 +641,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, timestampFormat=None): """ Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON - <[http://jsonlines.org/>`_) and returns a :class`DataFrame`. + `_) and returns a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 65dbd6a4e3f1d..244a5a34f3594 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayOutputStream, StringWriter} +import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter} import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions, SparkSQLJsonProcessingException} +import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.ParseModes import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -494,3 +495,46 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child: override def inputTypes: Seq[AbstractDataType] = StringType :: Nil } + +/** + * Converts a [[StructType]] to a json output string. + */ +case class StructToJson(options: Map[String, String], child: Expression) + extends Expression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + lazy val writer = new CharArrayWriter() + + @transient + lazy val gen = + new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer) + + override def dataType: DataType = StringType + override def children: Seq[Expression] = child :: Nil + + override def checkInputDataTypes(): TypeCheckResult = { + if (StructType.acceptsType(child.dataType)) { + try { + JacksonUtils.verifySchema(child.dataType.asInstanceOf[StructType]) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName requires that the expression is a struct expression.") + } + } + + override def eval(input: InternalRow): Any = { + gen.write(child.eval(input).asInstanceOf[InternalRow]) + gen.flush() + val json = writer.toString + writer.reset() + UTF8String.fromString(json) + } + + override def inputTypes: Seq[AbstractDataType] = StructType :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 5b55b701862b7..4b548e0e7f978 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -15,15 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.io.Writer import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index c4d9abb2c07e8..3b23c6cd2816f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} +import org.apache.spark.sql.types._ + object JacksonUtils { /** * Advance the parser until a null or a specific token is found @@ -29,4 +31,28 @@ object JacksonUtils { case x => x != stopOn } } + + /** + * Verify if the schema is supported in JSON parsing. + */ + def verifySchema(schema: StructType): Unit = { + def verifyType(name: String, dataType: DataType): Unit = dataType match { + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | StringType | TimestampType | DateType | BinaryType | _: DecimalType => + + case st: StructType => st.foreach(field => verifyType(field.name, field.dataType)) + + case at: ArrayType => verifyType(name, at.elementType) + + case mt: MapType => verifyType(name, mt.keyType) + + case udt: UserDefinedType[_] => verifyType(name, udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"Unable to convert column $name of type ${dataType.simpleString} to JSON.") + } + + schema.foreach(field => verifyType(field.name, field.dataType)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 84623934d95d2..f9db649bc2404 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -343,4 +343,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null ) } + + test("to_json") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), schema) + checkEvaluation( + StructToJson(Map.empty, struct), + """{"a":1}""" + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6e0a2471e0fb5..eb2b20afc37cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.json.JacksonGenerator import org.apache.spark.sql.catalyst.optimizer.CombineUnions import 
org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -45,7 +46,6 @@ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 5a409c04c929d..0e38aefecb673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextOutputWriter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5f1efd22d8204..944a476114faf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2883,10 +2883,10 @@ object functions { * (Scala-specific) Parses a column containing a JSON string into a [[StructType]] with the * specified schema. Returns `null`, in the case of an unparseable string. * + * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string * @param options options to control how the json is parsed. accepts the same options and the * json data source. - * @param e a string column containing JSON data. * * @group collection_funcs * @since 2.1.0 @@ -2936,6 +2936,48 @@ object functions { def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options) + + /** + * (Scala-specific) Converts a column containing a [[StructType]] into a JSON string with the + * specified schema. Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * @param options options to control how the struct column is converted into a json string. + * accepts the same options and the json data source. + * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column, options: Map[String, String]): Column = withExpr { + StructToJson(options, e.expr) + } + + /** + * (Java-specific) Converts a column containing a [[StructType]] into a JSON string with the + * specified schema. Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * @param options options to control how the struct column is converted into a json string. 
+ * accepts the same options and the json data source. + * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column, options: java.util.Map[String, String]): Column = + to_json(e, options.asScala.toMap) + + /** + * Converts a column containing a [[StructType]] into a JSON string with the + * specified schema. Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column): Column = + to_json(e, Map.empty[String, String]) + /** * Returns length of array or map. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 518d6e92b2ff7..59ae889cf3b92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.from_json +import org.apache.spark.sql.functions.{from_json, struct, to_json} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType} class JsonFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -31,7 +31,6 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("alice", "5")) } - val tuples: Seq[(String, String)] = ("1", """{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: ("2", """{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: @@ -97,7 +96,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(expr, expected) } - test("json_parser") { + test("from_json") { val df = Seq("""{"a": 1}""").toDS() val schema = new StructType().add("a", IntegerType) @@ -106,7 +105,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(1)) :: Nil) } - test("json_parser missing columns") { + test("from_json missing columns") { val df = Seq("""{"a": 1}""").toDS() val schema = new StructType().add("b", IntegerType) @@ -115,7 +114,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(null)) :: Nil) } - test("json_parser invalid json") { + test("from_json invalid json") { val df = Seq("""{"a" 1}""").toDS() val schema = new StructType().add("a", IntegerType) @@ -123,4 +122,23 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { df.select(from_json($"value", schema)), Row(null) :: Nil) } + + test("to_json") { + val df = Seq(Tuple1(Tuple1(1))).toDF("a") + + checkAnswer( + df.select(to_json($"a")), + Row("""{"_1":1}""") :: Nil) + } + + test("to_json unsupported type") { + val df = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a") + .select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c")) + val e = intercept[AnalysisException]{ + // Unsupported type throws an exception + df.select(to_json($"c")).collect() + } + assert(e.getMessage.contains( + "Unable to convert column a of type calendarinterval to JSON.")) + } } From 6e6298154aba63831a292117797798131a646869 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Nov 2016 16:23:47 -0700 Subject: [PATCH 007/198] [SPARK-17350][SQL] Disable default use of KryoSerializer in Thrift Server In SPARK-4761 / #3621 (December 2014) we enabled Kryo serialization by default in the Spark Thrift Server. 
However, I don't think that the original rationale for doing this still holds now that most Spark SQL serialization is now performed via encoders and our UnsafeRow format. In addition, the use of Kryo as the default serializer can introduce performance problems because the creation of new KryoSerializer instances is expensive and we haven't performed instance-reuse optimizations in several code paths (including DirectTaskResult deserialization). Given all of this, I propose to revert back to using JavaSerializer as the default serializer in the Thrift Server. /cc liancheng Author: Josh Rosen Closes #14906 from JoshRosen/disable-kryo-in-thriftserver. --- docs/configuration.md | 5 ++--- .../spark/sql/hive/thriftserver/SparkSQLEnv.scala | 10 ---------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 780fc94908d38..0017219e07261 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -767,7 +767,7 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.referenceTracking - true (false when using Spark SQL Thrift Server) + true Whether to track references to the same object when serializing data with Kryo, which is necessary if your object graphs have loops and useful for efficiency if they contain multiple @@ -838,8 +838,7 @@ Apart from these, the following properties are also available, and may be useful spark.serializer - org.apache.spark.serializer.
JavaSerializer (org.apache.spark.serializer.
- KryoSerializer when using Spark SQL Thrift Server) + org.apache.spark.serializer.
JavaSerializer Class to use for serializing objects that will be sent over the network or need to be cached diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 638911599aad3..78a309497ab57 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.thriftserver import java.io.PrintStream -import scala.collection.JavaConverters._ - import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} @@ -37,8 +35,6 @@ private[hive] object SparkSQLEnv extends Logging { def init() { if (sqlContext == null) { val sparkConf = new SparkConf(loadDefaults = true) - val maybeSerializer = sparkConf.getOption("spark.serializer") - val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of // the default appName [SparkSQLCLIDriver] in cli or beeline. val maybeAppName = sparkConf @@ -47,12 +43,6 @@ private[hive] object SparkSQLEnv extends Logging { sparkConf .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) - .set( - "spark.serializer", - maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) - .set( - "spark.kryo.referenceTracking", - maybeKryoReferenceTracking.getOrElse("false")) val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate() sparkContext = sparkSession.sparkContext From b929537b6eb0f8f34497c3dbceea8045bf5dffdb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Nov 2016 16:49:41 -0700 Subject: [PATCH 008/198] [SPARK-18182] Expose ReplayListenerBus.read() overload which takes string iterator The `ReplayListenerBus.read()` method is used when implementing a custom `ApplicationHistoryProvider`. The current interface only exposes a `read()` method which takes an `InputStream` and performs stream-to-lines conversion itself, but it would also be useful to expose an overloaded method which accepts an iterator of strings, thereby enabling events to be provided from non-`InputStream` sources. Author: Josh Rosen Closes #15698 from JoshRosen/replay-listener-bus-interface. --- .../spark/scheduler/ReplayListenerBus.scala | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 2424586431aa0..0bd5a6bc59a9e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -53,13 +53,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { sourceName: String, maybeTruncated: Boolean = false, eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { + val lines = Source.fromInputStream(logData).getLines() + replay(lines, sourceName, maybeTruncated, eventsFilter) + } + /** + * Overloaded variant of [[replay()]] which accepts an iterator of lines instead of an + * [[InputStream]]. Exposed for use by custom ApplicationHistoryProvider implementations. 
+ */ + def replay( + lines: Iterator[String], + sourceName: String, + maybeTruncated: Boolean, + eventsFilter: ReplayEventsFilter): Unit = { var currentLine: String = null var lineNumber: Int = 0 try { - val lineEntries = Source.fromInputStream(logData) - .getLines() + val lineEntries = lines .zipWithIndex .filter { case (line, _) => eventsFilter(line) } From 91c33a0ca5c8287f710076ed7681e5aa13ca068f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 1 Nov 2016 17:00:00 -0700 Subject: [PATCH 009/198] [SPARK-18088][ML] Various ChiSqSelector cleanups ## What changes were proposed in this pull request? - Renamed kbest to numTopFeatures - Renamed alpha to fpr - Added missing Since annotations - Doc cleanups ## How was this patch tested? Added new standardized unit tests for spark.ml. Improved existing unit test coverage a bit. Author: Joseph K. Bradley Closes #15647 from jkbradley/chisqselector-follow-ups. --- docs/ml-features.md | 12 +- docs/mllib-feature-extraction.md | 15 +- .../spark/ml/feature/ChiSqSelector.scala | 59 ++++---- .../mllib/api/python/PythonMLLibAPI.scala | 4 +- .../spark/mllib/feature/ChiSqSelector.scala | 45 +++--- .../spark/ml/feature/ChiSqSelectorSuite.scala | 135 ++++++++++-------- .../mllib/feature/ChiSqSelectorSuite.scala | 17 +-- python/pyspark/ml/feature.py | 37 ++--- python/pyspark/mllib/feature.py | 58 ++++---- 9 files changed, 197 insertions(+), 185 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 64c6a160239cc..352887d3ba6e3 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1338,14 +1338,14 @@ for more details on the API. `ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which -features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: +features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`: -* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. -* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. -* `FPR` chooses all features whose false positive rate meets some threshold. +* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection. -By default, the selection method is `KBest`, the default number of top features is 50. User can use -`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. **Examples** diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 87e1e027e945b..42568c312e70e 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -227,22 +227,19 @@ both speed and statistical learning behavior. [`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements Chi-Squared feature selection. 
It operates on labeled data with categorical features. ChiSqSelector uses the [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which -features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: +features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`: -* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. -* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. -* `FPR` chooses all features whose false positive rate meets some threshold. +* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection. -By default, the selection method is `KBest`, the default number of top features is 50. User can use -`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. The number of features to select can be tuned using a held-out validation set. ### Model Fitting -`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that -the selector will select. - The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index d0385e220e1e2..653fa41124f88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -42,69 +42,80 @@ private[feature] trait ChiSqSelectorParams extends Params with HasFeaturesCol with HasOutputCol with HasLabelCol { /** - * Number of features that selector will select (ordered by statistic value descending). If the + * Number of features that selector will select, ordered by ascending p-value. If the * number of features is less than numTopFeatures, then this will select all features. - * Only applicable when selectorType = "kbest". + * Only applicable when selectorType = "numTopFeatures". * The default value of numTopFeatures is 50. * * @group param */ + @Since("1.6.0") final val numTopFeatures = new IntParam(this, "numTopFeatures", - "Number of features that selector will select, ordered by statistics value descending. If the" + + "Number of features that selector will select, ordered by ascending p-value. If the" + " number of features is < numTopFeatures, then this will select all features.", ParamValidators.gtEq(1)) setDefault(numTopFeatures -> 50) /** @group getParam */ + @Since("1.6.0") def getNumTopFeatures: Int = $(numTopFeatures) /** * Percentile of features that selector will select, ordered by statistics value descending. * Only applicable when selectorType = "percentile". * Default value is 0.1. 
+ * @group param */ + @Since("2.1.0") final val percentile = new DoubleParam(this, "percentile", - "Percentile of features that selector will select, ordered by statistics value descending.", + "Percentile of features that selector will select, ordered by ascending p-value.", ParamValidators.inRange(0, 1)) setDefault(percentile -> 0.1) /** @group getParam */ + @Since("2.1.0") def getPercentile: Double = $(percentile) /** * The highest p-value for features to be kept. * Only applicable when selectorType = "fpr". * Default value is 0.05. + * @group param */ - final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", + final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1)) - setDefault(alpha -> 0.05) + setDefault(fpr -> 0.05) /** @group getParam */ - def getAlpha: Double = $(alpha) + def getFpr: Double = $(fpr) /** * The selector type of the ChisqSelector. - * Supported options: "kbest" (default), "percentile" and "fpr". + * Supported options: "numTopFeatures" (default), "percentile", "fpr". + * @group param */ + @Since("2.1.0") final val selectorType = new Param[String](this, "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: kbest (default), percentile and fpr.", - ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) - setDefault(selectorType -> OldChiSqSelector.KBest) + "Supported options: " + OldChiSqSelector.supportedSelectorTypes.mkString(", "), + ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes)) + setDefault(selectorType -> OldChiSqSelector.NumTopFeatures) /** @group getParam */ + @Since("2.1.0") def getSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. 
*/ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -113,10 +124,6 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) - /** @group setParam */ - @Since("2.1.0") - def setSelectorType(value: String): this.type = set(selectorType, value) - /** @group setParam */ @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) @@ -127,7 +134,11 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str /** @group setParam */ @Since("2.1.0") - def setAlpha(value: Double): this.type = set(alpha, value) + def setFpr(value: Double): this.type = set(fpr, value) + + /** @group setParam */ + @Since("2.1.0") + def setSelectorType(value: String): this.type = set(selectorType, value) /** @group setParam */ @Since("1.6.0") @@ -153,15 +164,15 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str .setSelectorType($(selectorType)) .setNumTopFeatures($(numTopFeatures)) .setPercentile($(percentile)) - .setAlpha($(alpha)) + .setFpr($(fpr)) val model = selector.fit(input) copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType)) - otherPairs.foreach { case (_, paramName: String) => + val otherPairs = OldChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType)) + otherPairs.foreach { paramName: String => if (isSet(getParam(paramName))) { logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 904000f50d0a2..034e3625e8c01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -638,13 +638,13 @@ private[python] class PythonMLLibAPI extends Serializable { selectorType: String, numTopFeatures: Int, percentile: Double, - alpha: Double, + fpr: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { new ChiSqSelector() .setSelectorType(selectorType) .setNumTopFeatures(numTopFeatures) .setPercentile(percentile) - .setAlpha(alpha) + .setFpr(fpr) .fit(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f8276de4f23d4..f9156b642785f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { Loader.checkSchema[Data](dataFrame.schema) val features = dataArray.rdd.map { - case Row(feature: Int) => (feature) + case Row(feature: Int) => feature }.collect() new ChiSqSelectorModel(features) @@ -171,18 +171,20 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. 
- * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. */ @Since("1.3.0") class ChiSqSelector @Since("2.1.0") () extends Serializable { var numTopFeatures: Int = 50 var percentile: Double = 0.1 - var alpha: Double = 0.05 - var selectorType = ChiSqSelector.KBest + var fpr: Double = 0.05 + var selectorType = ChiSqSelector.NumTopFeatures /** * The is the same to call this() and setNumTopFeatures(numTopFeatures) @@ -207,15 +209,15 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } @Since("2.1.0") - def setAlpha(value: Double): this.type = { - require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") - alpha = value + def setFpr(value: Double): this.type = { + require(0.0 <= value && value <= 1.0, "FPR must be in [0,1]") + fpr = value this } @Since("2.1.0") def setSelectorType(value: String): this.type = { - require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value), + require(ChiSqSelector.supportedSelectorTypes.contains(value), s"ChiSqSelector Type: $value was not supported.") selectorType = value this @@ -232,7 +234,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { - case ChiSqSelector.KBest => + case ChiSqSelector.NumTopFeatures => chiSqTestResult .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) @@ -242,7 +244,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult - .filter { case (res, _) => res.pValue < alpha } + .filter { case (res, _) => res.pValue < fpr } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } @@ -251,22 +253,17 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } } -@Since("2.1.0") -object ChiSqSelector { +private[spark] object ChiSqSelector { - /** String name for `kbest` selector type. */ - private[spark] val KBest: String = "kbest" + /** String name for `numTopFeatures` selector type. */ + val NumTopFeatures: String = "numTopFeatures" /** String name for `percentile` selector type. */ - private[spark] val Percentile: String = "percentile" + val Percentile: String = "percentile" /** String name for `fpr` selector type. */ private[spark] val FPR: String = "fpr" - /** Set of selector type and param pairs that ChiSqSelector supports. */ - private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures", - Percentile -> "percentile", FPR -> "alpha") - /** Set of selector types that ChiSqSelector supports. 
*/ - private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1) + val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, FPR) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 6af06d82d671a..80970fd744881 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -19,85 +19,72 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.feature import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("Test Chi-Square selector") { - import testImplicits._ - val data = Seq( - LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), - LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), - LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), - LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0))) - ) + @transient var dataset: Dataset[_] = _ - val preFilteredData = Seq( - Vectors.dense(8.0), - Vectors.dense(0.0), - Vectors.dense(0.0), - Vectors.dense(8.0) - ) + override def beforeAll(): Unit = { + super.beforeAll() - val df = sc.parallelize(data.zip(preFilteredData)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") - - val selector = new ChiSqSelector() - .setSelectorType("kbest") - .setNumTopFeatures(1) - .setFeaturesCol("data") - .setLabelCol("label") - .setOutputCol("filtered") - - selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } - - selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + // Toy dataset, including the top feature for a chi-squared test. + // These data are chosen such that each feature's test has a distinct p-value. 
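(Editor's note: alongside the R check that follows, the same p-values can be inspected inside Spark itself, since `ChiSqSelector.fit` delegates to `Statistics.chiSqTest` — a hedged sketch using the suite's toy data:)

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics

// Dense form of the toy dataset built in beforeAll() above.
val rdd = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(8.0, 7.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 9.0, 6.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 9.0, 8.0)),
  LabeledPoint(2.0, Vectors.dense(7.0, 9.0, 5.0)),
  LabeledPoint(2.0, Vectors.dense(8.0, 7.0, 3.0))))

// chiSqTest returns one result per feature; each should have a distinct p-value.
Statistics.chiSqTest(rdd).zipWithIndex.foreach { case (result, i) =>
  println(s"feature $i: pValue = ${result.pValue}")
}
```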
+ /* To verify the results with R, run: + library(stats) + x1 <- c(8.0, 0.0, 0.0, 7.0, 8.0) + x2 <- c(7.0, 9.0, 9.0, 9.0, 7.0) + x3 <- c(0.0, 6.0, 8.0, 5.0, 3.0) + y <- c(0.0, 1.0, 1.0, 2.0, 2.0) + chisq.test(x1,y) + chisq.test(x2,y) + chisq.test(x3,y) + */ + dataset = spark.createDataFrame(Seq( + (0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0))), Vectors.dense(8.0)), + (1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0))), Vectors.dense(0.0)), + (1.0, Vectors.dense(Array(0.0, 9.0, 8.0)), Vectors.dense(0.0)), + (2.0, Vectors.dense(Array(7.0, 9.0, 5.0)), Vectors.dense(7.0)), + (2.0, Vectors.dense(Array(8.0, 7.0, 3.0)), Vectors.dense(8.0)) + )).toDF("label", "features", "topFeature") + } - val preFilteredData2 = Seq( - Vectors.dense(8.0, 7.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(8.0, 9.0) - ) + test("params") { + ParamsSuite.checkParams(new ChiSqSelector) + val model = new ChiSqSelectorModel("myModel", + new org.apache.spark.mllib.feature.ChiSqSelectorModel(Array(1, 3, 4))) + ParamsSuite.checkParams(model) + } - val df2 = sc.parallelize(data.zip(preFilteredData2)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") + test("Test Chi-Square selector: numTopFeatures") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) + ChiSqSelectorSuite.testSelector(selector, dataset) + } - selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + test("Test Chi-Square selector: percentile") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.34) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelector read/write") { - val t = new ChiSqSelector() - .setFeaturesCol("myFeaturesCol") - .setLabelCol("myLabelCol") - .setOutputCol("myOutputCol") - .setNumTopFeatures(2) - testDefaultReadWrite(t) + test("Test Chi-Square selector: fpr") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.2) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelectorModel read/write") { - val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) - val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) - val newInstance = testDefaultReadWrite(instance) - assert(newInstance.selectedFeatures === instance.selectedFeatures) + test("read/write") { + def checkModelData(model: ChiSqSelectorModel, model2: ChiSqSelectorModel): Unit = { + assert(model.selectedFeatures === model2.selectedFeatures) + } + val nb = new ChiSqSelector + testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and not support other types") { @@ -108,3 +95,25 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext } } } + +object ChiSqSelectorSuite { + + private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = { + selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect() + .foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. 
+ * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "selectorType" -> "percentile", + "numTopFeatures" -> 1, + "percentile" -> 0.12, + "outputCol" -> "myOutput" + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index ac702b4b7c69e..77219e500617d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -54,33 +54,34 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))), + Seq(LabeledPoint(0.0, Vectors.dense(Array(8.0))), LabeledPoint(1.0, Vectors.dense(Array(0.0))), LabeledPoint(1.0, Vectors.dense(Array(0.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0)))) val model = new ChiSqSelector(1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) - }.collect().toSet - assert(filteredData == preFilteredData) + }.collect().toSeq + assert(filteredData === preFilteredData) } - test("ChiSqSelector by FPR transform test (sparse & dense vector)") { + test("ChiSqSelector by fpr transform test (sparse & dense vector)") { val labeledDiscreteData = sc.parallelize( Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))), LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), + Seq(LabeledPoint(0.0, Vectors.dense(Array(0.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(2.0, Vectors.dense(Array(9.0)))) - val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData) + val model: ChiSqSelectorModel = new ChiSqSelector().setSelectorType("fpr") + .setFpr(0.1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) - }.collect().toSet - assert(filteredData == preFilteredData) + }.collect().toSeq + assert(filteredData === preFilteredData) } test("model load / save") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 94afe82a36472..635cf1304588e 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2606,42 +2606,43 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja selectorType = Param(Params._dummy(), "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: kbest (default), percentile and fpr.", + "Supported options: numTopFeatures (default), percentile and fpr.", typeConverter=TypeConverters.toString) numTopFeatures = \ Param(Params._dummy(), "numTopFeatures", - "Number of features that selector will select, ordered by statistics value " + - "descending. If the number of features is < numTopFeatures, then this will select " + + "Number of features that selector will select, ordered by ascending p-value. 
" + + "If the number of features is < numTopFeatures, then this will select " + "all features.", typeConverter=TypeConverters.toInt) percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " + - "will select, ordered by statistics value descending.", + "will select, ordered by ascending p-value.", typeConverter=TypeConverters.toFloat) - alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.", - typeConverter=TypeConverters.toFloat) + fpr = Param(Params._dummy(), "fpr", "The highest p-value for features to be kept.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, - labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05): + labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05): """ __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ - labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05) + labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05) """ super(ChiSqSelector, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) - self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05) + self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, + fpr=0.05) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, - labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05): + labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05): """ setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ - labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05) + labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05) Sets params for this ChiSqSelector. """ kwargs = self.setParams._input_kwargs @@ -2665,7 +2666,7 @@ def getSelectorType(self): def setNumTopFeatures(self, value): """ Sets the value of :py:attr:`numTopFeatures`. - Only applicable when selectorType = "kbest". + Only applicable when selectorType = "numTopFeatures". """ return self._set(numTopFeatures=value) @@ -2692,19 +2693,19 @@ def getPercentile(self): return self.getOrDefault(self.percentile) @since("2.1.0") - def setAlpha(self, value): + def setFpr(self, value): """ - Sets the value of :py:attr:`alpha`. + Sets the value of :py:attr:`fpr`. Only applicable when selectorType = "fpr". """ - return self._set(alpha=value) + return self._set(fpr=value) @since("2.1.0") - def getAlpha(self): + def getFpr(self): """ - Gets the value of alpha or its default value. + Gets the value of fpr or its default value. """ - return self.getOrDefault(self.alpha) + return self.getOrDefault(self.fpr) def _create_model(self, java_model): return ChiSqSelectorModel(java_model) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 50ef7c7901c2c..7eaa2282cb8bb 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -274,52 +274,48 @@ def transform(self, vector): class ChiSqSelector(object): """ Creates a ChiSquared feature selector. - The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - `kbest` chooses the `k` top features according to a chi-squared test. + The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. 
+ `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. `percentile` is similar but chooses a fraction of all features instead of a fixed number. - `fpr` chooses all features whose false positive rate meets some threshold. - By default, the selection method is `kbest`, the default number of top features is 50. + `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + positive rate of selection. + By default, the selection method is `numTopFeatures`, with the default number of top features + set to 50. - >>> data = [ + >>> data = sc.parallelize([ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), ... LabeledPoint(1.0, [0.0, 9.0, 8.0]), - ... LabeledPoint(2.0, [8.0, 9.0, 5.0]) - ... ] - >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data)) + ... LabeledPoint(2.0, [7.0, 9.0, 5.0]), + ... LabeledPoint(2.0, [8.0, 7.0, 3.0]) + ... ]) + >>> model = ChiSqSelector(numTopFeatures=1).fit(data) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) SparseVector(1, {}) - >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([8.0]) - >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( - ... sc.parallelize(data)) + >>> model.transform(DenseVector([7.0, 9.0, 5.0])) + DenseVector([7.0]) + >>> model = ChiSqSelector(selectorType="fpr", fpr=0.2).fit(data) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) SparseVector(1, {}) - >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([8.0]) - >>> data = [ - ... LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})), - ... LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})), - ... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]), - ... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0]) - ... ] - >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data)) - >>> model.transform(DenseVector([1.0,2.0,3.0,4.0])) - DenseVector([4.0]) + >>> model.transform(DenseVector([7.0, 9.0, 5.0])) + DenseVector([7.0]) + >>> model = ChiSqSelector(selectorType="percentile", percentile=0.34).fit(data) + >>> model.transform(DenseVector([7.0, 9.0, 5.0])) + DenseVector([7.0]) .. versionadded:: 1.4.0 """ - def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05): + def __init__(self, numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, fpr=0.05): self.numTopFeatures = numTopFeatures self.selectorType = selectorType self.percentile = percentile - self.alpha = alpha + self.fpr = fpr @since('2.1.0') def setNumTopFeatures(self, numTopFeatures): """ set numTopFeature for feature selection by number of top features. - Only applicable when selectorType = "kbest". + Only applicable when selectorType = "numTopFeatures". """ self.numTopFeatures = int(numTopFeatures) return self @@ -334,19 +330,19 @@ def setPercentile(self, percentile): return self @since('2.1.0') - def setAlpha(self, alpha): + def setFpr(self, fpr): """ - set alpha [0.0, 1.0] for feature selection by FPR. + set FPR [0.0, 1.0] for feature selection by FPR. Only applicable when selectorType = "fpr". """ - self.alpha = float(alpha) + self.fpr = float(fpr) return self @since('2.1.0') def setSelectorType(self, selectorType): """ set the selector type of the ChisqSelector. - Supported options: "kbest" (default), "percentile" and "fpr". + Supported options: "numTopFeatures" (default), "percentile", "fpr". 
""" self.selectorType = str(selectorType) return self @@ -362,7 +358,7 @@ def fit(self, data): Apply feature discretizer before using this function. """ jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures, - self.percentile, self.alpha, data) + self.percentile, self.fpr, data) return ChiSqSelectorModel(jmodel) From 77a98162d1ec28247053b8b3ad4af28baa950797 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 1 Nov 2016 18:06:57 -0700 Subject: [PATCH 010/198] [SPARK-18025] Use commit protocol API in structured streaming ## What changes were proposed in this pull request? This patch adds a new commit protocol implementation ManifestFileCommitProtocol that follows the existing streaming flow, and uses it in FileStreamSink to consolidate the write path in structured streaming with the batch mode write path. This deletes a lot of code, and would make it trivial to support other functionalities that are currently available in batch but not in streaming, including all file formats and bucketing. ## How was this patch tested? Should be covered by existing tests. Author: Reynold Xin Closes #15710 from rxin/SPARK-18025. --- .../datasources/FileCommitProtocol.scala | 11 +- .../execution/datasources/FileFormat.scala | 14 -- ...iteOutput.scala => FileFormatWriter.scala} | 20 +- .../InsertIntoHadoopFsRelationCommand.scala | 25 +- .../parquet/ParquetFileFormat.scala | 11 - .../parquet/ParquetOutputWriter.scala | 116 +-------- .../execution/streaming/FileStreamSink.scala | 229 ++---------------- .../ManifestFileCommitProtocol.scala | 114 +++++++++ .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../sql/streaming/FileStreamSinkSuite.scala | 106 +------- 10 files changed, 174 insertions(+), 475 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/{WriteOutput.scala => FileFormatWriter.scala} (97%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala index 1ce9ae4266c1a..f5dd5ce22919d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala @@ -32,9 +32,9 @@ import org.apache.spark.util.Utils object FileCommitProtocol { - class TaskCommitMessage(obj: Any) extends Serializable + class TaskCommitMessage(val obj: Any) extends Serializable - object EmptyTaskCommitMessage extends TaskCommitMessage(Unit) + object EmptyTaskCommitMessage extends TaskCommitMessage(null) /** * Instantiates a FileCommitProtocol using the given className. @@ -62,8 +62,11 @@ object FileCommitProtocol { /** - * An interface to define how a Spark job commits its outputs. Implementations must be serializable, - * as the committer instance instantiated on the driver will be used for tasks on executors. + * An interface to define how a single Spark job commits its outputs. Two notes: + * + * 1. Implementations must be serializable, as the committer instance instantiated on the driver + * will be used for tasks on executors. + * 2. A committer should not be reused across multiple Spark jobs. 
* * The proper call sequence is: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 9d153cec731a8..4f4aaaa5026fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -55,20 +55,6 @@ trait FileFormat { options: Map[String, String], dataSchema: StructType): OutputWriterFactory - /** - * Returns a [[OutputWriterFactory]] for generating output writers that can write data. - * This method is current used only by FileStreamSinkWriter to generate output writers that - * does not use output committers to write data. The OutputWriter generated by the returned - * [[OutputWriterFactory]] must implement the method `newWriter(path)`.. - */ - def buildWriter( - sqlContext: SQLContext, - dataSchema: StructType, - options: Map[String, String]): OutputWriterFactory = { - // TODO: Remove this default implementation when the other formats have been ported - throw new UnsupportedOperationException(s"buildWriter is not supported for $this") - } - /** * Returns whether this format support returning columnar batch or not. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index a07855111b401..bc00a0a749c09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -43,8 +43,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter -/** A helper object for writing data out to a location. */ -object WriteOutput extends Logging { +/** A helper object for writing FileFormat data out to a location. */ +object FileFormatWriter extends Logging { /** A shared job description for all the write tasks. 
*/ private class WriteJobDescription( @@ -55,7 +55,6 @@ object WriteOutput extends Logging { val partitionColumns: Seq[Attribute], val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], - val isAppend: Boolean, val path: String) extends Serializable { @@ -82,18 +81,18 @@ object WriteOutput extends Logging { sparkSession: SparkSession, plan: LogicalPlan, fileFormat: FileFormat, - outputPath: Path, + committer: FileCommitProtocol, + outputPath: String, hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], refreshFunction: (Seq[TablePartitionSpec]) => Unit, - options: Map[String, String], - isAppend: Boolean): Unit = { + options: Map[String, String]): Unit = { val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, outputPath) + FileOutputFormat.setOutputPath(job, new Path(outputPath)) val partitionSet = AttributeSet(partitionColumns) val dataColumns = plan.output.filterNot(partitionSet.contains) @@ -111,16 +110,11 @@ object WriteOutput extends Logging { partitionColumns = partitionColumns, nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, - isAppend = isAppend, - path = outputPath.toString) + path = outputPath) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - val committer = FileCommitProtocol.instantiate( - sparkSession.sessionState.conf.fileCommitProtocolClass, - outputPath.toString, - isAppend) committer.setupJob(job) try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index a1221d0ae6d27..230c74a47ba2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -84,17 +84,22 @@ case class InsertIntoHadoopFsRelationCommand( val isAppend = pathExists && (mode == SaveMode.Append) if (doInsertion) { - WriteOutput.write( - sparkSession, - query, - fileFormat, - qualifiedOutputPath, - hadoopConf, - partitionColumns, - bucketSpec, - refreshFunction, - options, + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.fileCommitProtocolClass, + outputPath.toString, isAppend) + + FileFormatWriter.write( + sparkSession = sparkSession, + plan = query, + fileFormat = fileFormat, + committer = committer, + outputPath = qualifiedOutputPath.toString, + hadoopConf = hadoopConf, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + refreshFunction = refreshFunction, + options = options) } else { logInfo("Skipping insertion into a relation that already exists.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 77c83ba38efee..b8ea7f40c4ab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -415,17 +415,6 @@ class 
ParquetFileFormat } } } - - override def buildWriter( - sqlContext: SQLContext, - dataSchema: StructType, - options: Map[String, String]): OutputWriterFactory = { - new ParquetOutputWriterFactory( - sqlContext.conf, - dataSchema, - sqlContext.sessionState.newHadoopConf(), - options) - } } object ParquetFileFormat extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 92d4f27be3fd5..5c0f8af17a232 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -17,125 +17,13 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.parquet.hadoop.{ParquetOutputFormat, ParquetRecordWriter} -import org.apache.parquet.hadoop.codec.CodecConfig -import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration - - -/** - * A factory for generating OutputWriters for writing parquet files. This implemented is different - * from the [[ParquetOutputWriter]] as this does not use any [[OutputCommitter]]. It simply - * writes the data to the path used to generate the output writer. Callers of this factory - * has to ensure which files are to be considered as committed. - */ -private[parquet] class ParquetOutputWriterFactory( - sqlConf: SQLConf, - dataSchema: StructType, - hadoopConf: Configuration, - options: Map[String, String]) - extends OutputWriterFactory { - - private val serializableConf: SerializableConfiguration = { - val job = Job.getInstance(hadoopConf) - val conf = ContextUtil.getConfiguration(job) - val parquetOptions = new ParquetOptions(options, sqlConf) - - // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override - // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why - // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is - // bundled with `ParquetOutputFormat[Row]`. - job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - - ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - - // We want to clear this temporary metadata from saving into Parquet file. - // This metadata is only useful for detecting optional columns when pushing down filters. - val dataSchemaToWrite = StructType.removeMetadata( - StructType.metadataKeyForOptionalField, - dataSchema).asInstanceOf[StructType] - ParquetWriteSupport.setSchema(dataSchemaToWrite, conf) - - // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) - // and `CatalystWriteSupport` (writing actual rows to Parquet files). 
- conf.set( - SQLConf.PARQUET_BINARY_AS_STRING.key, - sqlConf.isParquetBinaryAsString.toString) - - conf.set( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp.toString) - - conf.set( - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - sqlConf.writeLegacyParquetFormat.toString) - - // Sets compression scheme - conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) - new SerializableConfiguration(conf) - } - - /** - * Returns a [[OutputWriter]] that writes data to the give path without using - * [[OutputCommitter]]. - */ - override def newWriter(path: String): OutputWriter = new OutputWriter { - - // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter - private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) - private val hadoopAttemptContext = new TaskAttemptContextImpl( - serializableConf.value, hadoopTaskAttemptId) - - // Instance of ParquetRecordWriter that does not use OutputCommitter - private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) - - override def write(row: Row): Unit = { - throw new UnsupportedOperationException("call writeInternal") - } - - protected[sql] override def writeInternal(row: InternalRow): Unit = { - recordWriter.write(null, row) - } - - override def close(): Unit = recordWriter.close(hadoopAttemptContext) - } - - /** Create a [[ParquetRecordWriter]] that writes the given path without using OutputCommitter */ - private def createNoCommitterRecordWriter( - path: String, - hadoopAttemptContext: TaskAttemptContext): RecordWriter[Void, InternalRow] = { - // Custom ParquetOutputFormat that disable use of committer and writes to the given path - val outputFormat = new ParquetOutputFormat[InternalRow]() { - override def getOutputCommitter(c: TaskAttemptContext): OutputCommitter = { null } - override def getDefaultWorkFile(c: TaskAttemptContext, ext: String): Path = { new Path(path) } - } - outputFormat.getRecordWriter(hadoopAttemptContext) - } - - /** Disable the use of the older API. */ - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - throw new UnsupportedOperationException("this version of newInstance not supported for " + - "ParquetOutputWriterFactory") - } - - override def getFileExtension(context: TaskAttemptContext): String = { - CodecConfig.from(context).getCodec.getExtension + ".parquet" - } -} - +import org.apache.spark.sql.execution.datasources.OutputWriter // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 02c5b857ee7fe..daec2b5450971 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -17,23 +17,12 @@ package org.apache.spark.sql.execution.streaming -import java.util.UUID - -import scala.collection.mutable.ArrayBuffer - -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkEnv, SparkException, TaskContext, TaskContextImpl} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.UnsafeKVExternalSorter -import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, PartitioningUtils} -import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter +import org.apache.spark.sql.execution.datasources.{FileCommitProtocol, FileFormat, FileFormatWriter} object FileStreamSink { // The name of the subdirectory that is used to store metadata about which files are valid. @@ -59,207 +48,41 @@ class FileStreamSink( private val fileLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) private val hadoopConf = sparkSession.sessionState.newHadoopConf() - private val fs = basePath.getFileSystem(hadoopConf) override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") } else { - val writer = new FileStreamSinkWriter( - data, fileFormat, path, partitionColumnNames, hadoopConf, options) - val fileStatuses = writer.write() - if (fileLog.add(batchId, fileStatuses)) { - logInfo(s"Committed batch $batchId") - } else { - throw new IllegalStateException(s"Race while writing batch $batchId") + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.streamingFileCommitProtocolClass, path, isAppend = false) + committer match { + case manifestCommitter: ManifestFileCommitProtocol => + manifestCommitter.setupManifestOptions(fileLog, batchId) + case _ => // Do nothing } - } - } - - override def toString: String = s"FileSink[$path]" -} - - -/** - * Writes data given to a [[FileStreamSink]] to the given `basePath` in the given `fileFormat`, - * partitioned by the given `partitionColumnNames`. This writer always appends data to the - * directory if it already has data. 
- */ -class FileStreamSinkWriter( - data: DataFrame, - fileFormat: FileFormat, - basePath: String, - partitionColumnNames: Seq[String], - hadoopConf: Configuration, - options: Map[String, String]) extends Serializable with Logging { - - PartitioningUtils.validatePartitionColumn( - data.schema, partitionColumnNames, data.sqlContext.conf.caseSensitiveAnalysis) - - private val serializableConf = new SerializableConfiguration(hadoopConf) - private val dataSchema = data.schema - private val dataColumns = data.logicalPlan.output - - // Get the actual partition columns as attributes after matching them by name with - // the given columns names. - private val partitionColumns = partitionColumnNames.map { col => - val nameEquality = data.sparkSession.sessionState.conf.resolver - data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $dataSchema") - } - } - - // Columns that are to be written to the files. If there are partitioning columns, then - // those will not be written to the files. - private val writeColumns = { - val partitionSet = AttributeSet(partitionColumns) - dataColumns.filterNot(partitionSet.contains) - } - - // An OutputWriterFactory for generating writers in the executors for writing the files. - private val outputWriterFactory = - fileFormat.buildWriter(data.sqlContext, writeColumns.toStructType, options) - - /** Expressions that given a partition key build a string like: col1=val/col2=val/... */ - private def partitionStringExpression: Seq[Expression] = { - partitionColumns.zipWithIndex.flatMap { case (c, i) => - val escaped = - ScalaUDF( - PartitioningUtils.escapePathName _, - StringType, - Seq(Cast(c, StringType)), - Seq(StringType)) - val str = If(IsNull(c), Literal(PartitioningUtils.DEFAULT_PARTITION_NAME), escaped) - val partitionName = Literal(c.name + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName - } - } - - /** Generate a new output writer from the writer factory */ - private def newOutputWriter(path: Path): OutputWriter = { - val newWriter = outputWriterFactory.newWriter(path.toString) - newWriter.initConverter(dataSchema) - newWriter - } - /** Write the dataframe to files. This gets called in the driver by the [[FileStreamSink]]. */ - def write(): Array[SinkFileStatus] = { - data.sqlContext.sparkContext.runJob( - data.queryExecution.toRdd, - (taskContext: TaskContext, iterator: Iterator[InternalRow]) => { - if (partitionColumns.isEmpty) { - Seq(writePartitionToSingleFile(iterator)) - } else { - writePartitionToPartitionedFiles(iterator) + // Get the actual partition columns as attributes after matching them by name with + // the given columns names. + val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col => + val nameEquality = data.sparkSession.sessionState.conf.resolver + data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}") } - }).flatten - } - - /** - * Writes a RDD partition to a single file without dynamic partitioning. - * This gets called in the executor, and it uses a [[OutputWriter]] to write the data. 
- */ - def writePartitionToSingleFile(iterator: Iterator[InternalRow]): SinkFileStatus = { - var writer: OutputWriter = null - try { - val path = new Path(basePath, UUID.randomUUID.toString) - val fs = path.getFileSystem(serializableConf.value) - writer = newOutputWriter(path) - while (iterator.hasNext) { - writer.writeInternal(iterator.next) - } - writer.close() - writer = null - SinkFileStatus(fs.getFileStatus(path)) - } catch { - case cause: Throwable => - logError("Aborting task.", cause) - // call failure callbacks first, so we could have a chance to cleanup the writer. - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause) - throw new SparkException("Task failed while writing rows.", cause) - } finally { - if (writer != null) { - writer.close() } - } - } - - /** - * Writes a RDD partition to multiple dynamically partitioned files. - * This gets called in the executor. It first sorts the data based on the partitioning columns - * and then writes the data of each key to separate files using [[OutputWriter]]s. - */ - def writePartitionToPartitionedFiles(iterator: Iterator[InternalRow]): Seq[SinkFileStatus] = { - - // Returns the partitioning columns for sorting - val getSortingKey = UnsafeProjection.create(partitionColumns, dataColumns) - - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(writeColumns, dataColumns) - - // Returns the partition path given a partition key - val getPartitionString = - UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - // Sort the data before write, so that we only need one writer at the same time. - val sorter = new UnsafeKVExternalSorter( - partitionColumns.toStructType, - StructType.fromAttributes(writeColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) - - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - logDebug(s"Sorting complete. Writing out partition files one at a time.") - - val sortedIterator = sorter.sortedIterator() - val paths = new ArrayBuffer[Path] - - // Write the sorted data to partitioned files, one for each unique key - var currentWriter: OutputWriter = null - try { - var currentKey: UnsafeRow = null - while (sortedIterator.next()) { - val nextKey = sortedIterator.getKey - - // If key changes, close current writer, and open a new writer to a new partitioned file - if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - currentKey = nextKey.copy() - val partitionPath = getPartitionString(currentKey).getString(0) - val path = new Path(new Path(basePath, partitionPath), UUID.randomUUID.toString) - paths += path - currentWriter = newOutputWriter(path) - logInfo(s"Writing partition $currentKey to $path") - } - currentWriter.writeInternal(sortedIterator.getValue) - } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - if (paths.nonEmpty) { - val fs = paths.head.getFileSystem(serializableConf.value) - paths.map(p => SinkFileStatus(fs.getFileStatus(p))) - } else Seq.empty - } catch { - case cause: Throwable => - logError("Aborting task.", cause) - // call failure callbacks first, so we could have a chance to cleanup the writer. 
- TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause) - throw new SparkException("Task failed while writing rows.", cause) - } finally { - if (currentWriter != null) { - currentWriter.close() - } + FileFormatWriter.write( + sparkSession = sparkSession, + plan = data.logicalPlan, + fileFormat = fileFormat, + committer = committer, + outputPath = path, + hadoopConf = hadoopConf, + partitionColumns = partitionColumns, + bucketSpec = None, + refreshFunction = _ => (), + options = options) } } + + override def toString: String = s"FileSink[$path]" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala new file mode 100644 index 0000000000000..510312267a98d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.FileCommitProtocol +import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitMessage + +/** + * A [[FileCommitProtocol]] that tracks the list of valid files in a manifest file, used in + * structured streaming. + * + * @param path path to write the final output to. + */ +class ManifestFileCommitProtocol(path: String) + extends FileCommitProtocol with Serializable with Logging { + + // Track the list of files added by a task, only used on the executors. + @transient private var addedFiles: ArrayBuffer[String] = _ + + @transient private var fileLog: FileStreamSinkLog = _ + private var batchId: Long = _ + + /** + * Sets up the manifest log output and the batch id for this job. + * Must be called before any other function. 
+ */ + def setupManifestOptions(fileLog: FileStreamSinkLog, batchId: Long): Unit = { + this.fileLog = fileLog + this.batchId = batchId + } + + override def setupJob(jobContext: JobContext): Unit = { + require(fileLog != null, "setupManifestOptions must be called before this function") + // Do nothing + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + require(fileLog != null, "setupManifestOptions must be called before this function") + val fileStatuses = taskCommits.flatMap(_.obj.asInstanceOf[Seq[SinkFileStatus]]).toArray + + if (fileLog.add(batchId, fileStatuses)) { + logInfo(s"Committed batch $batchId") + } else { + throw new IllegalStateException(s"Race while writing batch $batchId") + } + } + + override def abortJob(jobContext: JobContext): Unit = { + require(fileLog != null, "setupManifestOptions must be called before this function") + // Do nothing + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + addedFiles = new ArrayBuffer[String] + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + val uuid = UUID.randomUUID.toString + val filename = f"part-$split%05d-$uuid$ext" + + val file = dir.map { d => + new Path(new Path(path, d), filename).toString + }.getOrElse { + new Path(path, filename).toString + } + + addedFiles += file + file + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + if (addedFiles.nonEmpty) { + val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration) + val statuses: Seq[SinkFileStatus] = + addedFiles.map(f => SinkFileStatus(fs.getFileStatus(new Path(f)))) + new TaskCommitMessage(statuses) + } else { + new TaskCommitMessage(Seq.empty[SinkFileStatus]) + } + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + // Do nothing + // TODO: we can also try delete the addedFiles as a best-effort cleanup. 
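(Editor's note: the TODO above could be realized roughly as follows — a hypothetical sketch, not part of this patch; failures are swallowed so that abort never masks the task's original error:)

```scala
import scala.util.control.NonFatal

override def abortTask(taskContext: TaskAttemptContext): Unit = {
  // Best-effort: delete whatever temp files this task attempt has registered.
  if (addedFiles != null) {
    addedFiles.foreach { f =>
      try {
        val p = new Path(f)
        p.getFileSystem(taskContext.getConfiguration).delete(p, false)
      } catch {
        case NonFatal(e) => logWarning(s"Best-effort cleanup of $f failed", e)
      }
    }
  }
}
```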
+ } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 29e79847aa38b..7bb3ac02fa5d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -31,6 +31,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.execution.datasources.HadoopCommitProtocolWrapper +import org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -523,7 +524,7 @@ object SQLConf { SQLConfigBuilder("spark.sql.streaming.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[HadoopCommitProtocolWrapper].getName) + .createWithDefault(classOf[ManifestFileCommitProtocol].getName) val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion") .internal() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 18b42a81a098c..902cf05344716 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,106 +17,16 @@ package org.apache.spark.sql.streaming -import java.io.File - -import org.apache.commons.io.FileUtils -import org.apache.commons.io.filefilter.{DirectoryFileFilter, RegexFileFilter} - import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{FileStreamSinkWriter, MemoryStream, MetadataLogFileIndex} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils class FileStreamSinkSuite extends StreamTest { import testImplicits._ - - test("FileStreamSinkWriter - unpartitioned data") { - val path = Utils.createTempDir() - path.delete() - - val hadoopConf = spark.sparkContext.hadoopConfiguration - val fileFormat = new parquet.ParquetFileFormat() - - def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = spark - .range(start, end, 1, numPartitions) - .select($"id", lit(100).as("data")) - val writer = new FileStreamSinkWriter( - df, fileFormat, path.toString, partitionColumnNames = Nil, hadoopConf, Map.empty) - writer.write().map(_.path.stripPrefix("file://")) - } - - // Write and check whether new files are written correctly - val files1 = writeRange(0, 10, 2) - assert(files1.size === 2, s"unexpected number of files: $files1") - checkFilesExist(path, files1, "file not written") - checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) - - // Append and check whether new files are written correctly and old files still exist - val files2 = writeRange(10, 20, 3) - assert(files2.size === 3, s"unexpected number of files: $files2") - assert(files2.intersect(files1).isEmpty, "old files returned") - 
checkFilesExist(path, files2, s"New file not written")
-    checkFilesExist(path, files1, s"Old file not found")
-    checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100)))
-  }
-
-  test("FileStreamSinkWriter - partitioned data") {
-    implicit val e = ExpressionEncoder[java.lang.Long]
-    val path = Utils.createTempDir()
-    path.delete()
-
-    val hadoopConf = spark.sparkContext.hadoopConfiguration
-    val fileFormat = new parquet.ParquetFileFormat()
-
-    def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = {
-      val df = spark
-        .range(start, end, 1, numPartitions)
-        .flatMap(x => Iterator(x, x, x)).toDF("id")
-        .select($"id", lit(100).as("data1"), lit(1000).as("data2"))
-
-      require(df.rdd.partitions.size === numPartitions)
-      val writer = new FileStreamSinkWriter(
-        df, fileFormat, path.toString, partitionColumnNames = Seq("id"), hadoopConf, Map.empty)
-      writer.write().map(_.path.stripPrefix("file://"))
-    }
-
-    def checkOneFileWrittenPerKey(keys: Seq[Int], filesWritten: Seq[String]): Unit = {
-      keys.foreach { id =>
-        assert(
-          filesWritten.count(_.contains(s"/id=$id/")) == 1,
-          s"no file for id=$id. all files: \n\t${filesWritten.mkString("\n\t")}"
-        )
-      }
-    }
-
-    // Write and check whether new files are written correctly
-    val files1 = writeRange(0, 10, 2)
-    assert(files1.size === 10, s"unexpected number of files:\n${files1.mkString("\n")}")
-    checkFilesExist(path, files1, "file not written")
-    checkOneFileWrittenPerKey(0 until 10, files1)
-
-    val answer1 = (0 until 10).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _))
-    checkAnswer(spark.read.load(path.getCanonicalPath), answer1)
-
-    // Append and check whether new files are written correctly and old files still exist
-    val files2 = writeRange(0, 20, 3)
-    assert(files2.size === 20, s"unexpected number of files:\n${files2.mkString("\n")}")
-    assert(files2.intersect(files1).isEmpty, "old files returned")
-    checkFilesExist(path, files2, s"New file not written")
-    checkFilesExist(path, files1, s"Old file not found")
-    checkOneFileWrittenPerKey(0 until 20, files2)
-
-    val answer2 = (0 until 20).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _))
-    checkAnswer(spark.read.load(path.getCanonicalPath), answer1 ++ answer2)
-  }
-
   test("FileStreamSink - unpartitioned writing and batch reading") {
     val inputData = MemoryStream[Int]
     val df = inputData.toDF()
@@ -270,18 +180,4 @@ class FileStreamSinkSuite extends StreamTest {
       }
     }
   }
-
-  private def checkFilesExist(dir: File, expectedFiles: Seq[String], msg: String): Unit = {
-    import scala.collection.JavaConverters._
-    val files =
-      FileUtils.listFiles(dir, new RegexFileFilter("[^.]+"), DirectoryFileFilter.DIRECTORY)
-        .asScala
-        .map(_.getCanonicalPath)
-        .toSet
-
-    expectedFiles.foreach { f =>
-      assert(files.contains(f),
-        s"\n$msg\nexpected file:\n\t$f\nfound files:\n${files.mkString("\n\t")}")
-    }
-  }
-
 }

From ad4832a9faf2c0c869bbcad9d71afe1cecbd3ec8 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 1 Nov 2016 21:20:53 -0700
Subject: [PATCH 011/198] [SPARK-18216][SQL] Make Column.expr public

## What changes were proposed in this pull request?

Column.expr is private[sql], but it's actually a really useful field to have for debugging. We should open it up, similar to how we expose QueryExecution.

## How was this patch tested?

N/A - this is a simple visibility change.

Author: Reynold Xin

Closes #15724 from rxin/SPARK-18216.
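For illustration, a minimal sketch of the kind of debugging this visibility change enables (the column below is hypothetical, not part of the patch):

``` scala
// With `expr` public, the Catalyst expression behind a Column can be
// inspected directly, e.g. in a spark-shell session.
import org.apache.spark.sql.functions._

val c = col("id") + lit(1)
println(c.expr)            // prints the unresolved expression, e.g. ('id + 1)
println(c.expr.treeString) // full tree rendering, useful when debugging plans
```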
---
 sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 05e867bf5be96..249408e0fbce4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -118,6 +118,9 @@ class TypedColumn[-T, U](
  *   $"a" === $"b"
  * }}}
  *
+ * Note that the internal Catalyst expression can be accessed via "expr", but this method is for
+ * debugging purposes only and can change in any future Spark releases.
+ *
  * @groupname java_expr_ops Java-specific expression operators
  * @groupname expr_ops Expression operators
  * @groupname df_ops DataFrame functions
@@ -126,7 +129,7 @@ class TypedColumn[-T, U](
  * @since 1.3.0
 */
 @InterfaceStability.Stable
-class Column(protected[sql] val expr: Expression) extends Logging {
+class Column(val expr: Expression) extends Logging {

   def this(name: String) = this(name match {
     case "*" => UnresolvedStar(None)

From 1ecfafa0869cb3a3e367bda8be252a69874dc4de Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 1 Nov 2016 22:14:53 -0700
Subject: [PATCH 012/198] [SPARK-17838][SPARKR] Check named arguments for options and use formatted R friendly message from JVM exception message

## What changes were proposed in this pull request?

This PR proposes to

- improve the R-friendly error messages rather than surfacing the raw JVM exception. As `read.json`, `read.text`, `read.orc`, `read.parquet` and `read.jdbc` are executed on the same path as `read.df`, and `write.json`, `write.text`, `write.orc`, `write.parquet` and `write.jdbc` share the same path as `write.df`, it seems safe to call `handledCallJMethod` to handle JVM messages.

- prevent the `zero-length variable name` error and print the ignored options in a warning message.

**Before**

``` r
> read.json("path", a = 1, 2, 3, "a")
Error in env[[name]] <- value : zero-length variable name
```

``` r
> read.json("arbitrary_path")
Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...)
  : org.apache.spark.sql.AnalysisException: Path does not exist: file:/...;
    at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$12.apply(DataSource.scala:398)
    ...

> read.orc("arbitrary_path")
Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...)
  : org.apache.spark.sql.AnalysisException: Path does not exist: file:/...;
    at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$12.apply(DataSource.scala:398)
    ...

> read.text("arbitrary_path")
Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...)
  : org.apache.spark.sql.AnalysisException: Path does not exist: file:/...;
    at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$12.apply(DataSource.scala:398)
    ...

> read.parquet("arbitrary_path")
Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...)
  : org.apache.spark.sql.AnalysisException: Path does not exist: file:/...;
    at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$12.apply(DataSource.scala:398)
    ...
```

``` r
> write.json(df, "existing_path")
Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...)
  : org.apache.spark.sql.AnalysisException: path file:/... already exists.;
    at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:68)

> write.orc(df, "existing_path")
Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...)
: org.apache.spark.sql.AnalysisException: path file:/... already exists.; at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:68) > write.text(df, "existing_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: path file:/... already exists.; at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:68) > write.parquet(df, "existing_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: path file:/... already exists.; at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:68) ``` **After** ``` r read.json("arbitrary_path", a = 1, 2, 3, "a") Unnamed arguments ignored: 2, 3, a. ``` ``` r > read.json("arbitrary_path") Error in json : analysis error - Path does not exist: file:/... > read.orc("arbitrary_path") Error in orc : analysis error - Path does not exist: file:/... > read.text("arbitrary_path") Error in text : analysis error - Path does not exist: file:/... > read.parquet("arbitrary_path") Error in parquet : analysis error - Path does not exist: file:/... ``` ``` r > write.json(df, "existing_path") Error in json : analysis error - path file:/... already exists.; > write.orc(df, "existing_path") Error in orc : analysis error - path file:/... already exists.; > write.text(df, "existing_path") Error in text : analysis error - path file:/... already exists.; > write.parquet(df, "existing_path") Error in parquet : analysis error - path file:/... already exists.; ``` ## How was this patch tested? Unit tests in `test_utils.R` and `test_sparkSQL.R`. Author: hyukjinkwon Closes #15608 from HyukjinKwon/SPARK-17838. --- R/pkg/R/DataFrame.R | 10 +++--- R/pkg/R/SQLContext.R | 17 ++++----- R/pkg/R/utils.R | 44 ++++++++++++++++------- R/pkg/inst/tests/testthat/test_sparkSQL.R | 16 +++++++++ R/pkg/inst/tests/testthat/test_utils.R | 2 ++ 5 files changed, 64 insertions(+), 25 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1df8bbf9fe604..1cf9b38ea6483 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -788,7 +788,7 @@ setMethod("write.json", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "json", path)) + invisible(handledCallJMethod(write, "json", path)) }) #' Save the contents of SparkDataFrame as an ORC file, preserving the schema. @@ -819,7 +819,7 @@ setMethod("write.orc", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "orc", path)) + invisible(handledCallJMethod(write, "orc", path)) }) #' Save the contents of SparkDataFrame as a Parquet file, preserving the schema. @@ -851,7 +851,7 @@ setMethod("write.parquet", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "parquet", path)) + invisible(handledCallJMethod(write, "parquet", path)) }) #' @rdname write.parquet @@ -895,7 +895,7 @@ setMethod("write.text", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) 
- invisible(callJMethod(write, "text", path)) + invisible(handledCallJMethod(write, "text", path)) }) #' Distinct @@ -3342,7 +3342,7 @@ setMethod("write.jdbc", jprops <- varargsToJProperties(...) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - invisible(callJMethod(write, "jdbc", url, tableName, jprops)) + invisible(handledCallJMethod(write, "jdbc", url, tableName, jprops)) }) #' randomSplit diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 216ca51666ba8..38d83c6e5c52b 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -350,7 +350,7 @@ read.json.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "json", paths) + sdf <- handledCallJMethod(read, "json", paths) dataFrame(sdf) } @@ -422,7 +422,7 @@ read.orc <- function(path, ...) { path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "orc", path) + sdf <- handledCallJMethod(read, "orc", path) dataFrame(sdf) } @@ -444,7 +444,7 @@ read.parquet.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "parquet", paths) + sdf <- handledCallJMethod(read, "parquet", paths) dataFrame(sdf) } @@ -496,7 +496,7 @@ read.text.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "text", paths) + sdf <- handledCallJMethod(read, "text", paths) dataFrame(sdf) } @@ -914,12 +914,13 @@ read.jdbc <- function(url, tableName, } else { numPartitions <- numToInt(numPartitions) } - sdf <- callJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), - numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), + numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) } else if (length(predicates) > 0) { - sdf <- callJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), + jprops) } else { - sdf <- callJMethod(read, "jdbc", url, tableName, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, jprops) } dataFrame(sdf) } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index c4e78cbb804d9..20004549cc037 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -338,21 +338,41 @@ varargsToEnv <- function(...) { # into string. varargsToStrEnv <- function(...) { pairs <- list(...) + nameList <- names(pairs) env <- new.env() - for (name in names(pairs)) { - value <- pairs[[name]] - if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { - stop(paste0("Unsupported type for ", name, " : ", class(value), - ". Supported types are logical, numeric, character and NULL.")) - } - if (is.logical(value)) { - env[[name]] <- tolower(as.character(value)) - } else if (is.null(value)) { - env[[name]] <- value - } else { - env[[name]] <- as.character(value) + ignoredNames <- list() + + if (is.null(nameList)) { + # When all arguments are not named, names(..) returns NULL. 
+ ignoredNames <- pairs + } else { + for (i in seq_along(pairs)) { + name <- nameList[i] + value <- pairs[i] + if (identical(name, "")) { + # When some of arguments are not named, name is "". + ignoredNames <- append(ignoredNames, value) + } else { + value <- pairs[[name]] + if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { + stop(paste0("Unsupported type for ", name, " : ", class(value), + ". Supported types are logical, numeric, character and NULL."), call. = FALSE) + } + if (is.logical(value)) { + env[[name]] <- tolower(as.character(value)) + } else if (is.null(value)) { + env[[name]] <- value + } else { + env[[name]] <- as.character(value) + } + } } } + + if (length(ignoredNames) != 0) { + warning(paste0("Unnamed arguments ignored: ", paste(ignoredNames, collapse = ", "), "."), + call. = FALSE) + } env } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9289db57b6d63..806019d7524ff 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2660,6 +2660,14 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume # DataFrameWriter.save() without path. expect_error(write.df(df, source = "csv"), "Error in save : illegal argument - 'path' is not specified") + expect_error(write.json(df, jsonPath), + "Error in json : analysis error - path file:.*already exists") + expect_error(write.text(df, jsonPath), + "Error in text : analysis error - path file:.*already exists") + expect_error(write.orc(df, jsonPath), + "Error in orc : analysis error - path file:.*already exists") + expect_error(write.parquet(df, jsonPath), + "Error in parquet : analysis error - path file:.*already exists") # Arguments checking in R side. expect_error(write.df(df, "data.tmp", source = c(1, 2)), @@ -2679,6 +2687,11 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") + expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist") + expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist") + expect_error(read.parquet("arbitrary_path"), + "Error in parquet : analysis error - Path does not exist") # Arguments checking in R side. expect_error(read.df(path = c(3)), @@ -2686,6 +2699,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) + + expect_warning(read.json(jsonPath, a = 1, 2, 3, "a"), + "Unnamed arguments ignored: 2, 3, a.") }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index a20254e9b3fa9..607c407f04f97 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -224,6 +224,8 @@ test_that("varargsToStrEnv", { expect_error(varargsToStrEnv(a = list(1, "a")), paste0("Unsupported type for a : list. 
Supported types are logical, ", "numeric, character and NULL.")) + expect_warning(varargsToStrEnv(a = 1, 2, 3, 4), "Unnamed arguments ignored: 2, 3, 4.") + expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.") }) sparkR.session.stop() From 1bbf9ff634745148e782370009aa31d3a042638c Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Tue, 1 Nov 2016 22:20:19 -0700 Subject: [PATCH 013/198] [SPARK-17992][SQL] Return all partitions from HiveShim when Hive throws a metastore exception when attempting to fetch partitions by filter (Link to Jira issue: https://issues.apache.org/jira/browse/SPARK-17992) ## What changes were proposed in this pull request? We recently added table partition pruning for partitioned Hive tables converted to using `TableFileCatalog`. When the Hive configuration option `hive.metastore.try.direct.sql` is set to `false`, Hive will throw an exception for unsupported filter expressions. For example, attempting to filter on an integer partition column will throw a `org.apache.hadoop.hive.metastore.api.MetaException`. I discovered this behavior because VideoAmp uses the CDH version of Hive with a Postgresql metastore DB. In this configuration, CDH sets `hive.metastore.try.direct.sql` to `false` by default, and queries that filter on a non-string partition column will fail. Rather than throw an exception in query planning, this patch catches this exception, logs a warning and returns all table partitions instead. Clients of this method are already expected to handle the possibility that the filters will not be honored. ## How was this patch tested? A unit test was added. Author: Michael Allman Closes #15673 from mallman/spark-17992-catch_hive_partition_filter_exception. --- .../spark/sql/hive/client/HiveShim.scala | 31 ++++++-- .../sql/hive/client/HiveClientBuilder.scala | 56 ++++++++++++++ .../sql/hive/client/HiveClientSuite.scala | 61 +++++++++++++++ .../spark/sql/hive/client/VersionsSuite.scala | 77 +++++-------------- 4 files changed, 160 insertions(+), 65 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 85edaf63db886..3d9642dd1463d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, NoSuchObjectException, PrincipalType, ResourceType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException, Partition, Table} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ 
+import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegralType, StringType} import org.apache.spark.util.Utils @@ -586,17 +587,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] } else { logDebug(s"Hive metastore filter is '$filter'.") + val tryDirectSqlConfVar = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL + val tryDirectSql = + hive.getConf.getBoolean(tryDirectSqlConfVar.varname, tryDirectSqlConfVar.defaultBoolVal) try { + // Hive may throw an exception when calling this method in some circumstances, such as + // when filtering on a non-string partition column when the hive config key + // hive.metastore.try.direct.sql is false getPartitionsByFilterMethod.invoke(hive, table, filter) .asInstanceOf[JArrayList[Partition]] } catch { - case e: InvocationTargetException => - // SPARK-18167 retry to investigate the flaky test. This should be reverted before - // the release is cut. - val retry = Try(getPartitionsByFilterMethod.invoke(hive, table, filter)) - logError("getPartitionsByFilter failed, retry success = " + retry.isSuccess) - logError("all partitions: " + getAllPartitions(hive, table)) - throw e + case ex: InvocationTargetException if ex.getCause.isInstanceOf[MetaException] && + !tryDirectSql => + logWarning("Caught Hive MetaException attempting to get partition metadata by " + + "filter from Hive. Falling back to fetching all partition metadata, which will " + + "degrade performance. Modifying your Hive metastore configuration to set " + + s"${tryDirectSqlConfVar.varname} to true may resolve this problem.", ex) + // HiveShim clients are expected to handle a superset of the requested partitions + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] + case ex: InvocationTargetException if ex.getCause.isInstanceOf[MetaException] && + tryDirectSql => + throw new RuntimeException("Caught Hive MetaException attempting to get partition " + + "metadata by filter from Hive. You can set the Spark configuration setting " + + s"${SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key} to false to work around this " + + "problem, however this will result in degraded performance. Please report a bug: " + + "https://issues.apache.org/jira/browse/SPARK", ex) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala new file mode 100644 index 0000000000000..591a968c82847 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.client + +import java.io.File + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.util.VersionInfo + +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils + +private[client] class HiveClientBuilder { + private val sparkConf = new SparkConf() + + // In order to speed up test execution during development or in Jenkins, you can specify the path + // of an existing Ivy cache: + private val ivyPath: Option[String] = { + sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( + Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) + } + + private def buildConf() = { + lazy val warehousePath = Utils.createTempDir() + lazy val metastorePath = Utils.createTempDir() + metastorePath.delete() + Map( + "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", + "hive.metastore.warehouse.dir" -> warehousePath.toString) + } + + def buildClient(version: String, hadoopConf: Configuration): HiveClient = { + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = version, + hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, + hadoopConf = hadoopConf, + config = buildConf(), + ivyPath = ivyPath).createClient() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala new file mode 100644 index 0000000000000..4790331168bd2 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.client + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.types.IntegerType + +class HiveClientSuite extends SparkFunSuite { + private val clientBuilder = new HiveClientBuilder + + private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname + + test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { + val testPartitionCount = 5 + + val storageFormat = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty) + + val hadoopConf = new Configuration() + hadoopConf.setBoolean(tryDirectSqlKey, false) + val client = clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, hadoopConf) + client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)") + + val partitions = (1 to testPartitionCount).map { part => + CatalogTablePartition(Map("part" -> part.toString), storageFormat) + } + client.createPartitions( + "default", "test", partitions, ignoreIfExists = false) + + val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), + Seq(EqualTo(AttributeReference("part", IntegerType)(), Literal(3)))) + + assert(filteredPartitions.size == testPartitionCount) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9a10957c8efa5..081b0ed9bd688 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,9 +23,8 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.util.VersionInfo -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} @@ -48,46 +47,19 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - private val sparkConf = new SparkConf() - - // In order to speed up test execution during development or in Jenkins, you can specify the path - // of an existing Ivy cache: - private val ivyPath: Option[String] = { - sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( - Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) - } - - private def buildConf() = { - lazy val warehousePath = Utils.createTempDir() - lazy val metastorePath = Utils.createTempDir() - metastorePath.delete() - Map( - "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", - "hive.metastore.warehouse.dir" -> warehousePath.toString) - } + private val clientBuilder = new HiveClientBuilder + import clientBuilder.buildClient test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion( - hiveMetastoreVersion = 
HiveUtils.hiveExecutionVersion, - hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, - hadoopConf = new Configuration(), - config = buildConf(), - ivyPath = ivyPath).createClient() + val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) val db = new CatalogDatabase("default", "desc", "loc", Map()) badClient.createDatabase(db, ignoreIfExists = true) } test("hadoop configuration preserved") { - val hadoopConf = new Configuration(); + val hadoopConf = new Configuration() hadoopConf.set("test", "success") - val client = IsolatedClientLoader.forVersion( - hiveMetastoreVersion = HiveUtils.hiveExecutionVersion, - hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, - hadoopConf = hadoopConf, - config = buildConf(), - ivyPath = ivyPath).createClient() + val client = buildClient(HiveUtils.hiveExecutionVersion, hadoopConf) assert("success" === client.getConf("test", null)) } @@ -109,15 +81,7 @@ class VersionsSuite extends SparkFunSuite with Logging { // TODO: currently only works on mysql where we manually create the schema... ignore("failure sanity check") { val e = intercept[Throwable] { - val badClient = quietly { - IsolatedClientLoader.forVersion( - hiveMetastoreVersion = "13", - hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, - hadoopConf = new Configuration(), - config = buildConf(), - ivyPath = ivyPath).createClient() - } + val badClient = quietly { buildClient("13", new Configuration()) } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } @@ -130,16 +94,9 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. - val hadoopConf = new Configuration(); + val hadoopConf = new Configuration() hadoopConf.set("test", "success") - client = - IsolatedClientLoader.forVersion( - hiveMetastoreVersion = version, - hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, - hadoopConf = hadoopConf, - config = buildConf(), - ivyPath = ivyPath).createClient() + client = buildClient(version, hadoopConf) } def table(database: String, tableName: String): CatalogTable = { @@ -287,15 +244,19 @@ class VersionsSuite extends SparkFunSuite with Logging { client.runSqlHive("CREATE TABLE src_part (value INT) PARTITIONED BY (key1 INT, key2 INT)") } + val testPartitionCount = 2 + test(s"$version: createPartitions") { - val partition1 = CatalogTablePartition(Map("key1" -> "1", "key2" -> "1"), storageFormat) - val partition2 = CatalogTablePartition(Map("key1" -> "1", "key2" -> "2"), storageFormat) + val partitions = (1 to testPartitionCount).map { key2 => + CatalogTablePartition(Map("key1" -> "1", "key2" -> key2.toString), storageFormat) + } client.createPartitions( - "default", "src_part", Seq(partition1, partition2), ignoreIfExists = true) + "default", "src_part", partitions, ignoreIfExists = true) } test(s"$version: getPartitions(catalogTable)") { - assert(2 == client.getPartitions(client.getTable("default", "src_part")).size) + assert(testPartitionCount == + client.getPartitions(client.getTable("default", "src_part")).size) } test(s"$version: getPartitionsByFilter") { @@ -306,6 +267,8 @@ class VersionsSuite extends SparkFunSuite with Logging { // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition. 
if (version != "0.12") { assert(result.size == 1) + } else { + assert(result.size == testPartitionCount) } } @@ -327,7 +290,7 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: getPartitions(db: String, table: String)") { - assert(2 == client.getPartitions("default", "src_part", None).size) + assert(testPartitionCount == client.getPartitions("default", "src_part", None).size) } test(s"$version: loadPartition") { From 620da3b4828b3580c7ed7339b2a07938e6be1bb1 Mon Sep 17 00:00:00 2001 From: frreiss Date: Tue, 1 Nov 2016 23:00:17 -0700 Subject: [PATCH 014/198] [SPARK-17475][STREAMING] Delete CRC files if the filesystem doesn't use checksum files ## What changes were proposed in this pull request? When the metadata logs for various parts of Structured Streaming are stored on non-HDFS filesystems such as NFS or ext4, the HDFSMetadataLog class leaves hidden HDFS-style checksum (CRC) files in the log directory, one file per batch. This PR modifies HDFSMetadataLog so that it detects the use of a filesystem that doesn't use CRC files and removes the CRC files. ## How was this patch tested? Modified an existing test case in HDFSMetadataLogSuite to check whether HDFSMetadataLog correctly removes CRC files on the local POSIX filesystem. Ran the entire regression suite. Author: frreiss Closes #15027 from frreiss/fred-17475. --- .../spark/sql/execution/streaming/HDFSMetadataLog.scala | 5 +++++ .../sql/execution/streaming/HDFSMetadataLogSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index c7235320fd6bd..9a0f87cf0498c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -148,6 +148,11 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) // It will fail if there is an existing file (someone has committed the batch) logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") fileManager.rename(tempPath, batchIdToPath(batchId)) + + // SPARK-17475: HDFSMetadataLog should not leak CRC files + // If the underlying filesystem didn't rename the CRC file, delete it. + val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") + if (fileManager.exists(crcPath)) fileManager.delete(crcPath) return } catch { case e: IOException if isFileAlreadyExistsException(e) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 9c1d26dcb2241..d03e08d9a576c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -119,6 +119,12 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { assert(metadataLog.get(1).isEmpty) assert(metadataLog.get(2).isDefined) assert(metadataLog.getLatest().get._1 == 2) + + // There should be exactly one file, called "2", in the metadata directory. 
+ // This check also tests for regressions of SPARK-17475 + val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + assert(allFiles.size == 1) + assert(allFiles(0).getName() == "2") } } From abefe2ec428dc24a4112c623fb6fbe4b2ca60a2b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 2 Nov 2016 14:15:10 +0800 Subject: [PATCH 015/198] [SPARK-18183][SPARK-18184] Fix INSERT [INTO|OVERWRITE] TABLE ... PARTITION for Datasource tables ## What changes were proposed in this pull request? There are a couple issues with the current 2.1 behavior when inserting into Datasource tables with partitions managed by Hive. (1) OVERWRITE TABLE ... PARTITION will actually overwrite the entire table instead of just the specified partition. (2) INSERT|OVERWRITE does not work with partitions that have custom locations. This PR fixes both of these issues for Datasource tables managed by Hive. The behavior for legacy tables or when `manageFilesourcePartitions = false` is unchanged. There is one other issue in that INSERT OVERWRITE with dynamic partitions will overwrite the entire table instead of just the updated partitions, but this behavior is pretty complicated to implement for Datasource tables. We should address that in a future release. ## How was this patch tested? Unit tests. Author: Eric Liang Closes #15705 from ericl/sc-4942. --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 9 +++- .../plans/logical/basicLogicalOperators.scala | 19 ++++++- .../sql/catalyst/parser/PlanParserSuite.scala | 15 ++++-- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../datasources/CatalogFileIndex.scala | 5 +- .../datasources/DataSourceStrategy.scala | 30 +++++++++-- .../InsertIntoDataSourceCommand.scala | 6 +-- .../spark/sql/hive/HiveStrategies.scala | 3 +- .../CreateHiveTableAsSelectCommand.scala | 5 +- .../PartitionProviderCompatibilitySuite.scala | 52 +++++++++++++++++++ 11 files changed, 129 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 66e52ca68af19..e901683be6854 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -367,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, false) + Map.empty, logicalPlan, OverwriteOptions(overwrite), false) def as(alias: String): LogicalPlan = logicalPlan match { case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 38e9bb6c162ad..ac1577b3abb4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -177,12 +177,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. 
Specified " + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) } + val overwrite = ctx.OVERWRITE != null + val overwritePartition = + if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) { + Some(partitionKeys.map(t => (t._1, t._2.get))) + } else { + None + } InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, query, - ctx.OVERWRITE != null, + OverwriteOptions(overwrite, overwritePartition), ctx.EXISTS != null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a48974c6322ad..7a15c2285d584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.catalog.CatalogTypes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -345,18 +346,32 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true) } +/** + * Options for writing new data into a table. + * + * @param enabled whether to overwrite existing data in the table. + * @param specificPartition only data in the specified partition will be overwritten. + */ +case class OverwriteOptions( + enabled: Boolean, + specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) { + if (specificPartition.isDefined) { + assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.") + } +} + case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean, + overwrite: OverwriteOptions, ifNotExists: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty - assert(overwrite || !ifNotExists) + assert(overwrite.enabled || !ifNotExists) assert(partition.values.forall(_.nonEmpty) || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && table.resolved diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index ca86304d4d400..7400f3430e99c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -180,7 +180,16 @@ class PlanParserSuite extends PlanTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + InsertIntoTable( + table("s"), partition, plan, + OverwriteOptions( + overwrite, + if (overwrite && partition.nonEmpty) { + Some(partition.map(kv => (kv._1, kv._2.get))) + } else { + None + }), + ifNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -196,9 +205,9 @@ class PlanParserSuite extends PlanTest { val 
plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), OverwriteOptions(false), ifNotExists = false).union( InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + table("u"), Map.empty, plan2, OverwriteOptions(false), ifNotExists = false))) } test ("insert with if not exists") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 11dd1df909938..700f4835ac89a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Union} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, OverwriteOptions, Union} import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, CreateTable, DataSource, HadoopFsRelation} import org.apache.spark.sql.types.StructType @@ -259,7 +259,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], child = df.logicalPlan, - overwrite = mode == SaveMode.Overwrite, + overwrite = OverwriteOptions(mode == SaveMode.Overwrite), ifNotExists = false)).toRdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 092aabc89a36c..443a2ec033a98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -67,7 +67,10 @@ class CatalogFileIndex( val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => - PartitionPath(p.toRow(partitionSchema), p.storage.locationUri.get) + val path = new Path(p.storage.locationUri.get) + val fs = path.getFileSystem(hadoopConf) + PartitionPath( + p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) new PrunedInMemoryFileIndex( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 34b77cab65def..47c1f9d3fac1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.fs.Path + import org.apache.spark.internal.Logging import 
org.apache.spark.rdd.RDD import org.apache.spark.sql._ @@ -174,14 +176,32 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths }.flatten - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - if (overwrite && inputPaths.contains(outputPath)) { + val mode = if (overwrite.enabled) SaveMode.Overwrite else SaveMode.Append + if (overwrite.enabled && inputPaths.contains(outputPath)) { throw new AnalysisException( "Cannot overwrite a path that is also being read from.") } + val overwritingSinglePartition = (overwrite.specificPartition.isDefined && + t.sparkSession.sessionState.conf.manageFilesourcePartitions && + l.catalogTable.get.partitionProviderIsHive) + + val effectiveOutputPath = if (overwritingSinglePartition) { + val partition = t.sparkSession.sessionState.catalog.getPartition( + l.catalogTable.get.identifier, overwrite.specificPartition.get) + new Path(partition.storage.locationUri.get) + } else { + outputPath + } + + val effectivePartitionSchema = if (overwritingSinglePartition) { + Nil + } else { + query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + } + def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { - if (l.catalogTable.isDefined && + if (l.catalogTable.isDefined && updatedPartitions.nonEmpty && l.catalogTable.get.partitionColumnNames.nonEmpty && l.catalogTable.get.partitionProviderIsHive) { val metastoreUpdater = AlterTableAddPartitionCommand( @@ -194,8 +214,8 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { } val insertCmd = InsertIntoHadoopFsRelationCommand( - outputPath, - query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), + effectiveOutputPath, + effectivePartitionSchema, t.bucketSpec, t.fileFormat, refreshPartitionsCallback, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index b2ff68a833fea..2eba1e9986acd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OverwriteOptions} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.sources.InsertableRelation @@ -30,7 +30,7 @@ import org.apache.spark.sql.sources.InsertableRelation case class InsertIntoDataSourceCommand( logicalRelation: LogicalRelation, query: LogicalPlan, - overwrite: Boolean) + overwrite: OverwriteOptions) extends RunnableCommand { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) @@ -40,7 +40,7 @@ case class InsertIntoDataSourceCommand( val data = Dataset.ofRows(sparkSession, query) // Apply the schema of the existing table to the new data. val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + relation.insert(df, overwrite.enabled) // Invalidate the cache. 
sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9d2930948d6ba..ce1e3eb1a5bc9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -46,7 +46,8 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable( table: MetastoreRelation, partition, child, overwrite, ifNotExists) => - InsertIntoHiveTable(table, partition, planLater(child), overwrite, ifNotExists) :: Nil + InsertIntoHiveTable( + table, partition, planLater(child), overwrite.enabled, ifNotExists) :: Nil case CreateTable(tableDesc, mode, Some(query)) if tableDesc.provider.get == "hive" => val newTableDesc = if (tableDesc.storage.serde.isEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index ef5a5a001fb6f..cac43597aef21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -21,7 +21,7 @@ import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, OverwriteOptions} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.hive.MetastoreRelation @@ -88,7 +88,8 @@ case class CreateHiveTableAsSelectCommand( } else { try { sparkSession.sessionState.executePlan(InsertIntoTable( - metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd + metastoreRelation, Map(), query, overwrite = OverwriteOptions(true), + ifNotExists = false)).toRdd } catch { case NonFatal(e) => // drop the created table. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 5f16960fb1496..ac435bf6195b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -134,4 +134,56 @@ class PartitionProviderCompatibilitySuite } } } + + test("insert overwrite partition of legacy datasource table overwrites entire table") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir) + spark.sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(100)""".stripMargin) + assert(spark.sql("select * from test").count() == 100) + + // Dynamic partitions case + spark.sql("insert overwrite table test select id, id from range(10)".stripMargin) + assert(spark.sql("select * from test").count() == 10) + } + } + } + } + + test("insert overwrite partition of new datasource table overwrites just partition") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir) + sql("msck repair table test") + spark.sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(100)""".stripMargin) + assert(spark.sql("select * from test").count() == 104) + + // Test overwriting a partition that has a custom location + withTempDir { dir2 => + sql( + s"""alter table test partition (partCol=1) + |set location '${dir2.getAbsolutePath}'""".stripMargin) + assert(sql("select * from test").count() == 4) + sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(30)""".stripMargin) + sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(20)""".stripMargin) + assert(sql("select * from test").count() == 24) + } + } + } + } + } } From a36653c5b7b2719f8bfddf4ddfc6e1b828ac9af1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 1 Nov 2016 23:37:03 -0700 Subject: [PATCH 016/198] [SPARK-18192] Support all file formats in structured streaming ## What changes were proposed in this pull request? This patch adds support for all file formats in structured streaming sinks. This is actually a very small change thanks to all the previous refactoring done using the new internal commit protocol API. ## How was this patch tested? Updated FileStreamSinkSuite to add test cases for json, text, and parquet. Author: Reynold Xin Closes #15711 from rxin/SPARK-18192. 
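As a sketch of what this enables (the socket source and all paths below are placeholder choices, not part of the patch), any `FileFormat`-based format can now back a streaming file sink:

``` scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("sink-demo").getOrCreate()

// A hypothetical streaming source; any streaming DataFrame works here.
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()

val query = lines.writeStream
  .format("json") // json and text are now accepted in addition to parquet
  .option("checkpointLocation", "/tmp/demo/checkpoint") // placeholder paths
  .start("/tmp/demo/output")
```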
--- .../execution/datasources/DataSource.scala | 8 +-- .../sql/streaming/FileStreamSinkSuite.scala | 62 +++++++++---------- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d980e6a15aabe..3f956c427655e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -29,7 +29,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -37,7 +36,6 @@ import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} @@ -292,7 +290,7 @@ case class DataSource( case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode) - case parquet: parquet.ParquetFileFormat => + case fileFormat: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") @@ -301,7 +299,7 @@ case class DataSource( throw new IllegalArgumentException( s"Data source $className does not support $outputMode output mode") } - new FileStreamSink(sparkSession, path, parquet, partitionColumns, options) + new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, options) case _ => throw new UnsupportedOperationException( @@ -516,7 +514,7 @@ case class DataSource( val plan = data.logicalPlan plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { throw new AnalysisException( - s"Unable to resolve ${name} given [${plan.output.map(_.name).mkString(", ")}]") + s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") }.asInstanceOf[Attribute] } // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 902cf05344716..0f140f94f630e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql._ +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} @@ -142,42 +142,38 @@ class FileStreamSinkSuite extends StreamTest { } } - test("FileStreamSink - 
supported formats") { - def testFormat(format: Option[String]): Unit = { - val inputData = MemoryStream[Int] - val ds = inputData.toDS() + test("FileStreamSink - parquet") { + testFormat(None) // should not throw: parquet is used as the default format when none is specified + testFormat(Some("parquet")) + } - val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath - val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + test("FileStreamSink - text") { + testFormat(Some("text")) + } - var query: StreamingQuery = null + test("FileStreamSink - json") { + testFormat(Some("json")) + } - try { - val writer = - ds.map(i => (i, i * 1000)) - .toDF("id", "value") - .writeStream - if (format.nonEmpty) { - writer.format(format.get) - } - query = writer - .option("checkpointLocation", checkpointDir) - .start(outputDir) - } finally { - if (query != null) { - query.stop() - } - } - } + def testFormat(format: Option[String]): Unit = { + val inputData = MemoryStream[Int] + val ds = inputData.toDS() - testFormat(None) // should not throw error as default format parquet when not specified - testFormat(Some("parquet")) - val e = intercept[UnsupportedOperationException] { - testFormat(Some("text")) - } - Seq("text", "not support", "stream").foreach { s => - assert(e.getMessage.contains(s)) + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + var query: StreamingQuery = null + + try { + val writer = ds.map(i => (i, i * 1000)).toDF("id", "value").writeStream + if (format.nonEmpty) { + writer.format(format.get) + } + query = writer.option("checkpointLocation", checkpointDir).start(outputDir) + } finally { + if (query != null) { + query.stop() + } } } - } From 85c5424d466f4a5765c825e0e2ab30da97611285 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 1 Nov 2016 23:39:53 -0700 Subject: [PATCH 017/198] [SPARK-18144][SQL] logging StreamingQueryListener$QueryStartedEvent ## What changes were proposed in this pull request? The PR fixes the bug that the QueryStartedEvent is never logged: the postToAll() in the original code is actually calling StreamingQueryListenerBus.postToAll(), which has no listeners at all. We should post via sparkListenerBus.post(s) as well as this.postToAll(), to trigger the local listeners in addition to the listeners registered in LiveListenerBus. zsxwing ## How was this patch tested? The following snapshot shows that QueryStartedEvent has been logged correctly ![image](https://cloud.githubusercontent.com/assets/678008/19821553/007a7d28-9d2d-11e6-9f13-49851559cdaa.png) Author: CodingCat Closes #15675 from CodingCat/SPARK-18144.
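An illustrative helper, not from the patch: a listener whose start-callback count a test can assert on. After this fix it should be exactly 1 per query — not 0 (the bug fixed here) and not 2 (the double dispatch that the `withinListenerThread` guard in the diff below prevents).

```scala
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// Hypothetical test helper: counts onQueryStarted invocations so a test can
// check the event is delivered exactly once across both dispatch paths.
class CountingListener extends StreamingQueryListener {
  val started = new AtomicInteger(0)
  override def onQueryStarted(event: QueryStartedEvent): Unit = started.incrementAndGet()
  override def onQueryProgress(event: QueryProgressEvent): Unit = ()
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = ()
}

// Sketch of use: spark.streams.addListener(new CountingListener); run a query,
// then expect started.get == 1.
```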
--- .../streaming/StreamingQueryListenerBus.scala | 10 +++++++++- .../spark/sql/streaming/StreamingQuerySuite.scala | 7 ++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index fc2190d39da4f..22e4c6380fcd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -41,6 +41,8 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) def post(event: StreamingQueryListener.Event) { event match { case s: QueryStartedEvent => + sparkListenerBus.post(s) + // post to local listeners to trigger callbacks postToAll(s) case _ => sparkListenerBus.post(event) @@ -50,7 +52,13 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case e: StreamingQueryListener.Event => - postToAll(e) + // SPARK-18144: we broadcast QueryStartedEvent to all listeners attached to this bus + // synchronously, and to the ones attached to LiveListenerBus asynchronously. Therefore, + // we need to ignore QueryStartedEvent if this method is called within the SparkListenerBus + // thread + if (!LiveListenerBus.withinListenerThread.value || !e.isInstanceOf[QueryStartedEvent]) { + postToAll(e) + } case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 464c443beb6e7..31b7fe0b04da9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -290,7 +290,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { // A StreamingQueryListener that gets the query status after the first completed trigger val listener = new StreamingQueryListener { @volatile var firstStatus: StreamingQueryStatus = null - override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { } + @volatile var queryStartedEvent = 0 + override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { + queryStartedEvent += 1 + } override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = { if (firstStatus == null) firstStatus = queryProgress.queryStatus } @@ -303,6 +306,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { q.processAllAvailable() eventually(timeout(streamingTimeout)) { assert(listener.firstStatus != null) + // check that the QueryStartedEvent callback is invoked exactly once + assert(listener.queryStartedEvent === 1) } listener.firstStatus } finally { From 2dc048081668665f85623839d5f663b402e42555 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 2 Nov 2016 00:08:30 -0700 Subject: [PATCH 018/198] [SPARK-17532] Add lock debugging info to thread dumps. ## What changes were proposed in this pull request? This adds information to the web UI thread dump page about the JVM locks held by threads and the locks that threads are blocked waiting to acquire. This should help find cases where lock contention is causing Spark applications to run slowly. ## How was this patch tested? Tested by applying this patch and viewing the change in the web UI.
![thread-lock-info](https://cloud.githubusercontent.com/assets/87915/18493057/6e5da870-79c3-11e6-8c20-f54c18a37544.png) Additions: - A "Thread Locking" column with the locks held by the thread or that are blocking the thread - Links from a blocked thread to the thread holding the lock - Stack frames show where threads are inside `synchronized` blocks, "holding Monitor(...)" Author: Ryan Blue Closes #15088 from rdblue/SPARK-17532-add-thread-lock-info. --- .../org/apache/spark/ui/static/table.js | 3 +- .../ui/exec/ExecutorThreadDumpPage.scala | 12 +++++++ .../apache/spark/util/ThreadStackTrace.scala | 6 +++- .../scala/org/apache/spark/util/Utils.scala | 34 ++++++++++++++++--- 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 14b06bfe860ed..0315ebf5c48a9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -36,7 +36,7 @@ function toggleThreadStackTrace(threadId, forceAdd) { if (stackTrace.length == 0) { var stackTraceText = $('#' + threadId + "_td_stacktrace").html() var threadCell = $("#thread_" + threadId + "_tr") -        threadCell.after("<tr id=\"" + threadId + "_stacktrace\" class=\"accordion-body\"><td colspan=\"3\"><pre>" +
+        threadCell.after("<tr id=\"" + threadId + "_stacktrace\" class=\"accordion-body\"><td colspan=\"4\"><pre>" +
             stackTraceText +  "</pre></td></tr>
") } else { if (!forceAdd) { @@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) { $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover"); } function onSearchStringChange() { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index a0ef80d9bdae0..c6a07445f2a35 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -48,6 +48,16 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } }.map { thread => val threadId = thread.threadId + val blockedBy = thread.blockedByThreadId match { + case Some(blockedByThreadId) => + + case None => Text("") + } + val heldLocks = thread.holdingLocks.mkString(", ") + {threadId} {thread.threadName} {thread.threadState} + {blockedBy}{heldLocks} {thread.stackTrace} } @@ -86,6 +97,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage Thread ID Thread Name Thread State + Thread Locks {dumpRows} diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala index d4e0ad93b966a..b1217980faf1f 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -24,4 +24,8 @@ private[spark] case class ThreadStackTrace( threadId: Long, threadName: String, threadState: Thread.State, - stackTrace: String) + stackTrace: String, + blockedByThreadId: Option[Long], + blockedByLock: String, + holdingLocks: Seq[String]) + diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6027b07c0fee8..22c28fba2087e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.lang.management.ManagementFactory +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -2096,15 +2096,41 @@ private[spark] object Utils extends Logging { } } + private implicit class Lock(lock: LockInfo) { + def lockString: String = { + lock match { + case monitor: MonitorInfo => + s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode}})" + case _ => + s"Lock(${lock.getClassName}@${lock.getIdentityHashCode}})" + } + } + } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ def getThreadDump(): Array[ThreadStackTrace] = { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. 
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) threadInfos.sortBy(_.getThreadId).map { case threadInfo => - val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") - ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, - threadInfo.getThreadState, stackTrace) + val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap + val stackTrace = threadInfo.getStackTrace.map { frame => + monitors.get(frame) match { + case Some(monitor) => + monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" + case None => + frame.toString + } + }.mkString("\n") + + // use a set to dedup re-entrant locks that are held at multiple places + val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString) + ++ threadInfo.getLockedMonitors.map(_.lockString) + ).toSet + + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState, + stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), + Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq) } } From bcbe44440e6c871e217f06d2a4696fd41f1d2606 Mon Sep 17 00:00:00 2001 From: Maria Rydzy Date: Wed, 2 Nov 2016 09:09:16 +0000 Subject: [PATCH 019/198] [MINOR] Use <= for clarity in Pi examples' Monte Carlo process ## What changes were proposed in this pull request? If my understanding is correct, we should be looking at the closed disk rather than the open one: points that land exactly on the circle (x*x + y*y == 1) belong to the closed unit disk, so they should count as hits. ## How was this patch tested? I ran a simple comparison of the mean squared error of the closed-disk and open-disk approaches. https://gist.github.com/mrydzy/1cf0e5c316ef9d6fbd91426b91f1969f The closed-disk variant performed slightly better, but the tested sample wasn't very big, so I rely mostly on my understanding of the algorithm. Author: Maria Rydzy Closes #15687 from mrydzy/master. --- .../src/main/java/org/apache/spark/examples/JavaSparkPi.java | 2 +- examples/src/main/python/pi.py | 2 +- examples/src/main/scala/org/apache/spark/examples/LocalPi.scala | 2 +- examples/src/main/scala/org/apache/spark/examples/SparkPi.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 7df145e3117b8..89855e81f1f7a 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -54,7 +54,7 @@ public static void main(String[] args) throws Exception { public Integer call(Integer integer) { double x = Math.random() * 2 - 1; double y = Math.random() * 2 - 1; - return (x * x + y * y < 1) ? 1 : 0; + return (x * x + y * y <= 1) ? 
1 : 0; } }).reduce(new Function2<Integer, Integer, Integer>() { @Override diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index e3f0c4aeef1b7..37029b76798f6 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -38,7 +38,7 @@ def f(_): x = random() * 2 - 1 y = random() * 2 - 1 - return 1 if x ** 2 + y ** 2 < 1 else 0 + return 1 if x ** 2 + y ** 2 <= 1 else 0 count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add) print("Pi is roughly %f" % (4.0 * count / n)) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index 720d92fb9d029..121b768e4198e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -26,7 +26,7 @@ object LocalPi { for (i <- 1 to 100000) { val x = random * 2 - 1 val y = random * 2 - 1 - if (x*x + y*y < 1) count += 1 + if (x*x + y*y <= 1) count += 1 } println("Pi is roughly " + 4 * count / 100000.0) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 272c1a4fc2f47..a5cacf17a5cca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -34,7 +34,7 @@ object SparkPi { val count = spark.sparkContext.parallelize(1 until n, slices).map { i => val x = random * 2 - 1 val y = random * 2 - 1 - if (x*x + y*y < 1) 1 else 0 + if (x*x + y*y <= 1) 1 else 0 }.reduce(_ + _) println("Pi is roughly " + 4.0 * count / (n - 1)) spark.stop() From 98ede49496d0d7b4724085083d4f24436b92a7bf Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Wed, 2 Nov 2016 09:10:34 +0000 Subject: [PATCH 020/198] [SPARK-18198][DOC][STREAMING] Highlight code snippets ## What changes were proposed in this pull request? This patch uses `{% highlight lang %}...{% endhighlight %}` to highlight code snippets in the `Structured Streaming Kafka010 integration doc` and the `Spark Streaming Kafka010 integration doc`. This patch consists of two commits: - the first commit fixes only the leading spaces -- this is large - the second commit adds the highlight instructions -- this is much simpler and easier to review ## How was this patch tested? SKIP_API=1 jekyll build ## Screenshots **Before** ![snip20161101_3](https://cloud.githubusercontent.com/assets/15843379/19894258/47746524-a087-11e6-9a2a-7bff2d428d44.png) **After** ![snip20161101_1](https://cloud.githubusercontent.com/assets/15843379/19894324/8bebcd1e-a087-11e6-835b-88c4d2979cfa.png) Author: Liwei Lin Closes #15715 from lw-lin/doc-highlight-code-snippet. --- docs/streaming-kafka-0-10-integration.md | 391 +++++++++--------- .../structured-streaming-kafka-integration.md | 156 +++---- 2 files changed, 287 insertions(+), 260 deletions(-) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index c1ef396907db7..b645d3c3a4b53 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -17,69 +17,72 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea
- import org.apache.kafka.clients.consumer.ConsumerRecord - import org.apache.kafka.common.serialization.StringDeserializer - import org.apache.spark.streaming.kafka010._ - import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent - import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe - - val kafkaParams = Map[String, Object]( - "bootstrap.servers" -> "localhost:9092,anotherhost:9092", - "key.deserializer" -> classOf[StringDeserializer], - "value.deserializer" -> classOf[StringDeserializer], - "group.id" -> "use_a_separate_group_id_for_each_stream", - "auto.offset.reset" -> "latest", - "enable.auto.commit" -> (false: java.lang.Boolean) - ) - - val topics = Array("topicA", "topicB") - val stream = KafkaUtils.createDirectStream[String, String]( - streamingContext, - PreferConsistent, - Subscribe[String, String](topics, kafkaParams) - ) - - stream.map(record => (record.key, record.value)) - +{% highlight scala %} +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.serialization.StringDeserializer +import org.apache.spark.streaming.kafka010._ +import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent +import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe + +val kafkaParams = Map[String, Object]( + "bootstrap.servers" -> "localhost:9092,anotherhost:9092", + "key.deserializer" -> classOf[StringDeserializer], + "value.deserializer" -> classOf[StringDeserializer], + "group.id" -> "use_a_separate_group_id_for_each_stream", + "auto.offset.reset" -> "latest", + "enable.auto.commit" -> (false: java.lang.Boolean) +) + +val topics = Array("topicA", "topicB") +val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Subscribe[String, String](topics, kafkaParams) +) + +stream.map(record => (record.key, record.value)) +{% endhighlight %} Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html)
- import java.util.*; - import org.apache.spark.SparkConf; - import org.apache.spark.TaskContext; - import org.apache.spark.api.java.*; - import org.apache.spark.api.java.function.*; - import org.apache.spark.streaming.api.java.*; - import org.apache.spark.streaming.kafka010.*; - import org.apache.kafka.clients.consumer.ConsumerRecord; - import org.apache.kafka.common.TopicPartition; - import org.apache.kafka.common.serialization.StringDeserializer; - import scala.Tuple2; - - Map<String, Object> kafkaParams = new HashMap<>(); - kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); - kafkaParams.put("key.deserializer", StringDeserializer.class); - kafkaParams.put("value.deserializer", StringDeserializer.class); - kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); - kafkaParams.put("auto.offset.reset", "latest"); - kafkaParams.put("enable.auto.commit", false); - - Collection<String> topics = Arrays.asList("topicA", "topicB"); - - final JavaInputDStream<ConsumerRecord<String, String>> stream = - KafkaUtils.createDirectStream( - streamingContext, - LocationStrategies.PreferConsistent(), - ConsumerStrategies.<String, String>Subscribe(topics, kafkaParams) - ); - - stream.mapToPair( - new PairFunction<ConsumerRecord<String, String>, String, String>() { - @Override - public Tuple2<String, String> call(ConsumerRecord<String, String> record) { - return new Tuple2<>(record.key(), record.value()); - } - }) +{% highlight java %} +import java.util.*; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.kafka010.*; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.StringDeserializer; +import scala.Tuple2; + +Map<String, Object> kafkaParams = new HashMap<>(); +kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); +kafkaParams.put("key.deserializer", StringDeserializer.class); +kafkaParams.put("value.deserializer", StringDeserializer.class); +kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); +kafkaParams.put("auto.offset.reset", "latest"); +kafkaParams.put("enable.auto.commit", false); + +Collection<String> topics = Arrays.asList("topicA", "topicB"); + +final JavaInputDStream<ConsumerRecord<String, String>> stream = + KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.<String, String>Subscribe(topics, kafkaParams) + ); + +stream.mapToPair( + new PairFunction<ConsumerRecord<String, String>, String, String>() { + @Override + public Tuple2<String, String> call(ConsumerRecord<String, String> record) { + return new Tuple2<>(record.key(), record.value()); + } + }) +{% endhighlight %}
@@ -109,32 +112,35 @@ If you have a use case that is better suited to batch processing, you can create
- // Import dependencies and create kafka params as in Create Direct Stream above - - val offsetRanges = Array( - // topic, partition, inclusive starting offset, exclusive ending offset - OffsetRange("test", 0, 0, 100), - OffsetRange("test", 1, 0, 100) - ) +{% highlight scala %} +// Import dependencies and create kafka params as in Create Direct Stream above - val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +val offsetRanges = Array( + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange("test", 0, 0, 100), + OffsetRange("test", 1, 0, 100) +) +val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +{% endhighlight %}
- // Import dependencies and create kafka params as in Create Direct Stream above - - OffsetRange[] offsetRanges = { - // topic, partition, inclusive starting offset, exclusive ending offset - OffsetRange.create("test", 0, 0, 100), - OffsetRange.create("test", 1, 0, 100) - }; - - JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.createRDD( - sparkContext, - kafkaParams, - offsetRanges, - LocationStrategies.PreferConsistent() - ); +{% highlight java %} +// Import dependencies and create kafka params as in Create Direct Stream above + +OffsetRange[] offsetRanges = { + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange.create("test", 0, 0, 100), + OffsetRange.create("test", 1, 0, 100) +}; + +JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.createRDD( + sparkContext, + kafkaParams, + offsetRanges, + LocationStrategies.PreferConsistent() +); +{% endhighlight %}
@@ -144,29 +150,33 @@ Note that you cannot use `PreferBrokers`, because without the stream there is no
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - rdd.foreachPartition { iter => - val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") - } - } +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.foreachPartition { iter => + val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } +} +{% endhighlight %}
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() { - @Override - public void call(JavaRDD<ConsumerRecord<String, String>> rdd) { - final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); - rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() { - @Override - public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) { - OffsetRange o = offsetRanges[TaskContext.get().partitionId()]; - System.out.println( - o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset()); - } - }); - } - }); +{% highlight java %} +stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() { + @Override + public void call(JavaRDD<ConsumerRecord<String, String>> rdd) { + final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() { + @Override + public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) { + OffsetRange o = offsetRanges[TaskContext.get().partitionId()]; + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset()); + } + }); + } +}); +{% endhighlight %}
@@ -183,25 +193,28 @@ Kafka has an offset commit API that stores offsets in a special Kafka topic. By
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - - // some time later, after outputs have completed - stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) - } - +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + // some time later, after outputs have completed + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) +} +{% endhighlight %} As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics.
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() { - @Override - public void call(JavaRDD<ConsumerRecord<String, String>> rdd) { - OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); - - // some time later, after outputs have completed - ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges); - } - }); +{% highlight java %} +stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() { + @Override + public void call(JavaRDD<ConsumerRecord<String, String>> rdd) { + OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + + // some time later, after outputs have completed + ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges); + } +}); +{% endhighlight %}
@@ -210,64 +223,68 @@ For data stores that support transactions, saving offsets in the same transactio
- // The details depend on your data store, but the general idea looks like this +{% highlight scala %} +// The details depend on your data store, but the general idea looks like this - // begin from the the offsets committed to the database - val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => - new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") - }.toMap +// begin from the offsets committed to the database +val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => + new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") +}.toMap - val stream = KafkaUtils.createDirectStream[String, String]( - streamingContext, - PreferConsistent, - Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets) - ) +val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets) +) - stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - val results = yourCalculation(rdd) + val results = yourCalculation(rdd) - // begin your transaction + // begin your transaction - // update results - // update offsets where the end of existing offsets matches the beginning of this batch of offsets - // assert that offsets were updated correctly + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly - // end your transaction - } + // end your transaction +} +{% endhighlight %}
- // The details depend on your data store, but the general idea looks like this - - // begin from the the offsets committed to the database - Map<TopicPartition, Long> fromOffsets = new HashMap<>(); - for (resultSet : selectOffsetsFromYourDatabase) - fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset")); - } - - JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream( - streamingContext, - LocationStrategies.PreferConsistent(), - ConsumerStrategies.<String, String>Assign(fromOffsets.keySet(), kafkaParams, fromOffsets) - ); - - stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() { - @Override - public void call(JavaRDD<ConsumerRecord<String, String>> rdd) { - OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); - - Object results = yourCalculation(rdd); - - // begin your transaction - - // update results - // update offsets where the end of existing offsets matches the beginning of this batch of offsets - // assert that offsets were updated correctly - - // end your transaction - } - }); +{% highlight java %} +// The details depend on your data store, but the general idea looks like this + +// begin from the offsets committed to the database +Map<TopicPartition, Long> fromOffsets = new HashMap<>(); +for (resultSet : selectOffsetsFromYourDatabase) { + fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset")); +} + +JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.<String, String>Assign(fromOffsets.keySet(), kafkaParams, fromOffsets) +); + +stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() { + @Override + public void call(JavaRDD<ConsumerRecord<String, String>> rdd) { + OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + + Object results = yourCalculation(rdd); + + // begin your transaction + + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly + + // end your transaction + } +}); +{% endhighlight %}
@@ -277,25 +294,29 @@ The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html
- val kafkaParams = Map[String, Object]( - // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS - "security.protocol" -> "SSL", - "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", - "ssl.truststore.password" -> "test1234", - "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", - "ssl.keystore.password" -> "test1234", - "ssl.key.password" -> "test1234" - ) +{% highlight scala %} +val kafkaParams = Map[String, Object]( + // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS + "security.protocol" -> "SSL", + "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", + "ssl.truststore.password" -> "test1234", + "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", + "ssl.keystore.password" -> "test1234", + "ssl.key.password" -> "test1234" +) +{% endhighlight %}
- Map<String, Object> kafkaParams = new HashMap<String, Object>(); - // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS - kafkaParams.put("security.protocol", "SSL"); - kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks"); - kafkaParams.put("ssl.truststore.password", "test1234"); - kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks"); - kafkaParams.put("ssl.keystore.password", "test1234"); - kafkaParams.put("ssl.key.password", "test1234"); +{% highlight java %} +Map<String, Object> kafkaParams = new HashMap<String, Object>(); +// the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS +kafkaParams.put("security.protocol", "SSL"); +kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks"); +kafkaParams.put("ssl.truststore.password", "test1234"); +kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks"); +kafkaParams.put("ssl.keystore.password", "test1234"); +kafkaParams.put("ssl.key.password", "test1234"); +{% endhighlight %}
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index a6c3b3a9024d8..c4c9fb3f7d3db 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -19,97 +19,103 @@ application. See the [Deploying](#deploying) subsection below.
+{% highlight scala %} - // Subscribe to 1 topic - val ds1 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") - .load() - ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to 1 topic +val ds1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() +ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to multiple topics - val ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") - .load() - ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to multiple topics +val ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() +ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to a pattern - val ds3 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") - .load() - ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to a pattern +val ds3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() +ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] +{% endhighlight %}
+{% highlight java %} - // Subscribe to 1 topic - Dataset<Row> ds1 = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") - .load() - ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +// Subscribe to 1 topic +Dataset<Row> ds1 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load(); +ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); - // Subscribe to multiple topics - Dataset<Row> ds2 = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") - .load() - ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +// Subscribe to multiple topics +Dataset<Row> ds2 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load(); +ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); - // Subscribe to a pattern - Dataset<Row> ds3 = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") - .load() - ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +// Subscribe to a pattern +Dataset<Row> ds3 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load(); +ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +{% endhighlight %}
+{% highlight python %} - # Subscribe to 1 topic - ds1 = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") - .load() - ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +# Subscribe to 1 topic +ds1 = spark \ +  .readStream \ +  .format("kafka") \ +  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ +  .option("subscribe", "topic1") \ +  .load() +ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - # Subscribe to multiple topics - ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") - .load() - ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +# Subscribe to multiple topics +ds2 = spark \ +  .readStream \ +  .format("kafka") \ +  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ +  .option("subscribe", "topic1,topic2") \ +  .load() +ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - # Subscribe to a pattern - ds3 = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") - .load() - ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +# Subscribe to a pattern +ds3 = spark \ +  .readStream \ +  .format("kafka") \ +  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ +  .option("subscribePattern", "topic.*") \ +  .load() +ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +{% endhighlight %}
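A brief editor's note on the pattern all of the snippets above share (a sketch, not content from the patch; it assumes a `SparkSession` named `spark` is in scope): the Kafka source exposes `key` and `value` as binary columns, which is why every example casts them before use. In Scala the full round trip to a typed Dataset looks like this, with servers and topic as placeholders:

```scala
import spark.implicits._

// Illustrative only; "host1:port1" and "topic1" are placeholders.
val typed = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1")
  .option("subscribe", "topic1")
  .load()                                                    // key/value arrive as BinaryType
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
  .as[(String, String)]                                      // typed Dataset[(String, String)]
```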
From 70a5db7bbd192a4bc68bcfdc475ab221adf2fcdd Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Wed, 2 Nov 2016 09:21:26 +0000 Subject: [PATCH 021/198] [SPARK-18204][WEBUI] Remove SparkUI.appUIAddress ## What changes were proposed in this pull request? Removing `appUIAddress` attribute since it is no longer in use. ## How was this patch tested? Local build Author: Jacek Laskowski Closes #15603 from jaceklaskowski/sparkui-fixes. --- .../cluster/StandaloneSchedulerBackend.scala | 6 +++--- .../main/scala/org/apache/spark/ui/SparkUI.scala | 13 +++---------- .../main/scala/org/apache/spark/ui/WebUI.scala | 8 ++++---- .../org/apache/spark/ui/jobs/AllJobsPage.scala | 4 ++-- .../org/apache/spark/ui/UISeleniumSuite.scala | 16 ++++++++-------- .../test/scala/org/apache/spark/ui/UISuite.scala | 13 ++++++------- .../MesosCoarseGrainedSchedulerBackend.scala | 2 +- .../mesos/MesosFineGrainedSchedulerBackend.scala | 2 +- .../apache/spark/streaming/UISeleniumSuite.scala | 12 ++++++------ .../spark/deploy/yarn/ApplicationMaster.scala | 2 +- .../cluster/YarnClientSchedulerBackend.scala | 2 +- 11 files changed, 36 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 04d40e2907cff..368cd30a2e11a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class StandaloneSchedulerBackend( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) - val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") + val webUrl = sc.ui.map(_.webUrl).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) // If we're using dynamic allocation, set our initial executor limit to 0 for now. // ExecutorAllocationManager will send the real initial limit to the Master later. @@ -103,8 +103,8 @@ private[spark] class StandaloneSchedulerBackend( } else { None } - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) + val appDesc = ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + webUrl, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f631a047a707d..b828532aba7a3 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -82,7 +82,7 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.systemProperties.toMap.get("user.name").getOrElse("") + environmentListener.systemProperties.toMap.getOrElse("user.name", "") } def getAppName: String = appName @@ -94,16 +94,9 @@ private[spark] class SparkUI private ( /** Stop the server behind this web interface. Only valid after bind(). 
*/ override def stop() { super.stop() - logInfo("Stopped Spark web UI at %s".format(appUIAddress)) + logInfo(s"Stopped Spark web UI at $webUrl") } - /** - * Return the application UI host:port. This does not include the scheme (http://). - */ - private[spark] def appUIHostPort = publicHostName + ":" + boundPort - - private[spark] def appUIAddress = s"http://$appUIHostPort" - def getSparkUI(appId: String): Option[SparkUI] = { if (appId == this.appId) Some(this) else None } @@ -136,7 +129,7 @@ private[spark] class SparkUI private ( private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) extends WebUITab(parent, prefix) { - def appName: String = parent.getAppName + def appName: String = parent.appName } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index a05e0efb7a3e3..8c801558672fa 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -56,8 +56,8 @@ private[spark] abstract class WebUI( private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath - def getTabs: Seq[WebUITab] = tabs.toSeq - def getHandlers: Seq[ServletContextHandler] = handlers.toSeq + def getTabs: Seq[WebUITab] = tabs + def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager /** Attach a tab to this UI, along with all of its attached pages. */ @@ -133,7 +133,7 @@ private[spark] abstract class WebUI( def initialize(): Unit /** Bind to the HTTP server behind this web interface. */ - def bind() { + def bind(): Unit = { assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") @@ -156,7 +156,7 @@ private[spark] abstract class WebUI( def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) /** Stop the server behind this web interface. Only valid after bind(). 
*/ - def stop() { + def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") serverInfo.get.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 173fc3cf31ce8..50e8e2d19e155 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -289,8 +289,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val startTime = listener.startTime val endTime = listener.endTime val activeJobs = listener.activeJobs.values.toSeq - val completedJobs = listener.completedJobs.reverse.toSeq - val failedJobs = listener.failedJobs.reverse.toSeq + val completedJobs = listener.completedJobs.reverse + val failedJobs = listener.failedJobs.reverse val activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index e5d408a167361..f4786e3931c94 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -473,7 +473,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -486,7 +486,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -620,7 +620,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B test("live UI json application list") { withSpark(newSparkContext()) { sc => val appListRawJson = HistoryServerSuite.getUrl(new URL( - sc.ui.get.appUIAddress + "/api/v1/applications")) + sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) val attempts = (appListJsonAst \ "attempts").children @@ -640,7 +640,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) rdd.count() - val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + val stage0 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label=&quot;Stage 0&quot;;\n subgraph ")) @@ -651,7 +651,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("{\n label=&quot;groupBy&quot;;\n " + "2 [label=&quot;MapPartitionsRDD [2]")) - val 
stage1 = Source.fromURL(sc.ui.get.appUIAddress + + val stage1 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label=&quot;Stage 1&quot;;\n subgraph ")) @@ -662,7 +662,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage1.contains("{\n label=&quot;groupBy&quot;;\n " + "5 [label=&quot;MapPartitionsRDD [5]")) - val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + val stage2 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label=&quot;Stage 2&quot;;\n subgraph ")) @@ -687,7 +687,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def goToUi(ui: SparkUI, path: String): Unit = { - go to (ui.appUIAddress.stripSuffix("/") + path) + go to (ui.webUrl.stripSuffix("/") + path) } def parseDate(json: JValue): Long = { @@ -699,6 +699,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) + new URL(ui.webUrl + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 4abcfb7e51914..68c7657cb315b 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -66,7 +66,7 @@ class UISuite extends SparkFunSuite { withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.get.appUIAddress).mkString + val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) assert(html.toLowerCase.contains("stages")) assert(html.toLowerCase.contains("storage")) @@ -176,19 +176,18 @@ class UISuite extends SparkFunSuite { } } - test("verify appUIAddress contains the scheme") { + test("verify webUrl contains the scheme") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val uiAddress = ui.appUIAddress - val uiHostPort = ui.appUIHostPort - assert(uiAddress.equals("http://" + uiHostPort)) + val uiAddress = ui.webUrl + assert(uiAddress.startsWith("http://") || uiAddress.startsWith("https://")) } } - test("verify appUIAddress contains the port") { + test("verify webUrl contains the port") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val splitUIAddress = ui.appUIAddress.split(':') + val splitUIAddress = ui.webUrl.split(':') val boundPort = ui.boundPort assert(splitUIAddress(2).toInt == boundPort) } diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 5063c1fe988bc..842c05e7bf732 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -158,7 +158,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - 
sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)), + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), None, None, sc.conf.getOption("spark.mesos.driver.frameworkId") diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 09a252f3c74ac..c1aa00151e69b 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -77,7 +77,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)), + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), Option.empty, Option.empty, sc.conf.getOption("spark.mesos.driver.frameworkId") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 454c3dffa3db1..e7cec999c219e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -92,13 +92,13 @@ class UISeleniumSuite val sparkUI = ssc.sparkContext.ui.get eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (sparkUI.appUIAddress.stripSuffix("/")) + go to (sparkUI.webUrl.stripSuffix("/")) find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None) } eventually(timeout(10 seconds), interval(50 milliseconds)) { // check whether streaming page exists - go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + go to (sparkUI.webUrl.stripSuffix("/") + "/streaming") val h3Text = findAll(cssSelector("h3")).map(_.text).toSeq h3Text should contain("Streaming Statistics") @@ -180,23 +180,23 @@ class UISeleniumSuite jobDetails should contain("Completed Stages:") // Check a batch page without id - go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming/batch/") + go to (sparkUI.webUrl.stripSuffix("/") + "/streaming/batch/") webDriver.getPageSource should include ("Missing id parameter") // Check a non-exist batch - go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming/batch/?id=12345") + go to (sparkUI.webUrl.stripSuffix("/") + "/streaming/batch/?id=12345") webDriver.getPageSource should include ("does not exist") } ssc.stop(false) eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (sparkUI.appUIAddress.stripSuffix("/")) + go to (sparkUI.webUrl.stripSuffix("/")) find(cssSelector( """ul li a[href*="streaming"]""")) should be(None) } eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + go to (sparkUI.webUrl.stripSuffix("/") + "/streaming") val h3Text = findAll(cssSelector("h3")).map(_.text).toSeq h3Text should not contain("Streaming Statistics") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index aabae140af8b1..f2b9dfb4d184d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -406,7 +406,7 @@ private[spark] class 
ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), + registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl).getOrElse(""), securityMgr) } else { // Sanity check; should never happen in normal operation, since sc should only be null diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d8b36c5feaf52..60da356ad14aa 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -44,7 +44,7 @@ private[spark] class YarnClientSchedulerBackend( val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort - sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIAddress) } + sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.webUrl) } val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ("--arg", hostport) From 9c8deef64efee20a0ddc9b612f90e77c80aede60 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 2 Nov 2016 09:39:15 +0000 Subject: [PATCH 022/198] [SPARK-18076][CORE][SQL] Fix default Locale used in DateFormat, NumberFormat to Locale.US ## What changes were proposed in this pull request? Fix `Locale.US` for all usages of `DateFormat`, `NumberFormat` ## How was this patch tested? Existing tests. Author: Sean Owen Closes #15610 from srowen/SPARK-18076. --- .../org/apache/spark/SparkHadoopWriter.scala | 8 +++---- .../apache/spark/deploy/SparkHadoopUtil.scala | 4 ++-- .../apache/spark/deploy/master/Master.scala | 5 ++-- .../apache/spark/deploy/worker/Worker.scala | 4 ++-- .../org/apache/spark/rdd/HadoopRDD.scala | 5 ++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- .../apache/spark/rdd/PairRDDFunctions.scala | 4 ++-- .../status/api/v1/JacksonMessageWriter.scala | 4 ++-- .../spark/status/api/v1/SimpleDateParam.scala | 6 ++--- .../scala/org/apache/spark/ui/UIUtils.scala | 3 ++- .../spark/util/logging/RollingPolicy.scala | 6 ++--- .../org/apache/spark/util/UtilsSuite.scala | 2 +- .../deploy/rest/mesos/MesosRestServer.scala | 11 ++++----- .../mllib/pmml/export/PMMLModelExport.scala | 4 ++-- .../expressions/datetimeExpressions.scala | 17 ++++++------- .../expressions/stringExpressions.scala | 2 +- .../spark/sql/catalyst/json/JSONOptions.scala | 6 +++-- .../sql/catalyst/util/DateTimeUtils.scala | 6 ++--- .../expressions/DateExpressionsSuite.scala | 24 +++++++++---------- .../catalyst/util/DateTimeUtilsSuite.scala | 6 ++--- .../datasources/csv/CSVInferSchema.scala | 4 ++-- .../datasources/csv/CSVOptions.scala | 5 ++-- .../sql/execution/metric/SQLMetrics.scala | 2 +- .../sql/execution/streaming/socket.scala | 4 ++-- .../apache/spark/sql/DateFunctionsSuite.scala | 11 +++++---- .../execution/datasources/csv/CSVSuite.scala | 9 +++---- .../datasources/csv/CSVTypeCastSuite.scala | 9 ++++--- .../hive/execution/InsertIntoHiveTable.scala | 9 +++---- .../spark/sql/hive/hiveWriterContainers.scala | 4 ++-- .../sql/sources/SimpleTextRelation.scala | 3 ++- .../apache/spark/streaming/ui/UIUtils.scala | 8 ++++--- 31 files changed, 103 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala 
b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6550d703bc860..7f75a393bf8ff 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.IOException import java.text.NumberFormat import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path @@ -67,12 +67,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { def setup(jobid: Int, splitid: Int, attemptid: Int) { setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now), + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), jobid, splitID, attemptID, conf.value) } def open() { - val numfmt = NumberFormat.getInstance() + val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) @@ -162,7 +162,7 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { private[spark] object SparkHadoopWriter { def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(time) new JobID(jobtrackerID, id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 3f54ecc17ac33..23156072c3ebe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.io.IOException import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -357,7 +357,7 @@ class SparkHadoopUtil extends Logging { * @return a printable string value. 
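* Dates in the returned string are rendered with Locale.US so the output is identical regardless of the JVM's default locale.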
*/ private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { - val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US) val buffer = new StringBuilder(128) buffer.append(token.toString) try { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8c91aa15167c4..4618e6117a4fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -51,7 +51,8 @@ private[deploy] class Master( private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0bedd9a20a969..8b1c6bf2e5fd5 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{Date, UUID} +import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -68,7 +68,7 @@ private[deploy] class Worker( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1cf3938de098..36a2f5c87e372 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.immutable.Map import scala.reflect.ClassTag @@ -243,7 +243,8 @@ class HadoopRDD[K, V]( var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime), + HadoopRDD.addLocalConfiguration( + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) diff --git 
a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index baf31fb658870..488e777fea371 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.reflect.ClassTag @@ -79,7 +79,7 @@ class NewHadoopRDD[K, V]( // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) formatter.format(new Date()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 068f4ed8ad745..67baad1c51bca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.nio.ByteBuffer import java.text.SimpleDateFormat -import java.util.{Date, HashMap => JHashMap} +import java.util.{Date, HashMap => JHashMap, Locale} import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ @@ -1079,7 +1079,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val job = NewAPIHadoopJob.getInstance(hadoopConf) - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(new Date()) val stageId = self.id val jobConfiguration = job.getConfiguration diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index f6a9f9c5573db..76af33c1a18db 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -21,7 +21,7 @@ import java.lang.annotation.Annotation import java.lang.reflect.Type import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat -import java.util.{Calendar, SimpleTimeZone} +import java.util.{Calendar, Locale, SimpleTimeZone} import javax.ws.rs.Produces import javax.ws.rs.core.{MediaType, MultivaluedMap} import javax.ws.rs.ext.{MessageBodyWriter, Provider} @@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ private[spark] object JacksonMessageWriter { def makeISODateFormat: SimpleDateFormat = { - val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US) val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) iso8601.setCalendar(cal) iso8601 diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index 0c71cd2382225..d8d5e8958b23c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.text.{ParseException, SimpleDateFormat} -import 
java.util.TimeZone +import java.util.{Locale, TimeZone} import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status @@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status private[v1] class SimpleDateParam(val originalValue: String) { val timestamp: Long = { - val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US) try { format.parse(originalValue).getTime() } catch { case _: ParseException => - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US) gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) try { gmtDay.parse(originalValue).getTime() diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c0d1a2220f62a..66b097aa8166d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -36,7 +36,8 @@ private[spark] object UIUtils extends Logging { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } def formatDate(date: Date): String = dateFormat.get.format(date) diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 5c4238c0381a1..1f263df57c857 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -18,7 +18,7 @@ package org.apache.spark.util.logging import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.internal.Logging @@ -59,7 +59,7 @@ private[spark] class TimeBasedRollingPolicy( } @volatile private var nextRolloverTime = calculateNextRolloverTime() - private val formatter = new SimpleDateFormat(rollingFileSuffixPattern) + private val formatter = new SimpleDateFormat(rollingFileSuffixPattern, Locale.US) /** Should rollover if current time has exceeded next rollover time */ def shouldRollover(bytesToBeWritten: Long): Boolean = { @@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy( } @volatile private var bytesWrittenSinceRollover = 0L - val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS") + val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US) /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 15ef32f21d90c..feacfb7642f27 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -264,7 +264,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val hour = minute * 60 def str: (Long) => String = Utils.msDurationToString(_) - val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() + val sep = new DecimalFormatSymbols(Locale.US).getDecimalSeparator assert(str(123) === "123 ms") 
assert(str(second) === "1" + sep + "0 s") diff --git a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 3b96488a129a9..ff60b88c6d533 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.rest.mesos import java.io.File import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.atomic.AtomicLong import javax.servlet.http.HttpServletResponse @@ -62,11 +62,10 @@ private[mesos] class MesosSubmitRequestServlet( private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private def newDriverId(submitDate: Date): String = { - "driver-%s-%04d".format( - createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) - } + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + private def newDriverId(submitDate: Date): String = + f"driver-${createDateFormat.format(submitDate)}-${nextDriverNumber.incrementAndGet()}%04d" /** * Build a driver description from the fields specified in the submit request. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 426bb818c9266..f5ca1c221d66b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.pmml.export import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.beans.BeanProperty @@ -34,7 +34,7 @@ private[mllib] trait PMMLModelExport { val version = getClass.getPackage.getImplementationVersion val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.US).format(new Date())) val header = new Header() .setApplication(app) .setTimestamp(timestamp) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7ab68a13e09cf..67c078ae5e264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.util.Try @@ -331,7 +331,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override protected def nullSafeEval(timestamp: Any, format: Any): Any = { - val sdf = new SimpleDateFormat(format.toString) + val sdf = new SimpleDateFormat(format.toString, Locale.US) UTF8String.fromString(sdf.format(new 
java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } @@ -400,7 +400,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: SimpleDateFormat = - Try(new SimpleDateFormat(constFormat.toString)).getOrElse(null) + Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null) override def eval(input: InternalRow): Any = { val t = left.eval(input) @@ -425,7 +425,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { null } else { val formatString = f.asInstanceOf[UTF8String].toString - Try(new SimpleDateFormat(formatString).parse( + Try(new SimpleDateFormat(formatString, Locale.US).parse( t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) } } @@ -520,7 +520,7 @@ case class FromUnixTime(sec: Expression, format: Expression) private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: SimpleDateFormat = - Try(new SimpleDateFormat(constFormat.toString)).getOrElse(null) + Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null) override def eval(input: InternalRow): Any = { val time = left.eval(input) @@ -539,9 +539,10 @@ case class FromUnixTime(sec: Expression, format: Expression) if (f == null) { null } else { - Try(UTF8String.fromString(new SimpleDateFormat( - f.asInstanceOf[UTF8String].toString).format(new java.util.Date( - time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + Try( + UTF8String.fromString(new SimpleDateFormat(f.toString, Locale.US). + format(new java.util.Date(time.asInstanceOf[Long] * 1000L))) + ).getOrElse(null) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 1bcbb6cfc9246..25a5e3fd7da73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1415,7 +1415,7 @@ case class Sentences( val locale = if (languageStr != null && countryStr != null) { new Locale(languageStr.toString, countryStr.toString) } else { - Locale.getDefault + Locale.US } getSentences(string.asInstanceOf[UTF8String].toString, locale) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index aec18922ea6c8..c45970658cf07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.json +import java.util.Locale + import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.commons.lang3.time.FastDateFormat @@ -56,11 +58,11 @@ private[sql] class JSONOptions( // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. 
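// Pinning Locale.US below makes parsing deterministic across JVMs: with a
// pattern like "dd MMM yyyy", Locale.US always accepts "01 Jan 1970", whereas a
// JVM whose default locale is fr_FR would instead expect "01 janv. 1970".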
val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 0b643a5b84268..235ca8d2633a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import javax.xml.bind.DatatypeConverter import scala.annotation.tailrec @@ -79,14 +79,14 @@ object DateTimeUtils { // `SimpleDateFormat` is not thread-safe. val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) } } // `SimpleDateFormat` is not thread-safe. private val threadLocalDateFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd") + new SimpleDateFormat("yyyy-MM-dd", Locale.US) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 6118a34d29eaa..35cea25ba0b7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -30,8 +30,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { import IntegralLiteralTestUtils._ - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val sdfDate = new SimpleDateFormat("yyyy-MM-dd", Locale.US) val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) @@ -49,7 +49,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("DayOfYear") { - val sdfDay = new SimpleDateFormat("D") + val sdfDay = new SimpleDateFormat("D", Locale.US) (0 to 3).foreach { m => (0 to 5).foreach { i => val c = Calendar.getInstance() @@ -411,9 +411,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("from_unixtime") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd 
HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) checkEvaluation( FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0))) checkEvaluation(FromUnixTime( @@ -430,11 +430,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val date1 = Date.valueOf("2015-07-24") checkEvaluation( UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) @@ -466,11 +466,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("to_unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val date1 = Date.valueOf("2015-07-24") checkEvaluation( ToUnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 4f516d006458e..e0a9a0c3d5c00 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils._ @@ -68,8 +68,8 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(d2.toString === d1.toString) } - val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z") + val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z", Locale.US) checkFromToJavaDate(new Date(100)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3ab775c909238..1981d8607c0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -247,7 +247,7 @@ private[csv] object CSVTypeCast { case options.positiveInf => Float.PositiveInfinity case _ => Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue()) } case _: DoubleType => datum match { @@ -256,7 +256,7 @@ private[csv] object CSVTypeCast { case options.positiveInf => Double.PositiveInfinity case _ => Try(datum.toDouble) - 
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue()) } case _: BooleanType => datum.toBoolean case dt: DecimalType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 014614eb997a5..5903729c11fc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets +import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat @@ -104,11 +105,11 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 0cc1edd196bc8..dbc27d8b237f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -102,7 +102,7 @@ object SQLMetrics { */ def stringValue(metricsType: String, values: Seq[Long]): String = { if (metricsType == SUM_METRIC) { - val numberFormat = NumberFormat.getIntegerInstance(Locale.ENGLISH) + val numberFormat = NumberFormat.getIntegerInstance(Locale.US) numberFormat.format(values.sum) } else { val strFormat: Long => String = if (metricsType == SIZE_METRIC) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index c662e7c6bc775..042977f870b8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -21,7 +21,7 @@ import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ListBuffer @@ -37,7 +37,7 @@ object TextSocketSource { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) - val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index f7aa3b747ae5d..e05b2252ee346 100644 
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ @@ -55,8 +56,8 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = NOW()"""), Row(true)) } - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val sdfDate = new SimpleDateFormat("yyyy-MM-dd", Locale.US) val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) @@ -395,11 +396,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { } test("from_unixtime") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd HH-mm-ss" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") checkAnswer( df.select(from_unixtime(col("a"))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index f7c22c6c93f7a..8209b5bd7f9de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -21,6 +21,7 @@ import java.io.File import java.nio.charset.UnsupportedCharsetException import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType @@ -487,7 +488,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .select("date") .collect() - val dateFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm") + val dateFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm", Locale.US) val expected = Seq(Seq(new Timestamp(dateFormat.parse("26/08/2015 18:00").getTime)), Seq(new Timestamp(dateFormat.parse("27/10/2014 18:30").getTime)), @@ -509,7 +510,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .select("date") .collect() - val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm") + val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm", Locale.US) val expected = Seq( new Date(dateFormat.parse("26/08/2015 18:00").getTime), new Date(dateFormat.parse("27/10/2014 18:30").getTime), @@ -728,7 +729,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("inferSchema", "false") .load(iso8601timestampsPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ") + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) val expectedTimestamps = timestamps.collect().map { r => // This should be ISO8601 formatted string. 
Row(iso8501.format(r.toSeq.head)) @@ -761,7 +762,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("inferSchema", "false") .load(iso8601datesPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd") + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US) val expectedDates = dates.collect().map { r => // This should be ISO8601 formatted string. Row(iso8501.format(r.toSeq.head)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 51832a13cfe0b..c74406b9cbfbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -144,13 +144,12 @@ class CSVTypeCastSuite extends SparkFunSuite { DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) } - test("Float and Double Types are cast correctly with Locale") { + test("Float and Double Types are cast without respect to platform default Locale") { val originalLocale = Locale.getDefault try { - val locale : Locale = new Locale("fr", "FR") - Locale.setDefault(locale) - assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0) - assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0) + Locale.setDefault(new Locale("fr", "FR")) + assert(CSVTypeCast.castTo("1,00", FloatType) == 100.0) // Would parse as 1.0 in fr-FR + assert(CSVTypeCast.castTo("1,00", DoubleType) == 100.0) } finally { Locale.setDefault(originalLocale) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 2843100fb3b36..05164d774ccaf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.IOException import java.net.URI import java.text.SimpleDateFormat -import java.util.{Date, Random} - -import scala.collection.JavaConverters._ +import java.util.{Date, Locale, Random} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -60,9 +58,8 @@ case class InsertIntoHiveTable( private def executionId: String = { val rand: Random = new Random - val format: SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS") - val executionId: String = "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong) - return executionId + val format = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS", Locale.US) + "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong) } private def getStagingDir(inputPath: Path, hadoopConf: Configuration): Path = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ea88276bb96c0..e53c3e4d4833b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import java.text.NumberFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.JavaConverters._ @@ -95,7 +95,7 @@ private[hive] class 
SparkHiveWriterContainer( } protected def getOutputName: String = { - val numberFormat = NumberFormat.getInstance() + val numberFormat = NumberFormat.getInstance(Locale.US) numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 64d0ecbeefc98..cecfd99098659 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import java.text.NumberFormat +import java.util.Locale import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -141,7 +142,7 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) class AppendingTextOutputFormat(path: String) extends TextOutputFormat[NullWritable, Text] { - val numberFormat = NumberFormat.getInstance() + val numberFormat = NumberFormat.getInstance(Locale.US) numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index 9b1c939e9329f..84ecf81abfbf1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.ui import java.text.SimpleDateFormat -import java.util.TimeZone +import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit import scala.xml.Node @@ -80,11 +80,13 @@ private[streaming] object UIUtils { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val batchTimeFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } private val batchTimeFormatWithMilliseconds = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS", Locale.US) } /** From f151bd1af8a05d4b6c901ebe6ac0b51a4a1a20df Mon Sep 17 00:00:00 2001 From: eyal farago Date: Wed, 2 Nov 2016 11:12:20 +0100 Subject: [PATCH 023/198] [SPARK-16839][SQL] Simplify Struct creation code path ## What changes were proposed in this pull request? Simplify the struct creation code path, in particular `CleanupAliases`, which missed some aliases when handling trees created by `CreateStruct`. This PR includes: 1. A failing test (creating a struct with nested aliases; some of the aliases survive `CleanupAliases`). 2. A fix that transforms `CreateStruct` into a `CreateNamedStruct` constructor, effectively eliminating `CreateStruct` from all expression trees. 3. A `NamePlaceholder` used by `CreateStruct` when column names cannot be extracted from an unresolved `NamedExpression`. 4. A new Analyzer rule that resolves `NamePlaceholder` into a string literal once the `NamedExpression` is resolved. 5. `CleanupAliases` code was simplified, as it no longer has to deal with `CreateStruct`'s top-level columns.
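To make points 2-4 concrete, here is a hypothetical `spark-shell` sketch (not part of the patch; column and alias names are illustrative): `struct(...)` now always builds a `CreateNamedStruct`, so an alias nested inside the struct only supplies its field name, while the struct column itself is named by aliasing the whole expression, which is why the R tests below switch to `alias(struct(...), "d")`.

```scala
import org.apache.spark.sql.functions.{col, struct}

// Assumes a running spark-shell session exposing `spark`. The nested alias
// "a_plus_1" survives only as a struct field name; the top-level alias "s"
// names the resulting column.
val df = spark.range(2).toDF("a")
  .select(struct(col("a"), (col("a") + 1).as("a_plus_1")).as("s"))

df.printSchema()
// root
//  |-- s: struct (nullable = false)
//  |    |-- a: long (nullable = false)
//  |    |-- a_plus_1: long (nullable = false)
```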
## How was this patch tested? Ran all test suites in package org.apache.spark.sql, especially the analysis suite, verifying that the added test initially fails and that the entire analysis package passes after applying the suggested fix. Modified a few tests that expected `CreateStruct`, which is now transformed into `CreateNamedStruct`. Author: eyal farago Author: Herman van Hovell Author: eyal farago Author: Eyal Farago Author: Hyukjin Kwon Author: eyalfa Closes #15718 from hvanhovell/SPARK-16839-2. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 53 ++--- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/expressions/Projection.scala | 2 - .../expressions/complexTypeCreator.scala | 212 ++++++------------ .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++- .../expressions/ComplexTypeSuite.scala | 1 - .../scala/org/apache/spark/sql/Column.scala | 3 + .../command/AnalyzeColumnCommand.scala | 4 +- .../sql-tests/results/group-by.sql.out | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 20 +- .../resources/sqlgen/subquery_in_having_2.sql | 2 +- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 12 +- 14 files changed, 169 insertions(+), 198 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 806019d7524ff..d7fe6b32822a7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f8f4799322b3b..5011f2fdbf9b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.{TreeNodeRef} +import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ @@ -83,6 +83,7 @@ class Analyzer( ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences
:: + ResolveCreateNamedStruct :: ResolveDeserializer :: ResolveNewInstance :: ResolveUpCast :: @@ -653,11 +654,12 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) - case c: CreateStruct if containsStar(c.children) => - c.copy(children = c.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) + case c: CreateNamedStruct if containsStar(c.valExprs) => + val newChildren = c.children.grouped(2).flatMap { + case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children + case kv => kv + } + c.copy(children = newChildren.toList ) case c: CreateArray if containsStar(c.children) => c.copy(children = c.children.flatMap { case s: Star => s.expand(child, resolver) @@ -1141,7 +1143,7 @@ class Analyzer( case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => // Get the left hand side expressions. val expressions = e match { - case CreateStruct(exprs) => exprs + case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => @@ -2072,18 +2074,8 @@ object EliminateUnions extends Rule[LogicalPlan] { */ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { - var stop = false e.transformDown { - // CreateStruct is a special case, we need to retain its top level Aliases as they decide the - // name of StructField. We also need to stop transform down this expression, or the Aliases - // under CreateStruct will be mistakenly trimmed. - case c: CreateStruct if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case Alias(child, _) if !stop => child + case Alias(child, _) => child } } @@ -2116,15 +2108,8 @@ object CleanupAliases extends Rule[LogicalPlan] { case a: AppendColumns => a case other => - var stop = false other transformExpressionsDown { - case c: CreateStruct if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case Alias(child, _) if !stop => child + case Alias(child, _) => child } } } @@ -2217,3 +2202,19 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } + +/** + * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
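+ * For example, `struct(a)` with an unresolved attribute `a` is first lowered to + * `CreateNamedStruct(NamePlaceholder :: a :: Nil)`; once `a` is resolved, this rule + * replaces the placeholder with `Literal("a")`.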
+ */ +object ResolveCreateNamedStruct extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case e: CreateNamedStruct if !e.resolved => + val children = e.children.grouped(2).flatMap { + case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => + Seq(Literal(e.name), e) + case kv => + kv + } + CreateNamedStruct(children.toList) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3e836ca375e2e..b028d07fb8d0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -357,7 +357,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[CreateStruct]("struct"), + CreateStruct.registryEntry, // misc functions expression[AssertTrue]("assert_true"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a81fa1ce3adcc..03e054d098511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -119,7 +119,6 @@ object UnsafeProjection { */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(unsafeExprs) @@ -145,7 +144,6 @@ object UnsafeProjection { subexpressionEliminationEnabled: Boolean): UnsafeProjection = { val e = exprs.map(BindReferences.bindReference(_, inputSchema)) .map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 917aa0873130b..dbfb2996ec9d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -172,101 +174,71 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } /** - * Returns a Row containing the evaluation of all children expressions. 
+ * An expression representing a not yet available attribute name. This expression is unevaluable + * and as its name suggests it is a temporary place holder until we're able to determine the + * actual attribute name. */ -@ExpressionDescription( - usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") -case class CreateStruct(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - +case object NamePlaceholder extends LeafExpression with Unevaluable { + override lazy val resolved: Boolean = false + override def foldable: Boolean = false override def nullable: Boolean = false + override def dataType: DataType = StringType + override def prettyName: String = "NamePlaceholder" + override def toString: String = prettyName +} - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) +/** + * Returns a Row containing the evaluation of all children expressions. + */ +object CreateStruct extends FunctionBuilder { + def apply(children: Seq[Expression]): CreateNamedStruct = { + CreateNamedStruct(children.zipWithIndex.flatMap { + case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) + case (e: NamedExpression, _) => Seq(NamePlaceholder, e) + case (e, index) => Seq(Literal(s"col${index + 1}"), e) + }) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericInternalRow].getName - val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") - - ev.copy(code = s""" - boolean ${ev.isNull} = false; - this.$values = new Object[${children.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - children.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + - s""" - final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; - """) + /** + * Entry to use in the function registry. + */ + val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { + val info: ExpressionInfo = new ExpressionInfo( + "org.apache.spark.sql.catalyst.expressions.NamedStruct", + null, + "struct", + "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", + "") + ("struct", (info, this)) } - - override def prettyName: String = "struct" } - /** - * Creates a struct with the given field names and values - * - * @param children Seq(name1, val1, name2, val2, ...) + * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]]. */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.") -// scalastyle:on line.size.limit -case class CreateNamedStruct(children: Seq[Expression]) extends Expression { +trait CreateNamedStructLike extends Expression { + lazy val (nameExprs, valExprs) = children.grouped(2).map { + case Seq(name, value) => (name, value) + }.toList.unzip - /** - * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this - * StructType. - */ - def flatten: Seq[NamedExpression] = valExprs.zip(names).map { - case (v, n) => Alias(v, n.toString)() - } + lazy val names = nameExprs.map(_.eval(EmptyRow)) - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + override def nullable: Boolean = false - private lazy val names = nameExprs.map(_.eval(EmptyRow)) + override def foldable: Boolean = valExprs.forall(_.foldable) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, Metadata.empty) + case (name, expr) => + val metadata = expr match { + case ne: NamedExpression => ne.metadata + case _ => Metadata.empty + } + StructField(name.toString, expr.dataType, expr.nullable, metadata) } StructType(fields) } - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") @@ -274,8 +246,8 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Only foldable StringType expressions are allowed to appear at odd position , got :" + - s" ${invalidNames.mkString(",")}") + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { @@ -284,9 +256,29 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } } + /** + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } + override def eval(input: InternalRow): Any = { InternalRow(valExprs.map(_.eval(input)): _*) } +} + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit +case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName @@ -316,44 +308,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "named_struct" } -/** - * Returns a Row containing the evaluation of all children expressions. 
This is a variant that - * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with - * this expression automatically at runtime. - */ -case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val resolved: Boolean = childrenResolved - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval = GenerateUnsafeProjection.createCode(ctx, children) - ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) - } - - override def prettyName: String = "struct_unsafe" -} - - /** * Creates a struct with the given field names and values. This is a variant that returns * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with @@ -361,31 +315,7 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { - - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) - - override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) - } - StructType(fields) - } - - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(valExprs.map(_.eval(input)): _*) - } - +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ac1577b3abb4d..4b151c81d8f8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -688,8 +688,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // inline table comes in two styles: // style 1: values (1), (2), (3) -- multiple columns are supported // style 2: values 1, 2, 3 -- only a single column is supported here - case CreateStruct(children) => children // style 1 - case child => Seq(child) // style 2 + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 590774c043040..817de48de2798 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.catalyst.analysis
 
+import org.scalatest.ShouldMatchers
+
 import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -25,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
-class AnalysisSuite extends AnalysisTest {
+
+class AnalysisSuite extends AnalysisTest with ShouldMatchers {
   import org.apache.spark.sql.catalyst.analysis.TestRelations._
 
   test("union project *") {
@@ -218,9 +221,36 @@ class AnalysisSuite extends AnalysisTest {
     // CreateStruct is a special case that we should not trim Alias for it.
     plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
-    checkAnalysis(plan, plan)
-    plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
-    checkAnalysis(plan, plan)
+    expected = testRelation.select(CreateNamedStruct(Seq(
+      Literal(a.name), a,
+      Literal("a+1"), (a + 1))).as("col"))
+    checkAnalysis(plan, expected)
+  }
+
+  test("Analysis may leave unnecessary aliases") {
+    val att1 = testRelation.output.head
+    var plan = testRelation.select(
+      CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"),
+      att1
+    )
+    val prevPlan = getAnalyzer(true).execute(plan)
+    plan = prevPlan.select(CreateArray(Seq(
+      CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"),
+      /** alias should be eliminated by [[CleanupAliases]] */
+      "col".attr.as("col2")
+    )).as("arr"))
+    plan = getAnalyzer(true).execute(plan)
+
+    val expectedPlan = prevPlan.select(
+      CreateArray(Seq(
+        CreateNamedStruct(Seq(
+          Literal(att1.name), att1,
+          Literal("a_plus_1"), (att1 + 1))),
+        'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull
+      )).as("arr")
+    )
+
+    checkAnalysis(plan, expectedPlan)
   }
 
   test("SPARK-10534: resolve attribute references in order by clause") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 0c307b2b8576b..c21c6de32c0ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -243,7 +243,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     val b = AttributeReference("b", IntegerType)()
     checkMetadata(CreateStruct(Seq(a, b)))
     checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
-    checkMetadata(CreateStructUnsafe(Seq(a, b)))
     checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 249408e0fbce4..7a131b30eafd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -186,6 +186,9 @@ class Column(val expr: Expression) extends Logging {
       case a: AggregateExpression if
a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) + // Wait until the struct is resolved. This will generate a nicer looking alias. + case struct: CreateNamedStructLike => UnresolvedAlias(struct) + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index f873f34a845ef..6141fab4aff0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -137,7 +137,7 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateStruct = { + private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -168,7 +168,7 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) case StringType => getStruct(stringColumnStat(attr, relativeSD)) diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index a91f04e098b18..af6c930d64b76 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -87,7 +87,7 @@ struct -- !query 9 SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1 -- !query 9 schema -struct> +struct> -- !query 9 output diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6eb571b91ffab..90000445dffb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -190,6 +190,12 @@ private[hive] class TestHiveSparkSession( new File(Thread.currentThread().getContextClassLoader.getResource(path).getFile) } + private def quoteHiveFile(path : String) = if (Utils.isWindows) { + getHiveFile(path).getPath.replace('\\', '/') + } else { + getHiveFile(path).getPath + } + def getWarehousePath(): String = { val tempConf = new SQLConf sc.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } @@ -225,16 +231,16 @@ private[hive] class TestHiveSparkSession( val hiveQTestUtilTables: Seq[TestTable] = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), TestTable("src1", "CREATE TABLE src1 (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () 
=> { sql( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { sql( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -244,7 +250,7 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { sql( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -269,7 +275,7 @@ private[hive] class TestHiveSparkSession( sql( s""" - |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' |INTO TABLE src_thrift """.stripMargin) }), @@ -308,7 +314,7 @@ private[hive] class TestHiveSparkSession( |) """.stripMargin.cmd, s""" - |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/episodes.avro")}' |INTO TABLE episodes """.stripMargin.cmd ), @@ -379,7 +385,7 @@ private[hive] class TestHiveSparkSession( TestTable("src_json", s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql index de0116a4dcbaf..cdda29af50e37 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -7,4 +7,4 @@ having b.key in (select a.key where a.value > 'val_9' and a.value = min(b.value)) order by b.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (named_struct('gen_attr_0', `gen_attr_0`, 'gen_attr_4', `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM 
(SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index c7f10e569fa4d..12d18dc87ceb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst import java.nio.charset.StandardCharsets import java.nio.file.{Files, NoSuchFileException, Paths} +import scala.io.Source import scala.util.control.NonFatal import org.apache.spark.sql.Column @@ -109,12 +110,15 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { Files.write(path, answerText.getBytes(StandardCharsets.UTF_8)) } else { val goldenFileName = s"sqlgen/$answerFile.sql" - val resourceFile = getClass.getClassLoader.getResource(goldenFileName) - if (resourceFile == null) { + val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName) + if (resourceStream == null) { throw new NoSuchFileException(goldenFileName) } - val path = resourceFile.getPath - val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8) + val answerText = try { + Source.fromInputStream(resourceStream).mkString + } finally { + resourceStream.close + } val sqls = answerText.split(separator) assert(sqls.length == 2, "Golden sql files should have a separator.") val expectedSQL = sqls(1).trim() From 4af0ce2d96de3397c9bc05684cad290a52486577 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 2 Nov 2016 11:29:26 -0700 Subject: [PATCH 024/198] [SPARK-17683][SQL] Support ArrayType in Literal.apply ## What changes were proposed in this pull request? This pr is to add pattern-matching entries for array data in `Literal.apply`. ## How was this patch tested? Added tests in `LiteralExpressionSuite`. Author: Takeshi YAMAMURO Closes #15257 from maropu/SPARK-17683. 
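As a rough illustration of what this change enables (a sketch mirroring the tests added in `LiteralExpressionSuite` below, not code taken from the patch itself):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, IntegerType, LongType, StringType}

// Array[Byte] still maps to BinaryType as before; other arrays now infer an
// element type via componentTypeToDataType and are converted to Catalyst values.
val ints = Literal(Array(1, 2, 3))
assert(ints.dataType == ArrayType(IntegerType))

val strs = Literal(Array("a", "b", "c"))
assert(strs.dataType == ArrayType(StringType))

// Nested arrays recurse through the clz.isArray branch.
val nested = Literal(Array(Array(1L), Array(2L, 3L)))
assert(nested.dataType == ArrayType(ArrayType(LongType)))
```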
---
 .../sql/catalyst/expressions/literals.scala   | 57 ++++++++++++++++++-
 .../expressions/LiteralExpressionSuite.scala  | 27 ++++++++-
 2 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index a597a17aadd99..1985e68c94e2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -17,14 +17,25 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import java.lang.{Boolean => JavaBoolean}
+import java.lang.{Byte => JavaByte}
+import java.lang.{Double => JavaDouble}
+import java.lang.{Float => JavaFloat}
+import java.lang.{Integer => JavaInteger}
+import java.lang.{Long => JavaLong}
+import java.lang.{Short => JavaShort}
+import java.math.{BigDecimal => JavaBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 import java.util
 import java.util.Objects
 import javax.xml.bind.DatatypeConverter
 
+import scala.math.{BigDecimal, BigInt}
+
 import org.json4s.JsonAST._
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -46,12 +57,17 @@ object Literal {
     case s: String => Literal(UTF8String.fromString(s), StringType)
     case b: Boolean => Literal(b, BooleanType)
     case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale))
-    case d: java.math.BigDecimal =>
+    case d: JavaBigDecimal =>
       Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale()))
     case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
     case a: Array[Byte] => Literal(a, BinaryType)
+    case a: Array[_] =>
+      val elementType = componentTypeToDataType(a.getClass.getComponentType())
+      val dataType = ArrayType(elementType)
+      val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
+      Literal(convert(a), dataType)
     case i: CalendarInterval => Literal(i, CalendarIntervalType)
     case null => Literal(null, NullType)
     case v: Literal => v
@@ -59,6 +75,45 @@ object Literal {
       throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
   }
 
+  /**
+   * Returns the Spark SQL DataType for a given class object. Since this type needs to be resolved
+   * at runtime, we use match-case idioms for class objects here. However, there are similar
+   * functions in other files (e.g., HiveInspectors), so these functions need to be merged into one.
+ */ + private[this] def componentTypeToDataType(clz: Class[_]): DataType = clz match { + // primitive types + case JavaShort.TYPE => ShortType + case JavaInteger.TYPE => IntegerType + case JavaLong.TYPE => LongType + case JavaDouble.TYPE => DoubleType + case JavaByte.TYPE => ByteType + case JavaFloat.TYPE => FloatType + case JavaBoolean.TYPE => BooleanType + + // java classes + case _ if clz == classOf[Date] => DateType + case _ if clz == classOf[Timestamp] => TimestampType + case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[Array[Byte]] => BinaryType + case _ if clz == classOf[JavaShort] => ShortType + case _ if clz == classOf[JavaInteger] => IntegerType + case _ if clz == classOf[JavaLong] => LongType + case _ if clz == classOf[JavaDouble] => DoubleType + case _ if clz == classOf[JavaByte] => ByteType + case _ if clz == classOf[JavaFloat] => FloatType + case _ if clz == classOf[JavaBoolean] => BooleanType + + // other scala classes + case _ if clz == classOf[String] => StringType + case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + + case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType)) + + case _ => throw new AnalysisException(s"Unsupported component type $clz in arrays") + } + /** * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object * into code generation. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 450222d8cbba3..4af4da8a9f0c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -43,6 +44,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, TimestampType), null) checkEvaluation(Literal.create(null, CalendarIntervalType), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) + checkEvaluation(Literal.create(null, ArrayType(StringType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) } @@ -122,5 +124,28 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } } - // TODO(davies): add tests for ArrayType, MapType and StructType + test("array") { + def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = { + val toCatalyst = (a: Array[_], elementType: DataType) => { + CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a) + } + checkEvaluation(Literal(a), toCatalyst(a, elementType)) + } + checkArrayLiteral(Array(1, 2, 3), IntegerType) + checkArrayLiteral(Array("a", "b", "c"), StringType) + checkArrayLiteral(Array(1.0, 4.0), DoubleType) + checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, 
CalendarInterval.MICROS_PER_HOUR),
+      CalendarIntervalType)
+  }
+
+  test("unsupported types (map and struct) in literals") {
+    def checkUnsupportedTypeInLiteral(v: Any): Unit = {
+      val errMsgMap = intercept[RuntimeException] {
+        Literal(v)
+      }
+      assert(errMsgMap.getMessage.startsWith("Unsupported literal type"))
+    }
+    checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2))
+    checkUnsupportedTypeInLiteral(("mike", 29, 1.0))
+  }
 }

From 742e0fea5391857964e90d396641ecf95cac4248 Mon Sep 17 00:00:00 2001
From: buzhihuojie
Date: Wed, 2 Nov 2016 11:36:20 -0700
Subject: [PATCH 025/198] [SPARK-17895] Improve doc for rangeBetween and
 rowsBetween

## What changes were proposed in this pull request?

Copied description for row and range based frame boundary from https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala#L56

Added examples to show different behavior of rangeBetween and rowsBetween when involving duplicate values.

Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request.

Author: buzhihuojie

Closes #15727 from david-weiluo-ren/improveDocForRangeAndRowsBetween.
---
 .../apache/spark/sql/expressions/Window.scala | 55 +++++++++++++++++++
 .../spark/sql/expressions/WindowSpec.scala    | 55 +++++++++++++++++++
 2 files changed, 110 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index 0b26d863cac5d..327bc379d4132 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -121,6 +121,32 @@ object Window {
   * and [[Window.currentRow]] to specify special boundary values, rather than using integral
   * values directly.
   *
+   * A row-based boundary is based on the position of the row within the partition.
+   * An offset indicates the number of rows above or below the current row at which the frame
+   * for the current row starts or ends. For instance, given a row-based sliding frame with a
+   * lower bound offset of -1 and an upper bound offset of +2, the frame for the row with
+   * index 5 would range from index 4 to index 6.
+   *
+   * {{{
+   *   import org.apache.spark.sql.expressions.Window
+   *   val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"))
+   *     .toDF("id", "category")
+   *   df.withColumn("sum",
+   *       sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1))
+   *     .show()
+   *
+   *   +---+--------+---+
+   *   | id|category|sum|
+   *   +---+--------+---+
+   *   |  1|       b|  3|
+   *   |  2|       b|  5|
+   *   |  3|       b|  3|
+   *   |  1|       a|  2|
+   *   |  1|       a|  3|
+   *   |  2|       a|  2|
+   *   +---+--------+---+
+   * }}}
+   *
   * @param start boundary start, inclusive. The frame is unbounded if this is
   *              the minimum long value ([[Window.unboundedPreceding]]).
   * @param end boundary end, inclusive. The frame is unbounded if this is the
@@ -144,6 +170,35 @@ object Window {
   * and [[Window.currentRow]] to specify special boundary values, rather than using integral
   * values directly.
   *
+   * A range-based boundary is based on the actual value of the ORDER BY
+   * expression(s). An offset is used to alter the value of the ORDER BY expression; for
+   * instance, if the current ORDER BY expression has a value of 10 and the lower bound offset
+   * is -3, the resulting lower bound for the current row will be 10 - 3 = 7.
+   * This, however, puts a number of constraints on the ORDER BY expressions: there can be only
+   * one expression and this expression must have a numerical data type. An exception can be
+   * made when the offset is 0, because no value modification is needed; in this case, multiple
+   * and non-numeric ORDER BY expressions are allowed.
+   *
+   * {{{
+   *   import org.apache.spark.sql.expressions.Window
+   *   val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"))
+   *     .toDF("id", "category")
+   *   df.withColumn("sum",
+   *       sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1))
+   *     .show()
+   *
+   *   +---+--------+---+
+   *   | id|category|sum|
+   *   +---+--------+---+
+   *   |  1|       b|  3|
+   *   |  2|       b|  5|
+   *   |  3|       b|  3|
+   *   |  1|       a|  4|
+   *   |  1|       a|  4|
+   *   |  2|       a|  2|
+   *   +---+--------+---+
+   * }}}
+   *
   * @param start boundary start, inclusive. The frame is unbounded if this is
   *              the minimum long value ([[Window.unboundedPreceding]]).
   * @param end boundary end, inclusive. The frame is unbounded if this is the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 1e85b6e7881ad..4a8ce695bd4da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -89,6 +89,32 @@ class WindowSpec private[sql](
   * and [[Window.currentRow]] to specify special boundary values, rather than using integral
   * values directly.
   *
+   * A row-based boundary is based on the position of the row within the partition.
+   * An offset indicates the number of rows above or below the current row at which the frame
+   * for the current row starts or ends. For instance, given a row-based sliding frame with a
+   * lower bound offset of -1 and an upper bound offset of +2, the frame for the row with
+   * index 5 would range from index 4 to index 6.
+   *
+   * {{{
+   *   import org.apache.spark.sql.expressions.Window
+   *   val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"))
+   *     .toDF("id", "category")
+   *   df.withColumn("sum",
+   *       sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1))
+   *     .show()
+   *
+   *   +---+--------+---+
+   *   | id|category|sum|
+   *   +---+--------+---+
+   *   |  1|       b|  3|
+   *   |  2|       b|  5|
+   *   |  3|       b|  3|
+   *   |  1|       a|  2|
+   *   |  1|       a|  3|
+   *   |  2|       a|  2|
+   *   +---+--------+---+
+   * }}}
+   *
   * @param start boundary start, inclusive. The frame is unbounded if this is
   *              the minimum long value ([[Window.unboundedPreceding]]).
   * @param end boundary end, inclusive. The frame is unbounded if this is the
@@ -111,6 +137,35 @@ class WindowSpec private[sql](
   * and [[Window.currentRow]] to specify special boundary values, rather than using integral
   * values directly.
   *
+   * A range-based boundary is based on the actual value of the ORDER BY
+   * expression(s). An offset is used to alter the value of the ORDER BY expression; for
+   * instance, if the current ORDER BY expression has a value of 10 and the lower bound offset
+   * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This, however,
+   * puts a number of constraints on the ORDER BY expressions: there can be only one expression
+   * and this expression must have a numerical data type. An exception can be made when the
+   * offset is 0, because no value modification is needed; in this case, multiple and
+   * non-numeric ORDER BY expressions are allowed.
+ * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 4| + * | 1| a| 4| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the From 02f203107b8eda1f1576e36c4f12b0e3bc5e910e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 2 Nov 2016 11:41:49 -0700 Subject: [PATCH 026/198] [SPARK-14393][SQL] values generated by non-deterministic functions shouldn't change after coalesce or union ## What changes were proposed in this pull request? When a user appended a column using a "nondeterministic" function to a DataFrame, e.g., `rand`, `randn`, and `monotonically_increasing_id`, the expected semantic is the following: - The value in each row should remain unchanged, as if we materialize the column immediately, regardless of later DataFrame operations. However, since we use `TaskContext.getPartitionId` to get the partition index from the current thread, the values from nondeterministic columns might change if we call `union` or `coalesce` after. `TaskContext.getPartitionId` returns the partition index of the current Spark task, which might not be the corresponding partition index of the DataFrame where we defined the column. See the unit tests below or JIRA for examples. This PR uses the partition index from `RDD.mapPartitionWithIndex` instead of `TaskContext` and fixes the partition initialization logic in whole-stage codegen, normal codegen, and codegen fallback. `initializeStatesForPartition(partitionIndex: Int)` was added to `Projection`, `Nondeterministic`, and `Predicate` (codegen) and initialized right after object creation in `mapPartitionWithIndex`. `newPredicate` now returns a `Predicate` instance rather than a function for proper initialization. ## How was this patch tested? Unit tests. (Actually I'm not very confident that this PR fixed all issues without introducing new ones ...) cc: rxin davies Author: Xiangrui Meng Closes #15567 from mengxr/SPARK-14393. 
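For context, a hedged sketch of the intended semantics, adapted from the JIRA description (it assumes an active SparkSession named `spark`; the assertion states the post-patch contract and is not a test from this diff):

```scala
import org.apache.spark.sql.functions.rand

// A nondeterministic column should behave as if it were materialized where it
// is defined: its per-row values must survive a later coalesce, which re-maps
// rows to different task partitions.
val df = spark.range(0, 10, 1, 4).withColumn("r", rand(42))

val direct = df.collect().map(_.getDouble(1)).sorted
val coalesced = df.coalesce(2).collect().map(_.getDouble(1)).sorted

// Before this patch the seed came from TaskContext.getPartitionId, which changes
// under coalesce; now it comes from the index that mapPartitionsWithIndexInternal
// passes to initialize(partitionIndex), which is stable under coalesce.
assert(direct.sameElements(coalesced))
```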
--- .../main/scala/org/apache/spark/rdd/RDD.scala | 16 +++++- .../sql/catalyst/expressions/Expression.scala | 19 +++++-- .../catalyst/expressions/InputFileName.scala | 2 +- .../MonotonicallyIncreasingID.scala | 11 ++-- .../sql/catalyst/expressions/Projection.scala | 22 +++++--- .../expressions/SparkPartitionID.scala | 13 +++-- .../expressions/codegen/CodeGenerator.scala | 14 +++++ .../expressions/codegen/CodegenFallback.scala | 18 +++++-- .../codegen/GenerateMutableProjection.scala | 4 ++ .../codegen/GeneratePredicate.scala | 18 +++++-- .../codegen/GenerateSafeProjection.scala | 4 ++ .../codegen/GenerateUnsafeProjection.scala | 4 ++ .../sql/catalyst/expressions/package.scala | 10 +++- .../sql/catalyst/expressions/predicates.scala | 4 -- .../expressions/randomExpressions.scala | 14 ++--- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../expressions/ExpressionEvalHelper.scala | 5 +- .../CodegenExpressionCachingSuite.scala | 13 +++-- .../sql/execution/DataSourceScanExec.scala | 6 ++- .../spark/sql/execution/ExistingRDD.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 3 +- .../spark/sql/execution/SparkPlan.scala | 4 +- .../sql/execution/WholeStageCodegenExec.scala | 8 ++- .../execution/basicPhysicalOperators.scala | 8 +-- .../columnar/InMemoryTableScanExec.scala | 5 +- .../joins/BroadcastNestedLoopJoinExec.scala | 7 +-- .../joins/CartesianProductExec.scala | 8 +-- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 6 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 52 +++++++++++++++++++ .../hive/execution/HiveTableScanExec.scala | 3 +- 32 files changed, 231 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index db535de9e9bb3..e018af35cb18d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. */ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. 
+   */
   private[spark] def mapPartitionsInternal[U: ClassTag](
       f: Iterator[T] => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 9edc1ceff26a7..726a231fd814e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -272,17 +272,28 @@ trait Nondeterministic extends Expression {
   final override def deterministic: Boolean = false
   final override def foldable: Boolean = false
 
+  @transient private[this] var initialized = false
 
-  final def setInitialValues(): Unit = {
-    initInternal()
+  /**
+   * Initializes internal states given the current partition index and marks this expression
+   * as initialized. Subclasses should override [[initializeInternal()]].
+   */
+  final def initialize(partitionIndex: Int): Unit = {
+    initializeInternal(partitionIndex)
     initialized = true
   }
 
-  protected def initInternal(): Unit
+  protected def initializeInternal(partitionIndex: Int): Unit
 
+  /**
+   * @inheritdoc
+   * Throws an exception if [[initialize()]] has not been called yet.
+   * Subclasses should override [[evalInternal()]].
+   */
   final override def eval(input: InternalRow = null): Any = {
-    require(initialized, "nondeterministic expression should be initialized before evaluate")
+    require(initialized,
+      s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
     evalInternal(input)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index 96929ecf56375..b6c12c5351119 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
 
   override def prettyName: String = "input_file_name"
 
-  override protected def initInternal(): Unit = {}
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
 
   override protected def evalInternal(input: InternalRow): UTF8String = {
     InputFileNameHolder.getInputFileName()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 5b4922e0cf2b7..72b8dcca26e2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
 
   @transient private[this] var partitionMask: Long = _
 
-  override protected def initInternal(): Unit = {
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
     count = 0L
-    partitionMask = TaskContext.getPartitionId().toLong << 33
+    partitionMask = partitionIndex.toLong << 33
   }
 
   override def nullable: Boolean = false
@@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val
countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, - s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") + ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") + ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 03e054d098511..476e37e6a9bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * * @param expressions a sequence of expressions that determine the value of each column of the * output row. */ @@ -30,10 +31,12 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) - expressions.foreach(_.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - }) + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { /** * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified * expressions. + * * @param expressions a sequence of expressions that determine the value of each column of the * output row. 
*/ @@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu private[this] val buffer = new Array[Any](expressions.size) - expressions.foreach(_.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - }) + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } private[this] val exprArray = expressions.toArray private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 1f675d5b07270..6bef473cac060 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** - * Expression that returns the current partition id of the Spark task. + * Expression that returns the current partition id. */ @ExpressionDescription( - usage = "_FUNC_() - Returns the current partition id of the Spark task", + usage = "_FUNC_() - Returns the current partition id", extended = "> SELECT _FUNC_();\n 0") case class SparkPartitionID() extends LeafExpression with Nondeterministic { @@ -38,16 +37,16 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override val prettyName = "SPARK_PARTITION_ID" - override protected def initInternal(): Unit = { - partitionId = TaskContext.getPartitionId() + override protected def initializeInternal(partitionIndex: Int): Unit = { + partitionId = partitionIndex } override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm, - s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") + ctx.addMutableState(ctx.JAVA_INT, idTerm, "") + ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6cab50ae1bf8d..9c3c6d3b2a7f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -184,6 +184,20 @@ class CodegenContext { splitExpressions(initCodes, "init", Nil) } + /** + * Code statements to initialize states that depend on the partition index. + * An integer `partitionIndex` will be made available within the scope. 
+ */ + val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty + + def addPartitionInitializationStatement(statement: String): Unit = { + partitionInitializationStatements += statement + } + + def initPartition(): String = { + partitionInitializationStatements.mkString("\n") + } + /** * Holding all the functions those will be added into generated class. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6a5a3e7933eea..0322d1dd6a9ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -25,15 +25,23 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No trait CodegenFallback extends Expression { protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - } - // LeafNode does not need `input` val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW val idx = ctx.references.length ctx.references += this + var childIndex = idx + this.foreach { + case n: Nondeterministic => + // This might add the current expression twice, but it won't hurt. + ctx.references += n + childIndex += 1 + ctx.addPartitionInitializationStatement( + s""" + |((Nondeterministic) references[$childIndex]) + | .initialize(partitionIndex); + """.stripMargin) + case _ => + } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) if (nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 5c4b56b0b224c..4d732445544a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public ${classOf[BaseMutableProjection].getName} target(InternalRow row) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 39aa7b17de6c9..dcd1ed96a298e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -25,19 +25,26 @@ import org.apache.spark.sql.catalyst.expressions._ */ abstract class Predicate { def eval(r: InternalRow): Boolean + + /** + * Initializes internal states given the current partition index. + * This is used by nondeterministic expressions to set initial states. + * The default implementation does nothing. 
+ */ + def initialize(partitionIndex: Int): Unit = {} } /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]]. */ -object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] { +object GeneratePredicate extends CodeGenerator[Expression, Predicate] { protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) - protected def create(predicate: Expression): ((InternalRow) => Boolean) = { + protected def create(predicate: Expression): Predicate = { val ctx = newCodeGenContext() val eval = predicate.genCode(ctx) @@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public boolean eval(InternalRow ${ctx.INPUT_ROW}) { @@ -67,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] - (r: InternalRow) => p.eval(r) + CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 2773e1a666212..b1cb6edefb852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -173,6 +173,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public java.lang.Object apply(java.lang.Object _i) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7cc45372daa5a..7e4c9089a2cb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -380,6 +380,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} // Scala.Function1 need this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1510a4796683c..1b00c9e79da22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -64,7 +64,15 @@ package object expressions { * column of the new row. 
If the schema of the input row is specified, then the given expression * will be bound to that schema. */ - abstract class Projection extends (InternalRow => InternalRow) + abstract class Projection extends (InternalRow => InternalRow) { + + /** + * Initializes internal states given the current partition index. + * This is used by nondeterministic expressions to set initial states. + * The default implementation does nothing. + */ + def initialize(partitionIndex: Int): Unit = {} + } /** * Converts a [[InternalRow]] to another Row given a sequence of expression that define each diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 9394e39aadd9d..c941a576d00d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -31,10 +31,6 @@ object InterpretedPredicate { create(BindReferences.bindReference(expression, inputSchema)) def create(expression: Expression): (InternalRow => Boolean) = { - expression.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index ca200768b2286..e09029f5aab9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -42,8 +42,8 @@ abstract class RDG extends LeafExpression with Nondeterministic { */ @transient protected var rng: XORShiftRandom = _ - override protected def initInternal(): Unit = { - rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + override protected def initializeInternal(partitionIndex: Int): Unit = { + rng = new XORShiftRandom(seed + partitionIndex) } override def nullable: Boolean = false @@ -70,8 +70,9 @@ case class Rand(seed: Long) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") + ctx.addMutableState(className, rngTerm, "") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } @@ -93,8 +94,9 @@ case class Randn(seed: Long) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") + ctx.addMutableState(className, rngTerm, "") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e5e2cd7d27d15..b6ad5db74e3c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1060,6 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Project(projectList, LocalRelation(output, data)) if !projectList.exists(hasUnevaluableExpr) => val projection = new InterpretedProjection(projectList, output) + projection.initialize(0) LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index f0c149c02b9aa..9ceb709185417 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -75,7 +75,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { - case n: Nondeterministic => n.setInitialValues() + case n: Nondeterministic => n.initialize(0) case _ => } expression.eval(inputRow) @@ -121,6 +121,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { @@ -182,12 +183,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { var plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) actual = FromUnsafeProjection(expression.dataType :: Nil)( plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 06dc3bd33b90e..fe5cb8eda824f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -31,19 +31,22 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { // Use an Add to wrap two of them together in case we only initialize the top level expressions. 
val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = UnsafeProjection.create(Seq(expr)) + instance.initialize(0) assert(instance.apply(null).getBoolean(0) === false) } test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr)) + instance.initialize(0) assert(instance.apply(null).getBoolean(0) === false) } test("GeneratePredicate should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GeneratePredicate.generate(expr) - assert(instance.apply(null) === false) + instance.initialize(0) + assert(instance.eval(null) === false) } test("GenerateUnsafeProjection should not share expression instances") { @@ -73,13 +76,13 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GeneratePredicate should not share expression instances") { val expr1 = MutableExpression() val instance1 = GeneratePredicate.generate(expr1) - assert(instance1.apply(null) === false) + assert(instance1.eval(null) === false) val expr2 = MutableExpression() expr2.mutableState = true val instance2 = GeneratePredicate.generate(expr2) - assert(instance1.apply(null) === false) - assert(instance2.apply(null) === true) + assert(instance1.eval(null) === false) + assert(instance2.eval(null) === true) } } @@ -89,7 +92,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { */ case class NondeterministicExpression() extends LeafExpression with Nondeterministic with CodegenFallback { - override protected def initInternal(): Unit = { } + override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): Any = false override def nullable: Boolean = false override def dataType: DataType = BooleanType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index fdd1fa3648251..e485b52b43f76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -71,8 +71,9 @@ case class RowDataSourceScanExec( val unsafeRow = if (outputUnsafeRows) { rdd } else { - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + proj.initialize(index) iter.map(proj) } } @@ -284,8 +285,9 @@ case class FileSourceScanExec( val unsafeRows = { val scan = inputRDD if (needsUnsafeRowConversion) { - scan.mapPartitionsInternal { iter => + scan.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + proj.initialize(index) iter.map(proj) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 455fb5bfbb6f7..aab087cd98716 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -190,8 +190,9 @@ case class RDDScanExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + 
proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 2663129562660..19fbf0c162048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -94,8 +94,9 @@ case class GenerateExec( } val numOutputRows = longMetric("numOutputRows") - rows.mapPartitionsInternal { iter => + rows.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(output, output) + proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 24d0cffef82a2..cadab37a449aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetric @@ -354,7 +354,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } protected def newPredicate( - expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = { GeneratePredicate.generate(expression, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6303483f22fd3..516b9d5444d31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -331,6 +331,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co partitionIndex = index; this.inputs = inputs; ${ctx.initMutableStates()} + ${ctx.initPartition()} } ${ctx.declareAddedFunctions()} @@ -383,10 +384,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } else { // Right now, we support up to two input RDDs. 
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) => - val partitionIndex = TaskContext.getPartitionId() + Iterator((leftIter, rightIter)) + // a small hack to obtain the correct partition index + }.mapPartitionsWithIndex { (index, zippedIter) => + val (leftIter, rightIter) = zippedIter.next() val clazz = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] - buffer.init(partitionIndex, Array(leftIter, rightIter)) + buffer.init(index, Array(leftIter, rightIter)) new Iterator[InternalRow] { override def hasNext: Boolean = { val v = buffer.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a5291e0c12f88..32133f52630cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -70,9 +70,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val project = UnsafeProjection.create(projectList, child.output, subexpressionEliminationEnabled) + project.initialize(index) iter.map(project) } } @@ -205,10 +206,11 @@ case class FilterExec(condition: Expression, child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val predicate = newPredicate(condition, child.output) + predicate.initialize(0) iter.filter { row => - val r = predicate(row) + val r = predicate.eval(row) if (r) numOutputRows += 1 r } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index b87016d5a5696..9028caa446e8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -132,10 +132,11 @@ case class InMemoryTableScanExec( val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers - buffers.mapPartitionsInternal { cachedBatchIterator => + buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) + partitionFilter.initialize(index) // Find the ordinals and data types of the requested columns. 
val (requestedColumnIndices, requestedColumnDataTypes) = @@ -147,7 +148,7 @@ case class InMemoryTableScanExec( val cachedBatchesToScan = if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => - if (!partitionFilter(cachedBatch.stats)) { + if (!partitionFilter.eval(cachedBatch.stats)) { def statsString: String = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index bfe7e3dea45df..f526a19876670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -52,7 +52,7 @@ case class BroadcastNestedLoopJoinExec( UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil } - private[this] def genResultProjection: InternalRow => InternalRow = joinType match { + private[this] def genResultProjection: UnsafeProjection = joinType match { case LeftExistence(j) => UnsafeProjection.create(output, output) case other => @@ -84,7 +84,7 @@ case class BroadcastNestedLoopJoinExec( @transient private lazy val boundCondition = { if (condition.isDefined) { - newPredicate(condition.get, streamed.output ++ broadcast.output) + newPredicate(condition.get, streamed.output ++ broadcast.output).eval _ } else { (r: InternalRow) => true } @@ -366,8 +366,9 @@ case class BroadcastNestedLoopJoinExec( } val numOutputRows = longMetric("numOutputRows") - resultRdd.mapPartitionsInternal { iter => + resultRdd.mapPartitionsWithIndexInternal { (index, iter) => val resultProj = genResultProjection + resultProj.initialize(index) iter.map { r => numOutputRows += 1 resultProj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 15dc9b40662e2..8341fe2ffd078 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -98,15 +98,15 @@ case class CartesianProductExec( val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) - pair.mapPartitionsInternal { iter => + pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { - val boundCondition: (InternalRow) => Boolean = - newPredicate(condition.get, left.output ++ right.output) + val boundCondition = newPredicate(condition.get, left.output ++ right.output) + boundCondition.initialize(index) val joined = new JoinedRow iter.filter { r => - boundCondition(joined(r._1, r._2)) + boundCondition.eval(joined(r._1, r._2)) } } else { iter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 05c5e2f4cd77b..1aef5f6864263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -81,7 +81,7 @@ trait HashJoin { UnsafeProjection.create(streamedKeys) @transient 
private[this] lazy val boundCondition = if (condition.isDefined) { - newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) + newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _ } else { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ecf7cf289f034..ca9c0ed8cec32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -101,7 +101,7 @@ case class SortMergeJoinExec( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => - newPredicate(cond, left.output ++ right.output) + newPredicate(cond, left.output ++ right.output).eval _ }.getOrElse { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 9df56bbf1ef87..fde3b2a528994 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -87,8 +87,9 @@ case class DeserializeToObjectExec( } override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output) + projection.initialize(index) iter.map(projection) } } @@ -124,8 +125,9 @@ case class SerializeFromObjectExec( } override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val projection = UnsafeProjection.create(serializer) + projection.initialize(index) iter.map(projection) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 586a0fffeb7a1..0e9a2c6cf7dec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,13 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import scala.util.Random + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -406,4 +412,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(true), Row(true)) ) } + + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { + import DataFrameFunctionsSuite.CodegenFallbackExpr + for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { + val c = if (codegenFallback) { + Column(CodegenFallbackExpr(v.expr)) + } else { + v + } + withSQLConf( + (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString), + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { + val df = spark.range(0, 4, 1, 4).withColumn("c", c) + val 
rows = df.collect() + val rowsAfterCoalesce = df.coalesce(2).collect() + assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + + val df1 = spark.range(0, 2, 1, 2).withColumn("c", c) + val rows1 = df1.collect() + val df2 = spark.range(2, 4, 1, 2).withColumn("c", c) + val rows2 = df2.collect() + val rowsAfterUnion = df1.union(df2).collect() + assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + } + } + } + + test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + + "coalesce or union") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) + } +} + +object DataFrameFunctionsSuite { + case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback { + override def children: Seq[Expression] = Seq(child) + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + override lazy val resolved = true + override def eval(input: InternalRow): Any = child.eval(input) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 231f204b12b47..c80695bd3e0fe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -154,8 +154,9 @@ case class HiveTableScanExec( val numOutputRows = longMetric("numOutputRows") // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649) val outputSchema = schema - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(outputSchema) + proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) From 3c24299b71e23e159edbb972347b13430f92a465 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 2 Nov 2016 11:47:45 -0700 Subject: [PATCH 027/198] [SPARK-18160][CORE][YARN] spark.files & spark.jars should not be passed to driver in yarn mode ## What changes were proposed in this pull request? spark.files is still passed to the driver in yarn mode, so SparkContext still handles it, which causes the error described in the JIRA. ## How was this patch tested? Tested manually in a 5-node cluster. As this issue only happens in a multi-node cluster, I didn't write a test for it. Author: Jeff Zhang Closes #15669 from zjffdu/SPARK-18160.
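In yarn mode, `SparkSubmit` distributes `--jars` and `--files` through the YARN distributed cache, so the driver-side `SparkConf` must not carry `spark.jars`/`spark.files` or `SparkContext` will try to handle them a second time. A minimal sketch of the client-side scrubbing (the helper name is illustrative and not part of this patch; only `SparkConf.remove`, which the patch itself uses, is assumed):

```scala
import org.apache.spark.SparkConf

// Sketch only: strip the properties that SparkSubmit has already shipped via
// the YARN distributed cache, so the driver does not re-handle them.
def stripYarnDistributedConfs(conf: SparkConf): SparkConf = {
  Seq("spark.jars", "spark.files").foreach(conf.remove)
  conf
}
```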
--- .../scala/org/apache/spark/SparkContext.scala | 29 ++++--------------- .../org/apache/spark/deploy/yarn/Client.scala | 5 +++- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4694790c72cd8..63478c88b057b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1716,29 +1716,12 @@ class SparkContext(config: SparkConf) extends Logging { key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (master == "yarn" && deployMode == "cluster") { - // In order for this to work in yarn cluster mode the user must specify the - // --addJars option to the client to upload the file into the distributed cache - // of the AM to make it show up in the current working directory. - val fileName = new Path(uri.getPath).getName() - try { - env.rpcEnv.fileServer.addJar(new File(fileName)) - } catch { - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. - logError("Error adding jar (" + e + "), was the --addJars option used?") - null - } - } else { - try { - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") - null - } + try { + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null } // A JAR file which exists locally on every worker node case "local" => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 55e4a833b6707..053a78617d4e0 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1202,7 +1202,10 @@ private object Client extends Logging { // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf - + // SparkSubmit would use yarn cache to distribute files & jars in yarn mode, + // so remove them from sparkConf here for yarn mode. + sparkConf.remove("spark.jars") + sparkConf.remove("spark.files") val args = new ClientArguments(argStrings) new Client(args, sparkConf).run() } From 37d95227a21de602b939dae84943ba007f434513 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 2 Nov 2016 11:52:29 -0700 Subject: [PATCH 028/198] [SPARK-17058][BUILD] Add maven snapshots-and-staging profile to build/test against staging artifacts ## What changes were proposed in this pull request? Adds a `snapshots-and-staging` profile so that RCs of projects like Hadoop and HBase can be used in developer-only build and test runs. There's a comment above the profile telling people not to use this in production. There's no attempt to do the same for SBT, as Ivy is different. ## How was this patch tested? Tested by building against the Hadoop 2.7.3 RC 1 JARs without the profile (and without any local copy of the 2.7.3 artifacts); the build failed: ``` mvn install -DskipTests -Pyarn,hadoop-2.7,hive -Dhadoop.version=2.7.3 ...
[INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Launcher 2.1.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ Downloading: https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-client/2.7.3/hadoop-client-2.7.3.pom [WARNING] The POM for org.apache.hadoop:hadoop-client:jar:2.7.3 is missing, no dependency information available Downloading: https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-client/2.7.3/hadoop-client-2.7.3.jar [INFO] ------------------------------------------------------------------------ [INFO] Reactor Summary: [INFO] [INFO] Spark Project Parent POM ........................... SUCCESS [ 4.482 s] [INFO] Spark Project Tags ................................. SUCCESS [ 17.402 s] [INFO] Spark Project Sketch ............................... SUCCESS [ 11.252 s] [INFO] Spark Project Networking ........................... SUCCESS [ 13.458 s] [INFO] Spark Project Shuffle Streaming Service ............ SUCCESS [ 9.043 s] [INFO] Spark Project Unsafe ............................... SUCCESS [ 16.027 s] [INFO] Spark Project Launcher ............................. FAILURE [ 1.653 s] [INFO] Spark Project Core ................................. SKIPPED ... ``` With the profile, the build completed ``` mvn install -DskipTests -Pyarn,hadoop-2.7,hive,snapshots-and-staging -Dhadoop.version=2.7.3 ``` Author: Steve Loughran Closes #14646 from steveloughran/stevel/SPARK-17058-support-asf-snapshots. --- pom.xml | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/pom.xml b/pom.xml index aaf7cfa7eb2ad..04d2eaa1d3bac 100644 --- a/pom.xml +++ b/pom.xml @@ -2693,6 +2693,54 @@ + + <!-- Developer-only profile for building and testing against ASF staging/snapshot artifacts; not for use in production builds. --> + <profile> + <id>snapshots-and-staging</id> + <properties> + <asf.staging>https://repository.apache.org/content/groups/staging/</asf.staging> + <asf.snapshots>https://repository.apache.org/content/repositories/snapshots/</asf.snapshots> + </properties> + + <pluginRepositories> + <pluginRepository> + <id>ASF Staging</id> + <url>${asf.staging}</url> + </pluginRepository> + <pluginRepository> + <id>ASF Snapshots</id> + <url>${asf.snapshots}</url> + <snapshots> + <enabled>true</enabled> + </snapshots> + <releases> + <enabled>false</enabled> + </releases> + </pluginRepository> + </pluginRepositories> + <repositories> + <repository> + <id>ASF Staging</id> + <url>${asf.staging}</url> + </repository> + <repository> + <id>ASF Snapshots</id> + <url>${asf.snapshots}</url> + <snapshots> + <enabled>true</enabled> + </snapshots> + <releases> + <enabled>false</enabled> + </releases> + </repository> + </repositories> + </profile> + <groupId>org.json</groupId> <artifactId>json</artifactId> From b533fa2b205544b42dcebe0a6fee9d8275f6da7d Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Thu, 10 Nov 2016 13:41:13 -0800 Subject: [PATCH 112/198] [SPARK-17993][SQL] Fix Parquet log output redirection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (Link to Jira issue: https://issues.apache.org/jira/browse/SPARK-17993) ## What changes were proposed in this pull request? PR #14690 broke parquet log output redirection for converted partitioned Hive tables.
For example, when querying parquet files written by Parquet-mr 1.6.0, Spark prints a torrent of (harmless) warning messages from the Parquet reader: ``` Oct 18, 2016 7:42:18 PM WARNING: org.apache.parquet.CorruptStatistics: Ignoring statistics because created_by could not be parsed (see PARQUET-251): parquet-mr version 1.6.0 org.apache.parquet.VersionParser$VersionParseException: Could not parse created_by: parquet-mr version 1.6.0 using format: (.+) version ((.*) )?\(build ?(.*)\) at org.apache.parquet.VersionParser.parse(VersionParser.java:112) at org.apache.parquet.CorruptStatistics.shouldIgnoreStatistics(CorruptStatistics.java:60) at org.apache.parquet.format.converter.ParquetMetadataConverter.fromParquetStatistics(ParquetMetadataConverter.java:263) at org.apache.parquet.hadoop.ParquetFileReader$Chunk.readAllPages(ParquetFileReader.java:583) at org.apache.parquet.hadoop.ParquetFileReader.readNextRowGroup(ParquetFileReader.java:513) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.checkEndOfRowGroup(VectorizedParquetRecordReader.java:270) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.nextBatch(VectorizedParquetRecordReader.java:225) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.nextKeyValue(VectorizedParquetRecordReader.java:137) at org.apache.spark.sql.execution.datasources.RecordReaderIterator.hasNext(RecordReaderIterator.scala:39) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:102) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:162) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:102) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.scan_nextBatch$(Unknown Source) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:372) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:231) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:225) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319) at org.apache.spark.rdd.RDD.iterator(RDD.scala:283) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:99) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` This only happens during execution, not planning, and it doesn't matter what log level the `SparkContext` is set to. That's because Parquet (versions < 1.9) doesn't use slf4j for logging. Note that you can tell that log redirection is not working here because the log message format does not conform to the default Spark log message format.
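For reference, the fix boils down to the standard JUL-to-SLF4J bridging idiom, condensed here from the `ParquetLogRedirector` class added in this patch (a sketch only, assuming the `org.slf4j:jul-to-slf4j` bridge is on the classpath, which Spark already depends on):

```scala
import java.util.logging.{Logger => JulLogger}
import org.slf4j.bridge.SLF4JBridgeHandler

// Condensed sketch of the redirection idiom; the real implementation lives in
// the ParquetLogRedirector class added below.
def redirectJulToSlf4j(name: String): JulLogger = {
  val logger = JulLogger.getLogger(name)
  // Drop the stdout handler that JUL installs by default, then route to SLF4J.
  logger.getHandlers.foreach(logger.removeHandler)
  logger.setUseParentHandlers(false)
  logger.addHandler(new SLF4JBridgeHandler)
  // Return (and retain) the logger: JUL holds loggers weakly, so without a
  // strong reference this configuration can be garbage collected.
  logger
}
```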
This is a regression I noted as something we needed to fix as a follow up. It appears that the problem arose because we removed the call to `inferSchema` during Hive table conversion. That call is what triggered the output redirection. ## How was this patch tested? I tested this manually in four ways: 1. Executing `spark.sqlContext.range(10).selectExpr("id as a").write.mode("overwrite").parquet("test")`. 2. Executing `spark.read.format("parquet").load(legacyParquetFile).show` for a Parquet file `legacyParquetFile` written using Parquet-mr 1.6.0. 3. Executing `select * from legacy_parquet_table limit 1` for some unpartitioned Parquet-based Hive table written using Parquet-mr 1.6.0. 4. Executing `select * from legacy_partitioned_parquet_table where partcol=x limit 1` for some partitioned Parquet-based Hive table written using Parquet-mr 1.6.0. I ran each test with a new instance of `spark-shell` or `spark-sql`. Incidentally, I found that test case 3 was not a regression—redirection was not occurring in the master codebase prior to #14690. I spent some time working on a unit test, but based on my experience working on this ticket I feel that automated testing here is far from feasible. cc ericl dongjoon-hyun Author: Michael Allman Closes #15538 from mallman/spark-17993-fix_parquet_log_redirection. --- .../parquet/ParquetLogRedirector.java | 72 +++++++++++++++++++ .../parquet/ParquetFileFormat.scala | 58 ++++----------- sql/core/src/test/resources/log4j.properties | 4 +- sql/hive/src/test/resources/log4j.properties | 4 ++ 4 files changed, 90 insertions(+), 48 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java new file mode 100644 index 0000000000000..7a7f32ee1e87b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.Serializable; +import java.util.logging.Handler; +import java.util.logging.Logger; + +import org.apache.parquet.Log; +import org.slf4j.bridge.SLF4JBridgeHandler; + +// Redirects the JUL logging for parquet-mr versions <= 1.8 to SLF4J logging using +// SLF4JBridgeHandler. Parquet-mr versions >= 1.9 use SLF4J directly +final class ParquetLogRedirector implements Serializable { + // Client classes should hold a reference to INSTANCE to ensure redirection occurs. 
This is + // especially important for Serializable classes where fields are set but constructors are + // ignored + static final ParquetLogRedirector INSTANCE = new ParquetLogRedirector(); + + // JUL loggers must be held by a strong reference, otherwise they may get destroyed by GC. + // However, the root JUL logger used by Parquet isn't properly referenced. Here we keep + // references to loggers in both parquet-mr <= 1.6 and 1.7/1.8 + private static final Logger apacheParquetLogger = + Logger.getLogger(Log.class.getPackage().getName()); + private static final Logger parquetLogger = Logger.getLogger("parquet"); + + static { + // For parquet-mr 1.7 and 1.8, which are under `org.apache.parquet` namespace. + try { + Class.forName(Log.class.getName()); + redirect(Logger.getLogger(Log.class.getPackage().getName())); + } catch (ClassNotFoundException ex) { + throw new RuntimeException(ex); + } + + // For parquet-mr 1.6.0 and lower versions bundled with Hive, which are under `parquet` + // namespace. + try { + Class.forName("parquet.Log"); + redirect(Logger.getLogger("parquet")); + } catch (Throwable t) { + // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly + // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block + // should be removed after this issue is fixed. + } + } + + private ParquetLogRedirector() { + } + + private static void redirect(Logger logger) { + for (Handler handler : logger.getHandlers()) { + logger.removeHandler(handler); + } + logger.setUseParentHandlers(false); + logger.addHandler(new SLF4JBridgeHandler()); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index b8ea7f40c4ab3..031a0fe57893f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -29,14 +28,12 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType -import org.slf4j.bridge.SLF4JBridgeHandler import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -56,6 +53,11 @@ class ParquetFileFormat with DataSourceRegister with Logging with Serializable { + // Hold a reference to the (serializable) singleton instance of ParquetLogRedirector. This + // ensures the ParquetLogRedirector class is initialized whether an instance of ParquetFileFormat + // is constructed or deserialized. Do not heed the Scala compiler's warning about an unused field + // here. 
+ private val parquetLogRedirector = ParquetLogRedirector.INSTANCE override def shortName(): String = "parquet" @@ -129,10 +131,14 @@ class ParquetFileFormat conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) } - ParquetFileFormat.redirectParquetLogs() - new OutputWriterFactory { - override def newInstance( + // This OutputWriterFactory instance is deserialized when writing Parquet files on the + // executor side without constructing or deserializing ParquetFileFormat. Therefore, we hold + // another reference to ParquetLogRedirector.INSTANCE here to ensure the latter class is + // initialized. + private val parquetLogRedirector = ParquetLogRedirector.INSTANCE + + override def newInstance( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { @@ -673,44 +679,4 @@ object ParquetFileFormat extends Logging { Failure(cause) }.toOption } - - // JUL loggers must be held by a strong reference, otherwise they may get destroyed by GC. - // However, the root JUL logger used by Parquet isn't properly referenced. Here we keep - // references to loggers in both parquet-mr <= 1.6 and >= 1.7 - val apacheParquetLogger: JLogger = JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName) - val parquetLogger: JLogger = JLogger.getLogger("parquet") - - // Parquet initializes its own JUL logger in a static block which always prints to stdout. Here - // we redirect the JUL logger via SLF4J JUL bridge handler. - val redirectParquetLogsViaSLF4J: Unit = { - def redirect(logger: JLogger): Unit = { - logger.getHandlers.foreach(logger.removeHandler) - logger.setUseParentHandlers(false) - logger.addHandler(new SLF4JBridgeHandler) - } - - // For parquet-mr 1.7.0 and above versions, which are under `org.apache.parquet` namespace. - // scalastyle:off classforname - Class.forName(classOf[ApacheParquetLog].getName) - // scalastyle:on classforname - redirect(JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName)) - - // For parquet-mr 1.6.0 and lower versions bundled with Hive, which are under `parquet` - // namespace. - try { - // scalastyle:off classforname - Class.forName("parquet.Log") - // scalastyle:on classforname - redirect(JLogger.getLogger("parquet")) - } catch { case _: Throwable => - // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly - // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block - // should be removed after this issue is fixed. - } - } - - /** - * ParquetFileFormat.prepareWrite calls this function to initialize `redirectParquetLogsViaSLF4J`. 
- */ - def redirectParquetLogs(): Unit = {} } diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 33b9ecf1e2826..25b817382195a 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -53,5 +53,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.org.apache.parquet.hadoop=WARN -log4j.logger.org.apache.spark.sql.parquet=INFO +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties index fea3404769d9d..072bb25d30a87 100644 --- a/sql/hive/src/test/resources/log4j.properties +++ b/sql/hive/src/test/resources/log4j.properties @@ -59,3 +59,7 @@ log4j.logger.hive.ql.metadata.Hive=OFF log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR + +# Parquet related logging +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR From 2f7461f31331cfc37f6cfa3586b7bbefb3af5547 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 10 Nov 2016 13:42:48 -0800 Subject: [PATCH 113/198] [SPARK-17990][SPARK-18302][SQL] correct several partition related behaviours of ExternalCatalog ## What changes were proposed in this pull request? This PR corrects several partition-related behaviors of `ExternalCatalog`: 1. the default partition location should not always lower-case the partition column names in the path string (fix `HiveExternalCatalog`) 2. renaming a partition should not always lower-case the partition column names in the updated partition path string (fix `HiveExternalCatalog`; see the usage sketch below) 3. renaming a partition should update the partition location only for managed tables (fix `InMemoryCatalog`) 4. creating a partition with an existing directory should be fine (fix `InMemoryCatalog`) 5. creating a partition with a non-existing directory should create that directory (fix `InMemoryCatalog`) 6. dropping a partition from an external table should not delete the directory (fix `InMemoryCatalog`) ## How was this patch tested? New tests in `ExternalCatalogSuite` Author: Wenchen Fan Closes #15797 from cloud-fan/partition.
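To make the first two items concrete, here is a usage sketch of the `ExternalCatalogUtils.generatePartitionPath` helper introduced below; the call and its parameter names come from the patch, while the warehouse path and partition values are made up for illustration:

```scala
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils

// Mixed-case partition column names are preserved in the generated path, and
// special characters such as ':' are percent-escaped.
val partPath: Path = ExternalCatalogUtils.generatePartitionPath(
  spec = Map("partCol1" -> "1", "partCol2" -> "2016:11:10"),
  partitionColumnNames = Seq("partCol1", "partCol2"),
  tablePath = new Path("/warehouse/db1.db/tbl"))
// partPath: /warehouse/db1.db/tbl/partCol1=1/partCol2=2016%3A11%3A10
```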
--- .../catalog/ExternalCatalogUtils.scala | 121 ++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 92 +++++------ .../sql/catalyst/catalog/interface.scala | 11 ++ .../catalog/ExternalCatalogSuite.scala | 150 ++++++++++++++---- .../catalog/SessionCatalogSuite.scala | 24 ++- .../spark/sql/execution/command/ddl.scala | 8 +- .../spark/sql/execution/command/tables.scala | 3 +- .../datasources/CatalogFileIndex.scala | 2 +- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileFormatWriter.scala | 6 +- .../PartitioningAwareFileIndex.scala | 2 - .../datasources/PartitioningUtils.scala | 94 +---------- .../sql/execution/command/DDLSuite.scala | 8 +- .../ParquetPartitionDiscoverySuite.scala | 21 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 51 +++++- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 4 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 19 files changed, 397 insertions(+), 208 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala new file mode 100644 index 0000000000000..b1442eec164d8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.Shell + +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec + +object ExternalCatalogUtils { + // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't + // depend on Hive. + val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" + + ////////////////////////////////////////////////////////////////////////////////////////////////// + // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). + ////////////////////////////////////////////////////////////////////////////////////////////////// + + val charToEscape = { + val bitSet = new java.util.BitSet(128) + + /** + * ASCII 01-1F are HTTP control characters that need to be escaped. + * \u000A and \u000D are \n and \r, respectively. 
+ */ + val clist = Array( + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^') + + clist.foreach(bitSet.set(_)) + + if (Shell.WINDOWS) { + Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + } + + bitSet + } + + def needsEscaping(c: Char): Boolean = { + c >= 0 && c < charToEscape.size() && charToEscape.get(c) + } + + def escapePathName(path: String): String = { + val builder = new StringBuilder() + path.foreach { c => + if (needsEscaping(c)) { + builder.append('%') + builder.append(f"${c.asInstanceOf[Int]}%02X") + } else { + builder.append(c) + } + } + + builder.toString() + } + + + def unescapePathName(path: String): String = { + val sb = new StringBuilder + var i = 0 + + while (i < path.length) { + val c = path.charAt(i) + if (c == '%' && i + 2 < path.length) { + val code: Int = try { + Integer.parseInt(path.substring(i + 1, i + 3), 16) + } catch { + case _: Exception => -1 + } + if (code >= 0) { + sb.append(code.asInstanceOf[Char]) + i += 3 + } else { + sb.append(c) + i += 1 + } + } else { + sb.append(c) + i += 1 + } + } + + sb.toString() + } + + def generatePartitionPath( + spec: TablePartitionSpec, + partitionColumnNames: Seq[String], + tablePath: Path): Path = { + val partitionPathStrings = partitionColumnNames.map { col => + val partitionValue = spec(col) + val partitionString = if (partitionValue == null) { + DEFAULT_PARTITION_NAME + } else { + escapePathName(partitionValue) + } + escapePathName(col) + "=" + partitionString + } + partitionPathStrings.foldLeft(tablePath) { (totalPath, nextPartPath) => + new Path(totalPath, nextPartPath) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 20db81e6f9060..a3ffeaa63f690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -231,7 +231,7 @@ class InMemoryCatalog( assert(tableMeta.storage.locationUri.isDefined, "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") - val dir = new Path(tableMeta.storage.locationUri.get) + val dir = new Path(tableMeta.location) try { val fs = dir.getFileSystem(hadoopConfig) fs.delete(dir, true) @@ -259,7 +259,7 @@ class InMemoryCatalog( assert(oldDesc.table.storage.locationUri.isDefined, "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") - val oldDir = new Path(oldDesc.table.storage.locationUri.get) + val oldDir = new Path(oldDesc.table.location) val newDir = new Path(catalog(db).db.locationUri, newName) try { val fs = oldDir.getFileSystem(hadoopConfig) @@ -355,25 +355,28 @@ class InMemoryCatalog( } } - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + val tablePath = new Path(tableMeta.location) // TODO: we should follow hive to roll back if one partition path 
failed to create. parts.foreach { p => - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. - if (p.storage.locationUri.isEmpty) { - val partitionPath = partitionColumnNames.flatMap { col => - p.spec.get(col).map(col + "=" + _) - }.mkString("/") - try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.mkdirs(new Path(tableDir, partitionPath)) - } catch { - case e: IOException => - throw new SparkException(s"Unable to create partition path $partitionPath", e) + val partitionPath = p.storage.locationUri.map(new Path(_)).getOrElse { + ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath) + } + + try { + val fs = tablePath.getFileSystem(hadoopConfig) + if (!fs.exists(partitionPath)) { + fs.mkdirs(partitionPath) } + } catch { + case e: IOException => + throw new SparkException(s"Unable to create partition path $partitionPath", e) } - existingParts.put(p.spec, p) + + existingParts.put( + p.spec, + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString)))) } } @@ -392,19 +395,15 @@ class InMemoryCatalog( } } - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames - // TODO: we should follow hive to roll back if one partition path failed to delete. + val shouldRemovePartitionLocation = getTable(db, table).tableType == CatalogTableType.MANAGED + // TODO: we should follow hive to roll back if one partition path failed to delete, and support + // partial partition spec. partSpecs.foreach { p => - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. - if (existingParts.contains(p) && existingParts(p).storage.locationUri.isEmpty) { - val partitionPath = partitionColumnNames.flatMap { col => - p.get(col).map(col + "=" + _) - }.mkString("/") + if (existingParts.contains(p) && shouldRemovePartitionLocation) { + val partitionPath = new Path(existingParts(p).location) try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.delete(new Path(tableDir, partitionPath), true) + val fs = partitionPath.getFileSystem(hadoopConfig) + fs.delete(partitionPath, true) } catch { case e: IOException => throw new SparkException(s"Unable to delete partition path $partitionPath", e) @@ -423,33 +422,34 @@ class InMemoryCatalog( requirePartitionsExist(db, table, specs) requirePartitionsNotExist(db, table, newSpecs) - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + val tablePath = new Path(tableMeta.location) + val shouldUpdatePartitionLocation = getTable(db, table).tableType == CatalogTableType.MANAGED + val existingParts = catalog(db).tables(table).partitions // TODO: we should follow hive to roll back if one partition path failed to rename. specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => - val newPart = getPartition(db, table, oldSpec).copy(spec = newSpec) - val existingParts = catalog(db).tables(table).partitions - - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. 
- if (newPart.storage.locationUri.isEmpty) { - val oldPath = partitionColumnNames.flatMap { col => - oldSpec.get(col).map(col + "=" + _) - }.mkString("/") - val newPath = partitionColumnNames.flatMap { col => - newSpec.get(col).map(col + "=" + _) - }.mkString("/") + val oldPartition = getPartition(db, table, oldSpec) + val newPartition = if (shouldUpdatePartitionLocation) { + val oldPartPath = new Path(oldPartition.location) + val newPartPath = ExternalCatalogUtils.generatePartitionPath( + newSpec, partitionColumnNames, tablePath) try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.rename(new Path(tableDir, oldPath), new Path(tableDir, newPath)) + val fs = tablePath.getFileSystem(hadoopConfig) + fs.rename(oldPartPath, newPartPath) } catch { case e: IOException => - throw new SparkException(s"Unable to rename partition path $oldPath", e) + throw new SparkException(s"Unable to rename partition path $oldPartPath", e) } + oldPartition.copy( + spec = newSpec, + storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toString))) + } else { + oldPartition.copy(spec = newSpec) } existingParts.remove(oldSpec) - existingParts.put(newSpec, newPart) + existingParts.put(newSpec, newPartition) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 34748a04859ad..93c70de18ae7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -99,6 +99,12 @@ case class CatalogTablePartition( output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") } + /** Return the partition location, assuming it is specified. */ + def location: String = storage.locationUri.getOrElse { + val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") + throw new AnalysisException(s"Partition [$specString] did not specify locationUri") + } + /** * Given the partition schema, returns a row with that schema holding the partition values. */ @@ -171,6 +177,11 @@ case class CatalogTable( throw new AnalysisException(s"table $identifier did not specify database") } + /** Return the table location, assuming it is specified. */ + def location: String = storage.locationUri.getOrElse { + throw new AnalysisException(s"table $identifier did not specify locationUri") + } + /** Return the fully qualified name of this table, assuming the database was specified. 
*/ def qualifiedName: String = identifier.unquotedString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 34bdfc8a98710..303a8662d3f4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.catalog -import java.io.File -import java.net.URI - +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite @@ -320,6 +319,33 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = true) } + test("create partitions without location") { + val catalog = newBasicCatalog() + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat(None, None, None, None, false, Map.empty), + schema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string"), + provider = Some("hive"), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val partition = CatalogTablePartition(Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + catalog.createPartitions("db1", "tbl", Seq(partition), ignoreIfExists = false) + + val partitionLocation = catalog.getPartition( + "db1", + "tbl", + Map("partCol1" -> "1", "partCol2" -> "2")).location + val tableLocation = catalog.getTable("db1", "tbl").location + val defaultPartitionLocation = new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2") + assert(new Path(partitionLocation) == defaultPartitionLocation) + } + test("list partitions with partial partition spec") { val catalog = newBasicCatalog() val parts = catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "1"))) @@ -399,6 +425,46 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac intercept[AnalysisException] { catalog.getPartition("db2", "tbl2", part2.spec) } } + test("rename partitions should update the location for managed table") { + val catalog = newBasicCatalog() + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat(None, None, None, None, false, Map.empty), + schema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string"), + provider = Some("hive"), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val tableLocation = catalog.getTable("db1", "tbl").location + + val mixedCasePart1 = CatalogTablePartition( + Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + val mixedCasePart2 = CatalogTablePartition( + Map("partCol1" -> "3", "partCol2" -> "4"), storageFormat) + + catalog.createPartitions("db1", "tbl", Seq(mixedCasePart1), ignoreIfExists = false) + assert( + new Path(catalog.getPartition("db1", "tbl", mixedCasePart1.spec).location) == + new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2")) + + catalog.renamePartitions("db1", "tbl", Seq(mixedCasePart1.spec), Seq(mixedCasePart2.spec)) + 
assert( + new Path(catalog.getPartition("db1", "tbl", mixedCasePart2.spec).location) == + new Path(new Path(tableLocation, "partCol1=3"), "partCol2=4")) + + // For external tables, RENAME PARTITION should not update the partition location. + val existingPartLoc = catalog.getPartition("db2", "tbl2", part1.spec).location + catalog.renamePartitions("db2", "tbl2", Seq(part1.spec), Seq(part3.spec)) + assert( + new Path(catalog.getPartition("db2", "tbl2", part3.spec).location) == + new Path(existingPartLoc)) + } + test("rename partitions when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { @@ -419,11 +485,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("alter partitions") { val catalog = newBasicCatalog() try { - // Note: Before altering table partitions in Hive, you *must* set the current database - // to the one that contains the table of interest. Otherwise you will end up with the - // most helpful error message ever: "Unable to alter partition. alter is not possible." - // See HIVE-2742 for more detail. - catalog.setCurrentDatabase("db2") val newLocation = newUriForDatabase() val newSerde = "com.sparkbricks.text.EasySerde" val newSerdeProps = Map("spark" -> "bricks", "compressed" -> "false") @@ -571,10 +632,11 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac // -------------------------------------------------------------------------- private def exists(uri: String, children: String*): Boolean = { - val base = new File(new URI(uri)) - children.foldLeft(base) { - case (parent, child) => new File(parent, child) - }.exists() + val base = new Path(uri) + val finalPath = children.foldLeft(base) { + case (parent, child) => new Path(parent, child) + } + base.getFileSystem(new Configuration()).exists(finalPath) } test("create/drop database should create/delete the directory") { @@ -623,7 +685,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("create/drop/rename partitions should create/delete/rename the directory") { val catalog = newBasicCatalog() - val databaseDir = catalog.getDatabase("db1").locationUri val table = CatalogTable( identifier = TableIdentifier("tbl", Some("db1")), tableType = CatalogTableType.MANAGED, @@ -631,34 +692,61 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac schema = new StructType() .add("col1", "int") .add("col2", "string") - .add("a", "int") - .add("b", "string"), + .add("partCol1", "int") + .add("partCol2", "string"), provider = Some("hive"), - partitionColumnNames = Seq("a", "b") - ) + partitionColumnNames = Seq("partCol1", "partCol2")) catalog.createTable(table, ignoreIfExists = false) + val tableLocation = catalog.getTable("db1", "tbl").location + + val part1 = CatalogTablePartition(Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + val part2 = CatalogTablePartition(Map("partCol1" -> "3", "partCol2" -> "4"), storageFormat) + val part3 = CatalogTablePartition(Map("partCol1" -> "5", "partCol2" -> "6"), storageFormat) + catalog.createPartitions("db1", "tbl", Seq(part1, part2), ignoreIfExists = false) - assert(exists(databaseDir, "tbl", "a=1", "b=2")) - assert(exists(databaseDir, "tbl", "a=3", "b=4")) + assert(exists(tableLocation, "partCol1=1", "partCol2=2")) + assert(exists(tableLocation, "partCol1=3", "partCol2=4")) catalog.renamePartitions("db1", "tbl", Seq(part1.spec), Seq(part3.spec)) - assert(!exists(databaseDir, "tbl", "a=1", "b=2")) - 
assert(exists(databaseDir, "tbl", "a=5", "b=6")) + assert(!exists(tableLocation, "partCol1=1", "partCol2=2")) + assert(exists(tableLocation, "partCol1=5", "partCol2=6")) catalog.dropPartitions("db1", "tbl", Seq(part2.spec, part3.spec), ignoreIfNotExists = false, purge = false) - assert(!exists(databaseDir, "tbl", "a=3", "b=4")) - assert(!exists(databaseDir, "tbl", "a=5", "b=6")) + assert(!exists(tableLocation, "partCol1=3", "partCol2=4")) + assert(!exists(tableLocation, "partCol1=5", "partCol2=6")) - val externalPartition = CatalogTablePartition( - Map("a" -> "7", "b" -> "8"), + val tempPath = Utils.createTempDir() + // create partition with existing directory is OK. + val partWithExistingDir = CatalogTablePartition( + Map("partCol1" -> "7", "partCol2" -> "8"), CatalogStorageFormat( - Some(Utils.createTempDir().getAbsolutePath), - None, None, None, false, Map.empty) - ) - catalog.createPartitions("db1", "tbl", Seq(externalPartition), ignoreIfExists = false) - assert(!exists(databaseDir, "tbl", "a=7", "b=8")) + Some(tempPath.getAbsolutePath), + None, None, None, false, Map.empty)) + catalog.createPartitions("db1", "tbl", Seq(partWithExistingDir), ignoreIfExists = false) + + tempPath.delete() + // create partition with non-existing directory will create that directory. + val partWithNonExistingDir = CatalogTablePartition( + Map("partCol1" -> "9", "partCol2" -> "10"), + CatalogStorageFormat( + Some(tempPath.getAbsolutePath), + None, None, None, false, Map.empty)) + catalog.createPartitions("db1", "tbl", Seq(partWithNonExistingDir), ignoreIfExists = false) + assert(tempPath.exists()) + } + + test("drop partition from external table should not delete the directory") { + val catalog = newBasicCatalog() + catalog.createPartitions("db2", "tbl1", Seq(part1), ignoreIfExists = false) + + val partPath = new Path(catalog.getPartition("db2", "tbl1", part1.spec).location) + val fs = partPath.getFileSystem(new Configuration) + assert(fs.exists(partPath)) + + catalog.dropPartitions("db2", "tbl1", Seq(part1.spec), ignoreIfNotExists = false, purge = false) + assert(fs.exists(partPath)) } } @@ -731,7 +819,7 @@ abstract class CatalogTestUtils { CatalogTable( identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL, - storage = storageFormat, + storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().getAbsolutePath)), schema = new StructType() .add("col1", "int") .add("col2", "string") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 001d9c47785d2..52385de50db6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -527,13 +527,13 @@ class SessionCatalogSuite extends SparkFunSuite { sessionCatalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) sessionCatalog.createPartitions( TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog, "mydb", "tbl", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) // Create partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("mydb") sessionCatalog.createPartitions( TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) 
assert(catalogPartitionsEqual( - externalCatalog, "mydb", "tbl", Seq(part1, part2, partWithMixedOrder))) + externalCatalog.listPartitions("mydb", "tbl"), part1, part2, partWithMixedOrder)) } test("create partitions when database/table does not exist") { @@ -586,13 +586,13 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop partitions") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) sessionCatalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1.spec), ignoreIfNotExists = false, purge = false) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part2)) // Drop partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") sessionCatalog.dropPartitions( @@ -604,7 +604,7 @@ class SessionCatalogSuite extends SparkFunSuite { // Drop multiple partitions at once sessionCatalog.createPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) sessionCatalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), @@ -844,10 +844,11 @@ class SessionCatalogSuite extends SparkFunSuite { test("list partitions") { val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))).toSet == Set(part1, part2)) + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) // List partitions without explicitly specifying database catalog.setCurrentDatabase("db2") - assert(catalog.listPartitions(TableIdentifier("tbl2")).toSet == Set(part1, part2)) + assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) } test("list partitions when database/table does not exist") { @@ -860,6 +861,15 @@ class SessionCatalogSuite extends SparkFunSuite { } } + private def catalogPartitionsEqual( + actualParts: Seq[CatalogTablePartition], + expectedParts: CatalogTablePartition*): Boolean = { + // ExternalCatalog may set a default location for partitions, here we ignore the partition + // location when comparing them. 
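    // For example (hypothetical values): a partition created by the suite with
    // locationUri = None may be returned by the external catalog with a default
    // location such as Some("file:/warehouse/tbl2/a=1/b=2") filled in; stripping
    // locationUri from both sides lets the set comparison below succeed.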
+ actualParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet == + expectedParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet + } + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 8500ab460a1b6..84a63fdb9f36f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils} @@ -500,7 +500,7 @@ case class AlterTableRecoverPartitionsCommand( s"location provided: $tableIdentWithDB") } - val root = new Path(table.storage.locationUri.get) + val root = new Path(table.location) logInfo(s"Recover all the partitions in $root") val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) @@ -558,9 +558,9 @@ case class AlterTableRecoverPartitionsCommand( val name = st.getPath.getName if (st.isDirectory && name.contains("=")) { val ps = name.split("=", 2) - val columnName = PartitioningUtils.unescapePathName(ps(0)) + val columnName = ExternalCatalogUtils.unescapePathName(ps(0)) // TODO: Validate the value - val value = PartitioningUtils.unescapePathName(ps(1)) + val value = ExternalCatalogUtils.unescapePathName(ps(1)) if (resolver(columnName, partitionNames.head)) { scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), partitionNames.drop(1), threshold, resolver) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e49a1f5acd0c9..119e732d0202c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -710,7 +710,8 @@ case class ShowPartitionsCommand( private def getPartName(spec: TablePartitionSpec, partColNames: Seq[String]): String = { partColNames.map { name => - PartitioningUtils.escapePathName(name) + "=" + PartitioningUtils.escapePathName(spec(name)) + ExternalCatalogUtils.escapePathName(name) + "=" + + ExternalCatalogUtils.escapePathName(spec(name)) }.mkString(File.separator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 443a2ec033a98..4ad91dcceb432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -67,7 +67,7 @@ class 
CatalogFileIndex( val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => - val path = new Path(p.storage.locationUri.get) + val path = new Path(p.location) val fs = path.getFileSystem(hadoopConf) PartitionPath( p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d43a6ad098ed..739aeac877b99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -190,7 +190,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { val effectiveOutputPath = if (overwritingSinglePartition) { val partition = t.sparkSession.sessionState.catalog.getPartition( l.catalogTable.get.identifier, overwrite.specificPartition.get) - new Path(partition.storage.locationUri.get) + new Path(partition.location) } else { outputPath } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index fa7fe143daeba..69b3fa667ef54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning @@ -281,11 +281,11 @@ object FileFormatWriter extends Logging { private def partitionStringExpression: Seq[Expression] = { description.partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( - PartitioningUtils.escapePathName _, + ExternalCatalogUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) - val str = If(IsNull(c), Literal(PartitioningUtils.DEFAULT_PARTITION_NAME), escaped) + val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index a8a722dd3c620..3740caa22c37e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -128,7 +128,6 @@ abstract class PartitioningAwareFileIndex( case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( leafDirs, - 
PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = false, basePaths = basePaths) @@ -148,7 +147,6 @@ abstract class PartitioningAwareFileIndex( case _ => PartitioningUtils.parsePartitions( leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, basePaths = basePaths) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index b51b41869bf06..a28b04ca3fb5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -25,7 +25,6 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.fs.Path -import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow @@ -56,15 +55,15 @@ object PartitionSpec { } object PartitioningUtils { - // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't - // depend on Hive. - val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { require(columnNames.size == literals.size) } + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName + /** * Given a group of qualified paths, tries to parse them and returns a partition specification. * For example, given: @@ -90,12 +89,11 @@ object PartitioningUtils { */ private[datasources] def parsePartitions( paths: Seq[Path], - defaultPartitionName: String, typeInference: Boolean, basePaths: Set[Path]): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. val (partitionValues, optDiscoveredBasePaths) = paths.map { path => - parsePartition(path, defaultPartitionName, typeInference, basePaths) + parsePartition(path, typeInference, basePaths) }.unzip // We create pairs of (path -> path's partition value) here @@ -173,7 +171,6 @@ object PartitioningUtils { */ private[datasources] def parsePartition( path: Path, - defaultPartitionName: String, typeInference: Boolean, basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] @@ -196,7 +193,7 @@ object PartitioningUtils { // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = - parsePartitionColumn(currentPath.getName, defaultPartitionName, typeInference) + parsePartitionColumn(currentPath.getName, typeInference) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. 
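// Illustrative walk-through (hypothetical path, not part of this patch): for a leaf
// directory "/table/a=10/b=hello", the loop above parses "b=hello" first and then "a=10",
// accumulating ("b", Literal("hello")) and ("a", Literal(10)); the collected pairs are
// reversed before building PartitionValues so they come back in path order (a, b).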
@@ -228,7 +225,6 @@ object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - defaultPartitionName: String, typeInference: Boolean): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { @@ -240,7 +236,7 @@ object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName, typeInference) + val literal = inferPartitionColumnValue(rawColumnValue, typeInference) Some(columnName -> literal) } } @@ -355,7 +351,6 @@ object PartitioningUtils { */ private[datasources] def inferPartitionColumnValue( raw: String, - defaultPartitionName: String, typeInference: Boolean): Literal = { val decimalTry = Try { // `BigDecimal` conversion can fail when the `field` is not a form of number. @@ -380,14 +375,14 @@ object PartitioningUtils { .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw))))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) { + if (raw == DEFAULT_PARTITION_NAME) { Literal.create(null, NullType) } else { Literal.create(unescapePathName(raw), StringType) } } } else { - if (raw == defaultPartitionName) { + if (raw == DEFAULT_PARTITION_NAME) { Literal.create(null, NullType) } else { Literal.create(unescapePathName(raw), StringType) @@ -450,77 +445,4 @@ object PartitioningUtils { Literal.create(Cast(l, desiredType).eval(), desiredType) } } - - ////////////////////////////////////////////////////////////////////////////////////////////////// - // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). - ////////////////////////////////////////////////////////////////////////////////////////////////// - - val charToEscape = { - val bitSet = new java.util.BitSet(128) - - /** - * ASCII 01-1F are HTTP control characters that need to be escaped. - * \u000A and \u000D are \n and \r, respectively. 
- */ - val clist = Array( - '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', - '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', - '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', - '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', - '{', '[', ']', '^') - - clist.foreach(bitSet.set(_)) - - if (Shell.WINDOWS) { - Array(' ', '<', '>', '|').foreach(bitSet.set(_)) - } - - bitSet - } - - def needsEscaping(c: Char): Boolean = { - c >= 0 && c < charToEscape.size() && charToEscape.get(c) - } - - def escapePathName(path: String): String = { - val builder = new StringBuilder() - path.foreach { c => - if (needsEscaping(c)) { - builder.append('%') - builder.append(f"${c.asInstanceOf[Int]}%02X") - } else { - builder.append(c) - } - } - - builder.toString() - } - - def unescapePathName(path: String): String = { - val sb = new StringBuilder - var i = 0 - - while (i < path.length) { - val c = path.charAt(i) - if (c == '%' && i + 2 < path.length) { - val code: Int = try { - Integer.parseInt(path.substring(i + 1, i + 3), 16) - } catch { - case _: Exception => -1 - } - if (code >= 0) { - sb.append(code.asInstanceOf[Char]) - i += 3 - } else { - sb.append(c) - i += 1 - } - } else { - sb.append(c) - i += 1 - } - } - - sb.toString() - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index df3a3c34c39a0..363715c6d2249 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -875,7 +875,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) val part2 = Map("a" -> "2", "b" -> "6") - val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) @@ -1133,7 +1133,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) - assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty) // Verify that the location is set to the expected string def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { @@ -1296,9 +1296,9 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined) assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) - assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + 
assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined) // add partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 120a3a2ef33aa..22e35a1bc0b1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,6 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} @@ -48,11 +49,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha import PartitioningUtils._ import testImplicits._ - val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" + val defaultPartitionName = ExternalCatalogUtils.DEFAULT_PARTITION_NAME test("column type inference") { def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal) + assert(inferPartitionColumnValue(raw, true) === literal) } check("10", Literal.create(10, IntegerType)) @@ -76,7 +77,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), true, Set.empty[Path]) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -88,7 +89,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/"))) @@ -101,7 +101,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/something=true/table"))) @@ -114,7 +113,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/table=true"))) @@ -127,7 +125,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha exception = intercept[AssertionError] { parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/"))) } @@ -147,7 +144,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha exception = intercept[AssertionError] { parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/tmp/tables/"))) } @@ -156,13 +152,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - val actual = parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path])._1 + val actual = 
parsePartition(new Path(path), true, Set.empty[Path])._1 assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path]) + parsePartition(new Path(path), true, Set.empty[Path]) }.getMessage assert(message.contains(expected)) @@ -204,7 +200,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // when the basePaths is the same as the path to a leaf directory val partitionSpec1: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), - defaultPartitionName = defaultPartitionName, typeInference = true, basePaths = Set(new Path("file://path/a=10")))._1 @@ -213,7 +208,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // when the basePaths is the path to a base directory of leaf directories val partitionSpec2: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), - defaultPartitionName = defaultPartitionName, typeInference = true, basePaths = Set(new Path("file://path")))._1 @@ -231,7 +225,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val actualSpec = parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, rootPaths) assert(actualSpec === spec) @@ -314,7 +307,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { val actualSpec = - parsePartitions(paths.map(new Path(_)), defaultPartitionName, false, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), false, Set.empty[Path]) assert(actualSpec === spec) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b537061d0d221..42ce1a88a2b67 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import java.io.IOException import java.util import scala.util.control.NonFatal @@ -26,7 +27,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.thrift.TException -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier @@ -255,7 +256,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // compatible format, which means the data source is file-based and must have a `path`. 
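    // For illustration (hypothetical table, not from this patch): a table created with
    // CREATE TABLE t (i INT) USING parquet OPTIONS (path '/some/dir') is file-based and
    // has a defined location, so it satisfies the requirement below.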
require(tableDefinition.storage.locationUri.isDefined,
          "External file-based data source table must have a `path` entry in storage properties.")
-        Some(new Path(tableDefinition.storage.locationUri.get).toUri.toString)
+        Some(new Path(tableDefinition.location).toUri.toString)
       } else {
         None
       }
@@ -789,7 +790,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
       parts: Seq[CatalogTablePartition],
       ignoreIfExists: Boolean): Unit = withClient {
     requireTableExists(db, table)
-    val lowerCasedParts = parts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec)))
+
+    val tableMeta = getTable(db, table)
+    val partitionColumnNames = tableMeta.partitionColumnNames
+    val tablePath = new Path(tableMeta.location)
+    val partsWithLocation = parts.map { p =>
+      // Ideally we could leave the partition location empty and let the Hive metastore set it.
+      // However, the Hive metastore is not case preserving and would generate a wrong partition
+      // location with lower-cased partition column names. Here we set the default partition
+      // location manually to avoid this problem.
+      val partitionPath = p.storage.locationUri.map(new Path(_)).getOrElse {
+        ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath)
+      }
+      p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString)))
+    }
+    val lowerCasedParts = partsWithLocation.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec)))
     client.createPartitions(db, table, lowerCasedParts, ignoreIfExists)
   }
@@ -810,6 +825,31 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
       newSpecs: Seq[TablePartitionSpec]): Unit = withClient {
     client.renamePartitions(
       db, table, specs.map(lowerCasePartitionSpec), newSpecs.map(lowerCasePartitionSpec))
+
+    val tableMeta = getTable(db, table)
+    val partitionColumnNames = tableMeta.partitionColumnNames
+    // The Hive metastore is not case preserving and keeps partition columns with lower-cased
+    // names. When Hive renames a partition of a managed table, it creates the partition location
+    // from a default path generated from the new spec with lower-cased partition column names.
+    // This is unexpected, so we need to rename the directories manually and alter the partition
+    // locations.
+    val hasUpperCasePartitionColumn = partitionColumnNames.exists(col => col.toLowerCase != col)
+    if (tableMeta.tableType == MANAGED && hasUpperCasePartitionColumn) {
+      val tablePath = new Path(tableMeta.location)
+      val newParts = newSpecs.map { spec =>
+        val partition = client.getPartition(db, table, lowerCasePartitionSpec(spec))
+        val wrongPath = new Path(partition.location)
+        val rightPath = ExternalCatalogUtils.generatePartitionPath(
+          spec, partitionColumnNames, tablePath)
+        try {
+          tablePath.getFileSystem(hadoopConf).rename(wrongPath, rightPath)
+        } catch {
+          case e: IOException => throw new SparkException(
+            s"Unable to rename partition path from $wrongPath to $rightPath", e)
+        }
+        partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toString)))
+      }
+      alterPartitions(db, table, newParts)
+    }
   }

   override def alterPartitions(
@@ -817,6 +857,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
       db: String,
       table: String,
       newParts: Seq[CatalogTablePartition]): Unit = withClient {
     val lowerCasedParts = newParts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec)))
+    // Note: Before altering table partitions in Hive, you *must* set the current database
+    // to the one that contains the table of interest.
Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + client.setCurrentDatabase(db) client.alterPartitions(db, table, lowerCasedParts) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index d3873cf6c8231..fbd705172cae6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -445,7 +445,7 @@ object SetWarehouseLocationTest extends Logging { catalog.getTableMetadata(TableIdentifier("testLocation", Some("default"))) val expectedLocation = "file:" + expectedWarehouseLocation.toString + "/testlocation" - val actualLocation = tableMetadata.storage.locationUri.get + val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( s"Expected table location is $expectedLocation. But, it is actually $actualLocation") @@ -461,7 +461,7 @@ object SetWarehouseLocationTest extends Logging { catalog.getTableMetadata(TableIdentifier("testLocation", Some("testLocationDB"))) val expectedLocation = "file:" + expectedWarehouseLocation.toString + "/testlocationdb.db/testlocation" - val actualLocation = tableMetadata.storage.locationUri.get + val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( s"Expected table location is $expectedLocation. But, it is actually $actualLocation") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index cfc1d81d544eb..9f4401ae22560 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -29,7 +29,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle val expectedPath = spark.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName - assert(metastoreTable.storage.locationUri.get === expectedPath) + assert(metastoreTable.location === expectedPath) } private def getTableNames(dbName: Option[String] = None): Array[String] = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 0076a778683ca..6efae13ddf69d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -425,7 +425,7 @@ class HiveDDLSuite sql("CREATE TABLE tab1 (height INT, length INT) PARTITIONED BY (a INT, b INT)") val part1 = Map("a" -> "1", "b" -> "5") val part2 = Map("a" -> "2", "b" -> "6") - val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c21db3595fa19..e607af67f93e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala 
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -542,7 +542,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
         }
         userSpecifiedLocation match {
           case Some(location) =>
-            assert(r.catalogTable.storage.locationUri.get === location)
+            assert(r.catalogTable.location === location)
           case None => // OK.
         }
         // Also make sure that the format and serde are as desired.

From e0deee1f7df31177cfc14bbb296f0baa372f473d Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Thu, 10 Nov 2016 13:44:54 -0800
Subject: [PATCH 114/198] [SPARK-18403][SQL] Temporarily disable flaky ObjectHashAggregateSuite

## What changes were proposed in this pull request?

Randomized tests in `ObjectHashAggregateSuite` are flaky and break PR builds. This PR disables them temporarily to bring back the PR build.

## How was this patch tested?

N/A

Author: Cheng Lian

Closes #15845 from liancheng/ignore-flaky-object-hash-agg-suite.
---
 .../spark/sql/hive/execution/ObjectHashAggregateSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index 93fc5e8a5e376..b7f91d8c3a797 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -326,7 +326,8 @@ class ObjectHashAggregateSuite
         // Currently Spark SQL doesn't support evaluating distinct aggregate function together
         // with aggregate functions without partial aggregation support.
         if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
-          test(
+          // TODO: Re-enable these tests after fixing SPARK-18403
+          ignore(
             s"randomized aggregation test - " +
             s"${names.mkString("[", ", ", "]")} - " +
             s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +

From a3356343cbf58b930326f45721fb4ecade6f8029 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 10 Nov 2016 17:00:43 -0800
Subject: [PATCH 115/198] [SPARK-18185] Fix all forms of INSERT / OVERWRITE TABLE for Datasource tables

## What changes were proposed in this pull request?

As of current 2.1, INSERT OVERWRITE with dynamic partitions against a Datasource table will overwrite the entire table instead of only the partitions matching the static keys, as in Hive. It also doesn't respect custom partition locations.

This PR adds support for all these operations to Datasource tables managed by the Hive metastore. It is implemented as follows:

- During planning time, the full set of partitions affected by an INSERT or OVERWRITE command is read from the Hive metastore.
- The planner identifies any partitions with custom locations and includes this in the write task metadata.
- FileFormatWriter tasks refer to this custom locations map when determining where to write for dynamic partition output.
- When the write job finishes, the set of written partitions is compared against the initial set of matched partitions, and the Hive metastore is updated to reflect the newly added / removed partitions.

It was necessary to introduce a method for staging files with absolute output paths to `FileCommitProtocol`. These files are not handled by the Hadoop output committer but are moved to their final locations when the job commits.
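As a rough illustration (not code from this patch; `customLocations`, `partitionSpec`, `partitionDir`, and the file extension are assumed names for the example), a write task chooses between the two staging methods like this:

```scala
// Sketch only: stage a task output file, honoring a custom partition
// location captured at planning time, if any.
val ext = ".parquet" // assumed file extension
val stagingPath: String = customLocations.get(partitionSpec) match {
  case Some(absoluteDir) =>
    // Custom location: staged under the job-level _temporary-<jobId> dir,
    // then renamed into place when the job commits.
    committer.newTaskTempFileAbsPath(taskAttemptContext, absoluteDir, ext)
  case None =>
    // Default layout: the Hadoop OutputCommitter manages the temp file.
    committer.newTaskTempFile(taskAttemptContext, partitionDir, ext)
}
```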
The overwrite behavior of legacy Datasource tables is also changed: no longer will the entire table be overwritten if a partial partition spec is present. cc cloud-fan yhuai ## How was this patch tested? Unit tests, existing tests. Author: Eric Liang Author: Wenchen Fan Closes #15814 from ericl/sc-5027. --- .../internal/io/FileCommitProtocol.scala | 15 ++ .../io/HadoopMapReduceCommitProtocol.scala | 63 ++++++- .../sql/catalyst/parser/AstBuilder.scala | 12 +- .../plans/logical/basicLogicalOperators.scala | 10 +- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +- .../execution/datasources/DataSource.scala | 20 ++- .../datasources/DataSourceStrategy.scala | 94 +++++++--- .../datasources/FileFormatWriter.scala | 26 ++- .../InsertIntoHadoopFsRelationCommand.scala | 61 ++++++- .../datasources/PartitioningUtils.scala | 10 ++ .../execution/streaming/FileStreamSink.scala | 2 +- .../ManifestFileCommitProtocol.scala | 6 + .../PartitionProviderCompatibilitySuite.scala | 161 +++++++++++++++++- 13 files changed, 411 insertions(+), 73 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index fb8020585cf89..afd2250c93a8a 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -82,9 +82,24 @@ abstract class FileCommitProtocol { * * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest * are left to the commit protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. */ def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + /** + * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + /** * Commits a task after the writes succeed. Must be called on the executors when running tasks. */ diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 6b0bcb8f908b8..b2d9b8d2a012f 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -17,7 +17,9 @@ package org.apache.spark.internal.io -import java.util.Date +import java.util.{Date, UUID} + +import scala.collection.mutable import org.apache.hadoop.conf.Configurable import org.apache.hadoop.fs.Path @@ -42,6 +44,19 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) /** OutputCommitter from Hadoop is not serializable so marking it transient. 
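 * It is re-created via setupCommitter() whenever the protocol is set up for a job or a task
 * (see setupTask() below).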
*/ @transient private var committer: OutputCommitter = _ + /** + * Tracks files staged by this task for absolute output paths. These outputs are not managed by + * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. + * + * The mapping is from the temp output path to the final desired output path of the file. + */ + @transient private var addedAbsPathFiles: mutable.Map[String, String] = null + + /** + * The staging directory for all files committed with absolute output paths. + */ + private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. @@ -54,11 +69,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def newTaskTempFile( taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet - // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, - // the file name is fine and won't overflow. - val split = taskContext.getTaskAttemptID.getTaskID.getId - val filename = f"part-$split%05d-$jobId$ext" + val filename = getFilename(taskContext, ext) val stagingDir: String = committer match { // For FileOutputCommitter it has its own staging path called "work path". @@ -73,6 +84,28 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) } } + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + val filename = getFilename(taskContext, ext) + val absOutputPath = new Path(absoluteDir, filename).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path( + absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + + private def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. 
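    // For example (hypothetical values): split 3 of job "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb"
    // with ext ".gz.parquet" produces "part-00003-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb.gz.parquet".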
+ val split = taskContext.getTaskAttemptID.getTaskID.getId + f"part-$split%05d-$jobId$ext" + } + override def setupJob(jobContext: JobContext): Unit = { // Setup IDs val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) @@ -93,26 +126,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) + val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) + .foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) } override def setupTask(taskContext: TaskAttemptContext): Unit = { committer = setupCommitter(taskContext) committer.setupTask(taskContext) + addedAbsPathFiles = mutable.Map[String, String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - EmptyTaskCommitMessage + new TaskCommitMessage(addedAbsPathFiles.toMap) } override def abortTask(taskContext: TaskAttemptContext): Unit = { committer.abortTask(taskContext) + // best effort cleanup of other staged files + for ((src, _) <- addedAbsPathFiles) { + val tmp = new Path(src) + tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false) + } } /** Whether we are using a direct output committer */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2c4db0d2c3425..3fa7bf1cdbf16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -172,24 +172,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val tableIdent = visitTableIdentifier(ctx.tableIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) - val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty) + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. 
Specified " + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) } val overwrite = ctx.OVERWRITE != null - val overwritePartition = - if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) { - Some(partitionKeys.map(t => (t._1, t._2.get))) - } else { - None - } + val staticPartitionKeys: Map[String, String] = + partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get)) InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, query, - OverwriteOptions(overwrite, overwritePartition), + OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty), ctx.EXISTS != null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index dcae7b026f58c..4dcc2885536eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -349,13 +349,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { * Options for writing new data into a table. * * @param enabled whether to overwrite existing data in the table. - * @param specificPartition only data in the specified partition will be overwritten. + * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions + * that match this partial partition spec. If empty, all partitions + * will be overwritten. */ case class OverwriteOptions( enabled: Boolean, - specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) { - if (specificPartition.isDefined) { - assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.") + staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) { + if (staticPartitionKeys.nonEmpty) { + assert(enabled, "Overwrite must be enabled when specifying specific partitions.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5f0f6ee479c69..9aae520ae664a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -185,9 +185,9 @@ class PlanParserSuite extends PlanTest { OverwriteOptions( overwrite, if (overwrite && partition.nonEmpty) { - Some(partition.map(kv => (kv._1, kv._2.get))) + partition.map(kv => (kv._1, kv._2.get)) } else { - None + Map.empty }), ifNotExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 5d663949df6b5..65422f1495f03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -417,15 +417,17 @@ case class DataSource( // will be adjusted within InsertIntoHadoopFsRelation. val plan = InsertIntoHadoopFsRelationCommand( - outputPath, - columns, - bucketSpec, - format, - _ => Unit, // No existing table needs to be refreshed. 
- options, - data.logicalPlan, - mode, - catalogTable) + outputPath = outputPath, + staticPartitionKeys = Map.empty, + customPartitionLocations = Map.empty, + partitionColumns = columns, + bucketSpec = bucketSpec, + fileFormat = format, + refreshFunction = _ => Unit, // No existing table needs to be refreshed. + options = options, + query = data.logicalPlan, + mode = mode, + catalogTable = catalogTable) sparkSession.sessionState.executePlan(plan).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 739aeac877b99..4f19a2d00b0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,10 +24,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec} +import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -182,41 +182,53 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } - val overwritingSinglePartition = - overwrite.specificPartition.isDefined && + val partitionSchema = query.resolve( + t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + val partitionsTrackedByCatalog = t.sparkSession.sessionState.conf.manageFilesourcePartitions && + l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty && l.catalogTable.get.tracksPartitionsInCatalog - val effectiveOutputPath = if (overwritingSinglePartition) { - val partition = t.sparkSession.sessionState.catalog.getPartition( - l.catalogTable.get.identifier, overwrite.specificPartition.get) - new Path(partition.location) - } else { - outputPath - } - - val effectivePartitionSchema = if (overwritingSinglePartition) { - Nil - } else { - query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil 
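      // Hypothetical illustration: for INSERT OVERWRITE t PARTITION (p1 = 1, p2) ..., the static
      // keys are Map("p1" -> "1"), so this later holds every existing spec under p1=1, e.g.
      // Seq(Map("p1" -> "1", "p2" -> "a"), Map("p1" -> "1", "p2" -> "b")).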
+ var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + + // When partitions are tracked by the catalog, compute all custom partition locations that + // may be relevant to the insertion job. + if (partitionsTrackedByCatalog) { + val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions( + l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys)) + initialMatchingPartitions = matchingPartitions.map(_.spec) + customPartitionLocations = getCustomPartitionLocations( + t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions) } + // Callback for updating metastore partition metadata after the insertion job completes. + // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { - if (l.catalogTable.isDefined && updatedPartitions.nonEmpty && - l.catalogTable.get.partitionColumnNames.nonEmpty && - l.catalogTable.get.tracksPartitionsInCatalog) { - val metastoreUpdater = AlterTableAddPartitionCommand( - l.catalogTable.get.identifier, - updatedPartitions.map(p => (p, None)), - ifNotExists = true) - metastoreUpdater.run(t.sparkSession) + if (partitionsTrackedByCatalog) { + val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions + if (newPartitions.nonEmpty) { + AlterTableAddPartitionCommand( + l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), + ifNotExists = true).run(t.sparkSession) + } + if (overwrite.enabled) { + val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions + if (deletedPartitions.nonEmpty) { + AlterTableDropPartitionCommand( + l.catalogTable.get.identifier, deletedPartitions.toSeq, + ifExists = true, purge = true).run(t.sparkSession) + } + } } t.location.refresh() } val insertCmd = InsertIntoHadoopFsRelationCommand( - effectiveOutputPath, - effectivePartitionSchema, + outputPath, + if (overwrite.enabled) overwrite.staticPartitionKeys else Map.empty, + customPartitionLocations, + partitionSchema, t.bucketSpec, t.fileFormat, refreshPartitionsCallback, @@ -227,6 +239,34 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { insertCmd } + + /** + * Given a set of input partitions, returns those that have locations that differ from the + * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by + * the user. 
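   * For example (hypothetical layout): with base path "/warehouse/t", a partition (k1=1)
   * registered at "/custom/dir" appears in the returned map, while a partition (k1=2)
   * stored at the default "/warehouse/t/k1=2" is filtered out.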
+ * + * @return a mapping from partition specs to their custom locations + */ + private def getCustomPartitionLocations( + spark: SparkSession, + table: CatalogTable, + basePath: Path, + partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = { + val hadoopConf = spark.sessionState.newHadoopConf + val fs = basePath.getFileSystem(hadoopConf) + val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory) + partitions.flatMap { p => + val defaultLocation = qualifiedBasePath.suffix( + "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString + val catalogLocation = new Path(p.location).makeQualified( + fs.getUri, fs.getWorkingDirectory).toString + if (catalogLocation != defaultLocation) { + Some(p.spec -> catalogLocation) + } else { + None + } + }.toMap + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 69b3fa667ef54..4e4b0e48cd7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -47,6 +47,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { + /** Describes how output files should be placed in the filesystem. */ + case class OutputSpec( + outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String]) + /** A shared job description for all the write tasks. */ private class WriteJobDescription( val uuid: String, // prevent collision between different (appending) write jobs @@ -56,7 +60,8 @@ object FileFormatWriter extends Logging { val partitionColumns: Seq[Attribute], val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], - val path: String) + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String]) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), @@ -83,7 +88,7 @@ object FileFormatWriter extends Logging { plan: LogicalPlan, fileFormat: FileFormat, committer: FileCommitProtocol, - outputPath: String, + outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], @@ -93,7 +98,7 @@ object FileFormatWriter extends Logging { val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, new Path(outputPath)) + FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) val partitionSet = AttributeSet(partitionColumns) val dataColumns = plan.output.filterNot(partitionSet.contains) @@ -111,7 +116,8 @@ object FileFormatWriter extends Logging { partitionColumns = partitionColumns, nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, - path = outputPath) + path = outputSpec.outputPath, + customPartitionLocations = outputSpec.customPartitionLocations) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and @@ -308,7 +314,17 @@ object FileFormatWriter extends Logging { } val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext) - val path = committer.newTaskTempFile(taskAttemptContext, 
partDir, ext) + val customPath = partDir match { + case Some(dir) => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + case _ => + None + } + val path = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } val newWriter = description.outputWriterFactory.newInstance( path = path, dataSchema = description.nonPartitionColumns.toStructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index a0a8cb5024c33..28975e1546e79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.io.IOException -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ @@ -32,19 +32,32 @@ import org.apache.spark.sql.execution.command.RunnableCommand /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. * Writing to dynamic partitions is also supported. + * + * @param staticPartitionKeys partial partitioning spec for write. This defines the scope of + * partition overwrites: when the spec is empty, all partitions are + * overwritten. When it covers a prefix of the partition keys, only + * partitions matching the prefix are overwritten. + * @param customPartitionLocations mapping of partition specs to their custom locations. The + * caller should guarantee that exactly those table partitions + * falling under the specified static partition keys are contained + * in this map, and that no other partitions are. 
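+ *                                 For example, Map(Map("p1" -> "1", "p2" -> "2") -> "/custom/dir")
+ *                                 records that partition (p1=1, p2=2) lives in a
+ *                                 user-specified directory.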
*/ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, + staticPartitionKeys: TablePartitionSpec, + customPartitionLocations: Map[TablePartitionSpec, String], partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - refreshFunction: (Seq[TablePartitionSpec]) => Unit, + refreshFunction: Seq[TablePartitionSpec] => Unit, options: Map[String, String], @transient query: LogicalPlan, mode: SaveMode, catalogTable: Option[CatalogTable]) extends RunnableCommand { + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil override def run(sparkSession: SparkSession): Seq[Row] = { @@ -66,10 +79,7 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => - if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { - throw new IOException(s"Unable to clear output " + - s"directory $qualifiedOutputPath prior to writing to it") - } + deleteMatchingPartitions(fs, qualifiedOutputPath) true case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => true @@ -93,7 +103,8 @@ case class InsertIntoHadoopFsRelationCommand( plan = query, fileFormat = fileFormat, committer = committer, - outputPath = qualifiedOutputPath.toString, + outputSpec = FileFormatWriter.OutputSpec( + qualifiedOutputPath.toString, customPartitionLocations), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -105,4 +116,40 @@ case class InsertIntoHadoopFsRelationCommand( Seq.empty[Row] } + + /** + * Deletes all partition files that match the specified static prefix. Partitions with custom + * locations are also cleared based on the custom locations map given to this class. + */ + private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = { + val staticPartitionPrefix = if (staticPartitionKeys.nonEmpty) { + "/" + partitionColumns.flatMap { p => + staticPartitionKeys.get(p.name) match { + case Some(value) => + Some(escapePathName(p.name) + "=" + escapePathName(value)) + case None => + None + } + }.mkString("/") + } else { + "" + } + // first clear the path determined by the static partition keys (e.g. /table/foo=1) + val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix) + if (fs.exists(staticPrefixPath) && !fs.delete(staticPrefixPath, true /* recursively */)) { + throw new IOException(s"Unable to clear output " + + s"directory $staticPrefixPath prior to writing to it") + } + // now clear all custom partition locations (e.g. 
/custom/dir/where/foo=2/bar=4) + for ((spec, customLoc) <- customPartitionLocations) { + assert( + (staticPartitionKeys.toSet -- spec).isEmpty, + "Custom partition location did not match static partitioning keys") + val path = new Path(customLoc) + if (fs.exists(path) && !fs.delete(path, true)) { + throw new IOException(s"Unable to clear partition " + + s"directory $path prior to writing to it") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index a28b04ca3fb5a..bf9f318780ec2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -62,6 +62,7 @@ object PartitioningUtils { } import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName /** @@ -252,6 +253,15 @@ object PartitioningUtils { }.toMap } + /** + * This is the inverse of parsePathFragment(). + */ + def getPathFragment(spec: TablePartitionSpec, partitionSchema: StructType): String = { + partitionSchema.map { field => + escapePathName(field.name) + "=" + escapePathName(spec(field.name)) + }.mkString("/") + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index e849cafef4184..f1c5f9ab5067d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -80,7 +80,7 @@ class FileStreamSink( plan = data.logicalPlan, fileFormat = fileFormat, committer = committer, - outputPath = path, + outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala index 1fe13fa1623fc..92191c8b64b72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -96,6 +96,12 @@ class ManifestFileCommitProtocol(jobId: String, path: String) file } + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + throw new UnsupportedOperationException( + s"$this does not support adding files with an absolute path") + } + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { if (addedFiles.nonEmpty) { val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 
ac435bf6195b0..a1aa07456fd36 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils class PartitionProviderCompatibilitySuite extends QueryTest with TestHiveSingleton with SQLTestUtils { @@ -135,7 +136,7 @@ class PartitionProviderCompatibilitySuite } } - test("insert overwrite partition of legacy datasource table overwrites entire table") { + test("insert overwrite partition of legacy datasource table") { withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { withTable("test") { withTempDir { dir => @@ -144,9 +145,9 @@ class PartitionProviderCompatibilitySuite """insert overwrite table test |partition (partCol=1) |select * from range(100)""".stripMargin) - assert(spark.sql("select * from test").count() == 100) + assert(spark.sql("select * from test").count() == 104) - // Dynamic partitions case + // Overwriting entire table spark.sql("insert overwrite table test select id, id from range(10)".stripMargin) assert(spark.sql("select * from test").count() == 10) } @@ -186,4 +187,158 @@ class PartitionProviderCompatibilitySuite } } } + + /** + * Runs a test against a multi-level partitioned table, then validates that the custom locations + * were respected by the output writer. + * + * The initial partitioning structure is: + * /P1=0/P2=0 -- custom location a + * /P1=0/P2=1 -- custom location b + * /P1=1/P2=0 -- custom location c + * /P1=1/P2=1 -- default location + */ + private def testCustomLocations(testFn: => Unit): Unit = { + val base = Utils.createTempDir(namePrefix = "base") + val a = Utils.createTempDir(namePrefix = "a") + val b = Utils.createTempDir(namePrefix = "b") + val c = Utils.createTempDir(namePrefix = "c") + try { + spark.sql(s""" + |create table test (id long, P1 int, P2 int) + |using parquet + |options (path "${base.getAbsolutePath}") + |partitioned by (P1, P2)""".stripMargin) + spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=0, P2=1) location '${b.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=1, P2=0) location '${c.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=1, P2=1)") + + testFn + + // Now validate the partition custom locations were respected + val initialCount = spark.sql("select * from test").count() + val numA = spark.sql("select * from test where P1=0 and P2=0").count() + val numB = spark.sql("select * from test where P1=0 and P2=1").count() + val numC = spark.sql("select * from test where P1=1 and P2=0").count() + Utils.deleteRecursively(a) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=0 and P2=0").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA) + Utils.deleteRecursively(b) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=0 and P2=1").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA - numB) + Utils.deleteRecursively(c) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=1 and P2=0").count() == 0) + assert(spark.sql("select * from test").count() 
== initialCount - numA - numB - numC) + } finally { + Utils.deleteRecursively(base) + Utils.deleteRecursively(a) + Utils.deleteRecursively(b) + Utils.deleteRecursively(c) + spark.sql("drop table test") + } + } + + test("sanity check table setup") { + testCustomLocations { + assert(spark.sql("select * from test").count() == 0) + assert(spark.sql("show partitions test").count() == 4) + } + } + + test("insert into partial dynamic partitions") { + testCustomLocations { + spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1=1, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 30) + assert(spark.sql("show partitions test").count() == 20) + spark.sql("insert into test partition (P1=2, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 40) + assert(spark.sql("show partitions test").count() == 30) + } + } + + test("insert into fully dynamic partitions") { + testCustomLocations { + spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 12) + } + } + + test("insert into static partition") { + testCustomLocations { + spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert into test partition (P1=1, P2=1) select id from range(10)") + assert(spark.sql("select * from test").count() == 30) + assert(spark.sql("show partitions test").count() == 4) + } + } + + test("overwrite partial dynamic partitions") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 7) + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(1)") + assert(spark.sql("select * from test").count() == 1) + assert(spark.sql("show partitions test").count() == 3) + spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 11) + assert(spark.sql("show partitions test").count() == 11) + spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(1)") + assert(spark.sql("select * from test").count() == 2) + assert(spark.sql("show partitions test").count() == 2) + spark.sql("insert overwrite table test partition (P1=3, 
P2) select id, id from range(100)")
+      assert(spark.sql("select * from test").count() == 102)
+      assert(spark.sql("show partitions test").count() == 102)
+    }
+  }
+
+  test("overwrite fully dynamic partitions") {
+    testCustomLocations {
+      spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 10)
+      spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(5)")
+      assert(spark.sql("select * from test").count() == 5)
+      assert(spark.sql("show partitions test").count() == 5)
+    }
+  }
+
+  test("overwrite static partition") {
+    testCustomLocations {
+      spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(10)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 4)
+      spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(5)")
+      assert(spark.sql("select * from test").count() == 5)
+      assert(spark.sql("show partitions test").count() == 4)
+      spark.sql("insert overwrite table test partition (P1=1, P2=1) select id from range(5)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 4)
+      spark.sql("insert overwrite table test partition (P1=1, P2=2) select id from range(5)")
+      assert(spark.sql("select * from test").count() == 15)
+      assert(spark.sql("show partitions test").count() == 5)
+    }
+  }
 }

From 5ddf69470b93c0b8a28bb4ac905e7670d9c50a95 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Thu, 10 Nov 2016 17:13:10 -0800
Subject: [PATCH 116/198] [SPARK-18401][SPARKR][ML] SparkR random forest should
 support output original label.

## What changes were proposed in this pull request?

SparkR ```spark.randomForest``` classification prediction should output the original label rather than the indexed label. This issue is very similar to [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291).

## How was this patch tested?

Added unit tests.

Author: Yanbo Liang

Closes #15842 from yanboliang/spark-18401.
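For reference, a minimal sketch of the label-restoring stage this patch appends to the
wrapper's pipeline (not part of the patch itself; the column names come from the wrapper
below, while the iris labels are assumed for illustration):

```scala
import org.apache.spark.ml.feature.IndexToString
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("i2s-demo").getOrCreate()

// Indexed predictions, as the classifier stage would emit them.
val predicted = spark.createDataFrame(Seq(Tuple1(0.0), Tuple1(1.0), Tuple1(2.0)))
  .toDF("pred_label_idx")

// Map prediction indices back to the original string labels.
val idxToStr = new IndexToString()
  .setInputCol("pred_label_idx")
  .setOutputCol("prediction")
  .setLabels(Array("setosa", "versicolor", "virginica"))

idxToStr.transform(predicted).show()  // prediction: setosa, versicolor, virginica
```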
--- R/pkg/inst/tests/testthat/test_mllib.R | 24 ++++++++++++++++ .../r/RandomForestClassificationWrapper.scala | 28 ++++++++++++++++--- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 33e9d0d267ac5..b76f75dbdc682 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numTrees, 20) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) @@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numClasses, stats2$numClasses) unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) }) test_that("spark.gbt", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 6947ba7e7597a..31f846dc6cfec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -35,6 +35,8 @@ private[r] class RandomForestClassifierWrapper private ( val formula: String, val features: Array[String]) extends MLWritable { + import RandomForestClassifierWrapper._ + private val rfcModel: RandomForestClassificationModel = pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] @@ -46,7 +48,9 @@ private[r] class RandomForestClassifierWrapper private ( def summary: String = rfcModel.toDebugString def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(rfcModel.getFeaturesCol) + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(rfcModel.getFeaturesCol) } override def write: MLWriter = new @@ -54,6 +58,10 @@ private[r] class RandomForestClassifierWrapper private 
(
 }
 
 private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] {
+
+  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+  val PREDICTED_LABEL_COL = "prediction"
+
   def fit(  // scalastyle:ignore
       data: DataFrame,
       formula: String,
@@ -73,6 +81,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
 
     val rFormula = new RFormula()
       .setFormula(formula)
+      .setForceIndexLabel(true)
     RWrapperUtils.checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
@@ -82,6 +91,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       .attributes.get
     val features = featureAttrs.map(_.name.get)
 
+    // get label names from output schema
+    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
+      .asInstanceOf[NominalAttribute]
+    val labels = labelAttr.values.get
+
     // assemble and fit the pipeline
     val rfc = new RandomForestClassifier()
       .setMaxDepth(maxDepth)
@@ -97,10 +111,16 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       .setCacheNodeIds(cacheNodeIds)
       .setProbabilityCol(probabilityCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
 
+    val idxToStr = new IndexToString()
+      .setInputCol(PREDICTED_LABEL_INDEX_COL)
+      .setOutputCol(PREDICTED_LABEL_COL)
+      .setLabels(labels)
+
     val pipeline = new Pipeline()
-      .setStages(Array(rFormulaModel, rfc))
+      .setStages(Array(rFormulaModel, rfc, idxToStr))
       .fit(data)
 
     new RandomForestClassifierWrapper(pipeline, formula, features)

From 4f15d94cfec86130f8dab28ae2e228ded8124020 Mon Sep 17 00:00:00 2001
From: Junjie Chen
Date: Fri, 11 Nov 2016 10:37:58 -0800
Subject: [PATCH 117/198] [SPARK-13331] AES support for over-the-wire
 encryption

## What changes were proposed in this pull request?

The DIGEST-MD5 mechanism is used for SASL authentication and secure communication. DIGEST-MD5 supports the 3DES, DES, and RC4 ciphers; however, 3DES, DES and RC4 are relatively slow.

AES provides better performance and security by design, and is a replacement for 3DES according to NIST. Apache Commons Crypto is a cryptographic library optimized with AES-NI; this patch employs Commons Crypto as the encryption/decryption backend for SASL authentication and the secure channel, to improve Spark RPC.

## How was this patch tested?

Unit tests and integration tests.

Author: Junjie Chen

Closes #15172 from cjjnjust/shuffle_rpc_encrypt.
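As a sketch of how a deployment would opt in once this patch lands (the configuration keys
are the ones this change adds to `TransportConf`; SASL authentication with a shared secret
is assumed to be set up already):

```scala
import org.apache.spark.SparkConf

// SASL auth must already be in place; once the handshake completes, AES
// replaces the negotiated DIGEST-MD5 cipher on the channel.
val conf = new SparkConf()
  .set("spark.authenticate", "true")
  .set("spark.authenticate.encryption.aes.enabled", "true")
  // 16 (default), 24 or 32 bytes, i.e. AES-128/192/256.
  .set("spark.authenticate.encryption.aes.cipher.keySize", "32")
```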
--- common/network-common/pom.xml | 4 + .../network/sasl/SaslClientBootstrap.java | 23 +- .../spark/network/sasl/SaslRpcHandler.java | 101 ++++-- .../spark/network/sasl/aes/AesCipher.java | 294 ++++++++++++++++++ .../network/sasl/aes/AesConfigMessage.java | 101 ++++++ .../util/ByteArrayReadableChannel.java | 62 ++++ .../spark/network/util/TransportConf.java | 22 ++ .../spark/network/sasl/SparkSaslSuite.java | 93 +++++- docs/configuration.md | 26 ++ 9 files changed, 689 insertions(+), 37 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index fcefe64d59c91..ca99fa89ebe1b 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -76,6 +76,10 @@ guava compile + + org.apache.commons + commons-crypto + diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 9e5c616ee5a1f..a1bb453657460 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -30,6 +30,8 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.aes.AesCipher; +import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -88,9 +90,26 @@ public void doBootstrap(TransportClient client, Channel channel) { throw new RuntimeException( new SaslException("Encryption requests by negotiated non-encrypted connection.")); } - SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + + if (conf.aesEncryptionEnabled()) { + // Generate a request config message to send to server. + AesConfigMessage configMessage = AesCipher.createConfigMessage(conf); + ByteBuffer buf = configMessage.encodeMessage(); + + // Encrypted the config message. 
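+        // The fresh AES keys and IVs travel inside a SASL-wrapped payload, so the
+        // already-negotiated SASL cipher protects them during this exchange.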
+ byte[] toEncrypt = JavaUtils.bufferToArray(buf); + ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length)); + + client.sendRpcSync(encrypted, conf.saslRTTimeoutMs()); + AesCipher cipher = new AesCipher(configMessage, conf); + logger.info("Enabling AES cipher for client channel {}", client); + cipher.addToChannel(channel); + saslClient.dispose(); + } else { + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + } saslClient = null; - logger.debug("Channel {} configured for SASL encryption.", client); + logger.debug("Channel {} configured for encryption.", client); } } catch (IOException ioe) { throw new RuntimeException(ioe); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c41f5b6873f6c..b2f3ef214b7ac 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -29,6 +29,8 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.aes.AesCipher; +import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; @@ -59,6 +61,7 @@ class SaslRpcHandler extends RpcHandler { private SparkSaslServer saslServer; private boolean isComplete; + private boolean isAuthenticated; SaslRpcHandler( TransportConf conf, @@ -71,6 +74,7 @@ class SaslRpcHandler extends RpcHandler { this.secretKeyHolder = secretKeyHolder; this.saslServer = null; this.isComplete = false; + this.isAuthenticated = false; } @Override @@ -80,30 +84,31 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb delegate.receive(client, message, callback); return; } + if (saslServer == null || !saslServer.isComplete()) { + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); - SaslMessage saslMessage; - try { - saslMessage = SaslMessage.decode(nettyBuf); - } finally { - nettyBuf.release(); - } - - if (saslServer == null) { - // First message in the handshake, setup the necessary state. - client.setClientId(saslMessage.appId); - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); - } + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + client.setClientId(saslMessage.appId); + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, + conf.saslServerAlwaysEncrypt()); + } - byte[] response; - try { - response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); - } catch (IOException ioe) { - throw new RuntimeException(ioe); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); } - callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. 
It's ok to change the channel pipeline here since we are processing an incoming @@ -111,15 +116,42 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // method returns. This assumes that the code ensures, through other means, that no outbound // messages are being written to the channel while negotiation is still going on. if (saslServer.isComplete()) { - logger.debug("SASL authentication successful for channel {}", client); - isComplete = true; - if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + logger.debug("SASL authentication successful for channel {}", client); + complete(true); + return; + } + + if (!conf.aesEncryptionEnabled()) { logger.debug("Enabling encryption for channel {}", client); SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - saslServer = null; - } else { - saslServer.dispose(); - saslServer = null; + complete(false); + return; + } + + // Extra negotiation should happen after authentication, so return directly while + // processing authenticate. + if (!isAuthenticated) { + logger.debug("SASL authentication successful for channel {}", client); + isAuthenticated = true; + return; + } + + // Create AES cipher when it is authenticated + try { + byte[] encrypted = JavaUtils.bufferToArray(message); + ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length)); + + AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted); + AesCipher cipher = new AesCipher(configMessage, conf); + + // Send response back to client to confirm that server accept config. + callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM)); + logger.info("Enabling AES cipher for Server channel {}", client); + cipher.addToChannel(channel); + complete(true); + } catch (IOException ioe) { + throw new RuntimeException(ioe); } } } @@ -155,4 +187,17 @@ public void exceptionCaught(Throwable cause, TransportClient client) { delegate.exceptionCaught(cause, client); } + private void complete(boolean dispose) { + if (dispose) { + try { + saslServer.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL server", e); + } + } + + saslServer = null; + isComplete = true; + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java new file mode 100644 index 0000000000000..78034a69f734d --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl.aes; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import javax.crypto.spec.IvParameterSpec; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.util.AbstractReferenceCounted; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; +import org.apache.commons.crypto.random.CryptoRandom; +import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.apache.commons.crypto.stream.CryptoInputStream; +import org.apache.commons.crypto.stream.CryptoOutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.util.ByteArrayReadableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; + +/** + * AES cipher for encryption and decryption. + */ +public class AesCipher { + private static final Logger logger = LoggerFactory.getLogger(AesCipher.class); + public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption"; + public static final String DECRYPTION_HANDLER_NAME = "AesDecryption"; + public static final int STREAM_BUFFER_SIZE = 1024 * 32; + public static final String TRANSFORM = "AES/CTR/NoPadding"; + + private final SecretKeySpec inKeySpec; + private final IvParameterSpec inIvSpec; + private final SecretKeySpec outKeySpec; + private final IvParameterSpec outIvSpec; + private final Properties properties; + + public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException { + this.properties = CryptoStreamUtils.toCryptoConf(conf); + this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES"); + this.inIvSpec = new IvParameterSpec(configMessage.inIv); + this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES"); + this.outIvSpec = new IvParameterSpec(configMessage.outIv); + } + + /** + * Create AES crypto output stream + * @param ch The underlying channel to write out. + * @return Return output crypto stream for encryption. + * @throws IOException + */ + private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec); + } + + /** + * Create AES crypto input stream + * @param ch The underlying channel used to read data. + * @return Return input crypto stream for decryption. + * @throws IOException + */ + private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { + return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec); + } + + /** + * Add handlers to channel + * @param ch the channel for adding handlers + * @throws IOException + */ + public void addToChannel(Channel ch) throws IOException { + ch.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this)); + } + + /** + * Create the configuration message + * @param conf is the local transport configuration. + * @return Config message for sending. 
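+   * Keys and IVs are freshly generated for each connection via Commons Crypto's CryptoRandom.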
+ */ + public static AesConfigMessage createConfigMessage(TransportConf conf) { + int keySize = conf.aesCipherKeySize(); + Properties properties = CryptoStreamUtils.toCryptoConf(conf); + + try { + int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties) + .getBlockSize(); + byte[] inKey = new byte[keySize]; + byte[] outKey = new byte[keySize]; + byte[] inIv = new byte[paramLen]; + byte[] outIv = new byte[paramLen]; + + CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties); + random.nextBytes(inKey); + random.nextBytes(outKey); + random.nextBytes(inIv); + random.nextBytes(outIv); + + return new AesConfigMessage(inKey, inIv, outKey, outIv); + } catch (Exception e) { + logger.error("AES config error", e); + throw Throwables.propagate(e); + } + } + + /** + * CryptoStreamUtils is used to convert config from TransportConf to AES Crypto config. + */ + private static class CryptoStreamUtils { + public static Properties toCryptoConf(TransportConf conf) { + Properties props = new Properties(); + if (conf.aesCipherClass() != null) { + props.setProperty(CryptoCipherFactory.CLASSES_KEY, conf.aesCipherClass()); + } + return props; + } + } + + private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter { + private final ByteArrayWritableChannel byteChannel; + private final CryptoOutputStream cos; + + AesEncryptHandler(AesCipher cipher) throws IOException { + byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + cos = cipher.createOutputStream(byteChannel); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + try { + cos.close(); + } finally { + super.close(ctx, promise); + } + } + } + + private static class AesDecryptHandler extends ChannelInboundHandlerAdapter { + private final CryptoInputStream cis; + private final ByteArrayReadableChannel byteChannel; + + AesDecryptHandler(AesCipher cipher) throws IOException { + byteChannel = new ByteArrayReadableChannel(); + cis = cipher.createInputStream(byteChannel); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + byteChannel.feedData((ByteBuf) data); + + byte[] decryptedData = new byte[byteChannel.readableBytes()]; + int offset = 0; + while (offset < decryptedData.length) { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } + + ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + cis.close(); + } finally { + super.channelInactive(ctx); + } + } + } + + private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + private long transferred; + private CryptoOutputStream cos; + + // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has + // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data + // from upper handler, another is used to store encrypted data. 
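+    // byteRawChannel buffers plaintext handed down from the upper handler, while
+    // byteEncChannel collects the ciphertext produced by the CryptoOutputStream.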
+ private ByteArrayWritableChannel byteEncChannel; + private ByteArrayWritableChannel byteRawChannel; + + private ByteBuffer currentEncrypted; + + EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.transferred = 0; + this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + this.cos = cos; + this.byteEncChannel = ch; + } + + @Override + public long count() { + return isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return transferred; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == transfered(), "Invalid position."); + + do { + if (currentEncrypted == null) { + encryptMore(); + } + + int bytesWritten = currentEncrypted.remaining(); + target.write(currentEncrypted); + bytesWritten -= currentEncrypted.remaining(); + transferred += bytesWritten; + if (!currentEncrypted.hasRemaining()) { + currentEncrypted = null; + byteEncChannel.reset(); + } + } while (transferred < count()); + + return transferred; + } + + private void encryptMore() throws IOException { + byteRawChannel.reset(); + + if (isByteBuf) { + int copied = byteRawChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteRawChannel, region.transfered()); + } + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + + currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), + 0, byteEncChannel.length()); + } + + @Override + protected void deallocate() { + byteRawChannel.reset(); + byteEncChannel.reset(); + if (region != null) { + region.release(); + } + if (buf != null) { + buf.release(); + } + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java new file mode 100644 index 0000000000000..3ef6f74a1f89f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl.aes; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * The AES cipher options for encryption negotiation. + */ +public class AesConfigMessage implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEB; + + public byte[] inKey; + public byte[] outKey; + public byte[] inIv; + public byte[] outIv; + + public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) { + if (inKey == null || inIv == null || outKey == null || outIv == null) { + throw new IllegalArgumentException("Cipher Key or IV must not be null!"); + } + + this.inKey = inKey; + this.inIv = inIv; + this.outKey = outKey; + this.outIv = outIv; + } + + @Override + public int encodedLength() { + return 1 + + Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) + + Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.ByteArrays.encode(buf, inKey); + Encoders.ByteArrays.encode(buf, inIv); + Encoders.ByteArrays.encode(buf, outKey); + Encoders.ByteArrays.encode(buf, outIv); + } + + /** + * Encode the config message. + * @return ByteBuffer which contains encoded config message. + */ + public ByteBuffer encodeMessage(){ + ByteBuffer buf = ByteBuffer.allocate(encodedLength()); + + ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf); + wrappedBuf.clear(); + encode(wrappedBuf); + + return buf; + } + + /** + * Decode the config message from buffer + * @param buffer the buffer contain encoded config message + * @return config message + */ + public static AesConfigMessage decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected AesConfigMessage, received something else" + + " (maybe your client does not have AES enabled?)"); + } + + byte[] outKey = Encoders.ByteArrays.decode(buf); + byte[] outIv = Encoders.ByteArrays.decode(buf); + byte[] inKey = Encoders.ByteArrays.decode(buf); + byte[] inIv = Encoders.ByteArrays.decode(buf); + return new AesConfigMessage(inKey, inIv, outKey, outIv); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java new file mode 100644 index 0000000000000..25d103d0e316f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; + +import io.netty.buffer.ByteBuf; + +public class ByteArrayReadableChannel implements ReadableByteChannel { + private ByteBuf data; + + public int readableBytes() { + return data.readableBytes(); + } + + public void feedData(ByteBuf buf) { + data = buf; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + int totalRead = 0; + while (data.readableBytes() > 0 && dst.remaining() > 0) { + int bytesToRead = Math.min(data.readableBytes(), dst.remaining()); + dst.put(data.readSlice(bytesToRead).nioBuffer()); + totalRead += bytesToRead; + } + + if (data.readableBytes() == 0) { + data.release(); + } + + return totalRead; + } + + @Override + public void close() throws IOException { + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 64eaba103cccb..d0d072849d384 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -18,6 +18,7 @@ package org.apache.spark.network.util; import com.google.common.primitives.Ints; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; /** * A central location that tracks all the settings we expose to users. @@ -175,4 +176,25 @@ public boolean saslServerAlwaysEncrypt() { return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); } + /** + * The trigger for enabling AES encryption. + */ + public boolean aesEncryptionEnabled() { + return conf.getBoolean("spark.authenticate.encryption.aes.enabled", false); + } + + /** + * The implementation class for crypto cipher + */ + public String aesCipherClass() { + return conf.get("spark.authenticate.encryption.aes.cipher.class", null); + } + + /** + * The bytes of AES cipher key which is effective when AES cipher is enabled. Notice that + * the length should be 16, 24 or 32 bytes. 
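+   * These lengths correspond to AES-128, AES-192 and AES-256 respectively.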
+ */ + public int aesCipherKeySize() { + return conf.getInt("spark.authenticate.encryption.aes.cipher.keySize", 16); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 45cc03df435ac..4e6146cf070d0 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -53,6 +53,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.aes.AesCipher; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -149,7 +150,7 @@ public Void answer(InvocationOnMock invocation) { .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false); try { ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); @@ -275,7 +276,7 @@ public ManagedBuffer answer(InvocationOnMock invocation) { new Random().nextBytes(data); Files.write(data, file); - ctx = new SaslTestCtx(rpcHandler, true, false); + ctx = new SaslTestCtx(rpcHandler, true, false, false); final CountDownLatch lock = new CountDownLatch(1); @@ -317,7 +318,7 @@ public void testServerAlwaysEncrypt() throws Exception { SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false); fail("Should have failed to connect without encryption."); } catch (Exception e) { assertTrue(e.getCause() instanceof SaslException); @@ -336,7 +337,7 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { // able to understand RPCs sent to it and thus close the connection. 
SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false); ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); @@ -374,6 +375,69 @@ public void testDelegates() throws Exception { } } + @Test + public void testAesEncryption() throws Exception { + final AtomicReference response = new AtomicReference<>(); + final File file = File.createTempFile("sasltest", ".txt"); + SaslTestCtx ctx = null; + try { + final TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider()); + final TransportConf spyConf = spy(conf); + doReturn(true).when(spyConf).aesEncryptionEnabled(); + + StreamManager sm = mock(StreamManager.class); + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { + @Override + public ManagedBuffer answer(InvocationOnMock invocation) { + return new FileSegmentManagedBuffer(spyConf, file, 0, file.length()); + } + }); + + RpcHandler rpcHandler = mock(RpcHandler.class); + when(rpcHandler.getStreamManager()).thenReturn(sm); + + byte[] data = new byte[256 * 1024 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + ctx = new SaslTestCtx(rpcHandler, true, false, true); + + final Object lock = new Object(); + + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + response.set((ManagedBuffer) invocation.getArguments()[1]); + response.get().retain(); + synchronized (lock) { + lock.notifyAll(); + } + return null; + } + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + + synchronized (lock) { + ctx.client.fetchChunk(0, 0, callback); + lock.wait(10 * 1000); + } + + verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); + verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); + + byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); + assertTrue(Arrays.equals(data, received)); + } finally { + file.delete(); + if (ctx != null) { + ctx.close(); + } + if (response.get() != null) { + response.get().release(); + } + } + } + private static class SaslTestCtx { final TransportClient client; @@ -386,18 +450,28 @@ private static class SaslTestCtx { SaslTestCtx( RpcHandler rpcHandler, boolean encrypt, - boolean disableClientEncryption) + boolean disableClientEncryption, + boolean aesEnable) throws Exception { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + if (aesEnable) { + conf = spy(conf); + doReturn(true).when(conf).aesEncryptionEnabled(); + } + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); TransportContext ctx = new TransportContext(conf, rpcHandler); - this.checker = new EncryptionCheckerBootstrap(); + String encryptHandlerName = aesEnable ? 
AesCipher.ENCRYPTION_HANDLER_NAME :
+        SaslEncryption.ENCRYPTION_HANDLER_NAME;
+
+      this.checker = new EncryptionCheckerBootstrap(encryptHandlerName);
+
       this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
         checker));
 
@@ -437,13 +511,18 @@ private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAd
     implements TransportServerBootstrap {
 
     boolean foundEncryptionHandler;
+    String encryptHandlerName;
+
+    public EncryptionCheckerBootstrap(String encryptHandlerName) {
+      this.encryptHandlerName = encryptHandlerName;
+    }
 
     @Override
     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
       throws Exception {
       if (!foundEncryptionHandler) {
         foundEncryptionHandler =
-          ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
+          ctx.channel().pipeline().get(encryptHandlerName) != null;
       }
       ctx.write(msg, promise);
     }

diff --git a/docs/configuration.md b/docs/configuration.md
index d0acd944dd6b9..41c1778ee7fcf 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1529,6 +1529,32 @@ Apart from these, the following properties are also available, and may be useful
     currently supported by the external shuffle service.
   </td>
 </tr>
+<tr>
+  <td><code>spark.authenticate.encryption.aes.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enable AES for over-the-wire encryption.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.authenticate.encryption.aes.cipher.keySize</code></td>
+  <td>16</td>
+  <td>
+    The byte length of the AES cipher key, effective when AES encryption is enabled. AES
+    works with 16, 24 and 32 byte keys.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.authenticate.encryption.aes.cipher.class</code></td>
+  <td>null</td>
+  <td>
+    Specify the underlying implementation class of the crypto cipher. Set to null to use the
+    default. To use OpenSslCipher, users should install OpenSSL. Currently, there are two
+    cipher classes available in the Commons Crypto library:
+    org.apache.commons.crypto.cipher.OpenSslCipher
+    org.apache.commons.crypto.cipher.JceCipher
+  </td>
+</tr>
 <tr>
   <td><code>spark.core.connection.ack.wait.timeout</code></td>
   <td>60s</td>

From a531fe1a82ec515314f2db2e2305283fef24067f Mon Sep 17 00:00:00 2001
From: Vinayak
Date: Fri, 11 Nov 2016 12:54:16 -0600
Subject: [PATCH 118/198] [SPARK-17843][WEB UI] Indicate event logs pending for
 processing on history server UI

## What changes were proposed in this pull request?

The History Server UI's application listing now displays information on event logs that are
still being processed, so a user knows that an application may not appear in the listing
until its event log has been processed.

When there are no event logs under process, the application list page shows a "Last Updated"
date-time at the top, indicating the time of the last _completed_ scan of the event logs.
The value is displayed in the user's local time zone.

## How was this patch tested?

All unit tests pass. In particular, all the suites under org.apache.spark.deploy.history.\*
were run to test the changes.

- Very first startup - Pending logs - no logs processed yet (screenshot)
- Very first startup - Pending logs - some logs processed (screenshot)
- Last updated - No currently pending logs (screenshot)
- Last updated - With some currently pending logs (screenshot)
- No applications found and No currently pending logs (screenshot)

Author: Vinayak

Closes #15410 from vijoshi/SAAS-608_master.
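A self-contained sketch of the counting pattern the provider changes below rely on (the
names here are illustrative, not Spark APIs): replay tasks bump an atomic counter when they
are submitted and drop it as each one finishes, so the UI can report event logs still under
process.

```scala
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger

object PendingLogsDemo extends App {
  val pending = new AtomicInteger(0)          // mirrors pendingReplayTasksCount
  val pool = Executors.newFixedThreadPool(2)  // mirrors replayExecutor

  val tasks = (1 to 4).map { _ =>
    pending.incrementAndGet()
    pool.submit(new Runnable {
      override def run(): Unit =
        try Thread.sleep(100) // stand-in for mergeApplicationListing(file)
        finally pending.decrementAndGet()
    })
  }

  // While tasks run, a UI could poll the counter, as getEventLogsUnderProcess() does.
  println(s"event logs under process: ${pending.get()}")
  tasks.foreach(_.get()) // wait for completion, as checkForLogs() does
  println(s"event logs under process: ${pending.get()}")
  pool.shutdown()
}
```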
--- .../spark/ui/static/historypage-common.js | 24 ++++++++ .../history/ApplicationHistoryProvider.scala | 24 ++++++++ .../deploy/history/FsHistoryProvider.scala | 59 +++++++++++++------ .../spark/deploy/history/HistoryPage.scala | 19 ++++++ .../spark/deploy/history/HistoryServer.scala | 8 +++ 5 files changed, 116 insertions(+), 18 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage-common.js diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js new file mode 100644 index 0000000000000..55d540d8317a0 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +$(document).ready(function() { + if ($('#last-updated').length) { + var lastUpdatedMillis = Number($('#last-updated').text()); + var updatedDate = new Date(lastUpdatedMillis); + $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString()) + } +}); diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 06530ff836466..d7d82800b8b55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -74,6 +74,30 @@ private[history] case class LoadedAppUI( private[history] abstract class ApplicationHistoryProvider { + /** + * Returns the count of application event logs that the provider is currently still processing. + * History Server UI can use this to indicate to a user that the application listing on the UI + * can be expected to list additional known applications once the processing of these + * application event logs completes. + * + * A History Provider that does not have a notion of count of event logs that may be pending + * for processing need not override this method. + * + * @return Count of application event logs that are currently under process + */ + def getEventLogsUnderProcess(): Int = { + return 0; + } + + /** + * Returns the time the history provider last updated the application history information + * + * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis + */ + def getLastUpdatedTime(): Long = { + return 0; + } + /** * Returns a list of applications available for the history server to show. 
* diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index dfc1aad64c818..ca38a47639422 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -108,7 +108,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) - private var lastScanTime = -1L + private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -120,6 +120,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -226,6 +228,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) applications.get(appId) } + override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() + + override def getLastUpdatedTime(): Long = lastScanTime.get() + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => @@ -329,26 +335,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") } - logInfos.map { file => - replayExecutor.submit(new Runnable { + + var tasks = mutable.ListBuffer[Future[_]]() + + try { + for (file <- logInfos) { + tasks += replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(file) }) } - .foreach { task => - try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. - task.get() - } catch { - case e: InterruptedException => - throw e - case e: Exception => - logError("Exception while merging application listings", e) - } + } catch { + // let the iteration over logInfos break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + } + + pendingReplayTasksCount.addAndGet(tasks.size) + + tasks.foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } finally { + pendingReplayTasksCount.decrementAndGet() } + } - lastScanTime = newLastScanTime + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } @@ -365,7 +388,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } catch { case e: Exception => logError("Exception encountered when attempting to update last scan time", e) - lastScanTime + lastScanTime.get() } finally { if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 96b9ecf43b14c..0e7a6c24d4fa5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -30,13 +30,30 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) + val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() + val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = +
              {providerConfig.map { case (k, v) =>
                <li><strong>{k}:</strong> {v}</li>
              }}
            </ul>
+            {
+            if (eventLogsUnderProcessCount > 0) {
+              <p>There are {eventLogsUnderProcessCount} event log(s) currently being
+                processed which may result in additional applications getting listed on this page.
+                Refresh the page to view updates.</p>
+            }
+            }
+
+            {
+            if (lastUpdatedTime > 0) {
+              <p>Last updated: <span id="last-updated">{lastUpdatedTime}</span></p>
+            }
+            }
+
             {
               if (allAppsSize > 0) {
                 ++
@@ -46,6 +63,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("")
             } else if (requestedIncomplete) {
               <h4>No incomplete applications found!</h4>
+            } else if (eventLogsUnderProcessCount > 0) {
+              <h4>No completed applications found!</h4>
             } else {
               <h4>No completed applications found!</h4>
++ parent.emptyListingHtml } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 3175b36b3e56f..7e21fa681aa1e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -179,6 +179,14 @@ class HistoryServer( provider.getListing() } + def getEventLogsUnderProcess(): Int = { + provider.getEventLogsUnderProcess() + } + + def getLastUpdatedTime(): Long = { + provider.getLastUpdatedTime() + } + def getApplicationInfoList: Iterator[ApplicationInfo] = { getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } From d42bb7cc4e32c173769bd7da5b9b5eafb510860c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 11 Nov 2016 13:28:18 -0800 Subject: [PATCH 119/198] [SPARK-17982][SQL] SQLBuilder should wrap the generated SQL with parenthesis for LIMIT ## What changes were proposed in this pull request? Currently, `SQLBuilder` handles `LIMIT` by always adding `LIMIT` at the end of the generated subSQL. It makes `RuntimeException`s like the following. This PR adds a parenthesis always except `SubqueryAlias` is used together with `LIMIT`. **Before** ``` scala scala> sql("CREATE TABLE tbl(id INT)") scala> sql("CREATE VIEW v1(id2) AS SELECT id FROM tbl LIMIT 2") java.lang.RuntimeException: Failed to analyze the canonicalized SQL: ... ``` **After** ``` scala scala> sql("CREATE TABLE tbl(id INT)") scala> sql("CREATE VIEW v1(id2) AS SELECT id FROM tbl LIMIT 2") scala> sql("SELECT id2 FROM v1") res4: org.apache.spark.sql.DataFrame = [id2: int] ``` **Fixed cases in this PR** The following two cases are the detail query plans having problematic SQL generations. 1. `SELECT * FROM (SELECT id FROM tbl LIMIT 2)` Please note that **FROM SELECT** part of the generated SQL in the below. When we don't use '()' for limit, this fails. ```scala # Original logical plan: Project [id#1] +- GlobalLimit 2 +- LocalLimit 2 +- Project [id#1] +- MetastoreRelation default, tbl # Canonicalized logical plan: Project [gen_attr_0#1 AS id#4] +- SubqueryAlias tbl +- Project [gen_attr_0#1] +- GlobalLimit 2 +- LocalLimit 2 +- Project [gen_attr_0#1] +- SubqueryAlias gen_subquery_0 +- Project [id#1 AS gen_attr_0#1] +- SQLTable default, tbl, [id#1] # Generated SQL: SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2) AS tbl ``` 2. `SELECT * FROM (SELECT id FROM tbl TABLESAMPLE (2 ROWS))` Please note that **((~~~) AS gen_subquery_0 LIMIT 2)** in the below. When we use '()' for limit on `SubqueryAlias`, this fails. ```scala # Original logical plan: Project [id#1] +- Project [id#1] +- GlobalLimit 2 +- LocalLimit 2 +- MetastoreRelation default, tbl # Canonicalized logical plan: Project [gen_attr_0#1 AS id#4] +- SubqueryAlias tbl +- Project [gen_attr_0#1] +- GlobalLimit 2 +- LocalLimit 2 +- SubqueryAlias gen_subquery_0 +- Project [id#1 AS gen_attr_0#1] +- SQLTable default, tbl, [id#1] # Generated SQL: SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM ((SELECT `id` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2)) AS tbl ``` ## How was this patch tested? Pass the Jenkins test with a newly added test case. Author: Dongjoon Hyun Closes #15546 from dongjoon-hyun/SPARK-17982. 
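The essence of the fix is the ordering of the two `Limit` cases in `SQLBuilder`; a sketch of their shape, mirroring the diff below:

```scala
// A SubqueryAlias child must match first: wrapping "... AS alias LIMIT n" in
// parentheses yields invalid SQL, while every other child needs the wrapping.
case Limit(limitExpr, child: SubqueryAlias) =>
  s"${toSQL(child)} LIMIT ${limitExpr.sql}"
case Limit(limitExpr, child) =>
  s"(${toSQL(child)} LIMIT ${limitExpr.sql})"
```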
--- .../org/apache/spark/sql/catalyst/SQLBuilder.scala | 7 ++++++- .../test/resources/sqlgen/generate_with_other_1.sql | 2 +- .../test/resources/sqlgen/generate_with_other_2.sql | 2 +- sql/hive/src/test/resources/sqlgen/limit.sql | 4 ++++ .../spark/sql/catalyst/LogicalPlanToSQLSuite.scala | 10 ++++++++++ 5 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 sql/hive/src/test/resources/sqlgen/limit.sql diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala index 6f821f80cc4c5..380454267eaf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala @@ -138,9 +138,14 @@ class SQLBuilder private ( case g: Generate => generateToSQL(g) - case Limit(limitExpr, child) => + // This prevents a pattern of `((...) AS gen_subquery_0 LIMIT 1)` which does not work. + // For example, `SELECT * FROM (SELECT id FROM tbl TABLESAMPLE (2 ROWS))` makes this plan. + case Limit(limitExpr, child: SubqueryAlias) => s"${toSQL(child)} LIMIT ${limitExpr.sql}" + case Limit(limitExpr, child) => + s"(${toSQL(child)} LIMIT ${limitExpr.sql})" + case Filter(condition, child) => val whereOrHaving = child match { case _: Aggregate => "HAVING" diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql index ab444d0c70936..0739f8fff5467 100644 --- a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql @@ -5,4 +5,4 @@ WHERE id > 2 ORDER BY val, id LIMIT 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5) AS parquet_t3 +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM ((SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5)) AS parquet_t3 diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql index 42a2369f34d1c..c4b344ee238a5 100644 --- a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql @@ -7,4 +7,4 @@ WHERE val > 2 ORDER BY val, id LIMIT 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, 
`arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5) AS gen_subquery_1 +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM ((SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5)) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/limit.sql b/sql/hive/src/test/resources/sqlgen/limit.sql new file mode 100644 index 0000000000000..7a6b060fbf505 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/limit.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT * FROM (SELECT id FROM tbl LIMIT 2) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0`, `name` AS `gen_attr_1` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2)) AS tbl diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 8696337b9dc8a..557ea44d1c80b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -1173,4 +1173,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { ) } } + + test("SPARK-17982 - limit") { + withTable("tbl") { + sql("CREATE TABLE tbl(id INT, name STRING)") + checkSQL( + "SELECT * FROM (SELECT id FROM tbl LIMIT 2)", + "limit" + ) + } + } } From 6e95325fc3726d260054bd6e7c0717b3c139917e Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 11 Nov 2016 13:52:10 -0800 Subject: [PATCH 120/198] [SPARK-18387][SQL] Add serialization to checkEvaluation. ## What changes were proposed in this pull request? This removes the serialization test from RegexpExpressionsSuite and replaces it by serializing all expressions in checkEvaluation. This also fixes math constant expressions by making LeafMathExpression Serializable and fixes NumberFormat values that are null or invalid after serialization. ## How was this patch tested? This patch is to tests. Author: Ryan Blue Closes #15847 from rdblue/SPARK-18387-fix-serializable-expressions. 
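Concretely, the helper now round-trips every expression through Java serialization before evaluating it, so serializability is exercised on every `checkEvaluation` call; a sketch mirroring the diff below:

```scala
// Inside ExpressionEvalHelper.checkEvaluation: any Expression that is not
// serializable now fails every evaluation test, not just a dedicated one.
val serializer = new JavaSerializer(new SparkConf()).newInstance
val expr: Expression = serializer.deserialize(serializer.serialize(expression))
```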
--- .../expressions/mathExpressions.scala | 2 +- .../expressions/stringExpressions.scala | 44 +++++++++++-------- .../expressions/ExpressionEvalHelper.scala | 15 ++++--- .../expressions/RegexpExpressionsSuite.scala | 16 +------ 4 files changed, 36 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a60494a5bb69d..65273a77b1054 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.unsafe.types.UTF8String * @param name The short name of the function */ abstract class LeafMathExpression(c: Double, name: String) - extends LeafExpression with CodegenFallback { + extends LeafExpression with CodegenFallback with Serializable { override def dataType: DataType = DoubleType override def foldable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5f533fecf8d07..e74ef9a08750e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1431,18 +1431,20 @@ case class FormatNumber(x: Expression, d: Expression) // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. + // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after + // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Int = -100 + private var lastDValue: Option[Int] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. @transient - private val pattern: StringBuffer = new StringBuffer() + private lazy val pattern: StringBuffer = new StringBuffer() // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. 
@transient - private val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) + private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { val dValue = dObject.asInstanceOf[Int] @@ -1450,24 +1452,28 @@ case class FormatNumber(x: Expression, d: Expression) return null } - if (dValue != lastDValue) { - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") + lastDValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append("#,###,###,###,###,###,##0") + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } } - } - lastDValue = dValue - numberFormat.applyLocalizedPattern(pattern.toString) + lastDValue = Some(dValue) + + numberFormat.applyLocalizedPattern(pattern.toString) } x.dataType match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 9ceb709185417..f83650424a964 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -22,7 +22,8 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.exceptions.TestFailedException import org.scalatest.prop.GeneratorDrivenPropertyChecks -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer @@ -43,13 +44,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val serializer = new JavaSerializer(new SparkConf()).newInstance + val expr: Expression = serializer.deserialize(serializer.serialize(expression)) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) - checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) - checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) - if (GenerateUnsafeProjection.canSupport(expression.dataType)) { - checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) + checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow) } - checkEvaluationWithOptimization(expression, catalystValue, inputRow) + checkEvaluationWithOptimization(expr, catalystValue, inputRow) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index d0d1aaa9d299d..5299549e7b4da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.StringType @@ -192,17 +191,4 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringSplit(s1, s2), null, row3) } - test("RegExpReplace serialization") { - val serializer = new JavaSerializer(new SparkConf()).newInstance - - val row = create_row("abc", "b", "") - - val s = 's.string.at(0) - val p = 'p.string.at(1) - val r = 'r.string.at(2) - - val expr: RegExpReplace = serializer.deserialize(serializer.serialize(RegExpReplace(s, p, r))) - checkEvaluation(expr, "ac", row) - } - } From ba23f768f7419039df85530b84258ec31f0c22b4 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 11 Nov 2016 15:49:55 -0800 Subject: [PATCH 121/198] [SPARK-18264][SPARKR] build vignettes with package, update vignettes for CRAN release build and add info on release ## What changes were proposed in this pull request? Changes to DESCRIPTION to build vignettes. Changes the metadata for vignettes to generate the recommended format (which is about <10% of size before). Unfortunately it does not look as nice (before - left, after - right) ![image](https://cloud.githubusercontent.com/assets/8969467/20040492/b75883e6-a40d-11e6-9534-25cdd5d59a8b.png) ![image](https://cloud.githubusercontent.com/assets/8969467/20040490/a40f4d42-a40d-11e6-8c91-af00ddcbdad9.png) Also add information on how to run build/release to CRAN later. ## How was this patch tested? manually, unit tests shivaram We need this for branch-2.1 Author: Felix Cheung Closes #15790 from felixcheung/rpkgvignettes. --- R/CRAN_RELEASE.md | 91 ++++++++++++++++++++++++++++ R/README.md | 8 +-- R/check-cran.sh | 33 ++++++++-- R/create-docs.sh | 19 +----- R/pkg/DESCRIPTION | 9 ++- R/pkg/vignettes/sparkr-vignettes.Rmd | 9 +-- 6 files changed, 134 insertions(+), 35 deletions(-) create mode 100644 R/CRAN_RELEASE.md diff --git a/R/CRAN_RELEASE.md b/R/CRAN_RELEASE.md new file mode 100644 index 0000000000000..bea8f9fbe4eec --- /dev/null +++ b/R/CRAN_RELEASE.md @@ -0,0 +1,91 @@ +# SparkR CRAN Release + +To release SparkR as a package to CRAN, we would use the `devtools` package. Please work with the +`dev@spark.apache.org` community and R package maintainer on this. + +### Release + +First, check that the `Version:` field in the `pkg/DESCRIPTION` file is updated. Also, check for stale files not under source control. + +Note that while `check-cran.sh` is running `R CMD check`, it is doing so with `--no-manual --no-vignettes`, which skips a few vignettes or PDF checks - therefore it will be preferred to run `R CMD check` on the source package built manually before uploading a release. + +To upload a release, we would need to update the `cran-comments.md`. This should generally contain the results from running the `check-cran.sh` script along with comments on status of all `WARNING` (should not be any) or `NOTE`. 
As a part of `check-cran.sh` and the release process, the vignettes are built - make sure `SPARK_HOME` is set and Spark jars are accessible.
+
+Once everything is in place, run in R under the `SPARK_HOME/R` directory:
+
+```R
+paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::release(); .libPaths(paths)
+```
+
+For more information please refer to http://r-pkgs.had.co.nz/release.html#release-check
+
+### Testing: build package manually
+
+To build the package manually, such as to inspect the resulting `.tar.gz` file content, we would also use the `devtools` package.
+
+The source package is what gets released to CRAN. CRAN would then build platform-specific binary packages from the source package.
+
+#### Build source package
+
+To build the source package locally without releasing to CRAN, run in R under the `SPARK_HOME/R` directory:
+
+```R
+paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg"); .libPaths(paths)
+```
+
+(http://r-pkgs.had.co.nz/vignettes.html#vignette-workflow-2)
+
+Similarly, the source package is also created by `check-cran.sh` with `R CMD build pkg`.
+
+For example, this should be the content of the source package:
+
+```sh
+DESCRIPTION R inst tests
+NAMESPACE build man vignettes
+
+inst/doc/
+sparkr-vignettes.html
+sparkr-vignettes.Rmd
+sparkr-vignettes.Rman
+
+build/
+vignette.rds
+
+man/
+ *.Rd files...
+
+vignettes/
+sparkr-vignettes.Rmd
+```
+
+#### Test source package
+
+To install, run this:
+
+```sh
+R CMD INSTALL SparkR_2.1.0.tar.gz
+```
+
+With "2.1.0" replaced with the version of SparkR.
+
+This command installs SparkR to the default libPaths. Once that is done, you should be able to start R and run:
+
+```R
+library(SparkR)
+vignette("sparkr-vignettes", package="SparkR")
+```
+
+#### Build binary package
+
+To build the binary package locally, run in R under the `SPARK_HOME/R` directory:
+
+```R
+paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg", binary = TRUE); .libPaths(paths)
+```
+
+For example, this should be the content of the binary package:
+
+```sh
+DESCRIPTION Meta R html tests
+INDEX NAMESPACE help profile worker
+```
diff --git a/R/README.md b/R/README.md
index 932d5272d0b4f..47f9a86dfde11 100644
--- a/R/README.md
+++ b/R/README.md
@@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R
 Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` to the full path of the base directory where R is installed, before running the install-dev.sh script.
-Example: +Example: ```bash # where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript export R_HOME=/home/username/R @@ -46,7 +46,7 @@ Sys.setenv(SPARK_HOME="/Users/username/spark") # This line loads SparkR from the installed directory .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) library(SparkR) -sc <- sparkR.init(master="local") +sparkR.session() ``` #### Making changes to SparkR @@ -54,11 +54,11 @@ sc <- sparkR.init(master="local") The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. - + #### Generating documentation The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. Also, you may need to install these [prerequisites](https://github.com/apache/spark/tree/master/docs#prerequisites). See also, `R/DOCUMENTATION.md` - + ### Examples, Unit tests SparkR comes with several sample programs in the `examples/src/main/r` directory. diff --git a/R/check-cran.sh b/R/check-cran.sh index bb331466ae931..c5f042848c90c 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -36,11 +36,27 @@ if [ ! -z "$R_HOME" ] fi echo "USING R_HOME = $R_HOME" -# Build the latest docs +# Build the latest docs, but not vignettes, which is built with the package next $FWDIR/create-docs.sh -# Build a zip file containing the source package -"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg +# Build source package with vignettes +SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" +. "${SPARK_HOME}"/bin/load-spark-env.sh +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ -d "$SPARK_JARS_DIR" ]; then + # Build a zip file containing the source package with vignettes + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Error Spark JARs not found in $SPARK_HOME" + exit 1 +fi # Run check as-cran. 
VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` @@ -54,11 +70,16 @@ fi if [ -n "$NO_MANUAL" ] then - CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual" + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual --no-vignettes" fi echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" -"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz - +if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] +then + "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +else + # This will run tests and/or build vignettes, and require SPARK_HOME + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 69ffc5f678c36..84e6aa928cb0f 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -20,7 +20,7 @@ # Script to create API docs and vignettes for SparkR # This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. -# After running this script the html docs can be found in +# After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html # The vignettes can be found in # $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html @@ -52,21 +52,4 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd -# Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then - SPARK_JARS_DIR="${SPARK_HOME}/jars" -else - SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -# Only create vignettes if Spark JARs exist -if [ -d "$SPARK_JARS_DIR" ]; then - # render creates SparkR vignettes - Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' - - find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete -else - echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" -fi - popd diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 5a83883089e0e..fe41a9e7dabbd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package Title: R Frontend for Apache Spark -Version: 2.0.0 -Date: 2016-08-27 +Version: 2.1.0 +Date: 2016-11-06 Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", @@ -18,7 +18,9 @@ Depends: Suggests: testthat, e1071, - survival + survival, + knitr, + rmarkdown Description: The SparkR package provides an R frontend for Apache Spark. 
License: Apache License (== 2.0)
 Collate:
@@ -48,3 +50,4 @@ Collate:
     'utils.R'
     'window.R'
 RoxygenNote: 5.0.1
+VignetteBuilder: knitr
diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd
index 80e876027bddb..73a5e26a3ba9c 100644
--- a/R/pkg/vignettes/sparkr-vignettes.Rmd
+++ b/R/pkg/vignettes/sparkr-vignettes.Rmd
@@ -1,12 +1,13 @@
 ---
 title: "SparkR - Practical Guide"
 output:
-  html_document:
-    theme: united
+  rmarkdown::html_vignette:
     toc: true
     toc_depth: 4
-    toc_float: true
-    highlight: textmate
+vignette: >
+  %\VignetteIndexEntry{SparkR - Practical Guide}
+  %\VignetteEngine{knitr::rmarkdown}
+  \usepackage[utf8]{inputenc}
 ---

 ## Overview

From 46b2550bcd3690a260b995fd4d024a73b92a0299 Mon Sep 17 00:00:00 2001
From: sethah
Date: Sat, 12 Nov 2016 01:38:26 +0000
Subject: [PATCH 122/198] [SPARK-18060][ML] Avoid unnecessary computation for
 MLOR

## What changes were proposed in this pull request?

Before this patch, the gradient updates for multinomial logistic regression were computed by an outer loop over the number of classes and an inner loop over the number of features. Inside the inner loop, we standardized the feature value (`value / featuresStd(index)`), which means we performed the computation `numFeatures * numClasses` times. We only need to perform that computation `numFeatures` times, however. If we re-order the inner and outer loops, we can avoid this, but then we lose sequential memory access. In this patch, we instead lay out the coefficients in column-major order while we train, so that we can avoid the extra computation and retain sequential memory access. We convert back to row-major order when we create the model.

## How was this patch tested?

This is an implementation detail only, so the original behavior should be maintained. All tests pass. I ran some performance tests to verify speedups. The results are below, and show significant speedups.

## Performance Tests

**Setup**

- 3 node bare-metal cluster
- 120 cores total
- 384 GB RAM total

**Results**

NOTE: The `currentMasterTime` and `thisPatchTime` are times in seconds for a single iteration of L-BFGS or OWL-QN.

|    | numPoints | numFeatures | numClasses | regParam | elasticNetParam | currentMasterTime (sec) | thisPatchTime (sec) | pctSpeedup |
|----|-----------|-------------|------------|----------|-----------------|-------------------------|---------------------|------------|
| 0  | 1e+07     | 100         | 500        | 0.5      | 0               | 90                      | 18                  | 80         |
| 1  | 1e+08     | 100         | 50         | 0.5      | 0               | 90                      | 19                  | 78         |
| 2  | 1e+08     | 100         | 50         | 0.05     | 1               | 72                      | 19                  | 73         |
| 3  | 1e+06     | 100         | 5000       | 0.5      | 0               | 93                      | 53                  | 43         |
| 4  | 1e+07     | 100         | 5000       | 0.5      | 0               | 900                     | 390                 | 56         |
| 5  | 1e+08     | 100         | 500        | 0.5      | 0               | 840                     | 174                 | 79         |
| 6  | 1e+08     | 100         | 200        | 0.5      | 0               | 360                     | 72                  | 80         |
| 7  | 1e+08     | 1000        | 5          | 0.5      | 0               | 9                       | 3                   | 66         |

Author: sethah

Closes #15593 from sethah/MLOR_PERF_COL_MAJOR_COEF.
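To make the layout change concrete: with the coefficients stored column major, the margin update standardizes each feature value once and then scans the `numClasses` contiguous coefficients for that feature. A sketch of the new inner loop, mirroring the diff below:

```scala
// Coefficient for class j and feature `index` lives at index * numClasses + j,
// so the j-loop below is a sequential scan over memory.
val margins = new Array[Double](numClasses)
features.foreachActive { (index, value) =>
  val stdValue = value / localFeaturesStd(index)  // standardized once per feature
  var j = 0
  while (j < numClasses) {
    margins(j) += localCoefficients(index * numClasses + j) * stdValue
    j += 1
  }
}
```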
--- .../classification/LogisticRegression.scala | 125 +++++++++++------- 1 file changed, 74 insertions(+), 51 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index c4651054fd765..18b9b3043db8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -438,18 +438,14 @@ class LogisticRegression @Since("1.2.0") ( val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { // Remove the L1 penalization on the intercept - val isIntercept = $(fitIntercept) && ((index + 1) % numFeaturesPlusIntercept == 0) + val isIntercept = $(fitIntercept) && index >= numFeatures * numCoefficientSets if (isIntercept) { 0.0 } else { if (standardizationParam) { regParamL1 } else { - val featureIndex = if ($(fitIntercept)) { - index % numFeaturesPlusIntercept - } else { - index % numFeatures - } + val featureIndex = index / numCoefficientSets // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to // perform this reverse standardization by penalizing each component @@ -466,6 +462,15 @@ class LogisticRegression @Since("1.2.0") ( new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } + /* + The coefficients are laid out in column major order during training. e.g. for + `numClasses = 3` and `numFeatures = 2` and `fitIntercept = true` the layout is: + + Array(beta_11, beta_21, beta_31, beta_12, beta_22, beta_32, intercept_1, intercept_2, + intercept_3) + + where beta_jk corresponds to the coefficient for class `j` and feature `k`. + */ val initialCoefficientsWithIntercept = Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept) @@ -489,13 +494,14 @@ class LogisticRegression @Since("1.2.0") ( val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray val providedCoef = optInitialModel.get.coefficientMatrix providedCoef.foreachActive { (row, col, value) => - val flatIndex = row * numFeaturesPlusIntercept + col + // convert matrix to column major for training + val flatIndex = col * numCoefficientSets + row // We need to scale the coefficients since they will be trained in the scaled space initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col) } if ($(fitIntercept)) { optInitialModel.get.interceptVector.foreachActive { (index, value) => - val coefIndex = (index + 1) * numFeaturesPlusIntercept - 1 + val coefIndex = numCoefficientSets * numFeatures + index initialCoefWithInterceptArray(coefIndex) = value } } @@ -526,7 +532,7 @@ class LogisticRegression @Since("1.2.0") ( val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing val rawMean = rawIntercepts.sum / rawIntercepts.length rawIntercepts.indices.foreach { i => - initialCoefficientsWithIntercept.toArray(i * numFeaturesPlusIntercept + numFeatures) = + initialCoefficientsWithIntercept.toArray(numClasses * numFeatures + i) = rawIntercepts(i) - rawMean } } else if ($(fitIntercept)) { @@ -572,16 +578,20 @@ class LogisticRegression @Since("1.2.0") ( /* The coefficients are trained in the scaled space; we're converting them back to the original space. + + Additionally, since the coefficients were laid out in column major order during training + to avoid extra computation, we convert them back to row major before passing them to the + model. 
+ Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. */ val rawCoefficients = state.x.toArray.clone() val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i => - // flatIndex will loop though rawCoefficients, and skip the intercept terms. - val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i + val colMajorIndex = (i % numFeatures) * numCoefficientSets + i / numFeatures val featureIndex = i % numFeatures if (featuresStd(featureIndex) != 0.0) { - rawCoefficients(flatIndex) / featuresStd(featureIndex) + rawCoefficients(colMajorIndex) / featuresStd(featureIndex) } else { 0.0 } @@ -618,7 +628,7 @@ class LogisticRegression @Since("1.2.0") ( val interceptsArray: Array[Double] = if ($(fitIntercept)) { Array.tabulate(numCoefficientSets) { i => - val coefIndex = (i + 1) * numFeaturesPlusIntercept - 1 + val coefIndex = numFeatures * numCoefficientSets + i rawCoefficients(coefIndex) } } else { @@ -697,6 +707,7 @@ class LogisticRegressionModel private[spark] ( /** * A vector of model coefficients for "binomial" logistic regression. If this model was trained * using the "multinomial" family then an exception is thrown. + * * @return Vector */ @Since("2.0.0") @@ -720,6 +731,7 @@ class LogisticRegressionModel private[spark] ( /** * The model intercept for "binomial" logistic regression. If this model was fit with the * "multinomial" family then an exception is thrown. + * * @return Double */ @Since("1.3.0") @@ -1389,6 +1401,12 @@ class BinaryLogisticRegressionSummary private[classification] ( * $$ *

* + * @note In order to avoid unnecessary computation during calculation of the gradient updates + * we lay out the coefficients in column major order during training. This allows us to + * perform feature standardization once, while still retaining sequential memory access + * for speed. We convert back to row major order when we create the model, + * since this form is optimal for the matrix operations used for prediction. + * * @param bcCoefficients The broadcast coefficients corresponding to the features. * @param bcFeaturesStd The broadcast standard deviation values of the features. * @param numClasses the number of possible outcomes for k classes classification problem in @@ -1486,23 +1504,25 @@ private class LogisticAggregator( var marginOfLabel = 0.0 var maxMargin = Double.NegativeInfinity - val margins = Array.tabulate(numClasses) { i => - var margin = 0.0 - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - margin += localCoefficients(i * numFeaturesPlusIntercept + index) * - value / localFeaturesStd(index) - } + val margins = new Array[Double](numClasses) + features.foreachActive { (index, value) => + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + margins(j) += localCoefficients(index * numClasses + j) * stdValue + j += 1 } - + } + var i = 0 + while (i < numClasses) { if (fitIntercept) { - margin += localCoefficients(i * numFeaturesPlusIntercept + numFeatures) + margins(i) += localCoefficients(numClasses * numFeatures + i) } - if (i == label.toInt) marginOfLabel = margin - if (margin > maxMargin) { - maxMargin = margin + if (i == label.toInt) marginOfLabel = margins(i) + if (margins(i) > maxMargin) { + maxMargin = margins(i) } - margin + i += 1 } /** @@ -1510,33 +1530,39 @@ private class LogisticAggregator( * We address this by subtracting maxMargin from all the margins, so it's guaranteed * that all of the new margins will be smaller than zero to prevent arithmetic overflow. 
*/ + val multipliers = new Array[Double](numClasses) val sum = { var temp = 0.0 - if (maxMargin > 0) { - for (i <- 0 until numClasses) { - margins(i) -= maxMargin - temp += math.exp(margins(i)) - } - } else { - for (i <- 0 until numClasses) { - temp += math.exp(margins(i)) - } + var i = 0 + while (i < numClasses) { + if (maxMargin > 0) margins(i) -= maxMargin + val exp = math.exp(margins(i)) + temp += exp + multipliers(i) = exp + i += 1 } temp } - for (i <- 0 until numClasses) { - val multiplier = math.exp(margins(i)) / sum - { - if (label == i) 1.0 else 0.0 - } - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - localGradientArray(i * numFeaturesPlusIntercept + index) += - weight * multiplier * value / localFeaturesStd(index) + margins.indices.foreach { i => + multipliers(i) = multipliers(i) / sum - (if (label == i) 1.0 else 0.0) + } + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + localGradientArray(index * numClasses + j) += + weight * multipliers(j) * stdValue + j += 1 } } - if (fitIntercept) { - localGradientArray(i * numFeaturesPlusIntercept + numFeatures) += weight * multiplier + } + if (fitIntercept) { + var i = 0 + while (i < numClasses) { + localGradientArray(numFeatures * numClasses + i) += weight * multipliers(i) + i += 1 } } @@ -1637,6 +1663,7 @@ private class LogisticCostFun( val bcCoeffs = instances.context.broadcast(coeffs) val featuresStd = bcFeaturesStd.value val numFeatures = featuresStd.length + val numCoefficientSets = if (multinomial) numClasses else 1 val logisticAggregator = { val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) @@ -1656,7 +1683,7 @@ private class LogisticCostFun( var sum = 0.0 coeffs.foreachActive { case (index, value) => // We do not apply regularization to the intercepts - val isIntercept = fitIntercept && ((index + 1) % (numFeatures + 1) == 0) + val isIntercept = fitIntercept && index >= numCoefficientSets * numFeatures if (!isIntercept) { // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. @@ -1665,11 +1692,7 @@ private class LogisticCostFun( totalGradientArray(index) += regParamL2 * value value * value } else { - val featureIndex = if (fitIntercept) { - index % (numFeatures + 1) - } else { - index % numFeatures - } + val featureIndex = index / numCoefficientSets if (featuresStd(featureIndex) != 0.0) { // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to From 3af894511be6fcc17731e28b284dba432fe911f5 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Fri, 11 Nov 2016 18:36:23 -0800 Subject: [PATCH 123/198] [SPARK-16759][CORE] Add a configuration property to pass caller contexts of upstream applications into Spark ## What changes were proposed in this pull request? Many applications take Spark as a computing engine and run on it. This PR adds a configuration property `spark.log.callerContext` that can be used by Spark's upstream applications (e.g. Oozie) to set up their caller contexts into Spark. In the end, Spark will combine its own caller context with the caller contexts of its upstream applications, and write them into Yarn RM log and HDFS audit log. The audit log has a config to truncate the caller contexts passed in (default 128). 
The caller contexts will be sent over RPC, so they should be concise. The caller context written into the HDFS audit log and Yarn log consists of two parts: the information `A` specified by Spark itself and the value `B` of the `spark.log.callerContext` property. Currently `A` typically takes 64 to 74 characters, so `B` can have up to 50 characters (mentioned in the doc `running-on-yarn.md`).

## How was this patch tested?

Manual tests. I have run some Spark applications with the `spark.log.callerContext` configuration in Yarn client/cluster mode, and verified that the caller contexts were written into the Yarn RM log and HDFS audit log correctly.

The ways to configure the `spark.log.callerContext` property:

- In spark-defaults.conf:
```
spark.log.callerContext infoSpecifiedByUpstreamApp
```
- In the app's source code:
```
val spark = SparkSession
  .builder
  .appName("SparkKMeans")
  .config("spark.log.callerContext", "infoSpecifiedByUpstreamApp")
  .getOrCreate()
```

When running in Yarn cluster mode, the driver is unable to pass 'spark.log.callerContext' to the Yarn client and AM, since the Yarn client and AM have already started before the driver performs `.config("spark.log.callerContext", "infoSpecifiedByUpstreamApp")`.

The following example shows the command line used to submit a SparkKMeans application and the corresponding records in the Yarn RM log and HDFS audit log.

Command:
```
./bin/spark-submit --verbose --executor-cores 3 --num-executors 1 --master yarn --deploy-mode client --class org.apache.spark.examples.SparkKMeans examples/target/original-spark-examples_2.11-2.1.0-SNAPSHOT.jar hdfs://localhost:9000/lr_big.txt 2 5
```

Yarn RM log: screen shot 2016-10-19 at 9 12 03 pm
HDFS audit log: screen shot 2016-10-19 at 10 18 14 pm

Author: Weiqing Yang

Closes #15563 from weiqingy/SPARK-16759.
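For reference, the combined string that reaches the logs is an underscore-joined concatenation; an illustrative sketch with placeholder IDs (the real composition, and the truncation to `hadoop.caller.context.max.size`, default 128, are in the `CallerContext` diff below):

```scala
// Placeholder values, for illustration only.
val context = "SPARK_TASK" +
  "_application_1479161234_0001" +   // appId (placeholder)
  "_1" +                             // appAttemptId (placeholder)
  "_JId_0_SId_0_0_TId_42_0" +        // job/stage/task ids and attempts (placeholders)
  "_infoSpecifiedByUpstreamApp"      // the spark.log.callerContext value
```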
--- .../spark/internal/config/package.scala | 4 ++ .../org/apache/spark/scheduler/Task.scala | 13 ++++- .../scala/org/apache/spark/util/Utils.scala | 53 ++++++++++++------- docs/configuration.md | 9 ++++ .../spark/deploy/yarn/ApplicationMaster.scala | 3 +- .../org/apache/spark/deploy/yarn/Client.scala | 3 +- 6 files changed, 61 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4a3e3d5c79eff..2951bdc18bc57 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -207,6 +207,10 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") + .stringConf + .createOptional + private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9385e3c31e1e4..d39651a722325 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance @@ -92,8 +93,16 @@ private[spark] abstract class Task[T]( kill(interruptThread = false) } - new CallerContext("TASK", appId, appAttemptId, jobId, Option(stageId), Option(stageAttemptId), - Option(taskAttemptId), Option(attemptNumber)).setCurrentContext() + new CallerContext( + "TASK", + SparkEnv.get.conf.get(APP_CALLER_CONTEXT), + appId, + appAttemptId, + jobId, + Option(stageId), + Option(stageAttemptId), + Option(taskAttemptId), + Option(attemptNumber)).setCurrentContext() try { runTask(context) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1de66af632a8a..c27cbe3192846 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2569,6 +2569,7 @@ private[util] object CallerContext extends Logging { * @param from who sets up the caller context (TASK, CLIENT, APPMASTER) * * The parameters below are optional: + * @param upstreamCallerContext caller context the upstream application passes in * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to * @param jobId id of the job this task belongs to @@ -2578,26 +2579,38 @@ private[util] object CallerContext extends Logging { * @param taskAttemptNumber task attempt id */ private[spark] class CallerContext( - from: String, - appId: Option[String] = None, - appAttemptId: Option[String] = None, - jobId: Option[Int] = None, - stageId: Option[Int] = None, - stageAttemptId: Option[Int] = None, - taskId: Option[Long] = None, - taskAttemptNumber: Option[Int] = None) extends Logging { - - val appIdStr = if (appId.isDefined) s"_${appId.get}" else "" - val appAttemptIdStr = if (appAttemptId.isDefined) s"_${appAttemptId.get}" else "" - val 
jobIdStr = if (jobId.isDefined) s"_JId_${jobId.get}" else "" - val stageIdStr = if (stageId.isDefined) s"_SId_${stageId.get}" else "" - val stageAttemptIdStr = if (stageAttemptId.isDefined) s"_${stageAttemptId.get}" else "" - val taskIdStr = if (taskId.isDefined) s"_TId_${taskId.get}" else "" - val taskAttemptNumberStr = - if (taskAttemptNumber.isDefined) s"_${taskAttemptNumber.get}" else "" - - val context = "SPARK_" + from + appIdStr + appAttemptIdStr + - jobIdStr + stageIdStr + stageAttemptIdStr + taskIdStr + taskAttemptNumberStr + from: String, + upstreamCallerContext: Option[String] = None, + appId: Option[String] = None, + appAttemptId: Option[String] = None, + jobId: Option[Int] = None, + stageId: Option[Int] = None, + stageAttemptId: Option[Int] = None, + taskId: Option[Long] = None, + taskAttemptNumber: Option[Int] = None) extends Logging { + + private val context = prepareContext("SPARK_" + + from + + appId.map("_" + _).getOrElse("") + + appAttemptId.map("_" + _).getOrElse("") + + jobId.map("_JId_" + _).getOrElse("") + + stageId.map("_SId_" + _).getOrElse("") + + stageAttemptId.map("_" + _).getOrElse("") + + taskId.map("_TId_" + _).getOrElse("") + + taskAttemptNumber.map("_" + _).getOrElse("") + + upstreamCallerContext.map("_" + _).getOrElse("")) + + private def prepareContext(context: String): String = { + // The default max size of Hadoop caller context is 128 + lazy val len = SparkHadoopUtil.get.conf.getInt("hadoop.caller.context.max.size", 128) + if (context == null || context.length <= len) { + context + } else { + val finalContext = context.substring(0, len) + logWarning(s"Truncated Spark caller context from $context to $finalContext") + finalContext + } + } /** * Set up the caller context [[context]] by invoking Hadoop CallerContext API of diff --git a/docs/configuration.md b/docs/configuration.md index 41c1778ee7fcf..ea99592408bac 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -202,6 +202,15 @@ of the most common options to set are: or remotely ("cluster") on one of the nodes inside the cluster. + + spark.log.callerContext + (none) + + Application information that will be written into Yarn RM log/HDFS audit log when running on Yarn/HDFS. + Its length depends on the Hadoop configuration hadoop.caller.context.max.size. It should be concise, + and typically can have up to 50 characters. 
+ + Apart from these, the following properties are also available, and may be useful in some situations: diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index f2b9dfb4d184d..918cc2dd04ab6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -202,7 +202,8 @@ private[spark] class ApplicationMaster( attemptID = Option(appAttemptId.getAttemptId.toString) } - new CallerContext("APPMASTER", + new CallerContext( + "APPMASTER", sparkConf.get(APP_CALLER_CONTEXT), Option(appAttemptId.getApplicationId.toString), attemptID).setCurrentContext() logInfo("ApplicationAttemptId: " + appAttemptId) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e77fa386dc933..1b75688b280e6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -161,7 +161,8 @@ private[spark] class Client( reportLauncherState(SparkAppHandle.State.SUBMITTED) launcherBackend.setAppId(appId.toString) - new CallerContext("CLIENT", Option(appId.toString)).setCurrentContext() + new CallerContext("CLIENT", sparkConf.get(APP_CALLER_CONTEXT), + Option(appId.toString)).setCurrentContext() // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) From bc41d997ea287080f549219722b6d9049adef4e2 Mon Sep 17 00:00:00 2001 From: Guoqiang Li Date: Sat, 12 Nov 2016 09:49:14 +0000 Subject: [PATCH 124/198] [SPARK-18375][SPARK-18383][BUILD][CORE] Upgrade netty to 4.0.42.Final ## What changes were proposed in this pull request? One of the important changes for 4.0.42.Final is "Support any FileRegion implementation when using epoll transport netty/netty#5825". In 4.0.42.Final, `MessageWithHeader` can work properly when `spark.[shuffle|rpc].io.mode` is set to epoll ## How was this patch tested? Existing tests Author: Guoqiang Li Closes #15830 from witgo/SPARK-18375_netty-4.0.42. 
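For readers unfamiliar with the transport modes involved, here is a minimal sketch, not part of this patch, of how an application might opt into the native epoll transport that this upgrade fixes. The exact keys follow the `spark.[shuffle|rpc].io.mode` convention named above; epoll is a Linux-only transport and the default mode is NIO.

```scala
import org.apache.spark.SparkConf

// Hedged example: select the native epoll transport (Linux only) for both
// the shuffle and RPC Netty services; the default mode is NIO.
val conf = new SparkConf()
  .setAppName("epoll-transport-example")
  .set("spark.shuffle.io.mode", "EPOLL")
  .set("spark.rpc.io.mode", "EPOLL")
```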
--- core/src/main/scala/org/apache/spark/util/Utils.scala | 4 ++++ dev/deps/spark-deps-hadoop-2.2 | 2 +- dev/deps/spark-deps-hadoop-2.3 | 2 +- dev/deps/spark-deps-hadoop-2.4 | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 7 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c27cbe3192846..d341982ae9e8c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -39,6 +39,7 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses @@ -2222,6 +2223,9 @@ private[spark] object Utils extends Logging { isBindCollision(e.getCause) case e: MultiException => e.getThrowables.asScala.exists(isBindCollision) + case e: NativeIoException => + (e.getMessage != null && e.getMessage.startsWith("bind() failed: ")) || + isBindCollision(e.getCause) case e: Exception => isBindCollision(e.getCause) case _ => false } diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 6e749ac16cac0..bbdea069f9496 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -123,7 +123,7 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 515995a0a46bd..a2dec41d64519 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -130,7 +130,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index d2139fd952406..c1f02b93d751c 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -130,7 +130,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index b5cecf72ec35f..4f04636be712b 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -138,7 +138,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a5e03a78e7ea8..da3af9ffa155b 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -139,7 +139,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/pom.xml b/pom.xml index 8aa0a6c3caab9..650b4cd965b66 100644 --- a/pom.xml +++ b/pom.xml @@ -552,7 +552,7 @@ io.netty netty-all - 4.0.41.Final + 4.0.42.Final io.netty From 22cb3a060a440205281b71686637679645454ca6 Mon Sep 17 
00:00:00 2001 From: Yanbo Liang Date: Sat, 12 Nov 2016 06:13:22 -0800 Subject: [PATCH 125/198] [SPARK-14077][ML][FOLLOW-UP] Minor refactor and cleanup for NaiveBayes ## What changes were proposed in this pull request? * Refactor out ```trainWithLabelCheck``` and make ```mllib.NaiveBayes``` call into it. * Avoid capturing the outer object for ```modelType```. * Move ```requireNonnegativeValues``` and ```requireZeroOneBernoulliValues``` to companion object. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #15826 from yanboliang/spark-14077-2. --- .../spark/ml/classification/NaiveBayes.scala | 72 +++++++++---------- .../mllib/classification/NaiveBayes.scala | 6 +- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index b03a07a6bc1e7..f1a7676c74b0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") ( extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { - import NaiveBayes.{Bernoulli, Multinomial} + import NaiveBayes._ @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) @@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") ( @Since("2.1.0") def setWeightCol(value: String): this.type = set(weightCol, value) + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + trainWithLabelCheck(dataset, positiveLabel = true) + } + /** * ml assumes input labels in range [0, numClasses). But this implementation * is also called by mllib NaiveBayes which allows other kinds of input labels - * such as {-1, +1}. Here we use this parameter to switch between different processing logic. - * It should be removed when we remove mllib NaiveBayes. + * such as {-1, +1}. `positiveLabel` is used to determine whether the label + * should be checked and it should be removed when we remove mllib NaiveBayes. 
*/ - private[spark] var isML: Boolean = true - - private[spark] def setIsML(isML: Boolean): this.type = { - this.isML = isML - this - } - - override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - if (isML) { + private[spark] def trainWithLabelCheck( + dataset: Dataset[_], + positiveLabel: Boolean): NaiveBayesModel = { + if (positiveLabel) { val numClasses = getNumClasses(dataset) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") ( } } - val requireNonnegativeValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - - require(values.forall(_ >= 0.0), - s"Naive Bayes requires nonnegative feature values but found $v.") - } - - val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - - require(values.forall(v => v == 0.0 || v == 1.0), - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") - } - + val modelTypeValue = $(modelType) val requireValues: Vector => Unit = { - $(modelType) match { + modelTypeValue match { case Multinomial => requireNonnegativeValues case Bernoulli => @@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") ( @Since("1.6.0") object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { /** String name for multinomial model type. */ - private[spark] val Multinomial: String = "multinomial" + private[classification] val Multinomial: String = "multinomial" /** String name for Bernoulli model type. */ - private[spark] val Bernoulli: String = "bernoulli" + private[classification] val Bernoulli: String = "bernoulli" /* Set of modelTypes that NaiveBayes supports */ - private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) + private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) + + private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(_ >= 0.0), + s"Naive Bayes requires nonnegative feature values but found $v.") + } + + private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(v => v == 0.0 || v == 1.0), + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } @Since("1.6.0") override def load(path: String): NaiveBayes = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 33561be4b5bc1..767d056861a8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -364,12 +364,12 @@ class NaiveBayes private ( val nb = new NewNaiveBayes() .setModelType(modelType) .setSmoothing(lambda) - .setIsML(false) val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) } .toDF("label", "features") - val newModel = nb.fit(dataset) + // mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` as false. 
+ val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false) val pi = newModel.pi.toArray val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0) @@ -378,7 +378,7 @@ class NaiveBayes private ( theta(i)(j) = v } - require(newModel.oldLabels != null, + assert(newModel.oldLabels != null, "The underlying ML NaiveBayes training does not produce labels.") new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType) } From 1386fd28daf798bf152606f4da30a36223d75d18 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 12 Nov 2016 14:50:37 -0800 Subject: [PATCH 126/198] [SPARK-18418] Fix flags for make_binary_release for hadoop profile ## What changes were proposed in this pull request? Fix the flags used to specify the hadoop version ## How was this patch tested? Manually tested as part of https://github.com/apache/spark/pull/15659 by having the build succeed. cc joshrosen Author: Holden Karau Closes #15860 from holdenk/minor-fix-release-build-script. --- dev/create-release/release-build.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 96f9b5714ebb8..81f0d63054e29 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -187,10 +187,10 @@ if [[ "$1" == "package" ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. FLAGS="-Psparkr -Phive -Phive-thriftserver -Pyarn -Pmesos" - make_binary_release "hadoop2.3" "-Phadoop2.3 $FLAGS" "3033" & - make_binary_release "hadoop2.4" "-Phadoop2.4 $FLAGS" "3034" & - make_binary_release "hadoop2.6" "-Phadoop2.6 $FLAGS" "3035" & - make_binary_release "hadoop2.7" "-Phadoop2.7 $FLAGS" "3036" & + make_binary_release "hadoop2.3" "-Phadoop-2.3 $FLAGS" "3033" & + make_binary_release "hadoop2.4" "-Phadoop-2.4 $FLAGS" "3034" & + make_binary_release "hadoop2.6" "-Phadoop-2.6 $FLAGS" "3035" & + make_binary_release "hadoop2.7" "-Phadoop-2.7 $FLAGS" "3036" & make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn -Pmesos" "3037" & make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn -Pmesos" "3038" & wait From b91a51bb231af321860415075a7f404bc46e0a74 Mon Sep 17 00:00:00 2001 From: Denny Lee Date: Sun, 13 Nov 2016 18:10:06 -0800 Subject: [PATCH 127/198] [SPARK-18426][STRUCTURED STREAMING] Python Documentation Fix for Structured Streaming Programming Guide ## What changes were proposed in this pull request? Update the python section of the Structured Streaming Guide from .builder() to .builder ## How was this patch tested? Validated documentation and successfully running the test example. Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. 'Builder' object is not callable object hence changed .builder() to .builder Author: Denny Lee Closes #15872 from dennyglee/master. 
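As a side-by-side for readers coming from the Scala API, a minimal sketch (not part of this patch): in Scala, `builder` is a no-argument method on the `SparkSession` companion object, so both `builder` and `builder()` compile, whereas PySpark's `builder` is a plain attribute and must not be called.

```scala
import org.apache.spark.sql.SparkSession

// Both SparkSession.builder and SparkSession.builder() work in Scala;
// only the Python API needed the one-line fix below.
val spark = SparkSession.builder
  .appName("StructuredNetworkWordCount")
  .getOrCreate()
```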
--- docs/structured-streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index d838ed35a14fd..d2545584ae3b0 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -58,7 +58,7 @@ from pyspark.sql.functions import explode from pyspark.sql.functions import split spark = SparkSession \ - .builder() \ + .builder \ .appName("StructuredNetworkWordCount") \ .getOrCreate() {% endhighlight %} From 07be232ea12dfc8dc3701ca948814be7dbebf4ee Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 13 Nov 2016 20:25:12 -0800 Subject: [PATCH 128/198] [SPARK-18412][SPARKR][ML] Fix exception for some SparkR ML algorithms training on libsvm data ## What changes were proposed in this pull request? * Fix the following exceptions which throws when ```spark.randomForest```(classification), ```spark.gbt```(classification), ```spark.naiveBayes``` and ```spark.glm```(binomial family) were fitted on libsvm data. ``` java.lang.IllegalArgumentException: requirement failed: If label column already exists, forceIndexLabel can not be set with true. ``` See [SPARK-18412](https://issues.apache.org/jira/browse/SPARK-18412) for more detail about how to reproduce this bug. * Refactor out ```getFeaturesAndLabels``` to RWrapperUtils, since lots of ML algorithm wrappers use this function. * Drop some unwanted columns when making prediction. ## How was this patch tested? Add unit test. Author: Yanbo Liang Closes #15851 from yanboliang/spark-18412. --- R/pkg/inst/tests/testthat/test_mllib.R | 18 ++++++++-- .../spark/ml/r/GBTClassificationWrapper.scala | 18 ++++------ .../GeneralizedLinearRegressionWrapper.scala | 5 ++- .../apache/spark/ml/r/NaiveBayesWrapper.scala | 14 +++----- .../org/apache/spark/ml/r/RWrapperUtils.scala | 36 ++++++++++++++++--- .../r/RandomForestClassificationWrapper.scala | 18 ++++------ 6 files changed, 68 insertions(+), 41 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index b76f75dbdc682..07df4b6d6f844 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -881,7 +881,8 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) -test_that("spark.randomForest Regression", { +test_that("spark.randomForest", { + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 1) @@ -923,9 +924,8 @@ test_that("spark.randomForest Regression", { expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) -}) -test_that("spark.randomForest Classification", { + # classification data <- suppressWarnings(createDataFrame(iris)) model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", maxDepth = 5, maxBins = 16) @@ -971,6 +971,12 @@ test_that("spark.randomForest Classification", { predictions <- collect(predict(model, data))$prediction expect_equal(length(grep("1.0", predictions)), 50) expect_equal(length(grep("2.0", predictions)), 50) + + # spark.randomForest classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) }) 
test_that("spark.gbt", { @@ -1039,6 +1045,12 @@ test_that("spark.gbt", { expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) expect_equal(s$numFeatures, 5) expect_equal(s$numTrees, 20) + + # spark.gbt classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) }) sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index 8946025032200..aacb41ee2659b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -23,10 +23,10 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -51,6 +51,7 @@ private[r] class GBTClassifierWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(gbtcModel.getFeaturesCol) + .drop(gbtcModel.getLabelCol) } override def write: MLWriter = new @@ -81,19 +82,11 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) - // get feature names from output schema - val schema = rFormulaModel.transform(data).schema - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) - - // get label names from output schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val rfc = new GBTClassifier() @@ -109,6 +102,7 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] .setMaxMemoryInMB(maxMemoryInMB) .setCacheNodeIds(cacheNodeIds) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 995b1ef03bcec..add4d49110d16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.regression._ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.r.RWrapperUtils._ 
import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -64,6 +65,7 @@ private[r] class GeneralizedLinearRegressionWrapper private ( .drop(PREDICTED_LABEL_PROB_COL) .drop(PREDICTED_LABEL_INDEX_COL) .drop(glm.getFeaturesCol) + .drop(glm.getLabelCol) } else { pipeline.transform(dataset) .drop(glm.getFeaturesCol) @@ -92,7 +94,7 @@ private[r] object GeneralizedLinearRegressionWrapper regParam: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) if (family == "binomial") rFormula.setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema val schema = rFormulaModel.transform(data).schema @@ -109,6 +111,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) val pipeline = if (family == "binomial") { // Convert prediction from probability to label index. val probToPred = new ProbabilityToPrediction() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 4fdab2dd94655..0afea4be3d1dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -46,6 +46,7 @@ private[r] class NaiveBayesWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(naiveBayesModel.getFeaturesCol) + .drop(naiveBayesModel.getLabelCol) } override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this) @@ -60,21 +61,16 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema - val schema = rFormulaModel.transform(data).schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val naiveBayes = new NaiveBayes() .setSmoothing(smoothing) .setModelType("bernoulli") .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 379007c4d948d..665e50af67d46 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -18,11 +18,12 @@ package org.apache.spark.ml.r import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.feature.{RFormula, RFormulaModel} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.Dataset -object RWrapperUtils extends Logging { +private[r] object RWrapperUtils extends Logging { /** * DataFrame column check. @@ -32,14 +33,41 @@ object RWrapperUtils extends Logging { * * @param rFormula RFormula instance * @param data Input dataset - * @return Unit */ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}" - logWarning(s"data containing ${rFormula.getFeaturesCol} column, " + + logInfo(s"data containing ${rFormula.getFeaturesCol} column, " + s"using new name $newFeaturesName instead") rFormula.setFeaturesCol(newFeaturesName) } + + if (rFormula.getForceIndexLabel && data.schema.fieldNames.contains(rFormula.getLabelCol)) { + val newLabelName = s"${Identifiable.randomUID(rFormula.getLabelCol)}" + logInfo(s"data containing ${rFormula.getLabelCol} column and we force to index label, " + + s"using new name $newLabelName instead") + rFormula.setLabelCol(newLabelName) + } + } + + /** + * Get the feature names and original labels from the schema + * of DataFrame transformed by RFormulaModel. + * + * @param rFormulaModel The RFormulaModel instance. + * @param data Input dataset. + * @return The feature names and original labels. 
+ */ + def getFeaturesAndLabels( + rFormulaModel: RFormulaModel, + data: Dataset[_]): (Array[String], Array[String]) = { + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + (features, labels) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 31f846dc6cfec..0b860e5af96e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -23,10 +23,10 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -51,6 +51,7 @@ private[r] class RandomForestClassifierWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(rfcModel.getFeaturesCol) + .drop(rfcModel.getLabelCol) } override def write: MLWriter = new @@ -82,19 +83,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) - // get feature names from output schema - val schema = rFormulaModel.transform(data).schema - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) - - // get label names from output schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val rfc = new RandomForestClassifier() @@ -111,6 +104,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .setCacheNodeIds(cacheNodeIds) .setProbabilityCol(probabilityCol) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) From f95b124c68ccc2e318f6ac30685aa47770eea8f3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 14 Nov 2016 16:52:07 +0900 Subject: [PATCH 129/198] [SPARK-18382][WEBUI] "run at null:-1" in UI when no file/line info in call site info ## What changes were proposed in this pull request? Avoid reporting null/-1 file / line number in call sites if encountering StackTraceElement without this info ## How was this patch tested? Existing tests Author: Sean Owen Closes #15862 from srowen/SPARK-18382. 
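To make the failure mode concrete: `StackTraceElement.getFileName` may return null and `getLineNumber` may be negative, for example for native methods or classes compiled without debug info. A hedged sketch of the guard, using a hypothetical helper name rather than the patch's actual code path:

```scala
// Hypothetical helper illustrating the guard applied in the diff below:
// fall back to placeholders instead of rendering "run at null:-1" in the UI.
def renderCallSite(ste: StackTraceElement): String = {
  val file = Option(ste.getFileName).getOrElse("<unknown file>")
  val line = if (ste.getLineNumber >= 0) ste.getLineNumber.toString else "?"
  s"run at $file:$line"
}
```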
--- core/src/main/scala/org/apache/spark/util/Utils.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d341982ae9e8c..23b95b9f649fe 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1419,8 +1419,12 @@ private[spark] object Utils extends Logging { } callStack(0) = ste.toString // Put last Spark method on top of the stack trace. } else { - firstUserLine = ste.getLineNumber - firstUserFile = ste.getFileName + if (ste.getFileName != null) { + firstUserFile = ste.getFileName + if (ste.getLineNumber >= 0) { + firstUserLine = ste.getLineNumber + } + } callStack += ste.toString insideSpark = false } From ae6cddb78742be94aa0851ce719f293e0a64ce4f Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Mon, 14 Nov 2016 12:08:06 +0100 Subject: [PATCH 130/198] [SPARK-18166][MLLIB] Fix Poisson GLM bug due to wrong requirement of response values ## What changes were proposed in this pull request? The current implementation of Poisson GLM seems to allow only positive values. This is incorrect since the support of Poisson includes the origin. The bug is easily fixed by changing the test of the Poisson variable from 'require(y **>** 0.0' to 'require(y **>=** 0.0'. mengxr srowen Author: actuaryzhang Author: actuaryzhang Closes #15683 from actuaryzhang/master. --- .../GeneralizedLinearRegression.scala | 4 +- .../GeneralizedLinearRegressionSuite.scala | 45 +++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 1938e8ecc513d..1d2961e0277f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -501,8 +501,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine val defaultLink: Link = Log override def initialize(y: Double, weight: Double): Double = { - require(y > 0.0, "The response variable of Poisson family " + - s"should be positive, but got $y") + require(y >= 0.0, "The response variable of Poisson family " + + s"should be non-negative, but got $y") y } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 111bc974642d9..6a4ac1735b2cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -44,6 +44,7 @@ class GeneralizedLinearRegressionSuite @transient var datasetGaussianInverse: DataFrame = _ @transient var datasetBinomial: DataFrame = _ @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonLogWithZero: DataFrame = _ @transient var datasetPoissonIdentity: DataFrame = _ @transient var datasetPoissonSqrt: DataFrame = _ @transient var datasetGammaInverse: DataFrame = _ @@ -88,6 +89,12 @@ class GeneralizedLinearRegressionSuite xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "log").toDF() + datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput( + 
intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01, + family = "poisson", link = "log") + .map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF() + datasetPoissonIdentity = generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, @@ -139,6 +146,10 @@ class GeneralizedLinearRegressionSuite label + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile( "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLog") + datasetPoissonLogWithZero.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLogWithZero") datasetPoissonIdentity.rdd.map { case Row(label: Double, features: Vector) => label + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile( @@ -456,6 +467,40 @@ class GeneralizedLinearRegressionSuite } } + test("generalized linear regression: poisson family against glm (with zero values)") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="poisson", data=data) + print(as.vector(coef(model))) + } + [1] 0.4272661 -0.1565423 + [1] -3.6911354 0.6214301 0.1295814 + */ + val expected = Seq( + Vectors.dense(0.0, 0.4272661, -0.1565423), + Vectors.dense(-3.6911354, 0.6214301, 0.1295814)) + + import GeneralizedLinearRegression._ + + var idx = 0 + val link = "log" + val dataset = datasetPoissonLogWithZero + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + + s"$link link and fitIntercept = $fitIntercept (with zero values).") + idx += 1 + } + } + test("generalized linear regression: gamma family against glm") { /* R code: From 637a0bb88f74712001f32a53ff66fd0b8cb67e4a Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Mon, 14 Nov 2016 12:22:36 +0100 Subject: [PATCH 131/198] [SPARK-18396][HISTORYSERVER] Duration" column makes search result confused, maybe we should make it unsearchable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When we search data in History Server, it will check if any columns contains the search string. Duration is represented as long value in table, so if we search simple string like "003", "111", the duration containing "003", ‘111“ will be showed, which make not much sense to users. We cannot simply transfer the long value to meaning format like "1 h", "3.2 min" because they are also used for sorting. Better way to handle it is ban "Duration" columns from searching. ## How was this patch tested manually tests. 
Before("local-1478225166651" pass the filter because its duration in long value, which is "257244245" contains search string "244"): ![before](https://cloud.githubusercontent.com/assets/5276001/20203166/f851ffc6-a7ff-11e6-8fe6-91a90ca92b23.jpg) After: ![after](https://cloud.githubusercontent.com/assets/5276001/20178646/2129fbb0-a78d-11e6-9edb-39f885ce3ed0.jpg) Author: WangTaoTheTonic Closes #15838 from WangTaoTheTonic/duration. --- .../main/resources/org/apache/spark/ui/static/historypage.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 6c0ec8d5fce54..8fd91865b0429 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -139,6 +139,9 @@ $(document).ready(function() { {name: 'eighth'}, {name: 'ninth'}, ], + "columnDefs": [ + {"searchable": false, "targets": [5]} + ], "autoWidth": false, "order": [[ 4, "desc" ]] }; From 9d07ceee7860921eafb55b47852f1b51089c98da Mon Sep 17 00:00:00 2001 From: Noritaka Sekiyama Date: Mon, 14 Nov 2016 21:07:59 +0900 Subject: [PATCH 132/198] [SPARK-18432][DOC] Changed HDFS default block size from 64MB to 128MB Changed HDFS default block size from 64MB to 128MB. https://issues.apache.org/jira/browse/SPARK-18432 Author: Noritaka Sekiyama Closes #15879 from moomindani/SPARK-18432. --- docs/programming-guide.md | 6 +++--- docs/tuning.md | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index b9a2110b602a0..58bf17b4a84ef 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -343,7 +343,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Scala API also supports several other data formats: @@ -375,7 +375,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. 
Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Java API also supports several other data formats: @@ -407,7 +407,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Python API also supports several other data formats: diff --git a/docs/tuning.md b/docs/tuning.md index 9c43b315bbb9e..0de303a3bd9bf 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -224,8 +224,8 @@ temporary objects created during task execution. Some steps which may be useful * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the - size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 64 MB, - we can estimate size of Eden to be `4*3*64MB`. + size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 128 MB, + we can estimate size of Eden to be `4*3*128MB`. * Monitor how the frequency and time taken by garbage collection changes with the new settings. From bdfe60ac921172be0fb77de2f075cc7904a3b238 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Nov 2016 10:03:01 -0800 Subject: [PATCH 133/198] [SPARK-18416][STRUCTURED STREAMING] Fixed temp file leak in state store ## What changes were proposed in this pull request? StateStore.get() causes temporary files to be created immediately, even if the store is not used to make updates for new version. The temp file is not closed as store.commit() is not called in those cases, thus keeping the output stream to temp file open forever. This PR fixes it by opening the temp file only when there are updates being made. ## How was this patch tested? New unit test Author: Tathagata Das Closes #15859 from tdas/SPARK-18416. 
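The essence of the fix is a `lazy val`, which defers the side effect of creating the temp file until the first write. A minimal sketch of the idiom, with hypothetical names:

```scala
import java.io.OutputStream

// Hypothetical store: `out` is only materialized when put() runs, so a
// store that is opened but never updated creates no temp file at all.
class SketchStore(openTempFile: () => OutputStream) {
  private lazy val out = openTempFile() // side effect deferred to first use

  def put(bytes: Array[Byte]): Unit = out.write(bytes) // forces creation
  def get(): Unit = ()                 // read paths never touch `out`
}
```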
--- .../state/HDFSBackedStateStoreProvider.scala | 10 +-- .../streaming/state/StateStoreSuite.scala | 63 +++++++++++++++++++ 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 808713161c316..f07feaad5dc71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -87,8 +87,7 @@ private[state] class HDFSBackedStateStoreProvider( private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) - + private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() @volatile private var state: STATE = UPDATING @@ -101,7 +100,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or aborted") + verify(state == UPDATING, "Cannot put after already committed or aborted") val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) @@ -125,6 +124,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Remove keys that match the following condition */ override def remove(condition: UnsafeRow => Boolean): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") + val keyIter = mapToUpdate.keySet().iterator() while (keyIter.hasNext) { val key = keyIter.next @@ -154,7 +154,7 @@ private[state] class HDFSBackedStateStoreProvider( finalizeDeltaFile(tempDeltaFileStream) finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) state = COMMITTED - logInfo(s"Committed version $newVersion for $this") + logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion } catch { case NonFatal(e) => @@ -174,7 +174,7 @@ private[state] class HDFSBackedStateStoreProvider( if (tempDeltaFile != null) { fs.delete(tempDeltaFile, true) } - logInfo("Aborted") + logInfo(s"Aborted version $newVersion for $this") } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 504a26516107f..533cd0cd2a2ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -468,6 +468,69 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(e.getCause.getMessage.contains("Failed to rename")) } + test("SPARK-18416: do not create temp delta file until the store is updated") { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, 0, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + val deltaFileDir = new File(s"$dir/0/0/") + + def numTempFiles: Int = { + if (deltaFileDir.exists) { + deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + } else 0 + } + + def 
numDeltaFiles: Int = { + if (deltaFileDir.exists) { + deltaFileDir.listFiles.map(_.getName).count(n => n.contains(".delta") && !n.startsWith(".")) + } else 0 + } + + def shouldNotCreateTempFile[T](body: => T): T = { + val before = numTempFiles + val result = body + assert(numTempFiles === before) + result + } + + // Getting the store should not create temp file + val store0 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + } + + // Put should create a temp file + put(store0, "a", 1) + assert(numTempFiles === 1) + assert(numDeltaFiles === 0) + + // Commit should remove temp file and create a delta file + store0.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 1) + + // Remove should create a temp file + val store1 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + } + remove(store1, _ == "a") + assert(numTempFiles === 1) + assert(numDeltaFiles === 1) + + // Commit should remove temp file and create a delta file + store1.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 2) + + // Commit without any updates should create a delta file + val store2 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf) + } + store2.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 3) + } + def getDataFromFiles( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { From 89d1fa58dbe88560b1f2b0362fcc3035ccc888be Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Mon, 14 Nov 2016 11:10:37 -0800 Subject: [PATCH 134/198] [SPARK-17510][STREAMING][KAFKA] config max rate on a per-partition basis ## What changes were proposed in this pull request? Allow configuration of max rate on a per-topicpartition basis. ## How was this patch tested? Unit tests. The reporter (Jeff Nadler) said he could test on his workload, so let's wait on that report. Author: cody koeninger Closes #15132 from koeninger/SPARK-17510. 
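A hedged usage sketch of the new API added below; the topic name, the rates, and the surrounding `ssc`, `topics`, and `kafkaParams` values are hypothetical:

```scala
import org.apache.kafka.common.TopicPartition
import org.apache.spark.streaming.kafka010._

// Throttle partitions of one hot topic more aggressively than the rest.
val ppc = new PerPartitionConfig {
  override def maxRatePerPartition(tp: TopicPartition): Long =
    if (tp.topic == "hot-topic") 10000L else 1000L
}

val stream = KafkaUtils.createDirectStream[String, String](
  ssc,
  LocationStrategies.PreferConsistent,
  ConsumerStrategies.Subscribe[String, String](topics, kafkaParams),
  ppc)
```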
--- .../kafka010/DirectKafkaInputDStream.scala | 11 ++-- .../spark/streaming/kafka010/KafkaUtils.scala | 53 ++++++++++++++++++- .../kafka010/PerPartitionConfig.scala | 47 ++++++++++++++++ .../kafka010/DirectKafkaStreamSuite.scala | 34 ++++++++---- .../kafka/DirectKafkaInputDStream.scala | 4 +- 5 files changed, 131 insertions(+), 18 deletions(-) create mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 7e57bb18cbd50..794f53c5abfd0 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -57,7 +57,8 @@ import org.apache.spark.streaming.scheduler.rate.RateEstimator private[spark] class DirectKafkaInputDStream[K, V]( _ssc: StreamingContext, locationStrategy: LocationStrategy, - consumerStrategy: ConsumerStrategy[K, V] + consumerStrategy: ConsumerStrategy[K, V], + ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { val executorKafkaParams = { @@ -128,12 +129,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( } } - private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( - "spark.streaming.kafka.maxRatePerPartition", 0) - protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val estimatedRateLimit = rateController.map(_.getLatestRate()) // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { @@ -144,11 +142,12 @@ private[spark] class DirectKafkaInputDStream[K, V]( val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => + val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) val backpressureRate = Math.round(lag / totalLag.toFloat * rate) tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } } if (effectiveRateLimitPerPartition.values.sum > 0) { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala index b2190bfa05a3a..c11917f59d5b8 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala @@ -123,7 +123,31 @@ object KafkaUtils extends Logging { locationStrategy: LocationStrategy, consumerStrategy: ConsumerStrategy[K, V] ): InputDStream[ConsumerRecord[K, V]] = { - new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy) + val ppc = new DefaultPerPartitionConfig(ssc.sparkContext.getConf) + createDirectStream[K, V](ssc, locationStrategy, consumerStrategy, ppc) + } + + /** + * :: Experimental :: + * Scala constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD 
partition. + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details. + * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis. + * see [[PerPartitionConfig]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + ssc: StreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V], + perPartitionConfig: PerPartitionConfig + ): InputDStream[ConsumerRecord[K, V]] = { + new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy, perPartitionConfig) } /** @@ -150,6 +174,33 @@ object KafkaUtils extends Logging { jssc.ssc, locationStrategy, consumerStrategy)) } + /** + * :: Experimental :: + * Java constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details + * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis. + * see [[PerPartitionConfig]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + jssc: JavaStreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V], + perPartitionConfig: PerPartitionConfig + ): JavaInputDStream[ConsumerRecord[K, V]] = { + new JavaInputDStream( + createDirectStream[K, V]( + jssc.ssc, locationStrategy, consumerStrategy, perPartitionConfig)) + } + /** * Tweak kafka params to prevent issues on executors */ diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala new file mode 100644 index 0000000000000..4792f2a955110 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Interface for user-supplied configurations that can't otherwise be set via Spark properties, + * because they need tweaking on a per-partition basis, + */ +@Experimental +abstract class PerPartitionConfig extends Serializable { + /** + * Maximum rate (number of records per second) at which data will be read + * from each Kafka partition. + */ + def maxRatePerPartition(topicPartition: TopicPartition): Long +} + +/** + * Default per-partition configuration + */ +private class DefaultPerPartitionConfig(conf: SparkConf) + extends PerPartitionConfig { + val maxRate = conf.getLong("spark.streaming.kafka.maxRatePerPartition", 0) + + def maxRatePerPartition(topicPartition: TopicPartition): Long = maxRate +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index c81836da3cbbf..fde3714d3d02e 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -252,7 +252,8 @@ class DirectKafkaStreamSuite val s = new DirectKafkaInputDStream[String, String]( ssc, preferredHosts, - ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf)) s.consumer.poll(0) assert( s.consumer.position(topicPartition) >= offsetBeforeStart, @@ -307,7 +308,8 @@ class DirectKafkaStreamSuite ConsumerStrategies.Assign[String, String]( List(topicPartition), kafkaParams.asScala, - Map(topicPartition -> 11L))) + Map(topicPartition -> 11L)), + new DefaultPerPartitionConfig(sparkConf)) s.consumer.poll(0) assert( s.consumer.position(topicPartition) >= offsetBeforeStart, @@ -520,7 +522,7 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition with backpressure disabled") { val topic = "maxMessagesPerPartition" - val kafkaStream = getDirectKafkaStream(topic, None) + val kafkaStream = getDirectKafkaStream(topic, None, None) val input = Map(new TopicPartition(topic, 0) -> 50L, new TopicPartition(topic, 1) -> 50L) assert(kafkaStream.maxMessagesPerPartition(input).get == @@ -530,7 +532,7 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition with no lag") { val topic = "maxMessagesPerPartition" val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100)) - val kafkaStream = getDirectKafkaStream(topic, rateController) + val kafkaStream = getDirectKafkaStream(topic, rateController, None) val input = Map(new TopicPartition(topic, 0) -> 0L, new TopicPartition(topic, 1) -> 0L) assert(kafkaStream.maxMessagesPerPartition(input).isEmpty) @@ -539,11 +541,19 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition respects max rate") { val topic = "maxMessagesPerPartition" val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000)) - val kafkaStream = getDirectKafkaStream(topic, rateController) + val ppc = Some(new PerPartitionConfig { + def maxRatePerPartition(tp: TopicPartition) = + if (tp.topic == topic && tp.partition == 0) { + 50 + } else { + 100 + } + }) + val kafkaStream = 
getDirectKafkaStream(topic, rateController, ppc) val input = Map(new TopicPartition(topic, 0) -> 1000L, new TopicPartition(topic, 1) -> 1000L) assert(kafkaStream.maxMessagesPerPartition(input).get == - Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L)) + Map(new TopicPartition(topic, 0) -> 5L, new TopicPartition(topic, 1) -> 10L)) } test("using rate controller") { @@ -572,7 +582,9 @@ class DirectKafkaStreamSuite new DirectKafkaInputDStream[String, String]( ssc, preferredHosts, - ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) { + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { override protected[streaming] val rateController = Some(new DirectKafkaRateController(id, estimator)) }.map(r => (r.key, r.value)) @@ -618,7 +630,10 @@ class DirectKafkaStreamSuite }.toSeq.sortBy { _._1 } } - private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = { + private def getDirectKafkaStream( + topic: String, + mockRateController: Option[RateController], + ppc: Option[PerPartitionConfig]) = { val batchIntervalMilliseconds = 100 val sparkConf = new SparkConf() @@ -645,7 +660,8 @@ class DirectKafkaStreamSuite tps.foreach(tp => consumer.seek(tp, 0)) consumer } - } + }, + ppc.getOrElse(new DefaultPerPartitionConfig(sparkConf)) ) { override protected[streaming] val rateController = mockRateController } diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index c3c799375bbeb..d52c230eb7849 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -88,12 +88,12 @@ class DirectKafkaInputDStream[ protected val kc = new KafkaCluster(kafkaParams) - private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( + private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val estimatedRateLimit = rateController.map(_.getLatestRate()) // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { From 75934457d75996be71ffd0d4b448497d656c0d40 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 14 Nov 2016 19:42:00 +0000 Subject: [PATCH 135/198] [SPARK-11496][GRAPHX][FOLLOWUP] Add param checking for runParallelPersonalizedPageRank ## What changes were proposed in this pull request? add the param checking to keep in line with other algos ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15876 from zhengruifeng/param_check_runParallelPersonalizedPageRank. 
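For context, a minimal sketch of how the new checks surface to callers. The toy graph and the `SparkContext` named `sc` are assumptions for illustration; the method and its parameter order come from the diff below, and `require` failures surface as `IllegalArgumentException`:

```scala
import org.apache.spark.graphx.{Edge, Graph, VertexId}
import org.apache.spark.graphx.lib.PageRank

// Build a tiny illustrative graph (assumes an active SparkContext `sc`).
val edges = sc.parallelize(Seq(Edge(1L, 2L, 1.0), Edge(2L, 3L, 1.0)))
val graph = Graph.fromEdges(edges, defaultValue = 1)

// Valid call: all parameters pass the new checks.
val ranks = PageRank.runParallelPersonalizedPageRank(graph, 10, 0.15, Array(1L, 2L))

// Each of the following now fails fast with an IllegalArgumentException
// instead of starting a doomed computation:
// PageRank.runParallelPersonalizedPageRank(graph, 0, 0.15, Array(1L))   // numIter <= 0
// PageRank.runParallelPersonalizedPageRank(graph, 10, 1.5, Array(1L))   // resetProb outside [0, 1]
// PageRank.runParallelPersonalizedPageRank(graph, 10, 0.15, Array.empty[VertexId]) // empty sources
```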
---
 .../main/scala/org/apache/spark/graphx/lib/PageRank.scala | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index f4b00757a8b54..c0c3c73463aab 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -185,6 +185,13 @@ object PageRank extends Logging {
   def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED],
     numIter: Int, resetProb: Double = 0.15,
     sources: Array[VertexId]): Graph[Vector, Double] = {
+    require(numIter > 0, s"Number of iterations must be greater than 0," +
+      s" but got ${numIter}")
+    require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" +
+      s" to [0, 1], but got ${resetProb}")
+    require(sources.nonEmpty, s"The list of sources must be non-empty," +
+      s" but got ${sources.mkString("[", ",", "]")}")
+
     // TODO if one sources vertex id is outside of the int range
     // we won't be able to store its activations in a sparse vector
     val zero = Vectors.sparse(sources.size, List()).asBreeze

From bd85603ba5f9e61e1aa8326d3e4d5703b5977a4c Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Mon, 14 Nov 2016 20:59:15 +0100
Subject: [PATCH 136/198] [SPARK-17348][SQL] Incorrect results from subquery
 transformation

## What changes were proposed in this pull request?

Return an AnalysisException when there is a correlated non-equality predicate in a subquery and the correlated column from the outer reference is not from the immediate parent operator of the subquery. This PR prevents incorrect results from subquery transformation in such cases.

Test cases, both positive and negative tests, are added.

## How was this patch tested?

sql/test, catalyst/test, hive/test, and scenarios that will produce incorrect results without this PR and produce correct results when subquery transformation does happen.

Author: Nattavut Sutyanyong

Closes #15763 from nsyca/spark-17348.
---
 .../sql/catalyst/analysis/Analyzer.scala      | 44 +++++++++
 .../sql/catalyst/analysis/CheckAnalysis.scala |  7 --
 .../org/apache/spark/sql/SubquerySuite.scala  | 95 ++++++++++++++++++-
 3 files changed, 137 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dd68d60d3e839..c14f353517088 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1031,6 +1031,37 @@ class Analyzer(
       }
     }

+    // SPARK-17348: A potential incorrect result case.
+    // When a correlated predicate is a non-equality predicate,
+    // certain operators are not permitted from the operator
+    // hosting the correlated predicate up to the operator on the outer table.
+    // Otherwise, the pull-up of the correlated predicate
+    // will generate a plan with different semantics,
+    // which could return incorrect results.
+    // Currently we check for Aggregate and Window operators.
+    //
+    // Below shows an example of a Logical Plan during Analyzer phase that
+    // shows this problem. Pulling the correlated predicate [outer(c2#77) >= ..]
+    // through the Aggregate (or Window) operator could alter the result of
+    // the Aggregate.
+ // + // Project [c1#76] + // +- Project [c1#87, c2#88] + // : (Aggregate or Window operator) + // : +- Filter [outer(c2#77) >= c2#88)] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] + // +- SubqueryAlias t1, `t1` + // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] + // +- LocalRelation [_1#73, _2#74] + def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { + if (found) { + // Report a non-supported case as an exception + failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + } + } + /** Determine which correlated predicate references are missing from this plan. */ def missingReferences(p: LogicalPlan): AttributeSet = { val localPredicateReferences = p.collect(predicateMap) @@ -1041,12 +1072,20 @@ class Analyzer( localPredicateReferences -- p.outputSet } + var foundNonEqualCorrelatedPred : Boolean = false + // Simplify the predicates before pulling them out. val transformed = BooleanSimplification(sub) transformUp { case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + // Find any non-equality correlated predicates + foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { + case _: EqualTo | _: EqualNullSafe => false + case _ => true + } + // Rewrite the filter without the correlated predicates if any. correlated match { case Nil => f @@ -1068,12 +1107,17 @@ class Analyzer( } case a @ Aggregate(grouping, expressions, child) => failOnOuterReference(a) + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + val referencesToAdd = missingReferences(a) if (referencesToAdd.nonEmpty) { Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) } else { a } + case w : Window => + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w) + w case j @ Join(left, _, RightOuter, _) => failOnOuterReference(j) failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 3455a567b7786..7b75c1f70974b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -119,13 +119,6 @@ trait CheckAnalysis extends PredicateHelper { } case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => - // Make sure we are using equi-joins. - conditions.foreach { - case _: EqualTo | _: EqualNullSafe => // ok - case e => failAnalysis( - s"The correlated scalar subquery can only contain equality predicates: $e") - } - // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates which contain exactly one aggregate expressions. 
// The analyzer has already checked that subquery contained only one output column, and diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 89348668340be..c84a6f161893c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -498,10 +498,10 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("non-equal correlated scalar subquery") { val msg1 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1") + sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") } assert(msg1.getMessage.contains( - "The correlated scalar subquery can only contain equality predicates")) + "Correlated column is not allowed in a non-equality predicate:")) } test("disjunctive correlated scalar subquery") { @@ -639,6 +639,97 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | from t1 left join t2 on t1.c1=t2.c2) t3 | where c3 not in (select c2 from t2)""".stripMargin), Row(2) :: Nil) + } + } + + test("SPARK-17348: Correlated subqueries with non-equality predicate (good case)") { + withTempView("t1", "t2") { + Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2") + + // Simple case + checkAnswer( + sql( + """ + | select c1 + | from t1 + | where c1 in (select t2.c1 + | from t2 + | where t1.c2 >= t2.c2)""".stripMargin), + Row(1) :: Nil) + + // More complex case with OR predicate + checkAnswer( + sql( + """ + | select t1.c1 + | from t1, t1 as t3 + | where t1.c1 = t3.c1 + | and (t1.c1 in (select t2.c1 + | from t2 + | where t1.c2 >= t2.c2 + | or t3.c2 < t2.c2) + | or t1.c2 >= 0)""".stripMargin), + Row(1) :: Nil) + } + } + + test("SPARK-17348: Correlated subqueries with non-equality predicate (error case)") { + withTempView("t1", "t2", "t3", "t4") { + Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((2, 1)).toDF("c1", "c2").createOrReplaceTempView("t3") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t4") + + // Simplest case + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 >= t2.c2)""".stripMargin).collect() + } + + // Add a HAVING on top and augmented within an OR predicate + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 >= t2.c2 + | having count(*) > 0 ) + | or t1.c2 >= 0""".stripMargin).collect() + } + + // Add a HAVING on top and augmented within an OR predicate + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1, t1 as t3 + | where t1.c1 = t3.c1 + | and (t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 = t2.c2 + | or t3.c2 = t2.c2) + | )""".stripMargin).collect() + } + + // In Window expression: changing the data set to + // demonstrate if this query ran, it would return incorrect result. 
+      intercept[AnalysisException] {
+        sql(
+          """
+            | select c1
+            | from t3
+            | where c1 in (select max(t4.c1) over ()
+            |              from t4
+            |              where t3.c2 >= t4.c2)""".stripMargin).collect()
+      }
    }
  }
}

From c07187823a98f0d1a0f58c06e28a27e1abed157a Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Mon, 14 Nov 2016 16:46:26 -0800
Subject: [PATCH 137/198] [SPARK-18124] Observed delay based Event Time
 Watermarks

This PR adds a new method `withWatermark` to the `Dataset` API, which can be used to specify an _event time watermark_. An event time watermark allows the streaming engine to reason about the point in time after which we no longer expect to see late data.

This PR also augments `StreamExecution` to use this watermark for several purposes:
 - To know when a given time window aggregation is finalized and thus results can be emitted when using output modes that do not allow updates (e.g. `Append` mode).
 - To minimize the amount of state that we need to keep for on-going aggregations, by evicting state for groups that are no longer expected to change. We do still maintain all state if the query requires it (i.e. if the event time is not present in the `groupBy` or when running in `Complete` mode).

An example that emits windowed counts of records, waiting up to 5 minutes for late data to arrive:
```scala
df.withWatermark("eventTime", "5 minutes")
  .groupBy(window($"eventTime", "1 minute") as 'window)
  .count()
  .writeStream
  .format("console")
  .mode("append") // In append mode, we only output finalized aggregations.
  .start()
```

### Calculating the watermark.
The current event time is computed by looking at the `MAX(eventTime)` seen in this epoch across all of the partitions in the query, minus some user-defined _delayThreshold_. An additional constraint is that the watermark must increase monotonically.

Note that since we only coordinate this value across partitions occasionally, the actual watermark used is only guaranteed to be at least `delay` behind the actual event time. In some cases we may still process records that arrive more than `delay` late.

This mechanism was chosen for the initial implementation over processing time for two reasons:
 - it is robust to downtime that could affect processing delay;
 - it does not require syncing of time or timezones between the producer and the processing engine.

### Other notable implementation details
 - A new trigger metric `eventTimeWatermark` outputs the current value of the watermark.
 - We mark the event time column in the `Attribute` metadata using the key `spark.watermarkDelay`. This allows downstream operations to know which column holds the event time. Operations like `window` propagate this metadata.
 - `explain()` marks the watermark with a suffix of `-T${delayMs}` to ease debugging of how this information is propagated.
 - Currently, we don't filter out late records, but instead rely on the state store to avoid emitting records that are both added and filtered in the same epoch.

### Remaining in this PR
 - [ ] The test for recovery is currently failing as we don't record the watermark used in the offset log. We will need to do so to ensure determinism, but this is deferred until #15626 is merged.

### Other follow-ups
There are some natural additional features that we should consider for future work:
 - Ability to write records that arrive too late to some external store in case any out-of-band remediation is required.
 - `Update` mode so you can get partial results before a group is evicted.
 - Other mechanisms for calculating the watermark.
In particular a watermark based on quantiles would be more robust to outliers. Author: Michael Armbrust Closes #15702 from marmbrus/watermarks. --- .../spark/unsafe/types/CalendarInterval.java | 4 + .../apache/spark/sql/AnalysisException.scala | 3 +- .../sql/catalyst/analysis/Analyzer.scala | 8 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 + .../UnsupportedOperationChecker.scala | 18 +- .../sql/catalyst/analysis/unresolved.scala | 3 +- .../expressions/namedExpressions.scala | 17 +- .../plans/logical/EventTimeWatermark.scala | 51 +++++ .../scala/org/apache/spark/sql/Dataset.scala | 40 +++- .../spark/sql/execution/SparkStrategies.scala | 12 +- .../sql/execution/aggregate/AggUtils.scala | 9 +- .../sql/execution/command/commands.scala | 2 +- .../streaming/EventTimeWatermarkExec.scala | 93 +++++++++ .../sql/execution/streaming/ForeachSink.scala | 3 +- .../streaming/IncrementalExecution.scala | 12 +- .../streaming/StatefulAggregate.scala | 170 +++++++++------- .../execution/streaming/StreamExecution.scala | 25 ++- .../execution/streaming/StreamMetrics.scala | 1 + .../state/HDFSBackedStateStoreProvider.scala | 23 ++- .../streaming/state/StateStore.scala | 7 +- .../streaming/state/StateStoreSuite.scala | 6 +- .../spark/sql/streaming/WatermarkSuite.scala | 191 ++++++++++++++++++ 22 files changed, 597 insertions(+), 111 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 518ed6470a753..a7b0e6f80c2b6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -252,6 +252,10 @@ public static long parseSecondNano(String secondNano) throws IllegalArgumentExce public final int months; public final long microseconds; + public final long milliseconds() { + return this.microseconds / MICROS_PER_MILLI; + } + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 7defb9df862c0..ff8576157305b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -31,7 +31,8 @@ class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, val startPosition: Option[Int] = None, - val plan: Option[LogicalPlan] = None, + // Some plans fail to serialize due to bugs in scala collections. 
+ @transient val plan: Option[LogicalPlan] = None, val cause: Option[Throwable] = None) extends Exception(message, cause.orNull) with Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c14f353517088..ec5f710fd9872 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2272,7 +2272,13 @@ object TimeWindowing extends Rule[LogicalPlan] { windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { val window = windowExpressions.head - val windowAttr = AttributeReference("window", window.dataType)() + + val metadata = window.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + val windowAttr = + AttributeReference("window", window.dataType, metadata = metadata)() val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt val windows = Seq.tabulate(maxNumOverlapping + 1) { i => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b75c1f70974b..98e50d0d3c674 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -148,6 +148,16 @@ trait CheckAnalysis extends PredicateHelper { } operator match { + case etw: EventTimeWatermark => + etw.eventTime.dataType match { + case s: StructType + if s.find(_.name == "end").map(_.dataType) == Some(TimestampType) => + case _: TimestampType => + case _ => + failAnalysis( + s"Event time must be defined on a window or a timestamp, but " + + s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.simpleString}") + } case f: Filter if f.condition.dataType != BooleanType => failAnalysis( s"filter expression '${f.condition.sql}' " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index e81370c504abb..c054fcbef36f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.{AnalysisException, InternalOutputModes} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.OutputMode @@ -55,9 +56,20 @@ object UnsupportedOperationChecker { // Disallow some output mode outputMode match { case InternalOutputModes.Append if aggregates.nonEmpty => - throwError( - s"$outputMode output mode not supported when there are streaming aggregations on " + - s"streaming DataFrames/DataSets")(plan) + val aggregate = aggregates.head + + // Find any attributes that are associated with an eventTime watermark. 
+ val watermarkAttributes = aggregate.groupingExpressions.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + + // We can append rows to the sink once the group is under the watermark. Without this + // watermark a group is never "finished" so we would never output anything. + if (watermarkAttributes.isEmpty) { + throwError( + s"$outputMode output mode not supported when there are streaming aggregations on " + + s"streaming DataFrames/DataSets")(plan) + } case InternalOutputModes.Complete | InternalOutputModes.Update if aggregates.isEmpty => throwError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 235ae04782455..36ed9ba50372b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, Codege import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, Metadata, StructType} /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -98,6 +98,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def withNullability(newNullability: Boolean): UnresolvedAttribute = this override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) + override def withMetadata(newMetadata: Metadata): Attribute = this override def toString: String = s"'$name" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 306a99d5a37bf..1274757136051 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -22,6 +22,7 @@ import java.util.{Objects, UUID} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types._ @@ -104,6 +105,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn def withNullability(newNullability: Boolean): Attribute def withQualifier(newQualifier: Option[String]): Attribute def withName(newName: String): Attribute + def withMetadata(newMetadata: Metadata): Attribute override def toAttribute: Attribute = this def newInstance(): Attribute @@ -292,11 +294,22 @@ case class AttributeReference( } } + override def withMetadata(newMetadata: Metadata): Attribute = { + AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier, isGenerated) + } + override protected final def otherCopyArgs: Seq[AnyRef] = { exprId :: qualifier :: isGenerated :: Nil } - override def toString: 
String = s"$name#${exprId.id}$typeSuffix" + /** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}) */ + private def delaySuffix = if (metadata.contains(EventTimeWatermark.delayKey)) { + s"-T${metadata.getLong(EventTimeWatermark.delayKey)}ms" + } else { + "" + } + + override def toString: String = s"$name#${exprId.id}$typeSuffix$delaySuffix" // Since the expression id is not in the first constructor it is missing from the default // tree string. @@ -332,6 +345,8 @@ case class PrettyAttribute( override def withQualifier(newQualifier: Option[String]): Attribute = throw new UnsupportedOperationException override def withName(newName: String): Attribute = throw new UnsupportedOperationException + override def withMetadata(newMetadata: Metadata): Attribute = + throw new UnsupportedOperationException override def qualifier: Option[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala new file mode 100644 index 0000000000000..4224a7997c410 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.types.MetadataBuilder +import org.apache.spark.unsafe.types.CalendarInterval + +object EventTimeWatermark { + /** The [[org.apache.spark.sql.types.Metadata]] key used to hold the eventTime watermark delay. */ + val delayKey = "spark.watermarkDelayMs" +} + +/** + * Used to mark a user specified column as holding the event time for a row. + */ +case class EventTimeWatermark( + eventTime: Attribute, + delay: CalendarInterval, + child: LogicalPlan) extends LogicalPlan { + + // Update the metadata on the eventTime column to include the desired delay. 
+ override val output: Seq[Attribute] = child.output.map { a => + if (a semanticEquals eventTime) { + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .build() + a.withMetadata(updatedMetadata) + } else { + a + } + } + + override val children: Seq[LogicalPlan] = child :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index eb2b20afc37cf..af30683cc01c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils private[sql] object Dataset { @@ -476,7 +477,7 @@ class Dataset[T] private[sql]( * `collect()`, will throw an [[AnalysisException]] when there is a streaming * source present. * - * @group basic + * @group streaming * @since 2.0.0 */ @Experimental @@ -496,8 +497,6 @@ class Dataset[T] private[sql]( /** * Returns a checkpointed version of this Dataset. * - * @param eager When true, materializes the underlying checkpointed RDD eagerly. - * * @group basic * @since 2.1.0 */ @@ -535,6 +534,41 @@ class Dataset[T] private[sql]( )(sparkSession)).as[T] } + /** + * :: Experimental :: + * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time + * before which we assume no more late data is going to arrive. + * + * Spark will use this watermark for several purposes: + * - To know when a given time window aggregation can be finalized and thus can be emitted when + * using output modes that do not allow updates. + * - To minimize the amount of state that we need to keep for on-going aggregations. + * + * The current watermark is computed by looking at the `MAX(eventTime)` seen across + * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost + * of coordinating this value across partitions, the actual watermark used is only guaranteed + * to be at least `delayThreshold` behind the actual event time. In some cases we may still + * process records that arrive more than `delayThreshold` late. + * + * @param eventTime the name of the column that contains the event time of the row. + * @param delayThreshold the minimum delay to wait to data to arrive late, relative to the latest + * record that has been processed in the form of an interval + * (e.g. "1 minute" or "5 hours"). + * + * @group streaming + * @since 2.1.0 + */ + @Experimental + @InterfaceStability.Evolving + // We only accept an existing column name, not a derived column here as a watermark that is + // defined on a derived column cannot referenced elsewhere in the plan. + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { + val parsedDelay = + Option(CalendarInterval.fromString("interval " + delayThreshold)) + .getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'")) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan) + } + /** * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, * and all cells will be aligned right. 
For example: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 190fdd84343ee..2308ae8a6c611 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,20 +18,23 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, SaveMode, Strategy} +import org.apache.spark.sql.{SaveMode, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} -import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.StreamingQuery /** * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting @@ -224,6 +227,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object StatefulAggregationStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case EventTimeWatermark(columnName, delay, child) => + EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil + case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 3c8ef1ad84c0a..8b8ccf4239b13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -328,8 +328,13 @@ object AggUtils { } // Note: stateId and returnAllStates are filled in later with preparation rules // in IncrementalExecution. 
- val saved = StateStoreSaveExec( - groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) + val saved = + StateStoreSaveExec( + groupingAttributes, + stateId = None, + outputMode = None, + eventTimeWatermark = None, + partialMerged2) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index d82e54e57564c..52d8dc22a2d4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -104,7 +104,7 @@ case class ExplainCommand( if (logicalPlan.isStreaming) { // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. - new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0) + new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0, 0) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala new file mode 100644 index 0000000000000..4c8cb069d23a0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.math.max + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.MetadataBuilder +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.AccumulatorV2 + +/** Tracks the maximum positive long seen. 
*/ +class MaxLong(protected var currentValue: Long = 0) + extends AccumulatorV2[Long, Long] { + + override def isZero: Boolean = value == 0 + override def value: Long = currentValue + override def copy(): AccumulatorV2[Long, Long] = new MaxLong(currentValue) + + override def reset(): Unit = { + currentValue = 0 + } + + override def add(v: Long): Unit = { + currentValue = max(v, value) + } + + override def merge(other: AccumulatorV2[Long, Long]): Unit = { + currentValue = max(value, other.value) + } +} + +/** + * Used to mark a column as the containing the event time for a given record. In addition to + * adding appropriate metadata to this column, this operator also tracks the maximum observed event + * time. Based on the maximum observed time and a user specified delay, we can calculate the + * `watermark` after which we assume we will no longer see late records for a particular time + * period. + */ +case class EventTimeWatermarkExec( + eventTime: Attribute, + delay: CalendarInterval, + child: SparkPlan) extends SparkPlan { + + // TODO: Use Spark SQL Metrics? + val maxEventTime = new MaxLong + sparkContext.register(maxEventTime) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val getEventTime = UnsafeProjection.create(eventTime :: Nil, child.output) + iter.map { row => + maxEventTime.add(getEventTime(row).getLong(0)) + row + } + } + } + + // Update the metadata on the eventTime column to include the desired delay. + override val output: Seq[Attribute] = child.output.map { a => + if (a semanticEquals eventTime) { + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .build() + + a.withMetadata(updatedMetadata) + } else { + a + } + } + + override def children: Seq[SparkPlan] = child :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index 24f98b9211f12..f5c550dd6ac3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -60,7 +60,8 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria deserialized, data.queryExecution.asInstanceOf[IncrementalExecution].outputMode, data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation, - data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId) + data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId, + data.queryExecution.asInstanceOf[IncrementalExecution].currentEventTimeWatermark) incrementalExecution.toRdd.mapPartitions { rows => rows.map(_.get(0, objectType)) }.asInstanceOf[RDD[T]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 05294df2673dc..e9d072f8a98b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -32,11 +32,13 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, - val currentBatchId: Long) + val currentBatchId: Long, + val currentEventTimeWatermark: Long) extends QueryExecution(sparkSession, 
logicalPlan) { // TODO: make this always part of planning. - val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy +: + val stateStrategy = + sparkSession.sessionState.planner.StatefulAggregationStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -57,17 +59,17 @@ class IncrementalExecution( val state = new Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, None, + case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) - val returnAllStates = if (outputMode == InternalOutputModes.Complete) true else false operatorId += 1 StateStoreSaveExec( keys, Some(stateId), - Some(returnAllStates), + Some(outputMode), + Some(currentEventTimeWatermark), agg.withNewChildren( StateStoreRestoreExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index ad8238f189c64..7af978a9c4aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -21,12 +21,17 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution +import org.apache.spark.sql.InternalOutputModes._ +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + /** Used to identify the state store for a given operator. 
*/ case class OperatorStateId( @@ -92,8 +97,9 @@ case class StateStoreRestoreExec( */ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], - returnAllStates: Option[Boolean], + stateId: Option[OperatorStateId] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { @@ -104,9 +110,9 @@ case class StateStoreSaveExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - assert(returnAllStates.nonEmpty, - "Incorrect planning in IncrementalExecution, returnAllStates have not been set") - val saveAndReturnFunc = if (returnAllStates.get) saveAndReturnAll _ else saveAndReturnUpdated _ + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -114,75 +120,95 @@ case class StateStoreSaveExec( keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator) - )(saveAndReturnFunc) + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + outputMode match { + // Update and output all rows in the StateStore. + case Some(Complete) => + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + store.commit() + numTotalStateRows += store.numKeys() + store.iterator().map { case (k, v) => + numOutputRows += 1 + v.asInstanceOf[InternalRow] + } + + // Update and output only rows being evicted from the StateStore + case Some(Append) => + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + + val watermarkAttribute = + keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)).get + // If we are evicting based on a window, use the end of the window. Otherwise just + // use the attribute itself. + val evictionExpression = + if (watermarkAttribute.dataType.isInstanceOf[StructType]) { + LessThanOrEqual( + GetStructField(watermarkAttribute, 1), + Literal(eventTimeWatermark.get * 1000)) + } else { + LessThanOrEqual( + watermarkAttribute, + Literal(eventTimeWatermark.get * 1000)) + } + + logInfo(s"Filtering state store on: $evictionExpression") + val predicate = newPredicate(evictionExpression, keyExpressions) + store.remove(predicate.eval) + + store.commit() + + numTotalStateRows += store.numKeys() + store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed => + numOutputRows += 1 + removed.value.asInstanceOf[InternalRow] + } + + // Update and output modified rows from the StateStore. 
+ case Some(Update) => + new Iterator[InternalRow] { + private[this] val baseIterator = iter + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + numTotalStateRows += store.numKeys() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numOutputRows += 1 + numUpdatedStateRows += 1 + row + } + } + + case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + } + } } override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning - - /** - * Save all the rows to the state store, and return all the rows in the state store. - * Note that this returns an iterator that pipelines the saving to store with downstream - * processing. - */ - private def saveAndReturnUpdated( - store: StateStore, - iter: Iterator[InternalRow]): Iterator[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - - new Iterator[InternalRow] { - private[this] val baseIterator = iter - private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - store.commit() - numTotalStateRows += store.numKeys() - false - } else { - true - } - } - - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numOutputRows += 1 - numUpdatedStateRows += 1 - row - } - } - } - - /** - * Save all the rows to the state store, and return all the rows in the state store. - * Note that the saving to store is blocking; only after all the rows have been saved - * is the iterator on the update store data is generated. - */ - private def saveAndReturnAll( - store: StateStore, - iter: Iterator[InternalRow]): Iterator[InternalRow] = { - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val numOutputRows = longMetric("numOutputRows") - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 - } - store.commit() - numTotalStateRows += store.numKeys() - store.iterator().map { case (k, v) => - numOutputRows += 1 - v.asInstanceOf[InternalRow] - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 57e89f85361e4..3ca6feac05cef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -92,6 +92,9 @@ class StreamExecution( /** The current batchId or -1 if execution has not yet been initialized. */ private var currentBatchId: Long = -1 + /** The current eventTime watermark, used to bound the lateness of data that will processed. */ + private var currentEventTimeWatermark: Long = 0 + /** All stream sources present in the query plan. 
*/ private val sources = logicalPlan.collect { case s: StreamingExecutionRelation => s.source } @@ -427,7 +430,8 @@ class StreamExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), - currentBatchId) + currentBatchId, + currentEventTimeWatermark) lastExecution.executedPlan // Force the lazy generation of execution plan } @@ -436,6 +440,25 @@ class StreamExecution( sink.addBatch(currentBatchId, nextBatch) reportNumRows(executedPlan, triggerLogicalPlan, newData) + // Update the eventTime watermark if we find one in the plan. + // TODO: Does this need to be an AttributeMap? + lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec => + logTrace(s"Maximum observed eventTime: ${e.maxEventTime.value}") + (e.maxEventTime.value / 1000) - e.delay.milliseconds() + }.headOption.foreach { newWatermark => + if (newWatermark > currentEventTimeWatermark) { + logInfo(s"Updating eventTime watermark to: $newWatermark ms") + currentEventTimeWatermark = newWatermark + } else { + logTrace(s"Event time didn't move: $newWatermark < $currentEventTimeWatermark") + } + + if (newWatermark != 0) { + streamMetrics.reportTriggerDetail(EVENT_TIME_WATERMARK, newWatermark) + } + } + awaitBatchLock.lock() try { // Wake up any threads that are waiting for the stream to progress. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala index e98d1883e4596..5645554a58f6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -221,6 +221,7 @@ object StreamMetrics extends Logging { val IS_TRIGGER_ACTIVE = "isTriggerActive" val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" val STATUS_MESSAGE = "statusMessage" + val EVENT_TIME_WATERMARK = "eventTimeWatermark" val START_TIMESTAMP = "timestamp.triggerStart" val GET_OFFSET_TIMESTAMP = "timestamp.afterGetOffset" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index f07feaad5dc71..493fdaaec5069 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -109,7 +109,7 @@ private[state] class HDFSBackedStateStoreProvider( case Some(ValueAdded(_, _)) => // Value did not exist in previous version and was added already, keep it marked as added allUpdates.put(key, ValueAdded(key, value)) - case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => + case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) => // Value existed in previous version and updated/removed, mark it as updated allUpdates.put(key, ValueUpdated(key, value)) case None => @@ -124,24 +124,25 @@ private[state] class HDFSBackedStateStoreProvider( /** Remove keys that match the following condition */ override def remove(condition: UnsafeRow => Boolean): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") - - val keyIter = mapToUpdate.keySet().iterator() - while (keyIter.hasNext) { - val key = keyIter.next - if (condition(key)) { - keyIter.remove() + val entryIter = mapToUpdate.entrySet().iterator() + while (entryIter.hasNext) { + val entry = 
entryIter.next + if (condition(entry.getKey)) { + val value = entry.getValue + val key = entry.getKey + entryIter.remove() Option(allUpdates.get(key)) match { case Some(ValueUpdated(_, _)) | None => // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, KeyRemoved(key)) + allUpdates.put(key, ValueRemoved(key, value)) case Some(ValueAdded(_, _)) => // Value did not exist in previous version and was added, should not appear in updates allUpdates.remove(key) - case Some(KeyRemoved(_)) => + case Some(ValueRemoved(_, _)) => // Remove already in update map, no need to change } - writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key)) + writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) } } } @@ -334,7 +335,7 @@ private[state] class HDFSBackedStateStoreProvider( writeUpdate(key, value) case ValueUpdated(key, value) => writeUpdate(key, value) - case KeyRemoved(key) => + case ValueRemoved(key, value) => writeRemove(key) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7132e284c28f4..9bc6c0e2b9334 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -99,13 +99,16 @@ trait StateStoreProvider { /** Trait representing updates made to a [[StateStore]]. */ -sealed trait StoreUpdate +sealed trait StoreUpdate { + def key: UnsafeRow + def value: UnsafeRow +} case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate -case class KeyRemoved(key: UnsafeRow) extends StoreUpdate +case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 533cd0cd2a2ea..05fc7345a7daf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -668,11 +668,11 @@ private[state] object StateStoreSuite { } def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { - iterator.map { _ match { + iterator.map { case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) - case KeyRemoved(key) => Removed(rowToString(key)) - }}.toSet + case ValueRemoved(key, _) => Removed(rowToString(key)) + }.toSet } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala new file mode 100644 index 0000000000000..3617ec0f564c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions.{count, window} + +class WatermarkSuite extends StreamTest with BeforeAndAfter with Logging { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("error on bad column") { + val inputData = MemoryStream[Int].toDF() + val e = intercept[AnalysisException] { + inputData.withWatermark("badColumn", "1 minute") + } + assert(e.getMessage contains "badColumn") + } + + test("error on wrong type") { + val inputData = MemoryStream[Int].toDF() + val e = intercept[AnalysisException] { + inputData.withWatermark("value", "1 minute") + } + assert(e.getMessage contains "value") + assert(e.getMessage contains "int") + } + + + test("watermark metric") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 15), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "5000" + }, + AddData(inputData, 15), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "5000" + }, + AddData(inputData, 25), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "15000" + } + ) + } + + test("append-mode watermark aggregation") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. + CheckAnswer((10, 5)) + ) + } + + ignore("recovery") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + StopStream, + StartStream(), + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. 
+      StopStream,
+      StartStream(),
+      CheckAnswer((10, 5))
+    )
+  }
+
+  test("dropping old data") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+        .withColumn("eventTime", $"value".cast("timestamp"))
+        .withWatermark("eventTime", "10 seconds")
+        .groupBy(window($"eventTime", "5 seconds") as 'window)
+        .agg(count("*") as 'count)
+        .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+    testStream(windowedAggregation)(
+      AddData(inputData, 10, 11, 12),
+      CheckAnswer(),
+      AddData(inputData, 25),   // Advance watermark to 15 seconds
+      CheckAnswer(),
+      AddData(inputData, 25),   // Evict items less than previous watermark.
+      CheckAnswer((10, 3)),
+      AddData(inputData, 10),   // 10 is older than the 15 second watermark, so it is dropped as late data
+      CheckAnswer((10, 3)),
+      AddData(inputData, 25),
+      CheckAnswer((10, 3))      // Should not emit an incorrect partial result.
+    )
+  }
+
+  test("complete mode") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+        .withColumn("eventTime", $"value".cast("timestamp"))
+        .withWatermark("eventTime", "10 seconds")
+        .groupBy(window($"eventTime", "5 seconds") as 'window)
+        .agg(count("*") as 'count)
+        .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+    // No eviction when asked to compute complete results.
+    testStream(windowedAggregation, OutputMode.Complete)(
+      AddData(inputData, 10, 11, 12),
+      CheckAnswer((10, 3)),
+      AddData(inputData, 25),
+      CheckAnswer((10, 3), (25, 1)),
+      AddData(inputData, 25),
+      CheckAnswer((10, 3), (25, 2)),
+      AddData(inputData, 10),
+      CheckAnswer((10, 4), (25, 2)),
+      AddData(inputData, 25),
+      CheckAnswer((10, 4), (25, 3))
+    )
+  }
+
+  test("group by on raw timestamp") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+        .withColumn("eventTime", $"value".cast("timestamp"))
+        .withWatermark("eventTime", "10 seconds")
+        .groupBy($"eventTime")
+        .agg(count("*") as 'count)
+        .select($"eventTime".cast("long").as[Long], $"count".as[Long])
+
+    testStream(windowedAggregation)(
+      AddData(inputData, 10),
+      CheckAnswer(),
+      AddData(inputData, 25),   // Advance watermark to 15 seconds
+      CheckAnswer(),
+      AddData(inputData, 25),   // Evict items less than previous watermark.
+      CheckAnswer((10, 1))
+    )
+  }
+}

From c31def1ddcbed340bfc071d54fb3dc7945cb525a Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Mon, 14 Nov 2016 21:15:39 -0800
Subject: [PATCH 138/198] [SPARK-18428][DOC] Update docs for GraphX

## What changes were proposed in this pull request?
1, Add links for `VertexRDD` and `EdgeRDD`
2, Note in `Vertex and Edge RDDs` that not all methods are listed
3, `VertexID` -> `VertexId`

## How was this patch tested?
No tests; only the docs are modified.

Author: Zheng RuiFeng

Closes #15875 from zhengruifeng/update_graphop_doc.
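For context, a minimal sketch of the `VertexId`-keyed property graph that the updated guide describes (illustrative only, not part of this patch; it assumes an existing `SparkContext` named `sc`):

```scala
import org.apache.spark.graphx.{Edge, Graph, VertexId}
import org.apache.spark.rdd.RDD

// Vertices are keyed by VertexId, a type alias for Long.
val users: RDD[(VertexId, String)] =
  sc.parallelize(Seq((1L, "alice"), (2L, "bob")))
val relationships: RDD[Edge[String]] =
  sc.parallelize(Seq(Edge(1L, 2L, "co-worker")))
val graph: Graph[String, String] = Graph(users, relationships)

// graph.vertices is a VertexRDD[String], an optimized RDD[(VertexId, String)].
graph.vertices.filter { case (_, name) => name == "alice" }.count()
```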
---
 docs/graphx-programming-guide.md | 68 ++++++++++++++++----------------
 1 file changed, 35 insertions(+), 33 deletions(-)

diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index 58671e6f146d8..1097cf1211c1f 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -11,6 +11,7 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT
 [EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD
+[VertexRDD]: api/scala/index.html#org.apache.spark.graphx.VertexRDD
 [Edge]: api/scala/index.html#org.apache.spark.graphx.Edge
 [EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet
 [Graph]: api/scala/index.html#org.apache.spark.graphx.Graph
@@ -89,7 +90,7 @@ with user defined objects attached to each vertex and edge. A directed multigraph is a directed
 graph with potentially multiple parallel edges sharing the same source and destination vertex. The
 ability to support parallel edges simplifies modeling scenarios where there can be multiple
 relationships (e.g., co-worker and friend) between the same vertices. Each vertex is keyed by a
-*unique* 64-bit long identifier (`VertexID`). GraphX does not impose any ordering constraints on
+*unique* 64-bit long identifier (`VertexId`). GraphX does not impose any ordering constraints on
 the vertex identifiers. Similarly, edges have corresponding source and destination vertex
 identifiers.
@@ -130,12 +131,12 @@ class Graph[VD, ED] {
 }
 {% endhighlight %}

-The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID,
+The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexId,
 VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional
 functionality built around graph computation and leverage internal optimizations. We discuss the
-`VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge
+[`VertexRDD`][VertexRDD] and [`EdgeRDD`][EdgeRDD] API in greater detail in the section on [vertex and edge
 RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form:
-`RDD[(VertexID, VD)]` and `RDD[Edge[ED]]`.
+`RDD[(VertexId, VD)]` and `RDD[Edge[ED]]`.

 ### Example Property Graph

@@ -197,7 +198,7 @@ graph.edges.filter(e => e.srcId > e.dstId).count
 {% endhighlight %}

 > Note that `graph.vertices` returns an `VertexRDD[(String, String)]` which extends
-> `RDD[(VertexID, (String, String))]` and so we use the scala `case` expression to deconstruct the
+> `RDD[(VertexId, (String, String))]` and so we use the scala `case` expression to deconstruct the
 > tuple. On the other hand, `graph.edges` returns an `EdgeRDD` containing `Edge[String]` objects.
> We could have also used the case class type constructor as in the following: > {% highlight scala %} @@ -287,7 +288,7 @@ class Graph[VD, ED] { // Change the partitioning heuristic ============================================================ def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] // Transform vertex and edge attributes ========================================================== - def mapVertices[VD2](map: (VertexID, VD) => VD2): Graph[VD2, ED] + def mapVertices[VD2](map: (VertexId, VD) => VD2): Graph[VD2, ED] def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2] def mapEdges[ED2](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] @@ -297,18 +298,18 @@ class Graph[VD, ED] { def reverse: Graph[VD, ED] def subgraph( epred: EdgeTriplet[VD,ED] => Boolean = (x => true), - vpred: (VertexID, VD) => Boolean = ((v, d) => true)) + vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] def mask[VD2, ED2](other: Graph[VD2, ED2]): Graph[VD, ED] def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] // Join RDDs with the graph ====================================================================== - def joinVertices[U](table: RDD[(VertexID, U)])(mapFunc: (VertexID, VD, U) => VD): Graph[VD, ED] - def outerJoinVertices[U, VD2](other: RDD[(VertexID, U)]) - (mapFunc: (VertexID, VD, Option[U]) => VD2) + def joinVertices[U](table: RDD[(VertexId, U)])(mapFunc: (VertexId, VD, U) => VD): Graph[VD, ED] + def outerJoinVertices[U, VD2](other: RDD[(VertexId, U)]) + (mapFunc: (VertexId, VD, Option[U]) => VD2) : Graph[VD2, ED] // Aggregate information about adjacent triplets ================================================= - def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] - def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] + def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] + def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] def aggregateMessages[Msg: ClassTag]( sendMsg: EdgeContext[VD, ED, Msg] => Unit, mergeMsg: (Msg, Msg) => Msg, @@ -316,15 +317,15 @@ class Graph[VD, ED] { : VertexRDD[A] // Iterative graph-parallel computation ========================================================== def pregel[A](initialMsg: A, maxIterations: Int, activeDirection: EdgeDirection)( - vprog: (VertexID, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID,A)], + vprog: (VertexId, VD, A) => VD, + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], mergeMsg: (A, A) => A) : Graph[VD, ED] // Basic graph algorithms ======================================================================== def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] - def connectedComponents(): Graph[VertexID, ED] + def connectedComponents(): Graph[VertexId, ED] def triangleCount(): Graph[Int, ED] - def stronglyConnectedComponents(numIter: Int): Graph[VertexID, ED] + def stronglyConnectedComponents(numIter: Int): Graph[VertexId, ED] } {% endhighlight %} @@ -481,7 +482,7 @@ original value. > is therefore recommended that the input RDD be made unique using the following which will > also *pre-index* the resulting values to substantially accelerate the subsequent join. 
> {% highlight scala %}
-val nonUniqueCosts: RDD[(VertexID, Double)]
+val nonUniqueCosts: RDD[(VertexId, Double)]
 val uniqueCosts: VertexRDD[Double] =
   graph.vertices.aggregateUsingIndex(nonUnique, (a,b) => a + b)
 val joinedGraph = graph.joinVertices(uniqueCosts)(
@@ -511,7 +512,7 @@ val degreeGraph = graph.outerJoinVertices(outDegrees) { (id, oldAttr, outDegOpt)
 > provide type annotation for the user defined function:
 > {% highlight scala %}
 val joinedGraph = graph.joinVertices(uniqueCosts,
-  (id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost)
+  (id: VertexId, oldCost: Double, extraCost: Double) => oldCost + extraCost)
 {% endhighlight %}
 >
@@ -558,7 +559,7 @@ The user defined `mergeMsg` function takes two messages destined to the same ver
 yields a single message. Think of `mergeMsg` as the reduce function in map-reduce. The
 [`aggregateMessages`][Graph.aggregateMessages] operator returns a `VertexRDD[Msg]`
 containing the aggregate message (of type `Msg`) destined to each vertex. Vertices that did not
-receive a message are not included in the returned `VertexRDD`.
+receive a message are not included in the returned [`VertexRDD`][VertexRDD].
+
+More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.LinearRegression).
+
 {% include_example python/ml/linear_regression_with_elastic_net.py %}
@@ -519,18 +546,21 @@ function and extracting model summary statistics.
+ Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GeneralizedLinearRegression) for more details. {% include_example scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala %}
+ Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GeneralizedLinearRegression.html) for more details. {% include_example java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java %}
+ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GeneralizedLinearRegression) for more details. {% include_example python/ml/generalized_linear_regression_example.py %} @@ -705,14 +735,23 @@ The implementation matches the result from R's survival function
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.AFTSurvivalRegression) for more details. + {% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %}
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/AFTSurvivalRegression.html) for more details. + {% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %}
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.AFTSurvivalRegression) for more details. + {% include_example python/ml/aft_survival_regression.py %}
diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index adb057ba7e250..b4d6be94f5eb0 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -207,14 +207,29 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.
+ +Refer to the [`Estimator` Scala docs](api/scala/index.html#org.apache.spark.ml.Estimator), +the [`Transformer` Scala docs](api/scala/index.html#org.apache.spark.ml.Transformer) and +the [`Params` Scala docs](api/scala/index.html#org.apache.spark.ml.param.Params) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala %}
+ +Refer to the [`Estimator` Java docs](api/java/org/apache/spark/ml/Estimator.html), +the [`Transformer` Java docs](api/java/org/apache/spark/ml/Transformer.html) and +the [`Params` Java docs](api/java/org/apache/spark/ml/param/Params.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java %}
+ +Refer to the [`Estimator` Python docs](api/python/pyspark.ml.html#pyspark.ml.Estimator), +the [`Transformer` Python docs](api/python/pyspark.ml.html#pyspark.ml.Transformer) and +the [`Params` Python docs](api/python/pyspark.ml.html#pyspark.ml.param.Params) for more details on the API. + {% include_example python/ml/estimator_transformer_param_example.py %}
@@ -227,14 +242,24 @@ This example follows the simple text document `Pipeline` illustrated in the figu
+ +Refer to the [`Pipeline` Scala docs](api/scala/index.html#org.apache.spark.ml.Pipeline) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %}
+ + +Refer to the [`Pipeline` Java docs](api/java/org/apache/spark/ml/Pipeline.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %}
+ +Refer to the [`Pipeline` Python docs](api/python/pyspark.ml.html#pyspark.ml.Pipeline) for more details on the API. + {% include_example python/ml/pipeline_example.py %}
diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index e4b070331db4b..a135adc4334cc 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -75,15 +75,23 @@ However, it is also a well-established method for choosing parameters which is m
+ +Refer to the [`CrossValidator` Scala docs](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
+ +Refer to the [`CrossValidator` Java docs](api/java/org/apache/spark/ml/tuning/CrossValidator.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %}
+Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.CrossValidator) for more details on the API. + {% include_example python/ml/cross_validator.py %}
@@ -107,14 +115,23 @@ Like `CrossValidator`, `TrainValidationSplit` finally fits the `Estimator` using
+ +Refer to the [`TrainValidationSplit` Scala docs](api/scala/index.html#org.apache.spark.ml.tuning.TrainValidationSplit) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
+ +Refer to the [`TrainValidationSplit` Java docs](api/java/org/apache/spark/ml/tuning/TrainValidationSplit.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
+ +Refer to the [`TrainValidationSplit` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.TrainValidationSplit) for more details on the API. + {% include_example python/ml/train_validation_split.py %}
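For context, a minimal sketch of the `TrainValidationSplit` flow that the links above document (illustrative only, not part of this patch; `training` is an assumed `DataFrame` of labels and features):

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()

// Each parameter combination is evaluated once on a single held-out
// validation set (unlike CrossValidator, which uses k folds).
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8) // 80% train, 20% validation

val model = tvs.fit(training) // `training` is assumed, not defined here
```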
From 7569cf6cb85bda7d0e76d3e75e286d4796e77e08 Mon Sep 17 00:00:00 2001
From: Xianyang Liu
Date: Wed, 16 Nov 2016 11:59:00 +0000
Subject: [PATCH 160/198] [SPARK-18420][BUILD] Fix the errors caused by lint check in Java

## What changes were proposed in this pull request?

A small fix for the errors raised by the Java lint check:
- Clear unused objects and `UnusedImports`.
- Add comments around the method `finalize` of `NioBufferedFileInputStream` to turn off checkstyle.
- Split lines longer than 100 characters into two lines.

## How was this patch tested?
Travis CI.
```
$ build/mvn -T 4 -q -DskipTests -Pyarn -Phadoop-2.3 -Pkinesis-asl -Phive -Phive-thriftserver install
$ dev/lint-java
```
Before:
```
Checkstyle checks failed at following occurrences:
[ERROR] src/main/java/org/apache/spark/network/util/TransportConf.java:[21,8] (imports) UnusedImports: Unused import - org.apache.commons.crypto.cipher.CryptoCipherFactory.
[ERROR] src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java:[516,5] (modifier) RedundantModifier: Redundant 'public' modifier.
[ERROR] src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java:[133] (coding) NoFinalizer: Avoid using finalizer method.
[ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java:[71] (sizes) LineLength: Line is longer than 100 characters (found 113).
[ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java:[112] (sizes) LineLength: Line is longer than 100 characters (found 110).
[ERROR] src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java:[31,17] (modifier) ModifierOrder: 'static' modifier out of order with the JLS suggestions.
[ERROR] src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java:[64] (sizes) LineLength: Line is longer than 100 characters (found 103).
[ERROR] src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java:[22,8] (imports) UnusedImports: Unused import - org.apache.spark.ml.linalg.Vectors.
[ERROR] src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java:[51] (regexp) RegexpSingleline: No trailing whitespace allowed.
```
After:
```
$ build/mvn -T 4 -q -DskipTests -Pyarn -Phadoop-2.3 -Pkinesis-asl -Phive -Phive-thriftserver install
$ dev/lint-java
Using `mvn` from path: /home/travis/build/ConeyLiu/spark/build/apache-maven-3.3.9/bin/mvn
Checkstyle checks passed.
```

Author: Xianyang Liu

Closes #15865 from ConeyLiu/master.
--- .../apache/spark/network/util/TransportConf.java | 1 - .../apache/spark/network/sasl/SparkSaslSuite.java | 2 +- .../spark/io/NioBufferedFileInputStream.java | 2 ++ dev/checkstyle.xml | 15 +++++++++++++++ .../spark/examples/ml/JavaInteractionExample.java | 3 +-- ...vaLogisticRegressionWithElasticNetExample.java | 4 ++-- .../sql/catalyst/expressions/UnsafeArrayData.java | 3 ++- .../sql/catalyst/expressions/UnsafeMapData.java | 3 ++- .../sql/catalyst/expressions/HiveHasherSuite.java | 1 - 9 files changed, 25 insertions(+), 9 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index d0d072849d384..012bb098f6fc4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -18,7 +18,6 @@ package org.apache.spark.network.util; import com.google.common.primitives.Ints; -import org.apache.commons.crypto.cipher.CryptoCipherFactory; /** * A central location that tracks all the settings we expose to users. diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 4e6146cf070d0..ef2ab34b2277c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -513,7 +513,7 @@ private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAd boolean foundEncryptionHandler; String encryptHandlerName; - public EncryptionCheckerBootstrap(String encryptHandlerName) { + EncryptionCheckerBootstrap(String encryptHandlerName) { this.encryptHandlerName = encryptHandlerName; } diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263d..ea5f1a9abf69b 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,8 +130,10 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } + //checkstyle.on: NoFinalizer } diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 3de6aa91dcd51..92c5251c85037 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -52,6 +52,20 @@ + + + + + + + @@ -168,5 +182,6 @@ + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java index 4213c05703cc6..3684a87e22e7b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java @@ -19,7 +19,6 @@ import org.apache.spark.ml.feature.Interaction; import org.apache.spark.ml.feature.VectorAssembler; -import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; @@ -48,7 +47,7 @@ public static void main(String[] args) { RowFactory.create(5, 9, 2, 7, 10, 7, 3), RowFactory.create(6, 1, 1, 4, 2, 8, 4) ); - + 
StructType schema = new StructType(new StructField[]{ new StructField("id1", DataTypes.IntegerType, false, Metadata.empty()), new StructField("id2", DataTypes.IntegerType, false, Metadata.empty()), diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java index b8fb5972ea418..4cdec21d23023 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -60,8 +60,8 @@ public static void main(String[] args) { LogisticRegressionModel mlrModel = mlr.fit(training); // Print the coefficients and intercepts for logistic regression with multinomial family - System.out.println("Multinomial coefficients: " - + lrModel.coefficientMatrix() + "\nMultinomial intercepts: " + mlrModel.interceptVector()); + System.out.println("Multinomial coefficients: " + lrModel.coefficientMatrix() + + "\nMultinomial intercepts: " + mlrModel.interceptVector()); // $example off$ spark.stop(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 86523c1474015..e8c33871f97bc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -109,7 +109,8 @@ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { // Read the number of elements from the first 8 bytes. final long numElements = Platform.getLong(baseObject, baseOffset); assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; - assert numElements <= Integer.MAX_VALUE : "numElements (" + numElements + ") should <= Integer.MAX_VALUE"; + assert numElements <= Integer.MAX_VALUE : + "numElements (" + numElements + ") should <= Integer.MAX_VALUE"; this.numElements = (int)numElements; this.baseObject = baseObject; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 35029f5a50e3e..f17441dfccb6d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -68,7 +68,8 @@ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { // Read the numBytes of key array from the first 8 bytes. 
final long keyArraySize = Platform.getLong(baseObject, baseOffset); assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0"; - assert keyArraySize <= Integer.MAX_VALUE : "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE"; + assert keyArraySize <= Integer.MAX_VALUE : + "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE"; final int valueArraySize = sizeInBytes - (int)keyArraySize - 8; assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0"; diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 67a5eb0c7fe8f..b67c6f3e6e85e 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -28,7 +28,6 @@ import java.util.Set; public class HiveHasherSuite { - private final static HiveHasher hasher = new HiveHasher(); @Test public void testKnownIntegerInputs() { From 608ecc512b759514c75a1b475582f237ed569f10 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 16 Nov 2016 08:25:15 -0800 Subject: [PATCH 161/198] [SPARK-18415][SQL] Weird Plan Output when CTE used in RunnableCommand ### What changes were proposed in this pull request? Currently, when CTE is used in RunnableCommand, the Analyzer does not replace the logical node `With`. The child plan of RunnableCommand is not resolved. Thus, the output of the `With` plan node looks very confusing. For example, ``` sql( """ |CREATE VIEW cte_view AS |WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) |SELECT n FROM w """.stripMargin).explain() ``` The output is like ``` ExecutedCommand +- CreateViewCommand `cte_view`, WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) SELECT n FROM w, false, false, PersistedView +- 'With [(w,SubqueryAlias w +- Project [1 AS n#16] +- OneRowRelation$ ), (cte1,'SubqueryAlias cte1 +- 'Project [unresolvedalias(2, None)] +- OneRowRelation$ ), (cte2,'SubqueryAlias cte2 +- 'Project [unresolvedalias(3, None)] +- OneRowRelation$ )] +- 'Project ['n] +- 'UnresolvedRelation `w` ``` After the fix, the output is as shown below. ``` ExecutedCommand +- CreateViewCommand `cte_view`, WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) SELECT n FROM w, false, false, PersistedView +- CTE [w, cte1, cte2] : :- SubqueryAlias w : : +- Project [1 AS n#16] : : +- OneRowRelation$ : :- 'SubqueryAlias cte1 : : +- 'Project [unresolvedalias(2, None)] : : +- OneRowRelation$ : +- 'SubqueryAlias cte2 : +- 'Project [unresolvedalias(3, None)] : +- OneRowRelation$ +- 'Project ['n] +- 'UnresolvedRelation `w` ``` BTW, this PR also fixes the output of the view type. ### How was this patch tested? Manual Author: gatorsmile Closes #15854 from gatorsmile/cteName. 
---
 .../catalyst/plans/logical/basicLogicalOperators.scala | 8 ++++++++
 .../org/apache/spark/sql/execution/command/views.scala | 4 +++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 4dcc2885536eb..4e333d57f3623 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils

 /**
  * When planning take() or collect() operations, this special node that is inserted at the top of
@@ -404,6 +405,13 @@ case class InsertIntoTable(
  */
 case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+
+  override def simpleString: String = {
+    val cteAliases = Utils.truncatedString(cteRelations.map(_._1), "[", ", ", "]")
+    s"CTE $cteAliases"
+  }
+
+  override def innerChildren: Seq[QueryPlan[_]] = cteRelations.map(_._2)
 }

 case class WithWindowDefinition(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index 30472ec45ce44..154141bf83c7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -33,7 +33,9 @@ import org.apache.spark.sql.types.MetadataBuilder
  * ViewType is used to specify the expected view type when we want to create or replace a view in
  * [[CreateViewCommand]].
  */
-sealed trait ViewType
+sealed trait ViewType {
+  override def toString: String = getClass.getSimpleName.stripSuffix("$")
+}

 /**
  * LocalTempView means session-scoped local temporary views. Its lifetime is the lifetime of the

From 0048ce7ce64b02cbb6a1c4a2963a0b1b9541047e Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Wed, 16 Nov 2016 10:00:59 -0800
Subject: [PATCH 162/198] [SPARK-18459][SPARK-18460][STRUCTUREDSTREAMING] Rename triggerId to batchId and add triggerDetails to json in StreamingQueryStatus

## What changes were proposed in this pull request?

SPARK-18459: triggerId seems like a number that should increase with each trigger, whether or not there is data in it. In practice, however, triggerId increases only when there is a batch of data in a trigger, so it's better to rename it to batchId.

SPARK-18460: triggerDetails was missing from the JSON representation. Fixed it.

## How was this patch tested?
Updated existing unit tests.

Author: Tathagata Das

Closes #15895 from tdas/SPARK-18459.
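For context, a minimal sketch of how a caller observes the renamed key (illustrative only, not part of this patch; `inputDF` is an assumed streaming DataFrame):

```scala
// inputDF: an assumed streaming DataFrame (e.g., built from a MemoryStream)
val query = inputDF.writeStream
  .queryName("queryName")
  .format("memory")
  .start()

// triggerDetails is a java.util.Map[String, String]; after this change the
// per-trigger counter is exposed under "batchId" instead of "triggerId".
val batchId = Option(query.status.triggerDetails.get("batchId"))
println(s"Last batch id: ${batchId.getOrElse("n/a")}")
```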
--- python/pyspark/sql/streaming.py | 6 ++--- .../execution/streaming/StreamMetrics.scala | 8 +++---- .../sql/streaming/StreamingQueryStatus.scala | 4 ++-- .../streaming/StreamMetricsSuite.scala | 8 +++---- .../StreamingQueryListenerSuite.scala | 4 ++-- .../streaming/StreamingQueryStatusSuite.scala | 22 +++++++++++++++++-- 6 files changed, 35 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index f326f16232690..0e4589be976ea 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -212,12 +212,12 @@ def __str__(self): Processing rate 23.5 rows/sec Latency: 345.0 ms Trigger details: + batchId: 5 isDataPresentInTrigger: true isTriggerActive: true latency.getBatch.total: 20 latency.getOffset.total: 10 numRows.input.total: 100 - triggerId: 5 Source statuses [1 source]: Source 1 - MySource1 Available offset: 0 @@ -341,8 +341,8 @@ def triggerDetails(self): If no trigger is currently active, then it will have details of the last completed trigger. >>> sqs.triggerDetails - {u'triggerId': u'5', u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', - u'isTriggerActive': u'true', u'latency.getOffset.total': u'10', + {u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', + u'isTriggerActive': u'true', u'batchId': u'5', u'latency.getOffset.total': u'10', u'isDataPresentInTrigger': u'true'} """ return self._jsqs.triggerDetails() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala index 5645554a58f6e..942e6ed8944be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -78,13 +78,13 @@ class StreamMetrics(sources: Set[Source], triggerClock: Clock, codahaleSourceNam // =========== Setter methods =========== - def reportTriggerStarted(triggerId: Long): Unit = synchronized { + def reportTriggerStarted(batchId: Long): Unit = synchronized { numInputRows.clear() triggerDetails.clear() sourceTriggerDetails.values.foreach(_.clear()) - reportTriggerDetail(TRIGGER_ID, triggerId) - sources.foreach(s => reportSourceTriggerDetail(s, TRIGGER_ID, triggerId)) + reportTriggerDetail(BATCH_ID, batchId) + sources.foreach(s => reportSourceTriggerDetail(s, BATCH_ID, batchId)) reportTriggerDetail(IS_TRIGGER_ACTIVE, true) currentTriggerStartTimestamp = triggerClock.getTimeMillis() reportTriggerDetail(START_TIMESTAMP, currentTriggerStartTimestamp) @@ -217,7 +217,7 @@ object StreamMetrics extends Logging { } - val TRIGGER_ID = "triggerId" + val BATCH_ID = "batchId" val IS_TRIGGER_ACTIVE = "isTriggerActive" val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" val STATUS_MESSAGE = "statusMessage" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 99c7729d02351..ba732ff7fc2ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -102,7 +102,7 @@ class StreamingQueryStatus private( ("inputRate" -> JDouble(inputRate)) ~ ("processingRate" -> JDouble(processingRate)) ~ ("latency" -> latency.map(JDouble).getOrElse(JNothing)) ~ - ("triggerDetails" -> 
JsonProtocol.mapToJson(triggerDetails.asScala)) + ("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) ~ ("sourceStatuses" -> JArray(sourceStatuses.map(_.jsonValue).toList)) ~ ("sinkStatus" -> sinkStatus.jsonValue) } @@ -151,7 +151,7 @@ private[sql] object StreamingQueryStatus { desc = "MySink", offsetDesc = OffsetSeq(Some(LongOffset(1)) :: None :: Nil).toString), triggerDetails = Map( - TRIGGER_ID -> "5", + BATCH_ID -> "5", IS_TRIGGER_ACTIVE -> "true", IS_DATA_PRESENT_IN_TRIGGER -> "true", GET_OFFSET_LATENCY -> "10", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala index 938423db64745..38c4ece439770 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala @@ -50,10 +50,10 @@ class StreamMetricsSuite extends SparkFunSuite { assert(sm.currentSourceProcessingRate(source) === 0.0) assert(sm.currentLatency() === None) assert(sm.currentTriggerDetails() === - Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "true", + Map(BATCH_ID -> "1", IS_TRIGGER_ACTIVE -> "true", START_TIMESTAMP -> "0", "key" -> "value")) assert(sm.currentSourceTriggerDetails(source) === - Map(TRIGGER_ID -> "1", "key2" -> "value2")) + Map(BATCH_ID -> "1", "key2" -> "value2")) // Finishing the trigger should calculate the rates, except input rate which needs // to have another trigger interval @@ -66,11 +66,11 @@ class StreamMetricsSuite extends SparkFunSuite { assert(sm.currentSourceProcessingRate(source) === 100.0) assert(sm.currentLatency() === None) assert(sm.currentTriggerDetails() === - Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "false", + Map(BATCH_ID -> "1", IS_TRIGGER_ACTIVE -> "false", START_TIMESTAMP -> "0", FINISH_TIMESTAMP -> "1000", NUM_INPUT_ROWS -> "100", "key" -> "value")) assert(sm.currentSourceTriggerDetails(source) === - Map(TRIGGER_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) + Map(BATCH_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) // After another trigger starts, the rates and latencies should not change until // new rows are reported diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index cebb32a0a56cc..98f3bec7080af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -84,7 +84,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { AssertOnLastQueryStatus { status: StreamingQueryStatus => // Check the correctness of the trigger info of the last completed batch reported by // onQueryProgress - assert(status.triggerDetails.containsKey("triggerId")) + assert(status.triggerDetails.containsKey("batchId")) assert(status.triggerDetails.get("isTriggerActive") === "false") assert(status.triggerDetails.get("isDataPresentInTrigger") === "true") @@ -104,7 +104,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { assert(status.triggerDetails.get("numRows.state.aggregation1.updated") === "1") assert(status.sourceStatuses.length === 1) - assert(status.sourceStatuses(0).triggerDetails.containsKey("triggerId")) + 
assert(status.sourceStatuses(0).triggerDetails.containsKey("batchId"))
 assert(status.sourceStatuses(0).triggerDetails.get("latency.getOffset.source") === "100")
 assert(status.sourceStatuses(0).triggerDetails.get("latency.getBatch.source") === "200")
 assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "2")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala
index 6af19fb0c2327..50a7d92ede9a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala
@@ -48,12 +48,12 @@ class StreamingQueryStatusSuite extends SparkFunSuite {
        |    Processing rate 23.5 rows/sec
        |    Latency: 345.0 ms
        |    Trigger details:
+       |        batchId: 5
        |        isDataPresentInTrigger: true
        |        isTriggerActive: true
        |        latency.getBatch.total: 20
        |        latency.getOffset.total: 10
        |        numRows.input.total: 100
-       |        triggerId: 5
        |    Source statuses [1 source]:
        |    Source 1 - MySource1
        |        Available offset: 0
@@ -72,7 +72,11 @@ class StreamingQueryStatusSuite extends SparkFunSuite {
   test("json") {
     assert(StreamingQueryStatus.testStatus.json ===
       """
-        |{"sourceStatuses":[{"description":"MySource1","offsetDesc":"0","inputRate":15.5,
+        |{"name":"query","id":1,"timestamp":123,"inputRate":15.5,"processingRate":23.5,
+        |"latency":345.0,"triggerDetails":{"latency.getBatch.total":"20",
+        |"numRows.input.total":"100","isTriggerActive":"true","batchId":"5",
+        |"latency.getOffset.total":"10","isDataPresentInTrigger":"true"},
+        |"sourceStatuses":[{"description":"MySource1","offsetDesc":"0","inputRate":15.5,
         |"processingRate":23.5,"triggerDetails":{"numRows.input.source":"100",
         |"latency.getOffset.source":"10","latency.getBatch.source":"20"}}],
         |"sinkStatus":{"description":"MySink","offsetDesc":"[1, -]"}}
@@ -84,6 +88,20 @@
     StreamingQueryStatus.testStatus.prettyJson ===
       """
       |{
+      |  "name" : "query",
+      |  "id" : 1,
+      |  "timestamp" : 123,
+      |  "inputRate" : 15.5,
+      |  "processingRate" : 23.5,
+      |  "latency" : 345.0,
+      |  "triggerDetails" : {
+      |    "latency.getBatch.total" : "20",
+      |    "numRows.input.total" : "100",
+      |    "isTriggerActive" : "true",
+      |    "batchId" : "5",
+      |    "latency.getOffset.total" : "10",
+      |    "isDataPresentInTrigger" : "true"
+      |  },
       |  "sourceStatuses" : [ {
       |    "description" : "MySource1",
      |    "offsetDesc" : "0",

From bb6cdfd9a6a6b6c91aada7c3174436146045ed1e Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Wed, 16 Nov 2016 11:03:10 -0800
Subject: [PATCH 163/198] [SPARK-18461][DOCS][STRUCTUREDSTREAMING] Added more information about monitoring streaming queries

## What changes were proposed in this pull request?

Adds a "Monitoring Streaming Queries" section to the Structured Streaming programming guide, covering `streamingQuery.status` and `StreamingQueryListener`. (The original PR description consisted of two screenshots of the rendered guide, omitted here.)

Author: Tathagata Das

Closes #15897 from tdas/SPARK-18461.
---
 .../structured-streaming-programming-guide.md | 182 +++++++++++++++++-
 1 file changed, 179 insertions(+), 3 deletions(-)

diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index d2545584ae3b0..77b66b3b3a497 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -1087,9 +1087,185 @@ spark.streams().awaitAnyTermination() # block until any one of them terminates
-Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener`
-([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs),
-which will give you regular callback-based updates when queries are started and terminated.
+
+## Monitoring Streaming Queries
+There are two ways you can monitor queries. You can directly get the current status
+of an active query using `streamingQuery.status`, which will return a `StreamingQueryStatus` object
+([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryStatus)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryStatus.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryStatus) docs)
+that has all the details like current ingestion rates, processing rates, average latency,
+details of the currently active trigger, etc.
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+val query: StreamingQuery = ...
+
+println(query.status)
+
+/* Will print the current status of the query
+
+Status of query 'queryName'
+    Query id: 1
+    Status timestamp: 123
+    Input rate: 15.5 rows/sec
+    Processing rate 23.5 rows/sec
+    Latency: 345.0 ms
+    Trigger details:
+        batchId: 5
+        isDataPresentInTrigger: true
+        isTriggerActive: true
+        latency.getBatch.total: 20
+        latency.getOffset.total: 10
+        numRows.input.total: 100
+    Source statuses [1 source]:
+        Source 1 - MySource1
+            Available offset: 0
+            Input rate: 15.5 rows/sec
+            Processing rate: 23.5 rows/sec
+            Trigger details:
+                numRows.input.source: 100
+                latency.getOffset.source: 10
+                latency.getBatch.source: 20
+    Sink status - MySink
+        Committed offsets: [1, -]
+*/
+{% endhighlight %}
+
+</div>
+<div data-lang="java" markdown="1">
+
+{% highlight java %}
+StreamingQuery query = ...
+
+System.out.println(query.status());
+
+/* Will print the current status of the query
+
+Status of query 'queryName'
+    Query id: 1
+    Status timestamp: 123
+    Input rate: 15.5 rows/sec
+    Processing rate 23.5 rows/sec
+    Latency: 345.0 ms
+    Trigger details:
+        batchId: 5
+        isDataPresentInTrigger: true
+        isTriggerActive: true
+        latency.getBatch.total: 20
+        latency.getOffset.total: 10
+        numRows.input.total: 100
+    Source statuses [1 source]:
+        Source 1 - MySource1
+            Available offset: 0
+            Input rate: 15.5 rows/sec
+            Processing rate: 23.5 rows/sec
+            Trigger details:
+                numRows.input.source: 100
+                latency.getOffset.source: 10
+                latency.getBatch.source: 20
+    Sink status - MySink
+        Committed offsets: [1, -]
+*/
+{% endhighlight %}
+
+</div>
+<div data-lang="python" markdown="1">
+
+{% highlight python %}
+query = ...  # a StreamingQuery
+
+print(query.status)
+
+'''
+Will print the current status of the query
+
+Status of query 'queryName'
+    Query id: 1
+    Status timestamp: 123
+    Input rate: 15.5 rows/sec
+    Processing rate 23.5 rows/sec
+    Latency: 345.0 ms
+    Trigger details:
+        batchId: 5
+        isDataPresentInTrigger: true
+        isTriggerActive: true
+        latency.getBatch.total: 20
+        latency.getOffset.total: 10
+        numRows.input.total: 100
+    Source statuses [1 source]:
+        Source 1 - MySource1
+            Available offset: 0
+            Input rate: 15.5 rows/sec
+            Processing rate: 23.5 rows/sec
+            Trigger details:
+                numRows.input.source: 100
+                latency.getOffset.source: 10
+                latency.getBatch.source: 20
+    Sink status - MySink
+        Committed offsets: [1, -]
+'''
+{% endhighlight %}
+
+</div>
+</div>
+
+
+You can also asynchronously monitor all queries associated with a
+`SparkSession` by attaching a `StreamingQueryListener`
+([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs).
+Once you attach your custom `StreamingQueryListener` object with
+`sparkSession.streams.addListener()`, you will get callbacks when a query is started and
+stopped and when there is progress made in an active query. Here is an example:
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+val spark: SparkSession = ...
+
+spark.streams.addListener(new StreamingQueryListener() {
+
+  override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = {
+    println("Query started: " + queryStarted.queryStatus.name)
+  }
+  override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = {
+    println("Query terminated: " + queryTerminated.queryStatus.name)
+  }
+  override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = {
+    println("Query made progress: " + queryProgress.queryStatus)
+  }
+})
+{% endhighlight %}
+
+</div>
+<div data-lang="java" markdown="1">
+
+{% highlight java %}
+SparkSession spark = ...
+
+spark.streams().addListener(new StreamingQueryListener() {
+
+  @Override
+  public void onQueryStarted(QueryStartedEvent queryStarted) {
+    System.out.println("Query started: " + queryStarted.queryStatus().name());
+  }
+  @Override
+  public void onQueryTerminated(QueryTerminatedEvent queryTerminated) {
+    System.out.println("Query terminated: " + queryTerminated.queryStatus().name());
+  }
+  @Override
+  public void onQueryProgress(QueryProgressEvent queryProgress) {
+    System.out.println("Query made progress: " + queryProgress.queryStatus());
+  }
+});
+{% endhighlight %}
+
+</div>
+<div data-lang="python" markdown="1">
+{% highlight bash %}
+Not available in Python.
+{% endhighlight %}
+
+</div>
+</div>
## Recovering from Failures with Checkpointing In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. As of Spark 2.0, this checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). From a36a76ac43c36a3b897a748bd9f138b629dbc684 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 16 Nov 2016 14:22:15 -0800 Subject: [PATCH 164/198] [SPARK-1267][SPARK-18129] Allow PySpark to be pip installed ## What changes were proposed in this pull request? This PR aims to provide a pip installable PySpark package. This does a bunch of work to copy the jars over and package them with the Python code (to prevent challenges from trying to use different versions of the Python code with different versions of the JAR). It does not currently publish to PyPI but that is the natural follow up (SPARK-18129). Done: - pip installable on conda [manual tested] - setup.py installed on a non-pip managed system (RHEL) with YARN [manual tested] - Automated testing of this (virtualenv) - packaging and signing with release-build* Possible follow up work: - release-build update to publish to PyPI (SPARK-18128) - figure out who owns the pyspark package name on prod PyPI (is it someone with in the project or should we ask PyPI or should we choose a different name to publish with like ApachePySpark?) - Windows support and or testing ( SPARK-18136 ) - investigate details of wheel caching and see if we can avoid cleaning the wheel cache during our test - consider how we want to number our dev/snapshot versions Explicitly out of scope: - Using pip installed PySpark to start a standalone cluster - Using pip installed PySpark for non-Python Spark programs *I've done some work to test release-build locally but as a non-committer I've just done local testing. ## How was this patch tested? Automated testing with virtualenv, manual testing with conda, a system wide install, and YARN integration. release-build changes tested locally as a non-committer (no testing of upload artifacts to Apache staging websites) Author: Holden Karau Author: Juliet Hougland Author: Juliet Hougland Closes #15659 from holdenk/SPARK-1267-pip-install-pyspark. 
--- .gitignore | 2 + bin/beeline | 2 +- bin/find-spark-home | 41 ++++ bin/load-spark-env.sh | 2 +- bin/pyspark | 6 +- bin/run-example | 2 +- bin/spark-class | 6 +- bin/spark-shell | 4 +- bin/spark-sql | 2 +- bin/spark-submit | 2 +- bin/sparkR | 2 +- dev/create-release/release-build.sh | 26 ++- dev/create-release/release-tag.sh | 11 +- dev/lint-python | 4 +- dev/make-distribution.sh | 16 +- dev/pip-sanity-check.py | 36 +++ dev/run-pip-tests | 115 ++++++++++ dev/run-tests-jenkins.py | 1 + dev/run-tests.py | 7 + dev/sparktestsupport/__init__.py | 1 + docs/building-spark.md | 8 + docs/index.md | 4 +- .../spark/launcher/CommandBuilderUtils.java | 2 +- python/MANIFEST.in | 22 ++ python/README.md | 32 +++ python/pyspark/__init__.py | 1 + python/pyspark/find_spark_home.py | 74 +++++++ python/pyspark/java_gateway.py | 3 +- python/pyspark/version.py | 19 ++ python/setup.cfg | 22 ++ python/setup.py | 209 ++++++++++++++++++ 31 files changed, 660 insertions(+), 24 deletions(-) create mode 100755 bin/find-spark-home create mode 100644 dev/pip-sanity-check.py create mode 100755 dev/run-pip-tests create mode 100644 python/MANIFEST.in create mode 100644 python/README.md create mode 100755 python/pyspark/find_spark_home.py create mode 100644 python/pyspark/version.py create mode 100644 python/setup.cfg create mode 100644 python/setup.py diff --git a/.gitignore b/.gitignore index 39d17e1793f77..5634a434db0c0 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,8 @@ project/plugins/project/build.properties project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip +python/deps +python/pyspark/python reports/ scalastyle-on-compile.generated.xml scalastyle-output.xml diff --git a/bin/beeline b/bin/beeline index 1627626941a73..058534699e44b 100755 --- a/bin/beeline +++ b/bin/beeline @@ -25,7 +25,7 @@ set -o posix # Figure out if SPARK_HOME is set if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi CLASS="org.apache.hive.beeline.BeeLine" diff --git a/bin/find-spark-home b/bin/find-spark-home new file mode 100755 index 0000000000000..fa78407d4175a --- /dev/null +++ b/bin/find-spark-home @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Attempts to find a proper value for SPARK_HOME. Should be included using "source" directive. + +FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" + +# Short cirtuit if the user already has this set. +if [ ! -z "${SPARK_HOME}" ]; then + exit 0 +elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then + # If we are not in the same directory as find_spark_home.py we are not pip installed so we don't + # need to search the different Python directories for a Spark installation. 
+ # Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or + # spark-submit in another directory we want to use that version of PySpark rather than the + # pip installed version of PySpark. + export SPARK_HOME="$(cd "$(dirname "$0")"/..; pwd)" +else + # We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + # Default to standard python interpreter unless told otherwise + if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"python"}" + fi + export SPARK_HOME=$($PYSPARK_DRIVER_PYTHON "$FIND_SPARK_HOME_PYTHON_SCRIPT") +fi diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index eaea964ed5b3d..8a2f709960a25 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -23,7 +23,7 @@ # Figure out where Spark is installed if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi if [ -z "$SPARK_ENV_LOADED" ]; then diff --git a/bin/pyspark b/bin/pyspark index d6b3ab0a44321..98387c2ec5b8a 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh @@ -46,7 +46,7 @@ WORKS_WITH_IPYTHON=$(python -c 'import sys; print(sys.version_info >= (2, 7, 0)) # Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! WORKS_WITH_IPYTHON ]]; then + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! $WORKS_WITH_IPYTHON ]]; then echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 exit 1 else @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m $1 + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/bin/run-example b/bin/run-example index dd0e3c4120260..4ba5399311d33 100755 --- a/bin/run-example +++ b/bin/run-example @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/run-example [options] example-class [example args]" diff --git a/bin/spark-class b/bin/spark-class index 377c8d1add3f6..77ea40cc37946 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi . "${SPARK_HOME}"/bin/load-spark-env.sh @@ -27,7 +27,7 @@ fi if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" else - if [ `command -v java` ]; then + if [ "$(command -v java)" ]; then RUNNER="java" else echo "JAVA_HOME is not set" >&2 @@ -36,7 +36,7 @@ else fi # Find Spark jars. 
-if [ -f "${SPARK_HOME}/RELEASE" ]; then +if [ -d "${SPARK_HOME}/jars" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" diff --git a/bin/spark-shell b/bin/spark-shell index 6583b5bd880ee..421f36cac3d47 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -21,7 +21,7 @@ # Shell script for starting the Spark Shell REPL cygwin=false -case "`uname`" in +case "$(uname)" in CYGWIN*) cygwin=true;; esac @@ -29,7 +29,7 @@ esac set -o posix if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" diff --git a/bin/spark-sql b/bin/spark-sql index 970d12cbf51dd..b08b944ebd319 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" diff --git a/bin/spark-submit b/bin/spark-submit index 023f9c162f4b8..4e9d3614e6370 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi # disable randomized hash for string in Python 3.3+ diff --git a/bin/sparkR b/bin/sparkR index 2c07a82e2173b..29ab10df8ab6d 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 81f0d63054e29..1dbfa3b6e361b 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -162,14 +162,35 @@ if [[ "$1" == "package" ]]; then export ZINC_PORT=$ZINC_PORT echo "Creating distribution: $NAME ($FLAGS)" + # Write out the NAME and VERSION to PySpark version info we rewrite the - into a . and SNAPSHOT + # to dev0 to be closer to PEP440. We use the NAME as a "local version". + PYSPARK_VERSION=`echo "$SPARK_VERSION+$NAME" | sed -r "s/-/./" | sed -r "s/SNAPSHOT/dev0/"` + echo "__version__='$PYSPARK_VERSION'" > python/pyspark/version.py + # Get maven home set by MVN MVN_HOME=`$MVN -version 2>&1 | grep 'Maven home' | awk '{print $NF}'` - ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ + echo "Creating distribution" + ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --pip $FLAGS \ -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log cd .. - cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . + echo "Copying and signing python distribution" + PYTHON_DIST_NAME=pyspark-$PYSPARK_VERSION.tar.gz + cp spark-$SPARK_VERSION-bin-$NAME/python/dist/$PYTHON_DIST_NAME . + + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output $PYTHON_DIST_NAME.asc \ + --detach-sig $PYTHON_DIST_NAME + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 $PYTHON_DIST_NAME > \ + $PYTHON_DIST_NAME.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 $PYTHON_DIST_NAME > \ + $PYTHON_DIST_NAME.sha + + echo "Copying and signing regular binary distribution" + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . 
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz @@ -208,6 +229,7 @@ if [[ "$1" == "package" ]]; then # Re-upload a second time and leave the files in the timestamped upload directory: LFTP mkdir -p $dest_dir LFTP mput -O $dest_dir 'spark-*' + LFTP mput -O $dest_dir 'pyspark-*' exit 0 fi diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index b7e5100ca7408..370a62ce15bc4 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -65,6 +65,7 @@ sed -i".tmp1" 's/Version.*$/Version: '"$RELEASE_VERSION"'/g' R/pkg/DESCRIPTION # Set the release version in docs sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml +sed -i".tmp3" 's/__version__ = .*$/__version__ = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" @@ -74,12 +75,16 @@ git tag $RELEASE_TAG $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs # Remove -SNAPSHOT before setting the R version as R expects version strings to only have numbers R_NEXT_VERSION=`echo $NEXT_VERSION | sed 's/-SNAPSHOT//g'` -sed -i".tmp2" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION +sed -i".tmp4" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION +# Write out the R_NEXT_VERSION to PySpark version info we use dev0 instead of SNAPSHOT to be closer +# to PEP440. +sed -i".tmp5" 's/__version__ = .*$/__version__ = "'"$R_NEXT_VERSION.dev0"'"/' python/pyspark/version.py + # Update docs with next version -sed -i".tmp3" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml +sed -i".tmp6" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml # Use R version for short version -sed -i".tmp4" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml +sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml git commit -a -m "Preparing development version $NEXT_VERSION" diff --git a/dev/lint-python b/dev/lint-python index 63487043a50b6..3f878c2dad6b1 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -20,7 +20,9 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport" -PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" +# TODO: fix pep8 errors with the rest of the Python scripts under dev +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/*.py ./dev/run-tests-jenkins.py" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/pip-sanity-check.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 9be4fdfa51c93..49b46fbc3fb27 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -33,6 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`/.."; pwd)" DISTDIR="$SPARK_HOME/dist" MAKE_TGZ=false +MAKE_PIP=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -40,7 +41,7 @@ function exit_with_usage { echo "make-distribution.sh - tool for 
making binary distributions of Spark" echo "" echo "usage:" - cl_options="[--name] [--tgz] [--mvn ]" + cl_options="[--name] [--tgz] [--pip] [--mvn ]" echo "make-distribution.sh $cl_options " echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" @@ -67,6 +68,9 @@ while (( "$#" )); do --tgz) MAKE_TGZ=true ;; + --pip) + MAKE_PIP=true + ;; --mvn) MVN="$2" shift @@ -201,6 +205,16 @@ fi # Copy data files cp -r "$SPARK_HOME/data" "$DISTDIR" +# Make pip package +if [ "$MAKE_PIP" == "true" ]; then + echo "Building python distribution package" + cd $SPARK_HOME/python + python setup.py sdist + cd .. +else + echo "Skipping creating pip installable PySpark" +fi + # Copy other things mkdir "$DISTDIR"/conf cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf diff --git a/dev/pip-sanity-check.py b/dev/pip-sanity-check.py new file mode 100644 index 0000000000000..430c2ab52766a --- /dev/null +++ b/dev/pip-sanity-check.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark.sql import SparkSession +import sys + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("PipSanityCheck")\ + .getOrCreate() + sc = spark.sparkContext + rdd = sc.parallelize(range(100), 10) + value = rdd.reduce(lambda x, y: x + y) + if (value != 4950): + print("Value {0} did not match expected value.".format(value), file=sys.stderr) + sys.exit(-1) + print("Successfully ran pip sanity check") + + spark.stop() diff --git a/dev/run-pip-tests b/dev/run-pip-tests new file mode 100755 index 0000000000000..e1da18e60bb3d --- /dev/null +++ b/dev/run-pip-tests @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stop on error +set -e +# Set nullglob for when we are checking existence based on globs +shopt -s nullglob + +FWDIR="$(cd "$(dirname "$0")"/..; pwd)" +cd "$FWDIR" + +echo "Constructing virtual env for testing" +VIRTUALENV_BASE=$(mktemp -d) + +# Clean up the virtualenv environment used for testing, if we created one. 
+function delete_virtualenv() { + echo "Cleaning up temporary directory - $VIRTUALENV_BASE" + rm -rf "$VIRTUALENV_BASE" +} +trap delete_virtualenv EXIT + +# Some systems don't have pip or virtualenv - in those cases our tests won't work. +if ! hash virtualenv 2>/dev/null; then + echo "Missing virtualenv skipping pip installability tests." + exit 0 +fi +if ! hash pip 2>/dev/null; then + echo "Missing pip, skipping pip installability tests." + exit 0 +fi + +# Figure out which Python execs we should test pip installation with +PYTHON_EXECS=() +if hash python2 2>/dev/null; then + # We do this since we are testing with virtualenv and the default virtual env python + # is in /usr/bin/python + PYTHON_EXECS+=('python2') +elif hash python 2>/dev/null; then + # If python2 isn't installed fallback to python if available + PYTHON_EXECS+=('python') +fi +if hash python3 2>/dev/null; then + PYTHON_EXECS+=('python3') +fi + +# Determine which version of PySpark we are building for archive name +PYSPARK_VERSION=$(python -c "exec(open('python/pyspark/version.py').read());print __version__") +PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" +# The pip install options we use for all the pip commands +PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " +# Test both regular user and edit/dev install modes. +PIP_COMMANDS=("pip install $PIP_OPTIONS $PYSPARK_DIST" + "pip install $PIP_OPTIONS -e python/") + +for python in "${PYTHON_EXECS[@]}"; do + for install_command in "${PIP_COMMANDS[@]}"; do + echo "Testing pip installation with python $python" + # Create a temp directory for us to work in and save its name to a file for cleanup + echo "Using $VIRTUALENV_BASE for virtualenv" + VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python + rm -rf "$VIRTUALENV_PATH" + mkdir -p "$VIRTUALENV_PATH" + virtualenv --python=$python "$VIRTUALENV_PATH" + source "$VIRTUALENV_PATH"/bin/activate + # Upgrade pip + pip install --upgrade pip + + echo "Creating pip installable source dist" + cd "$FWDIR"/python + $python setup.py sdist + + + echo "Installing dist into virtual env" + cd dist + # Verify that the dist directory only contains one thing to install + sdists=(*.tar.gz) + if [ ${#sdists[@]} -ne 1 ]; then + echo "Unexpected number of targets found in dist directory - please cleanup existing sdists first." 
+ exit -1 + fi + # Do the actual installation + cd "$FWDIR" + $install_command + + cd / + + echo "Run basic sanity check on pip installed version with spark-submit" + spark-submit "$FWDIR"/dev/pip-sanity-check.py + echo "Run basic sanity check with import based" + python "$FWDIR"/dev/pip-sanity-check.py + echo "Run the tests for context.py" + python "$FWDIR"/python/pyspark/context.py + + cd "$FWDIR" + + done +done + +exit 0 diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index a48d918f9dc1f..1d1e72faccf2a 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -128,6 +128,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', + ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: 'PySpark pip packaging tests', ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests', ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of \`%s\`' % ( tests_timeout) diff --git a/dev/run-tests.py b/dev/run-tests.py index 5d661f5f1a1c5..ab285ac96af7e 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -432,6 +432,12 @@ def run_python_tests(test_modules, parallelism): run_cmd(command) +def run_python_packaging_tests(): + set_title_and_block("Running PySpark packaging tests", "BLOCK_PYSPARK_PIP_TESTS") + command = [os.path.join(SPARK_HOME, "dev", "run-pip-tests")] + run_cmd(command) + + def run_build_tests(): set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) @@ -583,6 +589,7 @@ def main(): modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: run_python_tests(modules_with_python_tests, opts.parallelism) + run_python_packaging_tests() if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 89015f8c4fb9c..38f25da41f775 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -33,5 +33,6 @@ "BLOCK_SPARKR_UNIT_TESTS": 20, "BLOCK_JAVA_STYLE": 21, "BLOCK_BUILD_TESTS": 22, + "BLOCK_PYSPARK_PIP_TESTS": 23, "BLOCK_TIMEOUT": 124 } diff --git a/docs/building-spark.md b/docs/building-spark.md index 2b404bd3e116c..88da0cc9c3bbf 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -265,6 +265,14 @@ or Java 8 tests are automatically enabled when a Java 8 JDK is detected. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. +## PySpark pip installable + +If you are building Spark for use in a Python environment and you wish to pip install it, you will first need to build the Spark JARs as described above. Then you can construct an sdist package suitable for setup.py and pip installable package. + + cd python; python setup.py sdist + +**Note:** Due to packaging requirements you can not directly pip install from the Python directory, rather you must first build the sdist package as described above. + ## PySpark Tests with Maven If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. 
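Once the sdist built above has been pip installed, the result can be smoke-tested directly from Python. The snippet below is editorial illustration only and is not part of the patch; it assumes the pip install succeeded, and it calls the private `_find_spark_home` helper that this patch adds purely for demonstration:

```python
# Illustrative check of a pip-installed PySpark (not part of this patch).
import os

import pyspark
from pyspark.find_spark_home import _find_spark_home  # private helper added by this patch

# The version attribute is exported by this patch via pyspark/version.py.
print("PySpark version:", pyspark.__version__)

spark_home = _find_spark_home()
print("Resolved SPARK_HOME:", spark_home)

# A usable SPARK_HOME should contain the launcher script and the packaged jars,
# mirroring the checks in find_spark_home.py's is_spark_home().
assert os.path.isfile(os.path.join(spark_home, "bin", "spark-submit"))
assert os.path.isdir(os.path.join(spark_home, "jars"))
```

In practice the same end-to-end verification is what dev/pip-sanity-check.py performs through spark-submit.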
diff --git a/docs/index.md b/docs/index.md index fe51439ae08d7..39de11de854a7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,7 +14,9 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. Users can also download a "Hadoop free" binary and run Spark with any Hadoop version -[by augmenting Spark's classpath](hadoop-provided.html). +[by augmenting Spark's classpath](hadoop-provided.html). +Scala and Java users can include Spark in their projects using its Maven coordinates, and in the future Python users will also be able to install Spark from PyPI. + If you'd like to build Spark from source, visit [Building Spark](building-spark.html). diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 62a22008d0d5d..250b2a882feb5 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -357,7 +357,7 @@ static int javaMajorVersion(String javaVersion) { static String findJarsDir(String sparkHome, String scalaVersion, boolean failIfNotFound) { // TODO: change to the correct directory once the assembly build is changed. File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { + if (new File(sparkHome, "jars").isDirectory()) { libdir = new File(sparkHome, "jars"); checkState(!failIfNotFound || libdir.isDirectory(), "Library directory '%s' does not exist.", diff --git a/python/MANIFEST.in b/python/MANIFEST.in new file mode 100644 index 0000000000000..bbcce1baa439d --- /dev/null +++ b/python/MANIFEST.in @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +global-exclude *.py[cod] __pycache__ .DS_Store +recursive-include deps/jars *.jar +graft deps/bin +recursive-include deps/examples *.py +recursive-include lib *.zip +include README.md diff --git a/python/README.md b/python/README.md new file mode 100644 index 0000000000000..0a5c8010b8486 --- /dev/null +++ b/python/README.md @@ -0,0 +1,32 @@ +# Apache Spark + +Spark is a fast and general cluster computing system for Big Data. It provides +high-level APIs in Scala, Java, Python, and R, and an optimized engine that +supports general computation graphs for data analysis. It also supports a +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, +and Spark Streaming for stream processing. 
+ + + +## Online Documentation + +You can find the latest Spark documentation, including a programming +guide, on the [project web page](http://spark.apache.org/documentation.html) + + +## Python Packaging + +This README file only contains basic information related to pip installed PySpark. +This packaging is currently experimental and may change in future versions (although we will do our best to keep compatibility). +Using PySpark requires the Spark JARs, and if you are building this from source, please see the build instructions at +["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). + +The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to set up your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). + + +**NOTE:** If you are using this with a Spark standalone cluster you must ensure that the version (including minor version) matches or you may experience odd errors. + +## Python Requirements + +At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas). \ No newline at end of file diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index ec1687415a7f6..5f93586a48a5a 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -50,6 +50,7 @@ from pyspark.serializers import MarshalSerializer, PickleSerializer from pyspark.status import * from pyspark.profiler import Profiler, BasicProfiler +from pyspark.version import __version__ def since(version): diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py new file mode 100755 index 0000000000000..212a618b767ab --- /dev/null +++ b/python/pyspark/find_spark_home.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script attempts to determine the correct setting for SPARK_HOME given +# that Spark may have been installed on the system with pip. + +from __future__ import print_function +import os +import sys + + +def _find_spark_home(): + """Find the SPARK_HOME.""" + # If the environment has SPARK_HOME set, trust it. 
+ if "SPARK_HOME" in os.environ: + return os.environ["SPARK_HOME"] + + def is_spark_home(path): + """Takes a path and returns true if the provided path could be a reasonable SPARK_HOME""" + return (os.path.isfile(os.path.join(path, "bin/spark-submit")) and + (os.path.isdir(os.path.join(path, "jars")) or + os.path.isdir(os.path.join(path, "assembly")))) + + paths = ["../", os.path.dirname(os.path.realpath(__file__))] + + # Add the path of the PySpark module if it exists + if sys.version < "3": + import imp + try: + module_home = imp.find_module("pyspark")[1] + paths.append(module_home) + # If we are installed in edit mode also look two dirs up + paths.append(os.path.join(module_home, "../../")) + except ImportError: + # Not pip installed no worries + pass + else: + from importlib.util import find_spec + try: + module_home = os.path.dirname(find_spec("pyspark").origin) + paths.append(module_home) + # If we are installed in edit mode also look two dirs up + paths.append(os.path.join(module_home, "../../")) + except ImportError: + # Not pip installed no worries + pass + + # Normalize the paths + paths = [os.path.abspath(p) for p in paths] + + try: + return next(path for path in paths if is_spark_home(path)) + except StopIteration: + print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr) + exit(-1) + +if __name__ == "__main__": + print(_find_spark_home()) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c1cf843d84388..3c783ae541a1f 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,6 +29,7 @@ xrange = range from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int @@ -41,7 +42,7 @@ def launch_gateway(conf=None): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: - SPARK_HOME = os.environ["SPARK_HOME"] + SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" diff --git a/python/pyspark/version.py b/python/pyspark/version.py new file mode 100644 index 0000000000000..08a301695fda7 --- /dev/null +++ b/python/pyspark/version.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "2.1.0.dev0" diff --git a/python/setup.cfg b/python/setup.cfg new file mode 100644 index 0000000000000..d100b932bbafc --- /dev/null +++ b/python/setup.cfg @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[bdist_wheel] +universal = 1 + +[metadata] +description-file = README.md diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 0000000000000..625aea04073f5 --- /dev/null +++ b/python/setup.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import glob +import os +import sys +from setuptools import setup, find_packages +from shutil import copyfile, copytree, rmtree + +if sys.version_info < (2, 7): + print("Python versions prior to 2.7 are not supported for pip installed PySpark.", + file=sys.stderr) + exit(-1) + +try: + exec(open('pyspark/version.py').read()) +except IOError: + print("Failed to load PySpark version file for packaging. You must be in Spark's python dir.", + file=sys.stderr) + sys.exit(-1) +VERSION = __version__ +# A temporary path so we can access above the Python project root and fetch scripts and jars we need +TEMP_PATH = "deps" +SPARK_HOME = os.path.abspath("../") + +# Provide guidance about how to use setup.py +incorrect_invocation_message = """ +If you are installing pyspark from spark source, you must first build Spark and +run sdist. + + To build Spark with maven you can run: + ./build/mvn -DskipTests clean package + Building the source dist is done in the Python directory: + cd python + python setup.py sdist + pip install dist/*.tar.gz""" + +# Figure out where the jars are we need to package with PySpark. 
+JARS_PATH = glob.glob(os.path.join(SPARK_HOME, "assembly/target/scala-*/jars/")) + +if len(JARS_PATH) == 1: + JARS_PATH = JARS_PATH[0] +elif (os.path.isfile("../RELEASE") and len(glob.glob("../jars/spark*core*.jar")) == 1): + # Release mode puts the jars in a jars directory + JARS_PATH = os.path.join(SPARK_HOME, "jars") +elif len(JARS_PATH) > 1: + print("Assembly jars exist for multiple Scala versions ({0}), please clean up assembly/target".format( + JARS_PATH), file=sys.stderr) + sys.exit(-1) +elif len(JARS_PATH) == 0 and not os.path.exists(TEMP_PATH): + print(incorrect_invocation_message, file=sys.stderr) + sys.exit(-1) + +EXAMPLES_PATH = os.path.join(SPARK_HOME, "examples/src/main/python") +SCRIPTS_PATH = os.path.join(SPARK_HOME, "bin") +SCRIPTS_TARGET = os.path.join(TEMP_PATH, "bin") +JARS_TARGET = os.path.join(TEMP_PATH, "jars") +EXAMPLES_TARGET = os.path.join(TEMP_PATH, "examples") + + +# Check and see if we are under the Spark path, in which case we need to build the symlink farm. +# This is important because we only want to build the symlink farm while under Spark; otherwise we +# want to use the existing symlink farm. And if the symlink farm already exists while we are under +# Spark (e.g. a partially built sdist) we should error and have the user sort it out. +in_spark = (os.path.isfile("../core/src/main/scala/org/apache/spark/SparkContext.scala") or + (os.path.isfile("../RELEASE") and len(glob.glob("../jars/spark*core*.jar")) == 1)) + + +def _supports_symlinks(): + """Check if the system supports symlinks (e.g. *nix) or not.""" + return getattr(os, "symlink", None) is not None + + +if (in_spark): + # Construct links for setup + try: + os.mkdir(TEMP_PATH) + except: + print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH), + file=sys.stderr) + exit(-1) + +try: + # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts + # find it where expected. The rest of the files aren't copied because they are accessed + # using Python imports instead, which will be resolved correctly. + try: + os.makedirs("pyspark/python/pyspark") + except OSError: + # Don't worry if the directory already exists. + pass + copyfile("pyspark/shell.py", "pyspark/python/pyspark/shell.py") + + if (in_spark): + # Construct the symlink farm - this is necessary since we can't refer to the path above the + # package root and we need to copy the jars and scripts which are up above the python root. + if _supports_symlinks(): + os.symlink(JARS_PATH, JARS_TARGET) + os.symlink(SCRIPTS_PATH, SCRIPTS_TARGET) + os.symlink(EXAMPLES_PATH, EXAMPLES_TARGET) + else: + # For Windows fall back to the slower copytree + copytree(JARS_PATH, JARS_TARGET) + copytree(SCRIPTS_PATH, SCRIPTS_TARGET) + copytree(EXAMPLES_PATH, EXAMPLES_TARGET) + else: + # If we are not inside SPARK_HOME, verify we have the required symlink farm + if not os.path.exists(JARS_TARGET): + print("To build the packaging you must be in the python directory under SPARK_HOME.", + file=sys.stderr) + + if not os.path.isdir(SCRIPTS_TARGET): + print(incorrect_invocation_message, file=sys.stderr) + exit(-1) + + # Scripts directive requires a list of each script path and does not take wildcards. + script_names = os.listdir(SCRIPTS_TARGET) + scripts = list(map(lambda script: os.path.join(SCRIPTS_TARGET, script), script_names)) + # We add find_spark_home.py to the bin directory we install so that pip installed PySpark + # will search for SPARK_HOME with Python. 
+ scripts.append("pyspark/find_spark_home.py") + + # Parse the README markdown file into rst for PyPI + long_description = "!!!!! missing pandoc do not upload to PyPI !!!!" + try: + import pypandoc + long_description = pypandoc.convert('README.md', 'rst') + except ImportError: + print("Could not import pypandoc - required to package PySpark", file=sys.stderr) + + setup( + name='pyspark', + version=VERSION, + description='Apache Spark Python API', + long_description=long_description, + author='Spark Developers', + author_email='dev@spark.apache.org', + url='https://github.com/apache/spark/tree/master/python', + packages=['pyspark', + 'pyspark.mllib', + 'pyspark.ml', + 'pyspark.sql', + 'pyspark.streaming', + 'pyspark.bin', + 'pyspark.jars', + 'pyspark.python.pyspark', + 'pyspark.python.lib', + 'pyspark.examples.src.main.python'], + include_package_data=True, + package_dir={ + 'pyspark.jars': 'deps/jars', + 'pyspark.bin': 'deps/bin', + 'pyspark.python.lib': 'lib', + 'pyspark.examples.src.main.python': 'deps/examples', + }, + package_data={ + 'pyspark.jars': ['*.jar'], + 'pyspark.bin': ['*'], + 'pyspark.python.lib': ['*.zip'], + 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, + scripts=scripts, + license='http://www.apache.org/licenses/LICENSE-2.0', + install_requires=['py4j==0.10.4'], + setup_requires=['pypandoc'], + extras_require={ + 'ml': ['numpy>=1.7'], + 'mllib': ['numpy>=1.7'], + 'sql': ['pandas'] + }, + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: Implementation :: CPython', + 'Programming Language :: Python :: Implementation :: PyPy'] + ) +finally: + # We only cleanup the symlink farm if we were in Spark, otherwise we are installing rather than + # packaging. + if (in_spark): + # Depending on cleaning up the symlink farm or copied version + if _supports_symlinks(): + os.remove(os.path.join(TEMP_PATH, "jars")) + os.remove(os.path.join(TEMP_PATH, "bin")) + os.remove(os.path.join(TEMP_PATH, "examples")) + else: + rmtree(os.path.join(TEMP_PATH, "jars")) + rmtree(os.path.join(TEMP_PATH, "bin")) + rmtree(os.path.join(TEMP_PATH, "examples")) + os.rmdir(TEMP_PATH) From 2ca8ae9aa1b29bf1f46d0b656d9885e438e67f53 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 16 Nov 2016 14:32:36 -0800 Subject: [PATCH 165/198] [SPARK-18186] Migrate HiveUDAFFunction to TypedImperativeAggregate for partial aggregation support ## What changes were proposed in this pull request? While being evaluated in Spark SQL, Hive UDAFs don't support partial aggregation. This PR migrates `HiveUDAFFunction`s to `TypedImperativeAggregate`, which already provides partial aggregation support for aggregate functions that may use arbitrary Java objects as aggregation states. 
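Before the PR's own Scala snippet below, the same behavior can be exercised from PySpark. This is an editorial sketch, not part of the patch; it assumes a Hive-enabled session, and `GenericUDAFMax` ships with Hive:

```python
# Sketch: register a Hive UDAF from PySpark and inspect the physical plan.
from pyspark.sql import SparkSession

spark = SparkSession.builder.enableHiveSupport().getOrCreate()

spark.sql("CREATE TEMPORARY FUNCTION hive_max AS "
          "'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax'")
spark.range(100).createOrReplaceTempView("t")

# With this patch the plan should show both a partial and a final aggregation
# stage for hive_max, instead of a single COMPLETE-mode aggregation.
spark.sql("SELECT max(id), hive_max(id) FROM t").explain()
```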
The following snippet shows the effect of this PR: ```scala import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax sql(s"CREATE FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") spark.range(100).createOrReplaceTempView("t") // A query using both Spark SQL native `max` and Hive `max` sql(s"SELECT max(id), hive_max(id) FROM t").explain() ``` Before this PR: ``` == Physical Plan == SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax7475f57e), id#1L, false, 0, 0)]) +- Exchange SinglePartition +- *Range (0, 100, step=1, splits=Some(1)) ``` After this PR: ``` == Physical Plan == SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)]) +- Exchange SinglePartition +- SortAggregate(key=[], functions=[partial_max(id#1L), partial_default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)]) +- *Range (0, 100, step=1, splits=Some(1)) ``` The tricky part of the PR is mostly about updating and passing around aggregation states of `HiveUDAFFunction`s since the aggregation state of a Hive UDAF may appear in three different forms. Let's take a look at the testing `MockUDAF` added in this PR as an example. This UDAF computes the count of non-null values together with the count of nulls of a given column. Its aggregation state may appear as the following forms at different time: 1. A `MockUDAFBuffer`, which is a concrete subclass of `GenericUDAFEvaluator.AggregationBuffer` The form used by Hive UDAF API. This form is required by the following scenarios: - Calling `GenericUDAFEvaluator.iterate()` to update an existing aggregation state with new input values. - Calling `GenericUDAFEvaluator.terminate()` to get the final aggregated value from an existing aggregation state. - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The existing aggregation state to be updated must be in this form. Conversions: - To form 2: `GenericUDAFEvaluator.terminatePartial()` - To form 3: Convert to form 2 first, and then to 3. 2. An `Object[]` array containing two `java.lang.Long` values. The form used to interact with Hive's `ObjectInspector`s. This form is required by the following scenarios: - Calling `GenericUDAFEvaluator.terminatePartial()` to convert an existing aggregation state in form 1 to form 2. - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The input aggregation state must be in this form. Conversions: - To form 1: No direct method. Have to create an empty `AggregationBuffer` and merge it into the empty buffer. - To form 3: `unwrapperFor()`/`unwrap()` method of `HiveInspectors` 3. The byte array that holds data of an `UnsafeRow` with two `LongType` fields. The form used by Spark SQL to shuffle partial aggregation results. This form is required because `TypedImperativeAggregate` always asks its subclasses to serialize their aggregation states into a byte array. Conversions: - To form 1: Convert to form 2 first, and then to 1. 
- To form 2: `wrapperFor()`/`wrap()` method of `HiveInspectors` Here're some micro-benchmark results produced by the most recent master and this PR branch. Master: ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o groupBy 339 / 372 3.1 323.2 1.0X w/ groupBy 503 / 529 2.1 479.7 0.7X ``` This PR: ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o groupBy 116 / 126 9.0 110.8 1.0X w/ groupBy 151 / 159 6.9 144.0 0.8X ``` Benchmark code snippet: ```scala test("Hive UDAF benchmark") { val N = 1 << 20 sparkSession.sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") val benchmark = new Benchmark( name = "hive udaf vs spark af", valuesPerIteration = N, minNumIters = 5, warmupTime = 5.seconds, minTime = 5.seconds, outputPerIteration = true ) benchmark.addCase("w/o groupBy") { _ => sparkSession.range(N).agg("id" -> "hive_max").collect() } benchmark.addCase("w/ groupBy") { _ => sparkSession.range(N).groupBy($"id" % 10).agg("id" -> "hive_max").collect() } benchmark.run() sparkSession.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") } ``` ## How was this patch tested? New test suite `HiveUDAFSuite` is added. Author: Cheng Lian Closes #15703 from liancheng/partial-agg-hive-udaf. --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 199 +++++++++++++----- .../sql/hive/execution/HiveUDAFSuite.scala | 152 +++++++++++++ 2 files changed, 301 insertions(+), 50 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 42033080dc34b..32edd4aec2865 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.hive +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, - ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.internal.Logging @@ -58,7 +60,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) + val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } @@ -75,7 +77,7 @@ private[hive] case 
class HiveSimpleUDF( @transient lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector( - method.getGenericReturnType(), ObjectInspectorOptions.JAVA)) + method.getGenericReturnType, ObjectInspectorOptions.JAVA)) @transient private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) @@ -263,8 +265,35 @@ private[hive] case class HiveGenericUDTF( } /** - * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt - * performance a lot. + * While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be in the following + * three formats: + * + * 1. An instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class + * + * This is the native Hive representation of an aggregation state. Hive `GenericUDAFEvaluator` + * methods like `iterate()`, `merge()`, `terminatePartial()`, and `terminate()` use this format. + * We call these methods to evaluate Hive UDAFs. + * + * 2. A Java object that can be inspected using the `ObjectInspector` returned by the + * `GenericUDAFEvaluator.init()` method. + * + * Hive uses this format to produce a serializable aggregation state so that it can shuffle + * partial aggregation results. Whenever we need to convert a Hive `AggregationBuffer` instance + * into a Spark SQL value, we have to convert it to this format first and then do the conversion + * with the help of `ObjectInspector`s. + * + * 3. A Spark SQL value + * + * We use this format for serializing Hive UDAF aggregation states on Spark side. To be more + * specific, we convert `AggregationBuffer`s into equivalent Spark SQL values, write them into + * `UnsafeRow`s, and then retrieve the byte array behind those `UnsafeRow`s as serialization + * results. + * + * We may use the following methods to convert the aggregation state back and forth: + * + * - `wrap()`/`wrapperFor()`: from 3 to 1 + * - `unwrap()`/`unwrapperFor()`: from 1 to 3 + * - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3 */ private[hive] case class HiveUDAFFunction( name: String, @@ -273,7 +302,7 @@ private[hive] case class HiveUDAFFunction( isUDAFBridgeRequired: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with HiveInspectors { + extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -281,73 +310,73 @@ private[hive] case class HiveUDAFFunction( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) + // Hive `ObjectInspector`s for all child expressions (input parameters of the function). @transient - private lazy val resolver = - if (isUDAFBridgeRequired) { + private lazy val inputInspectors = children.map(toInspector).toArray + + // Spark SQL data types of input parameters. 
+ @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + + private def newEvaluator(): GenericUDAFEvaluator = { + val resolver = if (isUDAFBridgeRequired) { new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - @transient - private lazy val inspectors = children.map(toInspector).toArray - - @transient - private lazy val functionAndInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - val f = resolver.getEvaluator(parameterInfo) - f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) + resolver.getEvaluator(parameterInfo) } + // The UDAF evaluator used to consume raw input rows and produce partial aggregation results. @transient - private lazy val function = functionAndInspector._1 + private lazy val partial1ModeEvaluator = newEvaluator() + // Hive `ObjectInspector` used to inspect partial aggregation results. @transient - private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + private val partialResultInspector = partial1ModeEvaluator.init( + GenericUDAFEvaluator.Mode.PARTIAL1, + inputInspectors + ) + // The UDAF evaluator used to merge partial aggregation results. @transient - private lazy val returnInspector = functionAndInspector._2 + private lazy val partial2ModeEvaluator = { + val evaluator = newEvaluator() + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector)) + evaluator + } + // Spark SQL data type of partial aggregation results @transient - private lazy val unwrapper = unwrapperFor(returnInspector) + private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) + // The UDAF evaluator used to compute the final result from a partial aggregation result objects. @transient - private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _ - - override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer)) + private lazy val finalModeEvaluator = newEvaluator() + // Hive `ObjectInspector` used to inspect the final aggregation result object. @transient - private lazy val inputProjection = new InterpretedProjection(children) + private val returnInspector = finalModeEvaluator.init( + GenericUDAFEvaluator.Mode.FINAL, + Array(partialResultInspector) + ) + // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format. @transient - private lazy val cached = new Array[AnyRef](children.length) + private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into + // Spark SQL specific format. @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation - // buffer for it. 
- override def aggBufferSchema: StructType = StructType(Nil) - - override def update(_buffer: InternalRow, input: InternalRow): Unit = { - val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) - } - - override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { - throw new UnsupportedOperationException( - "Hive UDAF doesn't support partial aggregate") - } + private lazy val resultUnwrapper = unwrapperFor(returnInspector) - override def initialize(_buffer: InternalRow): Unit = { - buffer = function.getNewAggregationBuffer - } - - override val aggBufferAttributes: Seq[AttributeReference] = Nil + @transient + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = Nil + @transient + private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. @@ -355,7 +384,7 @@ private[hive] case class HiveUDAFFunction( override def nullable: Boolean = true - override def supportsPartial: Boolean = false + override def supportsPartial: Boolean = true override lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -365,4 +394,74 @@ private[hive] case class HiveUDAFFunction( val distinct = if (isDistinct) "DISTINCT " else " " s"$name($distinct${children.map(_.sql).mkString(", ")})" } + + override def createAggregationBuffer(): AggregationBuffer = + partial1ModeEvaluator.getNewAggregationBuffer + + @transient + private lazy val inputProjection = UnsafeProjection.create(children) + + override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { + partial1ModeEvaluator.iterate( + buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes)) + } + + override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { + // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation + // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts + // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and + // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion. + partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input)) + } + + override def eval(buffer: AggregationBuffer): Any = { + resultUnwrapper(finalModeEvaluator.terminate(buffer)) + } + + override def serialize(buffer: AggregationBuffer): Array[Byte] = { + // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can + // shuffle it for global aggregation later. + aggBufferSerDe.serialize(buffer) + } + + override def deserialize(bytes: Array[Byte]): AggregationBuffer = { + // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare + // for global aggregation by merging multiple partial aggregation results within a single group. 
+ aggBufferSerDe.deserialize(bytes) + } + + // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects + private class AggregationBufferSerDe { + private val partialResultUnwrapper = unwrapperFor(partialResultInspector) + + private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType) + + private val projection = UnsafeProjection.create(Array(partialResultDataType)) + + private val mutableRow = new GenericInternalRow(1) + + def serialize(buffer: AggregationBuffer): Array[Byte] = { + // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object + // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`. + // Then we can unwrap it to a Spark SQL value. + mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer))) + val unsafeRow = projection(mutableRow) + val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes) + unsafeRow.writeTo(bytes) + bytes.array() + } + + def deserialize(bytes: Array[Byte]): AggregationBuffer = { + // `GenericUDAFEvaluator` doesn't provide any method that is capable to convert an object + // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The + // workaround here is creating an initial `AggregationBuffer` first and then merge the + // deserialized object into the buffer. + val buffer = partial2ModeEvaluator.getNewAggregationBuffer + val unsafeRow = new UnsafeRow(1) + unsafeRow.pointTo(bytes, bytes.length) + val partialResult = unsafeRow.get(0, partialResultDataType) + partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult)) + buffer + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala new file mode 100644 index 0000000000000..c9ef72ee112cf --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode} +import org.apache.hadoop.hive.ql.util.JavaDataModel +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'") + sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") + + Seq( + (0: Integer) -> "val_0", + (1: Integer) -> "val_1", + (2: Integer) -> null, + (3: Integer) -> null + ).toDF("key", "value").repartition(2).createOrReplaceTempView("t") + } + + protected override def afterAll(): Unit = { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } + + test("built-in Hive UDAF") { + val df = sql("SELECT key % 2, hive_max(key) FROM t GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: ObjectHashAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. + assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(0, 2), + Row(1, 3) + )) + } + + test("customized Hive UDAF") { + val df = sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: ObjectHashAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. + assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(0, Row(1, 1)), + Row(1, Row(1, 1)) + )) + } +} + +/** + * A testing Hive UDAF that computes the counts of both non-null values and nulls of a given column. 
+ */ +class MockUDAF extends AbstractGenericUDAFResolver { + override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator +} + +class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long) + extends GenericUDAFEvaluator.AbstractAggregationBuffer { + + override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2 +} + +class MockUDAFEvaluator extends GenericUDAFEvaluator { + private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector + + private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector + + private val bufferOI = { + val fieldNames = Seq("nonNullCount", "nullCount").asJava + val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava + ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs) + } + + private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount") + + private val nullCountField = bufferOI.getStructFieldRef("nullCount") + + override def getNewAggregationBuffer: AggregationBuffer = new MockUDAFBuffer(0L, 0L) + + override def reset(agg: AggregationBuffer): Unit = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + buffer.nonNullCount = 0L + buffer.nullCount = 0L + } + + override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = bufferOI + + override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + if (parameters.head eq null) { + buffer.nullCount += 1L + } else { + buffer.nonNullCount += 1L + } + } + + override def merge(agg: AggregationBuffer, partial: Object): Unit = { + if (partial ne null) { + val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField)) + val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField)) + val buffer = agg.asInstanceOf[MockUDAFBuffer] + buffer.nonNullCount += nonNullCount + buffer.nullCount += nullCount + } + } + + override def terminatePartial(agg: AggregationBuffer): AnyRef = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long) + } + + override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg) +} From 55589987be89ff78dadf44498352fbbd811a206e Mon Sep 17 00:00:00 2001 From: Artur Sukhenko Date: Wed, 16 Nov 2016 15:08:01 -0800 Subject: [PATCH 166/198] [YARN][DOC] Increasing NodeManager's heap size with External Shuffle Service ## What changes were proposed in this pull request? Suggest that users increase the `NodeManager's` heap size if the `External Shuffle Service` is enabled, as the `NM` can spend a lot of time doing GC, making shuffle operations a bottleneck because of increased `Shuffle Read blocked time`. GC can also make the `NodeManager` use an enormous amount of CPU, so overall cluster performance will suffer. I have seen a NodeManager using 5-13G RAM and up to 2700% CPU with the `spark_shuffle` service on. ## How was this patch tested? #### Added step 5: ![shuffle_service](https://cloud.githubusercontent.com/assets/15244468/20355499/2fec0fde-ac2a-11e6-8f8b-1c80daf71be1.png) Author: Artur Sukhenko Closes #15906 from Devian-ua/nmHeapSize. --- docs/running-on-yarn.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cd18808681ece..fe0221ce7c5b6 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -559,6 +559,8 @@ pre-packaged distribution. 1.
In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to `org.apache.spark.network.yarn.YarnShuffleService`. +1. Increase `NodeManager's` heap size by setting `YARN_HEAPSIZE` (1000 by default) in `etc/hadoop/yarn-env.sh` +to avoid garbage collection issues during shuffle. 1. Restart all `NodeManager`s in your cluster. The following extra configuration options are available when the shuffle service is running on YARN: From 170eeb345f951de89a39fe565697b3e913011768 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 17 Nov 2016 11:21:08 +0800 Subject: [PATCH 167/198] [SPARK-18442][SQL] Fix nullability of WrapOption. ## What changes were proposed in this pull request? The nullability of `WrapOption` should be `false`. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15887 from ueshin/issues/SPARK-18442. --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e2ac3c36d93..0e3d99127ed56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -341,7 +341,7 @@ case class WrapOption(child: Expression, optType: DataType) override def dataType: DataType = ObjectType(classOf[Option[_]]) - override def nullable: Boolean = true + override def nullable: Boolean = false override def inputTypes: Seq[AbstractDataType] = optType :: Nil From 07b3f045cd6f79b92bc86b3b1b51d3d5e6bd37ce Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 Nov 2016 00:00:38 -0800 Subject: [PATCH 168/198] [SPARK-18464][SQL] support old table which doesn't store schema in metastore ## What changes were proposed in this pull request? Before Spark 2.1, users could create an external data source table without a schema, and the table schema would be inferred at runtime. In Spark 2.1, we decided to infer the schema when the table was created, so that we don't need to infer it again and again at runtime. This is a good improvement, but we should still respect and support old tables which don't store the table schema in the metastore. ## How was this patch tested? regression test. Author: Wenchen Fan Closes #15900 from cloud-fan/hive-catalog.
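For reference, a rough sketch of what this re-enables at the user level (`old_tbl` and its path are hypothetical names; `spark` is assumed to be a Hive-enabled `SparkSession`):

```scala
// `old_tbl` was created by Spark < 2.1, so its metastore entry records the
// provider and data path but no schema. After this patch, resolving the table
// falls back to inferring the schema from the files at read time, much like:
val inferred = spark.read.format("parquet").load("/warehouse/old_tbl")
inferred.printSchema()          // schema comes from the parquet files

// ...so both of these work again instead of reporting an empty column list:
spark.table("old_tbl").show()
spark.sql("DESC old_tbl").show()
```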
--- .../spark/sql/execution/command/tables.scala | 8 ++++++- .../spark/sql/hive/HiveExternalCatalog.scala | 5 +++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +++- .../sql/hive/MetastoreDataSourcesSuite.scala | 22 +++++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 119e732d0202c..7049e53a78684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -431,7 +431,13 @@ case class DescribeTableCommand( describeSchema(catalog.lookupRelation(table).schema, result) } else { val metadata = catalog.getTableMetadata(table) - describeSchema(metadata.schema, result) + if (metadata.schema.isEmpty) { + // In older versions (prior to 2.1) of Spark, the table schema can be empty and should be + // inferred at runtime. We should still support it. + describeSchema(catalog.lookupRelation(metadata.identifier).schema, result) + } else { + describeSchema(metadata.schema, result) + } describePartitionInfo(metadata, result) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index cbd00da81cfcd..843305883abc8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1023,6 +1023,11 @@ object HiveExternalCatalog { // After SPARK-6024, we removed this flag. // Although we are not using `spark.sql.sources.schema` any more, we need to still support. DataType.fromJson(schema.get).asInstanceOf[StructType] + } else if (props.filterKeys(_.startsWith(DATASOURCE_SCHEMA_PREFIX)).isEmpty) { + // If there is no schema information in table properties, it means the schema of this table + // was empty when saving into metastore, which is possible in older versions (prior to 2.1) of + // Spark. We should respect it. + new StructType() } else { val numSchemaParts = props.get(DATASOURCE_SCHEMA_NUMPARTS) if (numSchemaParts.isDefined) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8e5fc88aad448..edbde5d10b47c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -64,7 +64,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val dataSource = DataSource( sparkSession, - userSpecifiedSchema = Some(table.schema), + // In older versions (prior to 2.1) of Spark, the table schema can be empty and should be + // inferred at runtime. We should still support it.
+ userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c50f92e783c88..4ab1a54edc46d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1371,4 +1371,26 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } } + + test("SPARK-18464: support old table which doesn't store schema in table properties") { + withTable("old") { + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) + val tableDesc = CatalogTable( + identifier = TableIdentifier("old", Some("default")), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat.empty.copy( + properties = Map("path" -> path.getAbsolutePath) + ), + schema = new StructType(), + properties = Map( + HiveExternalCatalog.DATASOURCE_PROVIDER -> "parquet")) + hiveClient.createTable(tableDesc, ignoreIfExists = false) + + checkAnswer(spark.table("old"), Row(1, "a")) + + checkAnswer(sql("DESC old"), Row("i", "int", null) :: Row("j", "string", null) :: Nil) + } + } + } } From a3cac7bd86a6fe8e9b42da1bf580aaeb59378304 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Thu, 17 Nov 2016 11:13:22 +0000 Subject: [PATCH 169/198] [YARN][DOC] Remove non-Yarn specific configurations from running-on-yarn.md ## What changes were proposed in this pull request? Remove `spark.driver.memory`, `spark.executor.memory`, `spark.driver.cores`, and `spark.executor.cores` from `running-on-yarn.md` as they are not Yarn-specific, and they are also defined in `configuration.md`. ## How was this patch tested? Build passed & manually checked. Author: Weiqing Yang Closes #15869 from weiqingy/yarnDoc. --- docs/running-on-yarn.md | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index fe0221ce7c5b6..4d1fafc07b8fc 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -117,28 +117,6 @@ To use a custom metrics.properties for the application master and executors, upd Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. - - spark.driver.memory - 1g - - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). -
Note: In client mode, this config must not be set through the SparkConf - directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-memory command line option - or in your default properties file. - - - - spark.driver.cores - 1 - - Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. - - spark.yarn.am.cores 1 @@ -233,13 +211,6 @@ To use a custom metrics.properties for the application master and executors, upd Comma-separated list of jars to be placed in the working directory of each executor. - - spark.executor.cores - 1 in YARN mode, all the available cores on the worker in standalone mode. - - The number of cores to use on each executor. For YARN and standalone mode only. - - spark.executor.instances 2 @@ -247,13 +218,6 @@ To use a custom metrics.properties for the application master and executors, upd The number of executors for static allocation. With spark.dynamicAllocation.enabled, the initial set of executors will be at least this large. - - spark.executor.memory - 1g - - Amount of memory to use per executor process (e.g. 2g, 8g). - - spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 From 49b6f456aca350e9e2c170782aa5cc75e7822680 Mon Sep 17 00:00:00 2001 From: anabranch Date: Thu, 17 Nov 2016 11:34:55 +0000 Subject: [PATCH 170/198] [SPARK-18365][DOCS] Improve Sample Method Documentation ## What changes were proposed in this pull request? I found the documentation for the sample method to be confusing; this adds more clarification across all languages. - [x] Scala - [x] Python - [x] R - [x] RDD Scala - [ ] RDD Python with SEED - [X] RDD Java - [x] RDD Java with SEED - [x] RDD Python ## How was this patch tested? NA Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. Author: anabranch Author: Bill Chambers Closes #15815 from anabranch/SPARK-18365. --- R/pkg/R/DataFrame.R | 4 +++- .../main/scala/org/apache/spark/api/java/JavaRDD.scala | 8 ++++++-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 +++ python/pyspark/rdd.py | 5 +++++ python/pyspark/sql/dataframe.py | 5 +++++ .../src/main/scala/org/apache/spark/sql/Dataset.scala | 10 ++++++++-- 6 files changed, 30 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1cf9b38ea6483..4e3d97bb3ad07 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -936,7 +936,9 @@ setMethod("unique", #' Sample #' -#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Note: this is not guaranteed to provide exactly the fraction specified +#' of the total count of the given SparkDataFrame.
#' #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 20d6c9341bf7a..d67cff64e6e46 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -98,7 +98,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD with a random seed. + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size @@ -109,7 +111,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) sample(withReplacement, fraction, Utils.random.nextLong) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD, with a user-supplied seed. + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e018af35cb18d..cded899db1f5c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -466,6 +466,9 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. * + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2de2c2fd1a60b..a163ceafe9d3b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -386,6 +386,11 @@ def sample(self, withReplacement, fraction, seed=None): with replacement: expected number of times each element is chosen; fraction must be >= 0 :param seed: seed for the random number generator + .. note:: + + This is not guaranteed to provide exactly the fraction specified of the total count + of the given :class:`RDD`. + >>> rdd = sc.parallelize(range(100), 4) >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 True diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 29710acf54c4f..38998900837cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -549,6 +549,11 @@ def distinct(self): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. + .. note:: + + This is not guaranteed to provide exactly the fraction specified of the total count + of the given :class:`DataFrame`.
+ >>> df.sample(False, 0.5, 42).count() 2 """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index af30683cc01c4..3761773698df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1646,7 +1646,10 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by sampling a fraction of rows. + * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. + * + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[Dataset]]. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. @@ -1665,7 +1668,10 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by sampling a fraction of rows, using a random seed. + * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. + * + * Note: this is NOT guaranteed to provide exactly the fraction of the total count + * of the given [[Dataset]]. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. From de77c67750dc868d75d6af173c3820b75a9fe4b7 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Thu, 17 Nov 2016 13:37:42 +0000 Subject: [PATCH 171/198] [SPARK-17462][MLLIB]use VersionUtils to parse Spark version strings ## What changes were proposed in this pull request? Several places in MLlib use custom regexes or other approaches to parse Spark versions. Those should be fixed to use VersionUtils. This PR replaces custom regexes with VersionUtils to get Spark version numbers. ## How was this patch tested? Existing tests. Signed-off-by: VinceShieh vincent.xie@intel.com Author: VinceShieh Closes #15055 from VinceShieh/SPARK-17462.
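To illustrate the switch (a standalone sketch; the regex line mirrors the code being removed in the diffs below, and `majorVersion` is the shared helper the patch adopts):

```scala
import org.apache.spark.util.VersionUtils.majorVersion

// Before: each model loader parsed the persisted Spark version with its own regex.
val versionRegex = "([0-9]+)\\.(.+)".r
val versionRegex(major, _) = "2.1.0"   // major: String = "2"

// After: one shared helper that returns an Int and gives a clear error
// message on malformed version strings.
val loadedByNewLayout = majorVersion("2.1.0") >= 2   // true
```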
--- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6 ++---- mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a0d481b294ac7..26505b4cc1501 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.VersionUtils.majorVersion /** * Common params for KMeans and KMeansModel @@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val versionRegex = "([0-9]+)\\.(.+)".r - val versionRegex(major, _) = metadata.sparkVersion - - val clusterCenters = if (major.toInt >= 2) { + val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 444006fe1edb6..1e49352b8517e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.VersionUtils.majorVersion /** * Params for [[PCA]] and [[PCAModel]]. @@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val versionRegex = "([0-9]+)\\.(.+)".r - val versionRegex(major, _) = metadata.sparkVersion - val dataPath = new Path(path, "data").toString - val model = if (major.toInt >= 2) { + val model = if (majorVersion(metadata.sparkVersion) >= 2) { val Row(pc: DenseMatrix, explainedVariance: DenseVector) = sparkSession.read.parquet(dataPath) .select("pc", "explainedVariance") From cdaf4ce9fe58c4606be8aa2a5c3756d30545c850 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 17 Nov 2016 13:40:16 +0000 Subject: [PATCH 172/198] [SPARK-18480][DOCS] Fix wrong links for ML guide docs ## What changes were proposed in this pull request? 1, There are two `[Graph.partitionBy]` in `graphx-programming-guide.md`, the first one had no effect. 2, `DataFrame`, `Transformer`, `Pipeline` and `Parameter` in `ml-pipeline.md` were linked to `ml-guide.html` by mistake. 3, `PythonMLLibAPI` in `mllib-linear-methods.md` was not accessible, because class `PythonMLLibAPI` is private. 4, Other link updates. ## How was this patch tested? manual tests Author: Zheng RuiFeng Closes #15912 from zhengruifeng/md_fix.
--- docs/graphx-programming-guide.md | 1 - docs/ml-classification-regression.md | 4 ++-- docs/ml-features.md | 2 +- docs/ml-pipeline.md | 12 ++++++------ docs/mllib-linear-methods.md | 4 +--- .../main/scala/org/apache/spark/ml/feature/LSH.scala | 2 +- .../spark/ml/tree/impl/GradientBoostedTrees.scala | 8 ++++---- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 8 ++++---- 8 files changed, 19 insertions(+), 22 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 1097cf1211c1f..e271b28fb4f28 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -36,7 +36,6 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] [Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] [PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED] [PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ [ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ [TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 1aacc3e054b52..43cc79b9c0811 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -984,7 +984,7 @@ Random forests combine many decision trees in order to reduce the risk of overfi The `spark.ml` implementation supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html#random-forests). ### Inputs and Outputs @@ -1065,7 +1065,7 @@ GBTs iteratively train decision trees in order to minimize a loss function. The `spark.ml` implementation supports GBTs for binary classification and for regression, using both continuous and categorical features. -For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html#gradient-boosted-trees-gbts). ### Inputs and Outputs diff --git a/docs/ml-features.md b/docs/ml-features.md index 19ec5746978ab..d2f036fb083da 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -710,7 +710,7 @@ for more details on the API. `VectorIndexer` helps index categorical features in datasets of `Vector`s. It can both automatically decide which features are categorical and convert original values to category indices. Specifically, it does the following: -1. Take an input column of type [Vector](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) and a parameter `maxCategories`. +1. 
Take an input column of type [Vector](api/scala/index.html#org.apache.spark.ml.linalg.Vector) and a parameter `maxCategories`. 2. Decide which features should be categorical based on the number of distinct values, where features with at most `maxCategories` are declared categorical. 3. Compute 0-based category indices for each categorical feature. 4. Index categorical features and transform original feature values to indices. diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index b4d6be94f5eb0..0384513ab7014 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -38,26 +38,26 @@ algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Pipelines API, where the pipeline concept is mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. -* **[`DataFrame`](ml-guide.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML +* **[`DataFrame`](ml-pipeline.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML dataset, which can hold a variety of data types. E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. -* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. +* **[`Transformer`](ml-pipeline.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. E.g., an ML model is a `Transformer` which transforms a `DataFrame` with features into a `DataFrame` with predictions. -* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. +* **[`Estimator`](ml-pipeline.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. -* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. +* **[`Pipeline`](ml-pipeline.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. -* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. +* **[`Parameter`](ml-pipeline.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. ## DataFrame Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. This API adopts the `DataFrame` from Spark SQL in order to support a variety of data types. -`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#data-types) for a list of supported types. In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. 
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 816bdf1317000..3085539b40e61 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -139,7 +139,7 @@ and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. For both methods, `spark.mllib` supports L1 and L2 regularized variants. -The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, +The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html#labeled-point) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. ### Linear Support Vector Machines (SVMs) @@ -491,5 +491,3 @@ Algorithms are all implemented in Scala: * [RidgeRegressionWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) * [LassoWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) -Python calls the Scala implementation via -[PythonMLLibAPI](api/scala/index.html#org.apache.spark.mllib.api.python.PythonMLLibAPI). diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 333a8c364a884..eb117c40eea3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -40,7 +40,7 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol { * @group param */ final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" + - "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + + " increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + " improves the running performance", ParamValidators.gt(0)) /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 7bef899a633d9..ede0a060eef95 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -34,7 +34,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Training dataset: RDD of [[LabeledPoint]]. * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) @@ -59,7 +59,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to validate a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Training dataset: RDD of [[LabeledPoint]]. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. @@ -162,7 +162,7 @@ private[spark] object GradientBoostedTrees extends Logging { * Method to calculate error of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
+ * @param data Training dataset: RDD of [[LabeledPoint]]. * @param trees Boosted Decision Tree models * @param treeWeights Learning rates at each boosting iteration. * @param loss evaluation metric. @@ -184,7 +184,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to compute error or loss for every iteration of gradient boosting. * - * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param data RDD of [[LabeledPoint]] * @param trees Boosted Decision Tree models * @param treeWeights Learning rates at each boosting iteration. * @param loss evaluation metric. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index b504f411d256d..8ae5ca3c84b0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -82,7 +82,7 @@ private[spark] object RandomForest extends Logging { /** * Train a random forest. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[LabeledPoint]] * @return an unweighted set of trees */ def run( @@ -343,7 +343,7 @@ private[spark] object RandomForest extends Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] + * @param input Training data: RDD of [[TreePoint]] * @param metadata Learning and dataset metadata * @param topNodesForGroup For each tree in group, tree index -> root node. * Used for matching instances with nodes. @@ -854,10 +854,10 @@ private[spark] object RandomForest extends Logging { * and for multiclass classification with a high-arity feature, * there is one bin per category. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[LabeledPoint]] * @param metadata Learning and dataset metadata * @param seed random seed - * @return Splits, an Array of [[org.apache.spark.mllib.tree.model.Split]] + * @return Splits, an Array of [[Split]] * of size (numFeatures, numSplits) */ protected[tree] def findSplits( From b0aa1aa1af6c513a6a881eaea96abdd2b480ef98 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 17 Nov 2016 17:04:19 +0000 Subject: [PATCH 173/198] [SPARK-18490][SQL] duplication nodename extrainfo for ShuffleExchange ## What changes were proposed in this pull request? In ShuffleExchange, the nodeName's extraInfo is the same whether exchangeCoordinator.isEstimated is true or false. This PR merges the two situations. Author: root Closes #15920 from windpiger/DupNodeNameShuffleExchange.
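The change is the usual collapse of two guarded cases with identical bodies; a minimal standalone illustration of the same refactoring (made-up names, not the Spark code itself):

```scala
// Two cases split on a guard but producing the same result...
def describe(flag: Option[Boolean]): String = flag match {
  case Some(estimated) if estimated  => "(coordinator)"
  case Some(estimated) if !estimated => "(coordinator)"
  case None => ""
}

// ...are equivalent to a single un-guarded case, which is all the patch does:
def describeMerged(flag: Option[Boolean]): String = flag match {
  case Some(_) => "(coordinator)"
  case None => ""
}
```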
--- .../apache/spark/sql/execution/exchange/ShuffleExchange.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 7a4a251370706..125a4930c6528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -45,9 +45,7 @@ case class ShuffleExchange( override def nodeName: String = { val extraInfo = coordinator match { - case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => - s"(coordinator id: ${System.identityHashCode(coordinator)})" - case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => + case Some(exchangeCoordinator) => s"(coordinator id: ${System.identityHashCode(coordinator)})" case None => "" } From ce13c2672318242748f7520ed4ce6bcfad4fb428 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 Nov 2016 17:31:12 -0800 Subject: [PATCH 174/198] [SPARK-18360][SQL] default table path of tables in default database should depend on the location of default database ## What changes were proposed in this pull request? The current semantic of the warehouse config: 1. it's a static config, which means you can't change it once your spark application is launched. 2. Once a database is created, its location won't change even if the warehouse path config is changed. 3. default database is a special case: although its location is fixed, the locations of tables created in it are not. If a Spark app starts with warehouse path B (while the location of default database is A), then users create a table `tbl` in default database, its location will be `B/tbl` instead of `A/tbl`. If users change the warehouse path config to C, and create another table `tbl2`, its location will still be `B/tbl2` instead of `C/tbl2`. rule 3 doesn't make sense and I think we made it by mistake, not intentionally. Data source tables don't follow rule 3 and treat default database like normal ones. This PR fixes hive serde tables to make them consistent with data source tables. ## How was this patch tested? HiveSparkSubmitSuite Author: Wenchen Fan Closes #15812 from cloud-fan/default-db. --- .../spark/sql/hive/HiveExternalCatalog.scala | 237 ++++++++++-------- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 76 +++++- 2 files changed, 190 insertions(+), 123 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 843305883abc8..cacffcf33c263 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -197,136 +197,151 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (tableDefinition.tableType == VIEW) { client.createTable(tableDefinition, ignoreIfExists) - } else if (tableDefinition.provider.get == DDLUtils.HIVE_PROVIDER) { - // Here we follow data source tables and put table metadata like provider, schema, etc. in - // table properties, so that we can work around the Hive metastore issue about not case - // preserving and make Hive serde table support mixed-case column names.
- val tableWithDataSourceProps = tableDefinition.copy( - properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) - client.createTable(tableWithDataSourceProps, ignoreIfExists) } else { - // To work around some hive metastore issues, e.g. not case-preserving, bad decimal type - // support, no column nullability, etc., we should do some extra works before saving table - // metadata into Hive metastore: - // 1. Put table metadata like provider, schema, etc. in table properties. - // 2. Check if this table is hive compatible. - // 2.1 If it's not hive compatible, set location URI, schema, partition columns and bucket - // spec to empty and save table metadata to Hive. - // 2.2 If it's hive compatible, set serde information in table metadata and try to save - // it to Hive. If it fails, treat it as not hive compatible and go back to 2.1 - val tableProperties = tableMetaToTableProps(tableDefinition) - // Ideally we should not create a managed table with location, but Hive serde table can // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have // to create the table directory and write out data before we create this table, to avoid // exposing a partial written table. val needDefaultTableLocation = tableDefinition.tableType == MANAGED && tableDefinition.storage.locationUri.isEmpty + val tableLocation = if (needDefaultTableLocation) { Some(defaultTablePath(tableDefinition.identifier)) } else { tableDefinition.storage.locationUri } - // Ideally we should also put `locationUri` in table properties like provider, schema, etc. - // However, in older version of Spark we already store table location in storage properties - // with key "path". Here we keep this behaviour for backward compatibility. - val storagePropsWithLocation = tableDefinition.storage.properties ++ - tableLocation.map("path" -> _) - - // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and - // bucket specification to empty. Note that partition columns are retained, so that we can - // call partition-related Hive API later. - def newSparkSQLSpecificMetastoreTable(): CatalogTable = { - tableDefinition.copy( - // Hive only allows directory paths as location URIs while Spark SQL data source tables - // also allow file paths. For non-hive-compatible format, we should not set location URI - // to avoid hive metastore to throw exception. - storage = tableDefinition.storage.copy( - locationUri = None, - properties = storagePropsWithLocation), - schema = tableDefinition.partitionSchema, - bucketSpec = None, - properties = tableDefinition.properties ++ tableProperties) + + if (tableDefinition.provider.get == DDLUtils.HIVE_PROVIDER) { + val tableWithDataSourceProps = tableDefinition.copy( + // We can't leave `locationUri` empty and count on Hive metastore to set a default table + // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default + // table location for tables in default database, while we expect to use the location of + // default database. + storage = tableDefinition.storage.copy(locationUri = tableLocation), + // Here we follow data source tables and put table metadata like provider, schema, etc. in + // table properties, so that we can work around the Hive metastore issue about not case + // preserving and make Hive serde table support mixed-case column names. 
+ properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) + client.createTable(tableWithDataSourceProps, ignoreIfExists) + } else { + createDataSourceTable( + tableDefinition.withNewStorage(locationUri = tableLocation), + ignoreIfExists) } + } + } - // converts the table metadata to Hive compatible format, i.e. set the serde information. - def newHiveCompatibleMetastoreTable(serde: HiveSerDe): CatalogTable = { - val location = if (tableDefinition.tableType == EXTERNAL) { - // When we hit this branch, we are saving an external data source table with hive - // compatible format, which means the data source is file-based and must have a `path`. - require(tableDefinition.storage.locationUri.isDefined, - "External file-based data source table must have a `path` entry in storage properties.") - Some(new Path(tableDefinition.location).toUri.toString) - } else { - None - } + private def createDataSourceTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = { + // To work around some hive metastore issues, e.g. not case-preserving, bad decimal type + // support, no column nullability, etc., we should do some extra works before saving table + // metadata into Hive metastore: + // 1. Put table metadata like provider, schema, etc. in table properties. + // 2. Check if this table is hive compatible. + // 2.1 If it's not hive compatible, set location URI, schema, partition columns and bucket + // spec to empty and save table metadata to Hive. + // 2.2 If it's hive compatible, set serde information in table metadata and try to save + // it to Hive. If it fails, treat it as not hive compatible and go back to 2.1 + val tableProperties = tableMetaToTableProps(table) + + // Ideally we should also put `locationUri` in table properties like provider, schema, etc. + // However, in older version of Spark we already store table location in storage properties + // with key "path". Here we keep this behaviour for backward compatibility. + val storagePropsWithLocation = table.storage.properties ++ + table.storage.locationUri.map("path" -> _) + + // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and + // bucket specification to empty. Note that partition columns are retained, so that we can + // call partition-related Hive API later. + def newSparkSQLSpecificMetastoreTable(): CatalogTable = { + table.copy( + // Hive only allows directory paths as location URIs while Spark SQL data source tables + // also allow file paths. For non-hive-compatible format, we should not set location URI + // to avoid hive metastore to throw exception. + storage = table.storage.copy( + locationUri = None, + properties = storagePropsWithLocation), + schema = table.partitionSchema, + bucketSpec = None, + properties = table.properties ++ tableProperties) + } - tableDefinition.copy( - storage = tableDefinition.storage.copy( - locationUri = location, - inputFormat = serde.inputFormat, - outputFormat = serde.outputFormat, - serde = serde.serde, - properties = storagePropsWithLocation - ), - properties = tableDefinition.properties ++ tableProperties) + // converts the table metadata to Hive compatible format, i.e. set the serde information. + def newHiveCompatibleMetastoreTable(serde: HiveSerDe): CatalogTable = { + val location = if (table.tableType == EXTERNAL) { + // When we hit this branch, we are saving an external data source table with hive + // compatible format, which means the data source is file-based and must have a `path`. 
+ require(table.storage.locationUri.isDefined, + "External file-based data source table must have a `path` entry in storage properties.") + Some(new Path(table.location).toUri.toString) + } else { + None } - val qualifiedTableName = tableDefinition.identifier.quotedString - val maybeSerde = HiveSerDe.sourceToSerDe(tableDefinition.provider.get) - val skipHiveMetadata = tableDefinition.storage.properties - .getOrElse("skipHiveMetadata", "false").toBoolean - - val (hiveCompatibleTable, logMessage) = maybeSerde match { - case _ if skipHiveMetadata => - val message = - s"Persisting data source table $qualifiedTableName into Hive metastore in" + - "Spark SQL specific format, which is NOT compatible with Hive." - (None, message) - - // our bucketing is un-compatible with hive(different hash function) - case _ if tableDefinition.bucketSpec.nonEmpty => - val message = - s"Persisting bucketed data source table $qualifiedTableName into " + - "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " - (None, message) - - case Some(serde) => - val message = - s"Persisting file based data source table $qualifiedTableName into " + - s"Hive metastore in Hive compatible format." - (Some(newHiveCompatibleMetastoreTable(serde)), message) - - case _ => - val provider = tableDefinition.provider.get - val message = - s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + - s"Persisting data source table $qualifiedTableName into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive." - (None, message) - } + table.copy( + storage = table.storage.copy( + locationUri = location, + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde, + properties = storagePropsWithLocation + ), + properties = table.properties ++ tableProperties) + } - (hiveCompatibleTable, logMessage) match { - case (Some(table), message) => - // We first try to save the metadata of the table in a Hive compatible way. - // If Hive throws an error, we fall back to save its metadata in the Spark SQL - // specific way. - try { - logInfo(message) - saveTableIntoHive(table, ignoreIfExists) - } catch { - case NonFatal(e) => - val warningMessage = - s"Could not persist ${tableDefinition.identifier.quotedString} in a Hive " + - "compatible way. Persisting it into Hive metastore in Spark SQL specific format." - logWarning(warningMessage, e) - saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) - } + val qualifiedTableName = table.identifier.quotedString + val maybeSerde = HiveSerDe.sourceToSerDe(table.provider.get) + val skipHiveMetadata = table.storage.properties + .getOrElse("skipHiveMetadata", "false").toBoolean + + val (hiveCompatibleTable, logMessage) = maybeSerde match { + case _ if skipHiveMetadata => + val message = + s"Persisting data source table $qualifiedTableName into Hive metastore in" + + "Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + + // our bucketing is un-compatible with hive(different hash function) + case _ if table.bucketSpec.nonEmpty => + val message = + s"Persisting bucketed data source table $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + (None, message) + + case Some(serde) => + val message = + s"Persisting file based data source table $qualifiedTableName into " + + s"Hive metastore in Hive compatible format." 
+ (Some(newHiveCompatibleMetastoreTable(serde)), message) + + case _ => + val provider = table.provider.get + val message = + s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + + s"Persisting data source table $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + } - case (None, message) => - logWarning(message) - saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) - } + (hiveCompatibleTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatible way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + saveTableIntoHive(table, ignoreIfExists) + } catch { + case NonFatal(e) => + val warningMessage = + s"Could not persist ${table.identifier.quotedString} in a Hive " + + "compatible way. Persisting it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, e) + saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) + } + + case (None, message) => + logWarning(message) + saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index fbd705172cae6..a670560c5969d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -24,6 +24,7 @@ import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.tools.nsc.Properties +import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException @@ -33,11 +34,12 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.util.{ResetSystemProperties, Utils} /** @@ -295,6 +297,20 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-18360: default table path of tables in default database should depend on the " + + "location of default database") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_18360.getClass.getName.stripSuffix("$"), + "--name", "SPARK-18360", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
// This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -397,11 +413,7 @@ object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") - val sparkConf = new SparkConf(loadDefaults = true) - val builder = SparkSession.builder() - .config(sparkConf) - .config("spark.ui.enabled", "false") - .enableHiveSupport() + val sparkConf = new SparkConf(loadDefaults = true).set("spark.ui.enabled", "false") val providedExpectedWarehouseLocation = sparkConf.getOption("spark.sql.test.expectedWarehouseDir") @@ -410,7 +422,7 @@ object SetWarehouseLocationTest extends Logging { // If spark.sql.test.expectedWarehouseDir is set, the warehouse dir is set // through spark-summit. So, neither spark.sql.warehouse.dir nor // hive.metastore.warehouse.dir is set at here. - (builder.getOrCreate(), warehouseDir) + (new TestHiveContext(new SparkContext(sparkConf)).sparkSession, warehouseDir) case None => val warehouseLocation = Utils.createTempDir() warehouseLocation.delete() @@ -420,10 +432,10 @@ object SetWarehouseLocationTest extends Logging { // spark.sql.warehouse.dir and hive.metastore.warehouse.dir. // We are expecting that the value of spark.sql.warehouse.dir will override the // value of hive.metastore.warehouse.dir. - val session = builder - .config("spark.sql.warehouse.dir", warehouseLocation.toString) - .config("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString) - .getOrCreate() + val session = new TestHiveContext(new SparkContext(sparkConf + .set("spark.sql.warehouse.dir", warehouseLocation.toString) + .set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString))) + .sparkSession (session, warehouseLocation.toString) } @@ -801,3 +813,43 @@ object SPARK_14244 extends QueryTest { } } } + +object SPARK_18360 { + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder() + .config("spark.ui.enabled", "false") + .enableHiveSupport().getOrCreate() + + val defaultDbLocation = spark.catalog.getDatabase("default").locationUri + assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) + + val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + + try { + val tableMeta = CatalogTable( + identifier = TableIdentifier("test_tbl", Some("default")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("i", "int"), + provider = Some(DDLUtils.HIVE_PROVIDER)) + + val newWarehousePath = Utils.createTempDir().getAbsolutePath + hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$newWarehousePath") + hiveClient.createTable(tableMeta, ignoreIfExists = false) + val rawTable = hiveClient.getTable("default", "test_tbl") + // Hive will use the value of `hive.metastore.warehouse.dir` to generate default table + // location for tables in default database. + assert(rawTable.storage.locationUri.get.contains(newWarehousePath)) + hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = false, purge = false) + + spark.sharedState.externalCatalog.createTable(tableMeta, ignoreIfExists = false) + val readBack = spark.sharedState.externalCatalog.getTable("default", "test_tbl") + // Spark SQL will use the location of default database to generate default table + // location for tables in default database. 
+ assert(readBack.storage.locationUri.get.contains(defaultDbLocation)) + } finally { + hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = true, purge = false) + hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$defaultDbLocation") + } + } +} From d9dd979d170f44383a9a87f892f2486ddb3cca7d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 17 Nov 2016 18:45:15 -0800 Subject: [PATCH 175/198] [SPARK-18462] Fix ClassCastException in SparkListenerDriverAccumUpdates event ## What changes were proposed in this pull request? This patch fixes a `ClassCastException: java.lang.Integer cannot be cast to java.lang.Long` error which could occur in the HistoryServer while trying to process a deserialized `SparkListenerDriverAccumUpdates` event. The problem stems from how `jackson-module-scala` handles primitive type parameters (see https://github.com/FasterXML/jackson-module-scala/wiki/FAQ#deserializing-optionint-and-other-primitive-challenges for more details). This was causing a problem where our code expected a field to be deserialized as a `(Long, Long)` tuple but we got an `(Int, Int)` tuple instead. This patch hacks around this issue by registering a custom `Converter` with Jackson in order to deserialize the tuples as `(Object, Object)` and perform the appropriate casting. ## How was this patch tested? New regression tests in `SQLListenerSuite`. Author: Josh Rosen Closes #15922 from JoshRosen/SPARK-18462. --- .../spark/sql/execution/ui/SQLListener.scala | 39 +++++++++++++++- .../sql/execution/ui/SQLListenerSuite.scala | 44 ++++++++++++++++++- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 60f13432d78d2..5daf21595d8a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,6 +19,11 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable +import com.fasterxml.jackson.databind.JavaType +import com.fasterxml.jackson.databind.`type`.TypeFactory +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.fasterxml.jackson.databind.util.Converter + import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging @@ -43,9 +48,41 @@ case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent @DeveloperApi -case class SparkListenerDriverAccumUpdates(executionId: Long, accumUpdates: Seq[(Long, Long)]) +case class SparkListenerDriverAccumUpdates( + executionId: Long, + @JsonDeserialize(contentConverter = classOf[LongLongTupleConverter]) + accumUpdates: Seq[(Long, Long)]) extends SparkListenerEvent +/** + * Jackson [[Converter]] for converting an (Int, Int) tuple into a (Long, Long) tuple. + * + * This is necessary due to limitations in how Jackson's scala module deserializes primitives; + * see the "Deserializing Option[Int] and other primitive challenges" section in + * https://github.com/FasterXML/jackson-module-scala/wiki/FAQ for a discussion of this issue and + * SPARK-18462 for the specific problem that motivated this conversion. 
+ */ +private class LongLongTupleConverter extends Converter[(Object, Object), (Long, Long)] { + + override def convert(in: (Object, Object)): (Long, Long) = { + def toLong(a: Object): Long = a match { + case i: java.lang.Integer => i.intValue() + case l: java.lang.Long => l.longValue() + } + (toLong(in._1), toLong(in._2)) + } + + override def getInputType(typeFactory: TypeFactory): JavaType = { + val objectType = typeFactory.uncheckedSimpleType(classOf[Object]) + typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(objectType, objectType)) + } + + override def getOutputType(typeFactory: TypeFactory): JavaType = { + val longType = typeFactory.uncheckedSimpleType(classOf[Long]) + typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(longType, longType)) + } +} + class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 948a155457b65..8aea112897fb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.ui import java.util.Properties +import org.json4s.jackson.JsonMethods._ import org.mockito.Mockito.mock import org.apache.spark._ @@ -35,10 +36,10 @@ import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanIn import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{AccumulatorMetadata, LongAccumulator} +import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} -class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils { import testImplicits._ import org.apache.spark.AccumulatorSuite.makeInfo @@ -416,6 +417,45 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(driverUpdates(physicalPlan.longMetric("dummy").id) == expectedAccumValue) } + test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { + val event = SparkListenerDriverAccumUpdates(1L, Seq((2L, 3L))) + val json = JsonProtocol.sparkEventToJson(event) + assertValidDataInJson(json, + parse(""" + |{ + | "Event": "org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates", + | "executionId": 1, + | "accumUpdates": [[2,3]] + |} + """.stripMargin)) + JsonProtocol.sparkEventFromJson(json) match { + case SparkListenerDriverAccumUpdates(executionId, accums) => + assert(executionId == 1L) + accums.foreach { case (a, b) => + assert(a == 2L) + assert(b == 3L) + } + } + + // Test a case where the numbers in the JSON can only fit in longs: + val longJson = parse( + """ + |{ + | "Event": "org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates", + | "executionId": 4294967294, + | "accumUpdates": [[4294967294,3]] + |} + """.stripMargin) + JsonProtocol.sparkEventFromJson(longJson) match { + case SparkListenerDriverAccumUpdates(executionId, accums) => + assert(executionId == 4294967294L) + accums.foreach { case (a, b) => + assert(a == 4294967294L) + assert(b == 3L) + } + } + } + } From 
51baca2219fda8692b88fc8552548544aec73a1e Mon Sep 17 00:00:00 2001
From: Tyson Condie
Date: Fri, 18 Nov 2016 11:11:24 -0800
Subject: [PATCH 176/198] [SPARK-18187][SQL] CompactibleFileStreamLog should not use "compactInterval" directly with the user setting.

## What changes were proposed in this pull request?

CompactibleFileStreamLog relies on "compactInterval" to detect a compaction batch. If the "compactInterval" is reset by the user, CompactibleFileStreamLog will return a wrong answer, resulting in data loss. This PR provides a way to check the validity of 'compactInterval' and calculate an appropriate value.

## How was this patch tested?

When restarting a stream, we change 'spark.sql.streaming.fileSource.log.compactInterval' to a value different from the former one.

The primary solution to this issue was given by uncleGen. Added extensions include an additional metadata field in OffsetSeq and CompactibleFileStreamLog APIs. zsxwing

Author: Tyson Condie
Author: genmao.ygm

Closes #15852 from tcondie/spark-18187.
---
 .../streaming/CompactibleFileStreamLog.scala | 61 ++++++++++++++++++-
 .../streaming/FileStreamSinkLog.scala | 8 ++-
 .../streaming/FileStreamSourceLog.scala | 9 +--
 .../execution/streaming/HDFSMetadataLog.scala | 2 +-
 .../sql/execution/streaming/OffsetSeq.scala | 12 +++-
 .../execution/streaming/OffsetSeqLog.scala | 31 +++++++---
 .../CompactibleFileStreamLogSuite.scala | 33 ++++++++++
 .../sql/streaming/FileStreamSourceSuite.scala | 41 ++++++++-----
 .../spark/sql/streaming/StreamTest.scala | 20 +++++-
 9 files changed, 178 insertions(+), 39 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
index 8af3db1968882..8529ceac30f1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
@@ -63,7 +63,46 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](

   protected def isDeletingExpiredLog: Boolean

-  protected def compactInterval: Int
+  protected def defaultCompactInterval: Int
+
+  protected final lazy val compactInterval: Int = {
+    // SPARK-18187: "compactInterval" can be set by the user via defaultCompactInterval.
+    // If there are existing log entries, then we should ensure a compatible compactInterval
+    // is used, irrespective of the defaultCompactInterval. There are three cases:
+    //
+    // 1. If there is no '.compact' file, we can use the default setting directly.
+    // 2. If there are two or more '.compact' files, we use the interval between the batch ids
+    //    of the two latest '.compact' files as compactInterval. This case could arise if
+    //    isDeletingExpiredLog == false.
+    // 3. If there is only one '.compact' file, then we must find a compact interval
+    //    that is compatible with (i.e., a divisor of) the previous compact file, and that
+    //    faithfully tries to represent the revised default compact interval, i.e., is at
+    //    least as large as the default if possible.
+    // e.g., if defaultCompactInterval is 5 (and the previous compact interval could have
+    // been any of 2, 3, 4, 6, 12), then a log could be: 11.compact, 12, 13, in which case
+    // we will ensure that the new compactInterval = 6 > 5 and (11 + 1) % 6 == 0
+    val compactibleBatchIds = fileManager.list(metadataPath, batchFilesFilter)
+      .filter(f => f.getPath.toString.endsWith(CompactibleFileStreamLog.COMPACT_FILE_SUFFIX))
+      .map(f => pathToBatchId(f.getPath))
+      .sorted
+      .reverse
+
+    // Case 1
+    var interval = defaultCompactInterval
+    if (compactibleBatchIds.length >= 2) {
+      // Case 2
+      val latestCompactBatchId = compactibleBatchIds(0)
+      val previousCompactBatchId = compactibleBatchIds(1)
+      interval = (latestCompactBatchId - previousCompactBatchId).toInt
+    } else if (compactibleBatchIds.length == 1) {
+      // Case 3
+      interval = CompactibleFileStreamLog.deriveCompactInterval(
+        defaultCompactInterval, compactibleBatchIds(0).toInt)
+    }
+    assert(interval > 0, s"intervalValue = $interval is not a positive value.")
+    logInfo(s"Set the compact interval to $interval " +
+      s"[defaultCompactInterval: $defaultCompactInterval]")
+    interval
+  }

   /**
    * Filter out the obsolete logs.
@@ -245,4 +284,24 @@ object CompactibleFileStreamLog {
   def nextCompactionBatchId(batchId: Long, compactInterval: Long): Long = {
     (batchId + compactInterval + 1) / compactInterval * compactInterval - 1
   }
+
+  /**
+   * Derives a compact interval from the latest compact batch id and
+   * a default compact interval.
+   */
+  def deriveCompactInterval(defaultInterval: Int, latestCompactBatchId: Int): Int = {
+    if (latestCompactBatchId + 1 <= defaultInterval) {
+      latestCompactBatchId + 1
+    } else if (defaultInterval < (latestCompactBatchId + 1) / 2) {
+      // Find the first divisor >= the default compact interval
+      def properDivisors(min: Int, n: Int) =
+        (min to n/2).view.filter(i => n % i == 0) :+ n
+
+      properDivisors(defaultInterval, latestCompactBatchId + 1).head
+    } else {
+      // the default compact interval is greater than any divisor other than the
+      // latest compact id
+      latestCompactBatchId + 1
+    }
+  }
 }
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
index b4f14151f1ef2..eb6eed87eca7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
@@ -88,9 +88,11 @@ class FileStreamSinkLog(

   protected override val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion

-  protected override val compactInterval = sparkSession.sessionState.conf.fileSinkLogCompactInterval
-  require(compactInterval > 0,
-    s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " +
+  protected override val defaultCompactInterval =
+    sparkSession.sessionState.conf.fileSinkLogCompactInterval
+
+  require(defaultCompactInterval > 0,
+    s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $defaultCompactInterval) " +
     "to a positive value.")

   override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala
index fe81b15607068..327b3ac267766 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -38,11 +38,12 @@ class FileStreamSourceLog( import CompactibleFileStreamLog._ // Configurations about metadata compaction - protected override val compactInterval = + protected override val defaultCompactInterval: Int = sparkSession.sessionState.conf.fileSourceLogCompactInterval - require(compactInterval > 0, - s"Please set ${SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key} (was $compactInterval) to a " + - s"positive value.") + + require(defaultCompactInterval > 0, + s"Please set ${SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key} " + + s"(was $defaultCompactInterval) to a positive value.") protected override val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSourceLogCleanupDelay diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index db7057d7da70c..080729b2ca8d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -70,7 +70,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** * A `PathFilter` to filter only batch files */ - private val batchFilesFilter = new PathFilter { + protected val batchFilesFilter = new PathFilter { override def accept(path: Path): Boolean = isBatchFile(path) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index a4e1fe6797097..7469caeee3be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -23,7 +23,7 @@ package org.apache.spark.sql.execution.streaming * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance * vector clock that must progress linearly forward. */ -case class OffsetSeq(offsets: Seq[Option[Offset]]) { +case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[String] = None) { /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of @@ -47,7 +47,13 @@ object OffsetSeq { * Returns a [[OffsetSeq]] with a variable sequence of offsets. * `nulls` in the sequence are converted to `None`s. */ - def fill(offsets: Offset*): OffsetSeq = { - OffsetSeq(offsets.map(Option(_))) + def fill(offsets: Offset*): OffsetSeq = OffsetSeq.fill(None, offsets: _*) + + /** + * Returns a [[OffsetSeq]] with metadata and a variable sequence of offsets. + * `nulls` in the sequence are converted to `None`s. + */ + def fill(metadata: Option[String], offsets: Offset*): OffsetSeq = { + OffsetSeq(offsets.map(Option(_)), metadata) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index d1c9d95be9fdb..cc25b4474ba2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -33,12 +33,13 @@ import org.apache.spark.sql.SparkSession * by a newline character. 
If a source offset is missing, then * that line will contain a string value defined in the * SERIALIZED_VOID_OFFSET variable in [[OffsetSeqLog]] companion object. - * For instance, when dealine wiht [[LongOffset]] types: - * v1 // version 1 - * {0} // LongOffset 0 - * {3} // LongOffset 3 - * - // No offset for this source i.e., an invalid JSON string - * {2} // LongOffset 2 + * For instance, when dealing with [[LongOffset]] types: + * v1 // version 1 + * metadata + * {0} // LongOffset 0 + * {3} // LongOffset 3 + * - // No offset for this source i.e., an invalid JSON string + * {2} // LongOffset 2 * ... */ class OffsetSeqLog(sparkSession: SparkSession, path: String) @@ -58,13 +59,25 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) if (version != OffsetSeqLog.VERSION) { throw new IllegalStateException(s"Unknown log version: ${version}") } - OffsetSeq.fill(lines.map(parseOffset).toArray: _*) + + // read metadata + val metadata = lines.next().trim match { + case "" => None + case md => Some(md) + } + OffsetSeq.fill(metadata, lines.map(parseOffset).toArray: _*) } - override protected def serialize(metadata: OffsetSeq, out: OutputStream): Unit = { + override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller out.write(OffsetSeqLog.VERSION.getBytes(UTF_8)) - metadata.offsets.map(_.map(_.json)).foreach { offset => + + // write metadata + out.write('\n') + out.write(offsetSeq.metadata.getOrElse("").getBytes(UTF_8)) + + // write offsets, one per line + offsetSeq.offsets.map(_.map(_.json)).foreach { offset => out.write('\n') offset match { case Some(json: String) => out.write(json.getBytes(UTF_8)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala new file mode 100644 index 0000000000000..2cd2157b293cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.SparkFunSuite + +class CompactibleFileStreamLogSuite extends SparkFunSuite { + + import CompactibleFileStreamLog._ + + test("deriveCompactInterval") { + // latestCompactBatchId(4) + 1 <= default(5) + // then use latestestCompactBatchId + 1 === 5 + assert(5 === deriveCompactInterval(5, 4)) + // First divisor of 10 greater than 4 === 5 + assert(5 === deriveCompactInterval(4, 9)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index b365af76c3795..a099153d2e58e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.streaming import java.io.File +import scala.collection.mutable + import org.scalatest.PrivateMethodTester import org.scalatest.time.SpanSugar._ @@ -896,32 +898,38 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } - test("compacat metadata log") { + test("compact interval metadata log") { val _sources = PrivateMethod[Seq[Source]]('sources) val _metadataLog = PrivateMethod[FileStreamSourceLog]('metadataLog) - def verify(execution: StreamExecution) - (batchId: Long, expectedBatches: Int): Boolean = { + def verify( + execution: StreamExecution, + batchId: Long, + expectedBatches: Int, + expectedCompactInterval: Int): Boolean = { import CompactibleFileStreamLog._ val fileSource = (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] val metadataLog = fileSource invokePrivate _metadataLog() - if (isCompactionBatch(batchId, 2)) { + if (isCompactionBatch(batchId, expectedCompactInterval)) { val path = metadataLog.batchIdToPath(batchId) // Assert path name should be ended with compact suffix. - assert(path.getName.endsWith(COMPACT_FILE_SUFFIX)) + assert(path.getName.endsWith(COMPACT_FILE_SUFFIX), + "path does not end with compact file suffix") // Compacted batch should include all entries from start. 
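// Hedged aside, not part of this patch: the compaction arithmetic that the checks below rely
// on, restated standalone so it can be sanity-checked in a REPL. Consistent with the
// derivation comments in CompactibleFileStreamLog above, a compaction batch is assumed to
// satisfy (batchId + 1) % compactInterval == 0, and nextCompactionBatchId follows the formula
// shown earlier in this patch; the "Sketch" suffix marks these as illustrative copies.
def isCompactionBatchSketch(batchId: Long, compactInterval: Int): Boolean =
  (batchId + 1) % compactInterval == 0

def nextCompactionBatchIdSketch(batchId: Long, compactInterval: Long): Long =
  (batchId + compactInterval + 1) / compactInterval * compactInterval - 1

assert(isCompactionBatchSketch(1L, 2))        // with interval 2, batches 1, 3, 5, ... compact
assert(!isCompactionBatchSketch(2L, 2))
assert(nextCompactionBatchIdSketch(0L, 2L) == 1L)
assert(nextCompactionBatchIdSketch(2L, 2L) == 3L)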
val entries = metadataLog.get(batchId) - assert(entries.isDefined) - assert(entries.get.length === metadataLog.allFiles().length) - assert(metadataLog.get(None, Some(batchId)).flatMap(_._2).length === entries.get.length) + assert(entries.isDefined, "Entries not defined") + assert(entries.get.length === metadataLog.allFiles().length, "clean up check") + assert(metadataLog.get(None, Some(batchId)).flatMap(_._2).length === + entries.get.length, "Length check") } assert(metadataLog.allFiles().sortBy(_.batchId) === - metadataLog.get(None, Some(batchId)).flatMap(_._2).sortBy(_.batchId)) + metadataLog.get(None, Some(batchId)).flatMap(_._2).sortBy(_.batchId), + "Batch id mismatch") metadataLog.get(None, Some(batchId)).flatMap(_._2).length === expectedBatches } @@ -932,26 +940,27 @@ class FileStreamSourceSuite extends FileStreamSourceTest { ) { val fileStream = createFileStream("text", src.getCanonicalPath) val filtered = fileStream.filter($"value" contains "keep") + val updateConf = Map(SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key -> "5") testStream(filtered)( AddTextFileData("drop1\nkeep2\nkeep3", src, tmp), CheckAnswer("keep2", "keep3"), - AssertOnQuery(verify(_)(0L, 1)), + AssertOnQuery(verify(_, 0L, 1, 2)), AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6"), - AssertOnQuery(verify(_)(1L, 2)), + AssertOnQuery(verify(_, 1L, 2, 2)), AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9"), - AssertOnQuery(verify(_)(2L, 3)), + AssertOnQuery(verify(_, 2L, 3, 2)), StopStream, - StartStream(), - AssertOnQuery(verify(_)(2L, 3)), + StartStream(additionalConfs = updateConf), + AssertOnQuery(verify(_, 2L, 3, 2)), AddTextFileData("drop10\nkeep11", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11"), - AssertOnQuery(verify(_)(3L, 4)), + AssertOnQuery(verify(_, 3L, 4, 2)), AddTextFileData("drop12\nkeep13", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11", "keep13"), - AssertOnQuery(verify(_)(4L, 5)) + AssertOnQuery(verify(_, 4L, 5, 2)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 742833065144d..a6b2d4b9ab4c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -161,7 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( trigger: Trigger = ProcessingTime(0), - triggerClock: Clock = new SystemClock) + triggerClock: Clock = new SystemClock, + additionalConfs: Map[String, String] = Map.empty) extends StreamAction /** Advance the trigger clock's time manually. 
*/ @@ -240,6 +241,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for val sink = new MemorySink(stream.schema, outputMode) + val resetConfValues = mutable.Map[String, Option[String]]() @volatile var streamDeathCause: Throwable = null @@ -330,7 +332,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { startedTest.foreach { action => logInfo(s"Processing test stream action: $action") action match { - case StartStream(trigger, triggerClock) => + case StartStream(trigger, triggerClock, additionalConfs) => verify(currentStream == null, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], @@ -338,6 +340,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { if (triggerClock.isInstanceOf[StreamManualClock]) { manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() } + + additionalConfs.foreach(pair => { + val value = + if (spark.conf.contains(pair._1)) Some(spark.conf.get(pair._1)) else None + resetConfValues(pair._1) = value + spark.conf.set(pair._1, pair._2) + }) + lastStream = currentStream currentStream = spark @@ -519,6 +529,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { currentStream.stop() } spark.streams.removeListener(statusCollector) + + // Rollback prev configuration values + resetConfValues.foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } } } From 795e9fc9213cb9941ae131aadcafddb94bde5f74 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 18 Nov 2016 11:19:49 -0800 Subject: [PATCH 177/198] [SPARK-18457][SQL] ORC and other columnar formats using HiveShim read all columns when doing a simple count ## What changes were proposed in this pull request? When reading zero columns (e.g., count(*)) from ORC or any other format that uses HiveShim, actually set the read column list to empty for Hive to use. ## How was this patch tested? Query correctness is handled by existing unit tests. I'm happy to add more if anyone can point out some case that is not covered. Reduction in data read can be verified in the UI when built with a recent version of Hadoop say: ``` build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.0 -Phive -DskipTests clean package ``` However the default Hadoop 2.2 that is used for unit tests does not report actual bytes read and instead just full file sizes (see FileScanRDD.scala line 80). Therefore I don't think there is a good way to add a unit test for this. I tested with the following setup using above build options ``` case class OrcData(intField: Long, stringField: String) spark.range(1,1000000).map(i => OrcData(i, s"part-$i")).toDF().write.format("orc").save("orc_test") sql( s"""CREATE EXTERNAL TABLE orc_test( | intField LONG, | stringField STRING |) |STORED AS ORC |LOCATION '${System.getProperty("user.dir") + "/orc_test"}' """.stripMargin) ``` ## Results query | Spark 2.0.2 | this PR ---|---|--- `sql("select count(*) from orc_test").collect`|4.4 MB|199.4 KB `sql("select intField from orc_test").collect`|743.4 KB|743.4 KB `sql("select * from orc_test").collect`|4.4 MB|4.4 MB Author: Andrew Ray Closes #15898 from aray/sql-orc-no-col. 
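As a hedged sketch of the one-line guard change in `HiveShim.appendReadColumns` that produces the numbers above: an empty but non-null column id list is now forwarded to Hive instead of being dropped, so a `count(*)` scan projects zero columns. The standalone wrapper below is illustrative, not the shim itself:

```scala
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.serde2.ColumnProjectionUtils

// Sketch of the patched guard: dropping the `nonEmpty` check lets an empty id list
// reach Hive, which then reads no columns rather than falling back to all of them.
def appendReadColumnsSketch(conf: Configuration, ids: Seq[Integer]): Unit = {
  if (ids != null) { // was: ids != null && ids.nonEmpty
    ColumnProjectionUtils.appendReadColumns(conf, ids.asJava)
  }
}

val conf = new Configuration()
appendReadColumnsSketch(conf, Seq.empty) // e.g. for: select count(*) from orc_test
```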
--- .../org/apache/spark/sql/hive/HiveShim.scala | 6 ++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 25 ++++++++++++++++++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 0d2a765a388aa..9e9894803ce25 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -69,13 +69,13 @@ private[hive] object HiveShim { } /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.nonEmpty) { + if (ids != null) { ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) } - if (names != null && names.nonEmpty) { + if (names != null) { appendReadColumnNames(conf, names) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index ecb5972984523..a628977af2f4e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets import java.sql.Timestamp +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -577,4 +579,25 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { assert(spark.table(tableName).schema == schema.copy(fields = expectedFields)) } } + + test("Empty schema does not read data from ORC file") { + val data = Seq((1, 1), (2, 2)) + withOrcFile(data) { path => + val requestedSchema = StructType(Nil) + val conf = new Configuration() + val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get + OrcRelation.setRequiredColumns(conf, physicalSchema, requestedSchema) + val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) + assert(maybeOrcReader.isDefined) + val orcRecordReader = new SparkOrcNewRecordReader( + maybeOrcReader.get, conf, 0, maybeOrcReader.get.getContentLength) + + val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) + try { + assert(recordsIterator.next().toString == "{null, null}") + } finally { + recordsIterator.close() + } + } + } } From 40d59ff5eaac6df237fe3d50186695c3806b268c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 18 Nov 2016 21:45:18 +0000 Subject: [PATCH 178/198] [SPARK-18422][CORE] Fix wholeTextFiles test to pass on Windows in JavaAPISuite ## What changes were proposed in this pull request? This PR fixes the test `wholeTextFiles` in `JavaAPISuite.java`. This is failed due to the different path format on Windows. 
For example, the path in `container` was ``` C:\projects\spark\target\tmp\1478967560189-0/part-00000 ``` whereas `new URI(res._1()).getPath()` was as below: ``` /C:/projects/spark/target/tmp/1478967560189-0/part-00000 ``` ## How was this patch tested? Tests in `JavaAPISuite.java`. Tested via AppVeyor. **Before** Build: https://ci.appveyor.com/project/spark-test/spark/build/63-JavaAPISuite-1 Diff: https://github.com/apache/spark/compare/master...spark-test:JavaAPISuite-1 ``` [info] Test org.apache.spark.JavaAPISuite.wholeTextFiles started [error] Test org.apache.spark.JavaAPISuite.wholeTextFiles failed: java.lang.AssertionError: expected: but was:, took 0.578 sec [error] at org.apache.spark.JavaAPISuite.wholeTextFiles(JavaAPISuite.java:1089) ... ``` **After** Build started: [CORE] `org.apache.spark.JavaAPISuite` [![PR-15866](https://ci.appveyor.com/api/projects/status/github/spark-test/spark?branch=198DDA52-F201-4D2B-BE2F-244E0C1725B2&svg=true)](https://ci.appveyor.com/project/spark-test/spark/branch/198DDA52-F201-4D2B-BE2F-244E0C1725B2) Diff: https://github.com/apache/spark/compare/master...spark-test:198DDA52-F201-4D2B-BE2F-244E0C1725B2 ``` [info] Test org.apache.spark.JavaAPISuite.wholeTextFiles started ... ``` Author: hyukjinkwon Closes #15866 from HyukjinKwon/SPARK-18422. --- .../java/org/apache/spark/JavaAPISuite.java | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 533025ba83e72..7bebe0612f9a8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -20,7 +20,6 @@ import java.io.*; import java.nio.channels.FileChannel; import java.nio.ByteBuffer; -import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +45,7 @@ import com.google.common.collect.Lists; import com.google.common.base.Throwables; import com.google.common.io.Files; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.DefaultCodec; @@ -1075,18 +1075,23 @@ public void wholeTextFiles() throws Exception { byte[] content2 = "spark is also easy to use.\n".getBytes(StandardCharsets.UTF_8); String tempDirName = tempDir.getAbsolutePath(); - Files.write(content1, new File(tempDirName + "/part-00000")); - Files.write(content2, new File(tempDirName + "/part-00001")); + String path1 = new Path(tempDirName, "part-00000").toUri().getPath(); + String path2 = new Path(tempDirName, "part-00001").toUri().getPath(); + + Files.write(content1, new File(path1)); + Files.write(content2, new File(path2)); Map container = new HashMap<>(); - container.put(tempDirName+"/part-00000", new Text(content1).toString()); - container.put(tempDirName+"/part-00001", new Text(content2).toString()); + container.put(path1, new Text(content1).toString()); + container.put(path2, new Text(content2).toString()); JavaPairRDD readRDD = sc.wholeTextFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { - assertEquals(res._2(), container.get(new URI(res._1()).getPath())); + // Note that the paths from `wholeTextFiles` are in URI format on Windows, + // for example, file:/C:/a/b/c. 
+ assertEquals(res._2(), container.get(new Path(res._1()).toUri().getPath())); } } From e5f5c29e021d504284fe5ad1a77dcd5a992ac10a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 18 Nov 2016 16:13:02 -0800 Subject: [PATCH 179/198] [SPARK-18477][SS] Enable interrupts for HDFS in HDFSMetadataLog ## What changes were proposed in this pull request? HDFS `write` may just hang until timeout if some network error happens. It's better to enable interrupts to allow stopping the query fast on HDFS. This PR just changes the logic to only disable interrupts for local file system, as HADOOP-10622 only happens for local file system. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15911 from zsxwing/interrupt-on-dfs. --- .../execution/streaming/HDFSMetadataLog.scala | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 080729b2ca8d6..d95ec7f67feb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -105,25 +105,34 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** * Store the metadata for the specified batchId and return `true` if successful. If the batchId's * metadata has already been stored, this method will return `false`. - * - * Note that this method must be called on a [[org.apache.spark.util.UninterruptibleThread]] - * so that interrupts can be disabled while writing the batch file. This is because there is a - * potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). If the thread - * running "Shell.runCommand" is interrupted, then the thread can get deadlocked. In our - * case, `writeBatch` creates a file using HDFS API and calls "Shell.runCommand" to set the - * file permissions, and can get deadlocked if the stream execution thread is stopped by - * interrupt. Hence, we make sure that this method is called on [[UninterruptibleThread]] which - * allows us to disable interrupts here. Also see SPARK-14131. */ override def add(batchId: Long, metadata: T): Boolean = { get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - Thread.currentThread match { - case ut: UninterruptibleThread => - ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) } - case _ => - throw new IllegalStateException( - "HDFSMetadataLog.add() must be executed on a o.a.spark.util.UninterruptibleThread") + if (fileManager.isLocalFileSystem) { + Thread.currentThread match { + case ut: UninterruptibleThread => + // When using a local file system, "writeBatch" must be called on a + // [[org.apache.spark.util.UninterruptibleThread]] so that interrupts can be disabled + // while writing the batch file. This is because there is a potential dead-lock in + // Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). If the thread running + // "Shell.runCommand" is interrupted, then the thread can get deadlocked. In our case, + // `writeBatch` creates a file using HDFS API and will call "Shell.runCommand" to set + // the file permission if using the local file system, and can get deadlocked if the + // stream execution thread is stopped by interrupt. 
Hence, we make sure that
+          // "writeBatch" is called on [[UninterruptibleThread]] which allows us to disable
+          // interrupts here. Also see SPARK-14131.
+          ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) }
+        case _ =>
+          throw new IllegalStateException(
+            "HDFSMetadataLog.add() on a local file system must be executed on " +
+              "a o.a.spark.util.UninterruptibleThread")
+      }
+      } else {
+        // For a distributed file system, such as HDFS or S3, if the network is broken, write
+        // operations may just hang until timeout. We should enable interrupts to allow stopping
+        // the query fast.
+        writeBatch(batchId, metadata, serialize)
       }
       true
     }
@@ -298,6 +307,9 @@ object HDFSMetadataLog {

     /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */
     def delete(path: Path): Unit
+
+    /** Whether the file system is a local FS. */
+    def isLocalFileSystem: Boolean
   }

   /**
    *
@@ -342,6 +354,13 @@ object HDFSMetadataLog {
         // ignore if file has already been deleted
       }
     }
+
+    override def isLocalFileSystem: Boolean = fc.getDefaultFileSystem match {
+      case _: local.LocalFs | _: local.RawLocalFs =>
+        // LocalFs = RawLocalFs + ChecksumFs
+        true
+      case _ => false
+    }
   }

   /**
    *
@@ -398,5 +417,12 @@
         // ignore if file has already been deleted
       }
     }
+
+    override def isLocalFileSystem: Boolean = fs match {
+      case _: LocalFileSystem | _: RawLocalFileSystem =>
+        // LocalFileSystem = RawLocalFileSystem + ChecksumFileSystem
+        true
+      case _ => false
+    }
   }
 }

From 6f7ff75091154fed7649ea6d79e887aad9fbde6a Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Fri, 18 Nov 2016 16:34:11 -0800
Subject: [PATCH 180/198] [SPARK-18505][SQL] Simplify AnalyzeColumnCommand

## What changes were proposed in this pull request?

I'm spending more time at the design & code level for the cost-based optimizer now, and have found a number of issues related to maintainability and compatibility that I would like to address.

This is a small pull request to clean up AnalyzeColumnCommand:

1. Removed the warning on duplicated columns. Warnings in log messages are useless since most users that run SQL don't see them.
2. Removed the nested updateStats function by just inlining it.
3. Renamed a few functions to better reflect what they do.
4. Removed the factory apply method for ColumnStatStruct. It is a bad pattern to use an apply method that returns an instantiation of a class that is not of the same type (ColumnStatStruct.apply used to return CreateNamedStruct).
5. Renamed ColumnStatStruct to just AnalyzeColumnCommand.
6. Added more documentation explaining some of the non-obvious return types and code blocks.

In follow-up pull requests, I'd like to address the following:

1. Get rid of the Map[String, ColumnStat] map, since internally we should be using Attribute to reference columns, rather than strings.
2. Decouple the fields exposed by ColumnStat and the internals of Spark SQL's execution path. Currently the two are coupled because ColumnStat takes in an InternalRow.
3. Correctness: Remove the code path that stores statistics in the catalog using the base64 encoding of the UnsafeRow format, which is not stable across Spark versions.
4. Clearly document the data representation stored in the catalog for statistics.

## How was this patch tested?

Affected test cases have been updated.

Author: Reynold Xin

Closes #15933 from rxin/SPARK-18505.
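To make the single-pass design concrete: the command issues one aggregate whose first expression is the overall row count and whose remaining expressions are one statistics struct per analyzed column. A rough public-API equivalent is sketched below; the DataFrame `df`, the column name `x`, and the exact set of aggregates are assumptions for illustration, since the real code builds these structs over internal expressions:

```scala
import org.apache.spark.sql.functions._

// One scan computes the global row count plus a struct of per-column statistics.
val statsRow = df.agg(
  count(lit(1)).as("rowCount"),
  struct(
    count(col("x")),                 // non-null count; null count = rowCount - this
    max(col("x")),
    min(col("x")),
    approx_count_distinct(col("x"))  // ndv estimate (HyperLogLog++), bounded by ndvMaxErr
  ).as("x_stats")
).head()
```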
--- .../command/AnalyzeColumnCommand.scala | 115 ++++++++++-------- .../spark/sql/StatisticsColumnSuite.scala | 2 +- .../org/apache/spark/sql/StatisticsTest.scala | 7 +- .../spark/sql/hive/HiveExternalCatalog.scala | 4 +- .../sql/hive/client/HiveClientImpl.scala | 2 +- 5 files changed, 74 insertions(+), 56 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 6141fab4aff0d..7fc57d09e9243 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.command -import scala.collection.mutable - +import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases @@ -44,13 +43,16 @@ case class AnalyzeColumnCommand( val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) - relation match { + // Compute total size + val (catalogTable: CatalogTable, sizeInBytes: Long) = relation match { case catalogRel: CatalogRelation => - updateStats(catalogRel.catalogTable, + // This is a Hive serde format table + (catalogRel.catalogTable, AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateStats(logicalRel.catalogTable.get, + // This is a data source format table + (logicalRel.catalogTable.get, AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) case otherRelation => @@ -58,45 +60,45 @@ case class AnalyzeColumnCommand( s"${otherRelation.nodeName}.") } - def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - val (rowCount, columnStats) = computeColStats(sparkSession, relation) - // We also update table-level stats in order to keep them consistent with column-level stats. - val statistics = Statistics( - sizeInBytes = newTotalSize, - rowCount = Some(rowCount), - // Newly computed column stats should override the existing ones. - colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats) - sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) - } + // Compute stats for each column + val (rowCount, newColStats) = + AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames) + + // We also update table-level stats in order to keep them consistent with column-level stats. + val statistics = Statistics( + sizeInBytes = sizeInBytes, + rowCount = Some(rowCount), + // Newly computed column stats should override the existing ones. + colStats = catalogTable.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) + + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + + // Refresh the cached data source table in the catalog. + sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } +} +object AnalyzeColumnCommand extends Logging { + + /** + * Compute stats for the given columns. + * @return (row count, map from column name to ColumnStats) + * + * This is visible for testing. 
+ */ def computeColStats( sparkSession: SparkSession, - relation: LogicalPlan): (Long, Map[String, ColumnStat]) = { + relation: LogicalPlan, + columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { - // check correctness of column names - val attributesToAnalyze = mutable.MutableList[Attribute]() - val duplicatedColumns = mutable.MutableList[String]() + // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver - columnNames.foreach { col => + val attributesToAnalyze = AttributeSet(columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) - val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) - // do deduplication - if (!attributesToAnalyze.contains(expr)) { - attributesToAnalyze += expr - } else { - duplicatedColumns += col - } - } - if (duplicatedColumns.nonEmpty) { - logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " + - s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " + - s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.") - } + exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + }).toSeq // Collect statistics per column. // The first element in the result will be the overall row count, the following elements @@ -104,22 +106,21 @@ case class AnalyzeColumnCommand( // The layout of each struct follows the layout of the ColumnStats. val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr)) + attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) .queryExecution.toRdd.collect().head // unwrap the result + // TODO: Get rid of numFields by using the public Dataset API. val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - val numFields = ColumnStatStruct.numStatFields(expr.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType) (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) }.toMap (rowCount, columnStats) } -} -object ColumnStatStruct { private val zero = Literal(0, LongType) private val one = Literal(1, LongType) @@ -137,7 +138,11 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { + /** + * Creates a struct that groups the sequence of expressions together. This is used to create + * one top level struct per column. + */ + private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -161,6 +166,7 @@ object ColumnStatStruct { Seq(numNulls(e), numTrues(e), numFalses(e)) } + // TODO(rxin): Get rid of this function. def numStatFields(dataType: DataType): Int = { dataType match { case BinaryType | BooleanType => 3 @@ -168,14 +174,25 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { - // Use aggregate functions to compute statistics we need. 
- case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) - case StringType => getStruct(stringColumnStat(attr, relativeSD)) - case BinaryType => getStruct(binaryColumnStat(attr)) - case BooleanType => getStruct(booleanColumnStat(attr)) - case otherType => - throw new AnalysisException("Analyzing columns is not supported for column " + - s"${attr.name} of data type: ${attr.dataType}.") + /** + * Creates a struct expression that contains the statistics to collect for a column. + * + * @param attr column to collect statistics + * @param relativeSD relative error for approximate number of distinct values. + */ + def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = { + attr.dataType match { + case _: NumericType | TimestampType | DateType => + createStruct(numericColumnStat(attr, relativeSD)) + case StringType => + createStruct(stringColumnStat(attr, relativeSD)) + case BinaryType => + createStruct(binaryColumnStat(attr)) + case BooleanType => + createStruct(booleanColumnStat(attr)) + case otherType => + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${attr.name} of data type: ${attr.dataType}.") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index f1a201abd8da6..e866ac2cb3b34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -79,7 +79,7 @@ class StatisticsColumnSuite extends StatisticsTest { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val (_, columnStats) = - AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation) + AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze) assert(columnStats.contains(colName1)) assert(columnStats.contains(colName2)) // check deduplication diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index 5134ac0e7e5b3..915ee0d31bca2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct} +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ + trait StatisticsTest extends QueryTest with SharedSQLContext { def checkColStats( @@ -36,7 +37,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val (_, columnStats) = - AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation) + AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name)) expectedColStatsSeq.foreach { case (field, expectedColStat) => assert(columnStats.contains(field.name)) val colStat = columnStats(field.name) @@ -48,7 +49,7 @@ trait StatisticsTest 
extends QueryTest with SharedSQLContext { // check if we get the same colStat after encoding and decoding val encodedCS = colStat.toString - val numFields = ColumnStatStruct.numStatFields(field.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(field.dataType) val decodedCS = ColumnStat(numFields, encodedCS) StatisticsTest.checkColStat( dataType = field.dataType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index cacffcf33c263..5dbb4024bbee0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils} import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -634,7 +634,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { case f if colStatsProps.contains(f.name) => - val numFields = ColumnStatStruct.numStatFields(f.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(f.dataType) (f.name, ColumnStat(numFields, colStatsProps(f.name))) }.toMap tableWithSchema.copy( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 2bf9a26b0b7fc..daae8523c6366 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -97,7 +97,7 @@ private[hive] class HiveClientImpl( } // Create an internal session state for this HiveClientImpl. - val state = { + val state: SessionState = { val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) From 2a40de408b5eb47edba92f9fe92a42ed1e78bf98 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 18 Nov 2016 16:34:38 -0800 Subject: [PATCH 181/198] [SPARK-18497][SS] Make ForeachSink support watermark ## What changes were proposed in this pull request? The issue in ForeachSink is the new created DataSet still uses the old QueryExecution. When `foreachPartition` is called, `QueryExecution.toString` will be called and then fail because it doesn't know how to plan EventTimeWatermark. This PR just replaces the QueryExecution with IncrementalExecution to fix the issue. ## How was this patch tested? `test("foreach with watermark")`. Author: Shixiong Zhu Closes #15934 from zsxwing/SPARK-18497. 
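For context, a hedged usage sketch of the scenario this patch enables, closely mirroring the new test below; the `input` stream, the writer body, and the session setup are illustrative assumptions:

```scala
import org.apache.spark.sql.{ForeachWriter, Row}
import org.apache.spark.sql.functions.{col, count, window}

val query = input.toDF()
  .withColumn("eventTime", col("value").cast("timestamp"))
  .withWatermark("eventTime", "10 seconds") // previously broke planning inside ForeachSink
  .groupBy(window(col("eventTime"), "5 seconds"))
  .agg(count("*").as("count"))
  .writeStream
  .outputMode("complete")
  .foreach(new ForeachWriter[Row] {
    override def open(partitionId: Long, version: Long): Boolean = true
    override def process(value: Row): Unit = println(value)
    override def close(errorOrNull: Throwable): Unit = ()
  })
  .start()
```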
--- .../sql/execution/streaming/ForeachSink.scala | 16 ++++----- .../streaming/ForeachSinkSuite.scala | 35 +++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index f5c550dd6ac3a..c93fcfb77cc93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -47,22 +47,22 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria // method supporting incremental planning. But in the long run, we should generally make newly // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to // resolve). - + val incrementalExecution = data.queryExecution.asInstanceOf[IncrementalExecution] val datasetWithIncrementalExecution = - new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) { + new Dataset(data.sparkSession, incrementalExecution, implicitly[Encoder[T]]) { override lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType val deserialized = CatalystSerde.deserialize[T](logicalPlan) // was originally: sparkSession.sessionState.executePlan(deserialized) ... - val incrementalExecution = new IncrementalExecution( + val newIncrementalExecution = new IncrementalExecution( this.sparkSession, deserialized, - data.queryExecution.asInstanceOf[IncrementalExecution].outputMode, - data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation, - data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId, - data.queryExecution.asInstanceOf[IncrementalExecution].currentEventTimeWatermark) - incrementalExecution.toRdd.mapPartitions { rows => + incrementalExecution.outputMode, + incrementalExecution.checkpointLocation, + incrementalExecution.currentBatchId, + incrementalExecution.currentEventTimeWatermark) + newIncrementalExecution.toRdd.mapPartitions { rows => rows.map(_.get(0, objectType)) }.asInstanceOf[RDD[T]] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 9e059216110f2..ee6261036fdd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext @@ -169,6 +170,40 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf assert(errorEvent.error.get.getMessage === "error") } } + + test("foreach with watermark") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"count".as[Long]) + .map(_.toInt) + .repartition(1) + + val query = windowedAggregation + .writeStream + .outputMode(OutputMode.Complete) + .foreach(new 
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      val expectedEvents = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 0),
+        ForeachSinkSuite.Process(value = 3),
+        ForeachSinkSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+    } finally {
+      query.stop()
+    }
+  }
 }

 /** A global object to collect events in the executor */

From db9fb9baacbf8640dd37a507b7450db727c7e6ea Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Sat, 19 Nov 2016 09:00:11 +0000
Subject: [PATCH 182/198] [SPARK-18448][CORE] SparkSession should implement java.lang.AutoCloseable like JavaSparkContext

## What changes were proposed in this pull request?

Just adds `close()` + `Closeable` as a synonym for `stop()`. This makes SparkSession usable in Java try-with-resources, as suggested by ash211 (`Closeable` extends `AutoCloseable`).

## How was this patch tested?

Existing tests

Author: Sean Owen

Closes #15932 from srowen/SPARK-18448.
---
 .../main/scala/org/apache/spark/sql/SparkSession.scala | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3045eb69f427f..58b2ab3957173 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql

 import java.beans.Introspector
+import java.io.Closeable
 import java.util.concurrent.atomic.AtomicReference

 import scala.collection.JavaConverters._
@@ -72,7 +73,7 @@ import org.apache.spark.util.Utils
 class SparkSession private(
     @transient val sparkContext: SparkContext,
     @transient private val existingSharedState: Option[SharedState])
-  extends Serializable with Logging { self =>
+  extends Serializable with Closeable with Logging { self =>

   private[sql] def this(sc: SparkContext) {
     this(sc, None)
@@ -647,6 +648,13 @@ class SparkSession private(
     sparkContext.stop()
   }

+  /**
+   * Synonym for `stop()`.
+   *
+   * @since 2.2.0
+   */
+  override def close(): Unit = stop()
+
   /**
    * Parses the data type in our internal string representation. The data type string should
    * have the same format as the one generated by `toString` in scala.
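As a usage note (a sketch, not part of the patch): with `Closeable` implemented, a `SparkSession` can be managed like any other resource. In Java this enables try-with-resources directly; the Scala equivalent is an explicit `try`/`finally`:

```scala
import org.apache.spark.sql.SparkSession

object CloseableSessionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    try {
      // Use the session as usual.
      println(spark.range(10).count())
    } finally {
      spark.close() // synonym for stop(), added by this patch
    }
  }
}
```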
From d5b1d5fc80153571c308130833d0c0774de62c92 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sat, 19 Nov 2016 11:24:15 +0000
Subject: [PATCH 183/198] [SPARK-18445][BUILD][DOCS] Fix the markdown for `Note:`/`NOTE:`/`Note that`/`'''Note:'''` across Scala/Java API documentation

## What changes were proposed in this pull request?

In the Scala/Java API documentation, notes currently appear in several inconsistent forms:

- `Note:`
- `NOTE:`
- `Note that`
- `'''Note:'''`
- `note`

This PR proposes to fix all of those to the `@note` tag so they render consistently.

**Before**
- Scala
![2016-11-17 6 16 39](https://cloud.githubusercontent.com/assets/6477701/20383180/1a7aed8c-acf2-11e6-9611-5eaf6d52c2e0.png)
- Java
![2016-11-17 6 14 41](https://cloud.githubusercontent.com/assets/6477701/20383096/c8ffc680-acf1-11e6-914a-33460bf1401d.png)

**After**
- Scala
![2016-11-17 6 16 44](https://cloud.githubusercontent.com/assets/6477701/20383167/09940490-acf2-11e6-937a-0d5e1dc2cadf.png)
- Java
![2016-11-17 6 13 39](https://cloud.githubusercontent.com/assets/6477701/20383132/e7c2a57e-acf1-11e6-9c47-b849674d4d88.png)

## How was this patch tested?

The notes were found via

```bash
grep -r "NOTE: " . | \ # Note:|NOTE:|Note that|'''Note:'''
grep -v "// NOTE: " | \ # starting with // does not appear in API documentation.
grep -E '.scala|.java' | \ # java/scala files
grep -v Suite | \ # exclude tests
grep -v Test | \ # exclude tests
grep -e 'org.apache.spark.api.java' \ # packages that appear in API documentation
     -e 'org.apache.spark.api.java.function' \ # note that this is a regular expression. So actual matches were mostly `org/apache/spark/api/java/functions ...`
     -e 'org.apache.spark.api.r' \
     ...
```

```bash
grep -r "Note that " . | \ # Note:|NOTE:|Note that|'''Note:'''
grep -v "// Note that " | \ # starting with // does not appear in API documentation.
grep -E '.scala|.java' | \ # java/scala files
grep -v Suite | \ # exclude tests
grep -v Test | \ # exclude tests
grep -e 'org.apache.spark.api.java' \ # packages that appear in API documentation
     -e 'org.apache.spark.api.java.function' \
     -e 'org.apache.spark.api.r' \
     ...
```

```bash
grep -r "Note: " . | \ # Note:|NOTE:|Note that|'''Note:'''
grep -v "// Note: " | \ # starting with // does not appear in API documentation.
grep -E '.scala|.java' | \ # java/scala files
grep -v Suite | \ # exclude tests
grep -v Test | \ # exclude tests
grep -e 'org.apache.spark.api.java' \ # packages that appear in API documentation
     -e 'org.apache.spark.api.java.function' \
     -e 'org.apache.spark.api.r' \
     ...
```

```bash
grep -r "'''Note:'''" . | \ # Note:|NOTE:|Note that|'''Note:'''
grep -v "// '''Note:''' " | \ # starting with // does not appear in API documentation.
grep -E '.scala|.java' | \ # java/scala files
grep -v Suite | \ # exclude tests
grep -v Test | \ # exclude tests
grep -e 'org.apache.spark.api.java' \ # packages that appear in API documentation
     -e 'org.apache.spark.api.java.function' \
     -e 'org.apache.spark.api.r' \
     ...
```

The matches were then fixed one by one, comparing against the rendered API documentation and access modifiers. After that, the result was manually verified via `jekyll build`.

Author: hyukjinkwon

Closes #15889 from HyukjinKwon/SPARK-18437.
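For reference, a self-contained sketch of the convention this patch standardizes on (using a placeholder `ExampleRDD`, not Spark's real class): prose such as `Note:` is rendered by Scaladoc as ordinary paragraph text, whereas the `@note` tag produces a consistently styled, clearly separated block:

```scala
abstract class ExampleRDD[T] {
  /**
   * Return the intersection of this RDD and another one. The output will not
   * contain any duplicate elements, even if the input RDDs did.
   *
   * @note This method performs a shuffle internally.
   */
  def intersection(other: ExampleRDD[T]): ExampleRDD[T]
}
```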
--- .../org/apache/spark/ContextCleaner.scala | 2 +- .../scala/org/apache/spark/Partitioner.scala | 2 +- .../scala/org/apache/spark/SparkConf.scala | 6 +- .../scala/org/apache/spark/SparkContext.scala | 47 ++++++++------- .../apache/spark/api/java/JavaDoubleRDD.scala | 4 +- .../apache/spark/api/java/JavaPairRDD.scala | 26 ++++---- .../org/apache/spark/api/java/JavaRDD.scala | 12 ++-- .../apache/spark/api/java/JavaRDDLike.scala | 3 +- .../spark/api/java/JavaSparkContext.scala | 21 +++---- .../api/java/JavaSparkStatusTracker.scala | 2 +- .../io/SparkHadoopMapReduceWriter.scala | 2 +- .../apache/spark/io/CompressionCodec.scala | 23 ++++--- .../apache/spark/partial/BoundedDouble.scala | 2 +- .../org/apache/spark/rdd/CoGroupedRDD.scala | 8 +-- .../apache/spark/rdd/DoubleRDDFunctions.scala | 2 +- .../org/apache/spark/rdd/HadoopRDD.scala | 6 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 6 +- .../apache/spark/rdd/PairRDDFunctions.scala | 23 +++---- .../spark/rdd/PartitionPruningRDD.scala | 2 +- .../spark/rdd/PartitionwiseSampledRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 46 +++++++------- .../apache/spark/rdd/RDDCheckpointData.scala | 2 +- .../spark/rdd/ReliableCheckpointRDD.scala | 2 +- .../spark/rdd/SequenceFileRDDFunctions.scala | 5 +- .../apache/spark/rdd/ZippedWithIndexRDD.scala | 2 +- .../spark/scheduler/AccumulableInfo.scala | 10 ++-- .../spark/serializer/JavaSerializer.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 2 +- .../apache/spark/serializer/Serializer.scala | 2 +- .../apache/spark/storage/StorageUtils.scala | 19 +++--- .../org/apache/spark/util/AccumulatorV2.scala | 5 +- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- docs/mllib-isotonic-regression.md | 2 +- docs/streaming-programming-guide.md | 2 +- .../spark/sql/kafka010/KafkaSource.scala | 2 +- .../spark/streaming/kafka/KafkaUtils.scala | 8 +-- .../streaming/kinesis/KinesisUtils.scala | 60 +++++++++---------- .../kinesis/KinesisBackedBlockRDDSuite.scala | 2 +- .../apache/spark/graphx/impl/GraphImpl.scala | 2 +- .../apache/spark/graphx/lib/PageRank.scala | 2 +- .../org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../scala/org/apache/spark/ml/Model.scala | 2 +- .../DecisionTreeClassifier.scala | 6 +- .../ml/classification/GBTClassifier.scala | 6 +- .../classification/LogisticRegression.scala | 36 +++++------ .../spark/ml/clustering/GaussianMixture.scala | 6 +- .../spark/ml/feature/MinMaxScaler.scala | 3 +- .../spark/ml/feature/OneHotEncoder.scala | 3 +- .../org/apache/spark/ml/feature/PCA.scala | 5 +- .../spark/ml/feature/StopWordsRemover.scala | 5 +- .../spark/ml/feature/StringIndexer.scala | 6 +- .../org/apache/spark/ml/param/params.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 6 +- .../GeneralizedLinearRegression.scala | 4 +- .../ml/regression/LinearRegression.scala | 28 +++++---- .../ml/source/libsvm/LibSVMDataSource.scala | 2 +- .../ml/tree/impl/GradientBoostedTrees.scala | 4 +- .../org/apache/spark/ml/util/ReadWrite.scala | 2 +- .../classification/LogisticRegression.scala | 28 +++++---- .../spark/mllib/classification/SVM.scala | 20 ++++--- .../mllib/clustering/GaussianMixture.scala | 8 +-- .../spark/mllib/clustering/KMeans.scala | 8 ++- .../apache/spark/mllib/clustering/LDA.scala | 4 +- .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 6 +- .../mllib/evaluation/AreaUnderCurve.scala | 2 +- .../apache/spark/mllib/linalg/Vectors.scala | 6 +- .../linalg/distributed/BlockMatrix.scala | 2 +- 
.../linalg/distributed/IndexedRowMatrix.scala | 5 +- .../mllib/linalg/distributed/RowMatrix.scala | 21 ++++--- .../spark/mllib/optimization/Gradient.scala | 3 +- .../apache/spark/mllib/rdd/RDDFunctions.scala | 2 +- .../MatrixFactorizationModel.scala | 6 +- .../apache/spark/mllib/stat/Statistics.scala | 34 +++++------ .../spark/mllib/tree/DecisionTree.scala | 32 +++++----- .../apache/spark/mllib/tree/loss/Loss.scala | 12 ++-- .../mllib/tree/model/treeEnsembleModels.scala | 4 +- pom.xml | 7 +++ project/SparkBuild.scala | 3 +- python/pyspark/mllib/stat/KernelDensity.py | 2 +- python/pyspark/mllib/util.py | 2 +- python/pyspark/rdd.py | 4 +- python/pyspark/streaming/kafka.py | 4 +- .../scala/org/apache/spark/sql/Encoders.scala | 8 +-- .../sql/types/CalendarIntervalType.scala | 4 +- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../spark/sql/DataFrameStatFunctions.scala | 3 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 56 ++++++++--------- .../org/apache/spark/sql/SQLContext.scala | 7 ++- .../org/apache/spark/sql/SparkSession.scala | 9 +-- .../apache/spark/sql/UDFRegistration.scala | 3 +- .../execution/streaming/state/package.scala | 4 +- .../sql/expressions/UserDefinedFunction.scala | 8 ++- .../org/apache/spark/sql/functions.scala | 22 +++---- .../apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- .../apache/spark/sql/sources/interfaces.scala | 10 ++-- .../sql/util/QueryExecutionListener.scala | 8 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../spark/streaming/StreamingContext.scala | 18 +++--- .../streaming/api/java/JavaPairDStream.scala | 2 +- .../api/java/JavaStreamingContext.scala | 40 +++++++------ .../spark/streaming/dstream/DStream.scala | 4 +- .../dstream/MapWithStateDStream.scala | 2 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 2 +- 105 files changed, 517 insertions(+), 436 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5678d790e9e76..af913454fce69 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -139,7 +139,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { periodicGCService.shutdown() } - /** Register a RDD for cleanup when it is garbage collected. */ + /** Register an RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 93dfbc0e6ed65..f83f5278e8b8f 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -101,7 +101,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. * - * Note that the actual number of partitions created by the RangePartitioner might not be the same + * @note The actual number of partitions created by the RangePartitioner might not be the same * as the `partitions` parameter, in the case where the number of sampled records is less than * the value of `partitions`. 
*/ diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c9c342df82c97..04d657c09afd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -42,10 +42,10 @@ import org.apache.spark.util.Utils * All setter methods in this class support chaining. For example, you can write * `new SparkConf().setMaster("local").setAppName("My app")`. * - * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified - * by the user. Spark does not support modifying the configuration at runtime. - * * @param loadDefaults whether to also load values from Java system properties + * + * @note Once a SparkConf object is passed to Spark, it is cloned and can no longer be modified + * by the user. Spark does not support modifying the configuration at runtime. */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 25a3d609a6b09..1261e3e735761 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -281,7 +281,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration: Configuration = _hadoopConfiguration @@ -700,7 +700,7 @@ class SparkContext(config: SparkConf) extends Logging { * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. * - * Note: Return statements are NOT allowed in the given body. + * @note Return statements are NOT allowed in the given body. */ private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) @@ -927,7 +927,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Load data from a flat binary file, assuming the length of each record is constant. * - * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * @note We ensure that the byte array for each record in the resulting RDD * has the provided record length. * * @param path Directory to the input data files, the path can be comma separated paths as the @@ -970,7 +970,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. 
* If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -995,7 +995,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1034,7 +1034,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minPartitions) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1058,7 +1058,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1084,7 +1084,7 @@ class SparkContext(config: SparkConf) extends Logging { * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1124,7 +1124,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1150,7 +1150,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. 
* - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1169,7 +1169,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1199,7 +1199,7 @@ class SparkContext(config: SparkConf) extends Logging { * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1330,16 +1330,18 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Register the given accumulator. Note that accumulators must be registered before use, or it - * will throw exception. + * Register the given accumulator. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _]): Unit = { acc.register(this) } /** - * Register the given accumulator with given name. Note that accumulators must be registered - * before use, or it will throw exception. + * Register the given accumulator with given name. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { acc.register(this, name = Some(name)) @@ -1550,7 +1552,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1572,7 +1574,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executor. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. 
If the application wishes to replace the executor it kills * through this method with a new one, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1590,7 +1592,7 @@ class SparkContext(config: SparkConf) extends Logging { * this request. This assumes the cluster manager will automatically and eventually * fulfill all missing application resource requests. * - * Note: The replace is by no means guaranteed; another application on the same cluster + * @note The replace is by no means guaranteed; another application on the same cluster * can steal the window of opportunity and acquire this application's resources in the * mean time. * @@ -1639,7 +1641,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap @@ -2298,7 +2301,7 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -2323,7 +2326,7 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(): SparkContext = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 0026fc9dad517..a32a4b28c1731 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -153,7 +153,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) @@ -256,7 +256,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. 
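To make the histogram bucket semantics documented above concrete, here is a minimal sketch against the Scala `histogram` API (the Java wrapper documented here has the same contract); the sample values are arbitrary assumptions:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object HistogramSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    val data = sc.parallelize(Seq(1.0, 9.5, 10.0, 19.9, 50.0))

    // Buckets [1, 10), [10, 20), [20, 50]; the last bucket is closed on the right.
    val counts = data.histogram(Array(1.0, 10.0, 20.0, 50.0), evenBuckets = false)
    println(counts.mkString(", ")) // 2, 2, 1

    // Evenly spaced buckets enable the O(1)-per-element fast path the note mentions.
    val even = data.histogram(Array(0.0, 10.0, 20.0, 30.0), evenBuckets = true)
    println(even.mkString(", ")) // 2, 2, 0 (50.0 falls outside the range and is ignored)

    sc.stop()
  }
}
```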
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 1c95bc4bfcaaf..bff5a29bb60f1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -206,7 +206,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.intersection(other.rdd)) @@ -223,9 +223,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -234,6 +234,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple * items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -255,9 +258,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -265,6 +268,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -398,7 +404,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. 
* - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ @@ -409,7 +415,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ @@ -539,7 +545,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index d67cff64e6e46..ccd94f876e0b8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -99,27 +99,29 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD with a random seed. - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** * Return a sampled subset of this RDD, with a user-supplied seed. - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -157,7 +159,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return the intersection of this RDD and another one. 
The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index a37c52cbaf210..eda16d957cc58 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -47,7 +47,8 @@ private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This /** * Defines operations common to several Java RDD implementations. - * Note that this trait is not intended to be implemented by user code. + * + * @note This trait is not intended to be implemented by user code. */ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4e50c2686dd53..38d347aeab8c6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -298,7 +298,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -316,7 +316,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -366,7 +366,7 @@ class JavaSparkContext(val sc: SparkContext) * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -396,7 +396,7 @@ class JavaSparkContext(val sc: SparkContext) * @param keyClass Class of the keys * @param valueClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. 
@@ -416,7 +416,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -437,7 +437,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -458,7 +458,7 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -487,7 +487,7 @@ class JavaSparkContext(val sc: SparkContext) * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -694,7 +694,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { @@ -811,7 +811,8 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns a Java map of JavaRDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. 
*/ def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { sc.getPersistentRDDs.mapValues(s => JavaRDD.fromRDD(s)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 99ca3c77cced0..6aa290ecd7bb5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} * will provide information for the last `spark.ui.retainedStages` stages and * `spark.ui.retainedJobs` jobs. * - * NOTE: this class's constructor should be considered private and may be subject to change. + * @note This class's constructor should be considered private and may be subject to change. */ class JavaSparkStatusTracker private[spark] (sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala index 796439276a22e..aaeb3d003829a 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -119,7 +119,7 @@ object SparkHadoopMapReduceWriter extends Logging { } } - /** Write a RDD partition out in a single Spark task. */ + /** Write an RDD partition out in a single Spark task. */ private def executeTask[K, V: ClassTag]( context: TaskContext, jobTrackerId: String, diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ae014becef755..6ba79e506a648 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -32,9 +32,8 @@ import org.apache.spark.util.Utils * CompressionCodec allows the customization of choosing different compression implementations * to be used in block storage. * - * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark. - * This is intended for use as an internal compression utility within a single - * Spark application. + * @note The wire protocol for a codec is not guaranteed compatible across versions of Spark. + * This is intended for use as an internal compression utility within a single Spark application. */ @DeveloperApi trait CompressionCodec { @@ -103,9 +102,9 @@ private[spark] object CompressionCodec { * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.lz4.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -123,9 +122,9 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { * :: DeveloperApi :: * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. 
This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -143,9 +142,9 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.snappy.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index ab6aba6fc7d6a..8f579c5a3033c 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,7 +28,7 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false + * @note Consistent with Double, any NaN value will make equality false */ override def equals(that: Any): Boolean = that match { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 2381f54ee3f06..a091f06b4ed7c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -66,14 +66,14 @@ private[spark] class CoGroupPartition( /** * :: DeveloperApi :: - * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a + * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. * - * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of - * instantiating this directly. - * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output + * + * @note This is an internal API. We recommend users use RDD.cogroup(...) instead of + * instantiating this directly. */ @DeveloperApi class CoGroupedRDD[K: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a05a770b40c57..f3ab324d59119 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -158,7 +158,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * - * Note: if your histogram is evenly spaced (e.g. 
[0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 36a2f5c87e372..86351b8c575e5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -84,9 +84,6 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.hadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. @@ -97,6 +94,9 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate. + * + * @note Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.hadoopRDD()]] */ @DeveloperApi class HadoopRDD[K, V]( diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 488e777fea371..a5965f597038d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -57,13 +57,13 @@ private[spark] class NewHadoopPartition( * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. + * + * @note Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] */ @DeveloperApi class NewHadoopRDD[K, V]( diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index f9b9631d9e7ca..33e695ec5322b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -57,8 +57,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. 
Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -66,6 +66,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). */ @Experimental def combineByKeyWithClassTag[C]( @@ -361,7 +364,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Count the number of elements for each key, collecting the results to a local Map. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -488,11 +491,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * The ordering of elements within each group is not guaranteed, and may even differ * each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { @@ -512,11 +515,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. The ordering of elements within * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { @@ -633,7 +636,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * within each group is not guaranteed, and may even differ each time the resulting RDD is * evaluated. * - * Note: This operation may be very expensive. 
If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -1014,7 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1068,7 +1071,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 0c6ddda52cee9..ce75a16031a3f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -48,7 +48,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => /** * :: DeveloperApi :: - * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on + * An RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index 3b1acacf409b9..6a89ea8786464 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -32,7 +32,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) } /** - * A RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, + * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain * a random sample of the records in the partition. The random seeds assigned to the samplers * are guaranteed to have different values. 
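The recurring note above (prefer `reduceByKey` or `aggregateByKey` over `groupByKey` for aggregations) is easy to see in a minimal sketch; the sample data is an arbitrary assumption:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object GroupVsReduceSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

    // groupByKey materializes every value for a key before summing, and can
    // OOM on heavily skewed keys.
    val viaGroup = pairs.groupByKey().mapValues(_.sum)

    // reduceByKey combines map-side, so only partial sums cross the shuffle.
    val viaReduce = pairs.reduceByKey(_ + _)

    assert(viaGroup.collect().toMap == viaReduce.collect().toMap) // a -> 4, b -> 2
    sc.stop()
  }
}
```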
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cded899db1f5c..bff2b8f1d06c9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -428,7 +428,7 @@ abstract class RDD[T: ClassTag]( * current upstream partitions will be executed in parallel (per whatever * the current partitioning is). * - * Note: With shuffle = true, you can actually coalesce to a larger number + * @note With shuffle = true, you can actually coalesce to a larger number * of partitions. This is useful if you have a small number of partitions, * say 100, potentially with a few partitions being abnormally large. Calling * coalesce(1000, shuffle = true) will result in 1000 partitions with the @@ -466,14 +466,14 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. * - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. - * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample( withReplacement: Boolean, @@ -537,13 +537,13 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * - * @note this method should only be used if the resulting array is expected to be small, as - * all the data is loaded into the driver's memory. - * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator * @return sample of specified size in an array + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeSample( withReplacement: Boolean, @@ -618,7 +618,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) @@ -630,7 +630,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param partitioner Partitioner to use for the resulting RDD */ @@ -646,7 +646,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. Performs a hash partition across the cluster * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param numPartitions How many partitions to use in the resulting RDD */ @@ -674,7 +674,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. 
The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -687,7 +687,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -702,7 +702,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -921,7 +921,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { @@ -934,7 +934,7 @@ abstract class RDD[T: ClassTag]( * * The iterator will consume as much memory as the largest partition in this RDD. * - * Note: this results in multiple Spark jobs, and if the input RDD is the result + * @note This results in multiple Spark jobs, and if the input RDD is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input RDD should be cached first. */ @@ -1182,7 +1182,7 @@ abstract class RDD[T: ClassTag]( /** * Return the count of each unique value in this RDD as a local map of (value, count) pairs. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -1272,7 +1272,7 @@ abstract class RDD[T: ClassTag]( * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The index assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. 
If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1286,7 +1286,7 @@ abstract class RDD[T: ClassTag]( * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The unique ID assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1305,10 +1305,10 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * - * @note due to complications in the internal implementation, this method will raise + * @note Due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = withScope { @@ -1370,7 +1370,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of top elements to return @@ -1393,7 +1393,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of elements to return @@ -1438,7 +1438,7 @@ abstract class RDD[T: ClassTag]( } /** - * @note due to complications in the internal implementation, this method will raise an + * @note Due to complications in the internal implementation, this method will raise an * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 429514b4f6bee..1070bb96b2524 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -32,7 +32,7 @@ private[spark] object CheckpointState extends Enumeration { /** * This class contains all the information related to RDD checkpointing. Each instance of this - * class is associated with a RDD. It manages process of checkpointing of the associated RDD, + * class is associated with an RDD. 
It manages the process of checkpointing of the associated RDD, * as well as manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 9f800e3a0953c..e0a29b48314fb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -151,7 +151,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { } /** - * Write a RDD partition's data to a checkpoint file. + * Write an RDD partition's data to a checkpoint file. */ def writePartitionToCheckpointFile[T: ClassTag]( path: String, diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 1311b481c7c71..86a332790fb00 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -27,9 +27,10 @@ import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. + * through an implicit conversion. * + * @note This can't be part of PairRDDFunctions because we need more implicit parameters to + * convert our keys and values to Writable. */ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)], diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b0e5ba0865c63..8425b211d6ecf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -29,7 +29,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) } /** - * Represents a RDD zipped with its element indices. The ordering is first based on the partition + * Represents an RDD zipped with its element indices. The ordering is first based on the partition * index and then the ordering of items within each partition. So the first item in the first * partition gets index 0, and the last item in the last partition receives the largest index. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index cedacad44afec..0a5fe5a1d3ee1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. * - * Note: once this is JSON serialized the types of `update` and `value` will be lost and be - * cast to strings. This is because the user can define an accumulator of any type and it will - * be difficult to preserve the type in consumers of the event log. This does not apply to - * internal accumulators that represent task level metrics.
- * * @param id accumulator ID * @param name accumulator name * @param update partial value from a task, may be None if used on driver to describe a stage @@ -36,6 +31,11 @@ import org.apache.spark.annotation.DeveloperApi * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed * @param metadata internal metadata associated with this accumulator, if any + * + * @note Once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. */ @DeveloperApi case class AccumulableInfo private[spark] ( diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 8b72da2ee01b7..f60dcfddfdc20 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -131,7 +131,7 @@ private[spark] class JavaSerializerInstance( * :: DeveloperApi :: * A Spark serializer that uses Java's built-in serialization. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0d26281fe1076..19e020c968a9a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.collection.CompactBuffer /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index cb95246d5b0ca..afe6cd86059f0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.NextIterator * * 2. Java serialization interface. * - * Note that serializers are not required to be wire-compatible across different versions of Spark. + * @note Serializers are not required to be wire-compatible across different versions of Spark. * They are intended to be used to serialize/de-serialize data within a single Spark application. 
*/ @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index fb9941bbd9e0f..e12f2e6095d5a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -71,7 +71,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * contains, get, and size. */ @@ -80,7 +80,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the RDD blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * getting the memory, disk, and off-heap memory sizes occupied by this RDD. */ @@ -128,7 +128,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return whether the given block is stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. */ def containsBlock(blockId: BlockId): Boolean = { blockId match { @@ -141,7 +142,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the given block stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.get`, which is O(blocks) time. */ def getBlock(blockId: BlockId): Option[BlockStatus] = { blockId match { @@ -154,19 +156,22 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the number of blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.size`, which is O(blocks) time. */ def numBlocks: Int = _nonRddBlocks.size + numRddBlocks /** * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. + * + * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. */ def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum /** * Return the number of blocks that belong to the given RDD in O(1) time. - * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * + * @note This is much faster than `this.rddBlocksById(rddId).size`, which is * O(blocks in this RDD) time. 
*/ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index d3ddd39131326..1326f0977c241 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -59,8 +59,9 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } /** - * Returns true if this accumulator has been registered. Note that all accumulators must be - * registered before use, or it will throw exception. + * Returns true if this accumulator has been registered. + * + * @note All accumulators must be registered before use, or it will throw an exception. */ final def isRegistered: Boolean = metadata != null && AccumulatorContext.get(metadata.id).isDefined diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bec95d13d193a..5e8a854e46a0f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2076,7 +2076,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } /** - * Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that + * Checks the DAGScheduler's internal logic for traversing an RDD DAG by making sure that * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s * denotes a shuffle dependency): diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index d90905a86ade9..ca84551506b2b 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -27,7 +27,7 @@ best fitting the original data points. [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). -The training input is a RDD of tuples of three double values that represent +The training input is an RDD of tuples of three double values that represent label, feature and weight in this order. Additionally IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 0b0315b366501..18fc1cd934826 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2191,7 +2191,7 @@ consistent batch processing times. Make sure you set the CMS GC on both the driv - When data is received from a stream source, receiver creates blocks of data. A new block of data is generated every blockInterval milliseconds. N blocks of data are created during the batchInterval where N = batchInterval/blockInterval. These blocks are distributed by the BlockManager of the current executor to the block managers of other executors. After that, the Network Input Tracker running on the driver is informed about the block locations for further processing. -- A RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in spark.
blockInterval== batchinterval would mean that a single partition is created and probably it is processed locally. +- An RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in Spark. blockInterval == batchInterval would mean that a single partition is created and it is probably processed locally. - The map tasks on the blocks are processed in the executors (one that received the block, and another where the block was replicated) that have the blocks irrespective of block interval, unless non-local scheduling kicks in. Having a bigger blockInterval means bigger blocks. A high value of `spark.locality.wait` increases the chance of processing a block on the local node. A balance needs to be found between these two parameters to ensure that the bigger blocks are processed locally. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 5bcc5124b0915..341081a338c0e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -279,7 +279,7 @@ private[kafka010] case class KafkaSource( } }.toArray - // Create a RDD that reads from Kafka and get the (key, value) pair as byte arrays. + // Create an RDD that reads from Kafka and gets the (key, value) pairs as byte arrays. val rdd = new KafkaSourceRDD( sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index b17e198077949..56f0cb0b166a2 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -223,7 +223,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. + * Create an RDD from Kafka using offset ranges for each topic and partition. * * @param sc SparkContext object * @param kafkaParams Kafka @@ -255,7 +255,7 @@ } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you to * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. * @@ -303,7 +303,7 @@ } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. + * Create an RDD from Kafka using offset ranges for each topic and partition. * * @param jsc JavaSparkContext object * @param kafkaParams Kafka @@ -340,7 +340,7 @@ } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you to * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata.
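Since several of the `createRDD` variants retouched above describe batch reads over explicit Kafka offset ranges, a minimal sketch of the simplest overload may help orient reviewers. This assumes the spark-streaming-kafka-0-8 artifact on the classpath; the broker address, topic name, and offsets are placeholders, not values from the patch:

```scala
import kafka.serializer.StringDecoder
import org.apache.spark.SparkContext
import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

// sc is an existing SparkContext; broker, topic, and offsets are made up.
def readOffsetRange(sc: SparkContext): Unit = {
  val kafkaParams = Map("metadata.broker.list" -> "broker-1:9092")
  // Messages 0 (inclusive) to 100 (exclusive) of partition 0 of topic "events".
  val ranges = Array(OffsetRange("events", 0, fromOffset = 0L, untilOffset = 100L))
  // Returns an RDD[(K, V)] of decoded key/value pairs for exactly these offsets.
  val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
    sc, kafkaParams, ranges)
  rdd.take(5).foreach { case (key, value) => println(s"$key -> $value") }
}
```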
* diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index a0007d33d6257..b2daffa34ccbf 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -33,10 +33,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -57,6 +53,10 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param messageHandler A custom message handler that can generate a generic output from a * Kinesis `Record`, which contains both message data, and metadata. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream[T: ClassTag]( ssc: StreamingContext, @@ -81,10 +81,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -107,6 +103,9 @@ object KinesisUtils { * Kinesis `Record`, which contains both message data, and metadata. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off def createStream[T: ClassTag]( @@ -134,10 +133,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -156,6 +151,10 @@ object KinesisUtils { * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. 
*/ def createStream( ssc: StreamingContext, @@ -178,10 +177,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -202,6 +197,9 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ def createStream( ssc: StreamingContext, @@ -225,10 +223,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -250,6 +244,10 @@ object KinesisUtils { * @param messageHandler A custom message handler that can generate a generic output from a * Kinesis `Record`, which contains both message data, and metadata. * @param recordClass Class of the records in DStream + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream[T]( jssc: JavaStreamingContext, @@ -272,10 +270,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -299,6 +293,9 @@ object KinesisUtils { * @param recordClass Class of the records in DStream * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off def createStream[T]( @@ -326,10 +323,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. 
- * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -348,6 +341,10 @@ object KinesisUtils { * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream( jssc: JavaStreamingContext, @@ -367,10 +364,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -391,6 +384,9 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ def createStream( jssc: JavaStreamingContext, diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 905c33834df16..a4d81a680979e 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -221,7 +221,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) assert(collectedData.toSet === testData.toSet) // Verify that the block fetching is skipped when isBlockValid is set to false. - // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // This is done by using an RDD whose data is only in memory but is set to skip block fetching // Using that RDD will throw exception, as it skips block fetching even if the blocks are in // in BlockManager. if (testIsBlockValid) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index e18831382d4d5..3810110099993 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -42,7 +42,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges - /** Return a RDD that brings edges together with their source and destination vertices. */ + /** Return an RDD that brings edges together with their source and destination vertices. 
*/ @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { replicatedVertexView.upgrade(vertices, true, true) replicatedVertexView.edges.partitionsRDD.mapPartitions(_.flatMap { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index c0c3c73463aab..f926984aa6335 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -58,7 +58,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`. * - * Note that this is not the "normalized" PageRank and as a consequence pages that have no + * @note This is not the "normalized" PageRank and as a consequence pages that have no * inlinks will have a PageRank of alpha. */ object PageRank extends Logging { diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 2e4a58dc6291c..22e4ec693b1f7 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -30,7 +30,7 @@ import org.apache.spark.annotation.Since /** * Represents a numeric vector, whose index type is Int and value type is Double. * - * Note: Users should not implement this interface. + * @note Users should not implement this interface. */ @Since("2.0.0") sealed trait Vector extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 252acc156583f..c581fed177273 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.param.ParamMap abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. - * Note: For ensembles' component Models, this value can be null. + * @note For ensembles' component Models, this value can be null. */ @transient var parent: Estimator[M] = _ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index bb192ab5f25ab..7424031ed4608 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -207,9 +207,9 @@ class DecisionTreeClassificationModel private[ml] ( * where gain is scaled by the number of instances passing through node * - Normalize importances for tree to sum to 1. * - * Note: Feature importance for single decision trees can have high variance due to - * correlated predictor variables. Consider using a [[RandomForestClassifier]] - * to determine feature importance instead. + * @note Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestClassifier]] + * to determine feature importance instead. 
*/ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f8f164e8c14bd..52f93f5a6b345 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -43,7 +43,6 @@ import org.apache.spark.sql.types.DoubleType * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. - * Note: Multiclass labels are not currently supported. * * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. * @@ -54,6 +53,8 @@ import org.apache.spark.sql.types.DoubleType * based on the loss function, whereas the original gradient boosting method does not. * - We expect to implement TreeBoost in the future: * [https://issues.apache.org/jira/browse/SPARK-4240] + * + * @note Multiclass labels are not currently supported. */ @Since("1.4.0") class GBTClassifier @Since("1.4.0") ( @@ -169,10 +170,11 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * model for classification. * It supports binary labels, as well as both continuous and categorical features. - * Note: Multiclass labels are not currently supported. * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. + * + * @note Multiclass labels are not currently supported. */ @Since("1.6.0") class GBTClassificationModel private[ml]( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 18b9b3043db8a..71a7fe53c15f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1191,8 +1191,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") @@ -1200,8 +1200,8 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Computes the area under the receiver operating characteristic (ROC) curve. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. 
*/ @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() @@ -1210,8 +1210,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the precision-recall curve, which is a Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") @@ -1219,8 +1219,8 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { @@ -1232,8 +1232,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { @@ -1245,8 +1245,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { @@ -1401,18 +1401,18 @@ class BinaryLogisticRegressionSummary private[classification] ( * $$ *

* - * @note In order to avoid unnecessary computation during calculation of the gradient updates - * we lay out the coefficients in column major order during training. This allows us to - * perform feature standardization once, while still retaining sequential memory access - * for speed. We convert back to row major order when we create the model, - * since this form is optimal for the matrix operations used for prediction. - * * @param bcCoefficients The broadcast coefficients corresponding to the features. * @param bcFeaturesStd The broadcast standard deviation values of the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. * @param multinomial Whether to use multinomial (softmax) or binary loss + * + * @note In order to avoid unnecessary computation during calculation of the gradient updates + * we lay out the coefficients in column major order during training. This allows us to + * perform feature standardization once, while still retaining sequential memory access + * for speed. We convert back to row major order when we create the model, + * since this form is optimal for the matrix operations used for prediction. */ private class LogisticAggregator( bcCoefficients: Broadcast[Vector], diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index a0bd66e731a1d..c6035cc4c9647 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -268,9 +268,9 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. * - * Note: For high-dimensional data (with many features), this algorithm may perform poorly. - * This is due to high-dimensional data (a) making it difficult to cluster at all (based - * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. + * @note For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. */ @Since("2.0.0") @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 28cbe1cb01e9a..ccfb0ce8f85ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -85,7 +85,8 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H *

* * For the case $E_{max} == E_{min}$, $Rescaled(e_i) = 0.5 * (max + min)$. - * Note that since zero values will probably be transformed to non-zero values, output of the + * + * @note Since zero values will probably be transformed to non-zero values, output of the * transformer will be DenseVector even for sparse input. */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index e8e28ba29c841..ea401216aec7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -36,7 +36,8 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] * because it makes the vector entries sum up to one, and hence linearly dependent. * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * + * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. * The output vectors are sparse. * * @see [[StringIndexer]] for converting categorical values into category indices diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 1e49352b8517e..6e08bf059124c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -142,8 +142,9 @@ class PCAModel private[ml] ( /** * Transform a vector by computed Principal Components. - * NOTE: Vectors to be transformed must be the same length - * as the source vectors given to [[PCA.fit()]]. + * + * @note Vectors to be transformed must be the same length as the source vectors given + * to [[PCA.fit()]]. */ @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 666070037cdd8..0ced21365ff6f 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -28,7 +28,10 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructType} /** * A feature transformer that filters out stop words from input. - * Note: null values from input array are preserved unless adding null to stopWords explicitly. + * + * @note null values from the input array are preserved unless null is explicitly added + * to stopWords. + * * @see [[http://en.wikipedia.org/wiki/Stop_words]] */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 80fe46796f807..8b155f00017cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -113,11 +113,11 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { /** * Model fitted by [[StringIndexer]]. * - * NOTE: During transformation, if the input column does not exist, + * @param labels Ordered list of labels, corresponding to indices to be assigned.
+ * + * @note During transformation, if the input column does not exist, * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. - * - * @param labels Ordered list of labels, corresponding to indices to be assigned. */ @Since("1.4.0") class StringIndexerModel ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9245931b27ca6..96206e0b7ad88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -533,7 +533,7 @@ trait Params extends Identifiable with Serializable { * Returns all params sorted by their names. The default implementation uses Java reflection to * list all public methods that have no arguments and return [[Param]]. * - * Note: Developer should not use this method in constructor because we cannot guarantee that + * @note Developers should not use this method in constructors because we cannot guarantee that * this variable gets initialized before other params. */ lazy val params: Array[Param[_]] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index ebc6c12ddcf92..1419da874709f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -207,9 +207,9 @@ class DecisionTreeRegressionModel private[ml] ( * where gain is scaled by the number of instances passing through node * - Normalize importances for tree to sum to 1. * - * Note: Feature importance for single decision trees can have high variance due to - * correlated predictor variables. Consider using a [[RandomForestRegressor]] - * to determine feature importance instead. + * @note Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestRegressor]] + * to determine feature importance instead. */ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 1d2961e0277f5..736fd3b9e0f64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -879,8 +879,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( * Private copy of model to ensure Params are not modified outside this class. * Coefficients is not a deep copy, but that is acceptable. * - * NOTE: [[predictionCol]] must be set correctly before the value of [[model]] is set, - * and [[model]] must be set before [[predictions]] is set! + * @note [[predictionCol]] must be set correctly before the value of [[model]] is set, + * and [[model]] must be set before [[predictions]] is set!
*/ protected val model: GeneralizedLinearRegressionModel = origModel.copy(ParamMap.empty).setPredictionCol(predictionCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 71c542adf6f6f..da7ce6b46f2ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -103,11 +103,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Whether to standardize the training features before fitting the model. * The coefficients of models will be always returned on the original scale, - * so it will be transparent for users. Note that with/without standardization, - * the models should be always converged to the same solution when no regularization - * is applied. In R's GLMNET package, the default behavior is true as well. + * so it will be transparent for users. * Default is true. * + * @note With or without standardization, the models should always converge + * to the same solution when no regularization is applied. In R's GLMNET package, + * the default behavior is true as well. + * * @group setParam */ @Since("1.5.0") @@ -624,8 +626,8 @@ class LinearRegressionSummary private[regression] ( * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -634,8 +636,8 @@ * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -644,8 +646,8 @@ * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -654,8 +656,8 @@ * Returns the root mean squared error, which is defined as the square root of * the mean squared error. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions.
*/ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -664,8 +666,8 @@ class LinearRegressionSummary private[regression] ( * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala index 73d813064decb..e1376927030e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, DataFrameReader} * inconsistent feature dimensions. * - "vectorType": feature vector type, "sparse" (default) or "dense". * - * Note that this class is public for documentation purpose. Please don't use this class directly. + * @note This class is public for documentation purpose. Please don't use this class directly. * Rather, use the data source API as illustrated above. * * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index ede0a060eef95..0a0bc4c006389 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -98,7 +98,7 @@ private[spark] object GradientBoostedTrees extends Logging { * @param initTreeWeight: learning rate assigned to the first tree. * @param initTree: first DecisionTreeModel. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to every sample. */ def computeInitialPredictionAndError( @@ -121,7 +121,7 @@ private[spark] object GradientBoostedTrees extends Logging { * @param treeWeight: Learning rate. * @param tree: Tree using which the prediction and error should be updated. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to each sample. */ def updatePredictionError( diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index bc4f9e6716ee8..e5fa5d53e3fca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -221,7 +221,7 @@ trait MLReadable[T] { /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. * - * Note: Implementing classes should override this to be Java-friendly. + * @note Implementing classes should override this to be Java-friendly. 
*/ @Since("1.6.0") def load(path: String): T = read.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index d851b983349c9..4b650000736e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -202,9 +202,11 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { * Train a classification model for Binary Logistic Regression * using Stochastic Gradient Descent. By default L2 regularization is used, * which can be changed via `LogisticRegressionWithSGD.optimizer`. - * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} - * for k classes multi-label classification problem. + * * Using [[LogisticRegressionWithLBFGS]] is recommended over this. + * + * @note Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ @Since("0.8.0") class LogisticRegressionWithSGD private[mllib] ( @@ -239,7 +241,8 @@ class LogisticRegressionWithSGD private[mllib] ( /** * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("0.8.0") @deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0") @@ -252,7 +255,6 @@ object LogisticRegressionWithSGD { * number of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -260,6 +262,8 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -276,13 +280,13 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -298,13 +302,13 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using the specified step size. We use the entire data * set to update the gradient in each iteration. 
- * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -318,11 +322,12 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using a step size of 1.0. We use the entire data set * to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -335,8 +340,6 @@ object LogisticRegressionWithSGD { /** * Train a classification model for Multinomial/Binary Logistic Regression using * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. - * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} - * for k classes multi-label classification problem. * * Earlier implementations of LogisticRegressionWithLBFGS applies a regularization * penalty to all elements including the intercept. If this is called with one of @@ -344,6 +347,9 @@ object LogisticRegressionWithSGD { * into a call to ml.LogisticRegression, otherwise this will use the existing mllib * GeneralizedLinearAlgorithm trainer, resulting in a regularization penalty to the * intercept. + * + * @note Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ @Since("1.1.0") class LogisticRegressionWithLBFGS diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 7c3ccbb40b812..aec1526b55c49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -125,7 +125,8 @@ object SVMModel extends Loader[SVMModel] { /** * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. - * NOTE: Labels used in SVM should be {0, 1}. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") class SVMWithSGD private ( @@ -158,7 +159,9 @@ class SVMWithSGD private ( } /** - * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}. + * Top-level methods for calling SVM. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") object SVMWithSGD { @@ -169,8 +172,6 @@ object SVMWithSGD { * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. * - * NOTE: Labels used in SVM should be {0, 1}. - * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. 
@@ -178,6 +179,8 @@ object SVMWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") def train( @@ -195,7 +198,8 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in SVM should be {0, 1} + * + * @note Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -217,13 +221,14 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. We use the entire data set to * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * + * @note Labels used in SVM should be {0, 1} */ @Since("0.8.0") def train( @@ -238,11 +243,12 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using a step size of 1.0. We use the entire data set to * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * + * @note Labels used in SVM should be {0, 1} */ @Since("0.8.0") def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 43193adf3e184..56cdeea5f7a3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -41,14 +41,14 @@ import org.apache.spark.util.Utils * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. * - * Note: For high-dimensional data (with many features), this algorithm may perform poorly. - * This is due to high-dimensional data (a) making it difficult to cluster at all (based - * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. - * * @param k Number of independent Gaussians in the mixture model. * @param convergenceTol Maximum change in log-likelihood at which convergence * is considered to have occurred. * @param maxIterations Maximum number of iterations allowed. + * + * @note For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. 
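A sketch of the binary-label requirement the SVM notes describe, assuming `data` is an RDD[LabeledPoint] with labels in {0, 1}:
{{{
import org.apache.spark.mllib.classification.SVMWithSGD

val model = SVMWithSGD.train(data, 100)  // numIterations = 100
model.clearThreshold()  // return raw margins instead of 0/1 predictions
}}}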
*/ @Since("1.3.0") class GaussianMixture private ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index ed9c064879d01..fa72b72e2d921 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -56,14 +56,18 @@ class KMeans private ( def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) /** - * Number of clusters to create (k). Note that it is possible for fewer than k clusters to + * Number of clusters to create (k). + * + * @note It is possible for fewer than k clusters to * be returned, for example, if there are fewer than k distinct points to cluster. */ @Since("1.4.0") def getK: Int = k /** - * Set the number of clusters to create (k). Note that it is possible for fewer than k clusters to + * Set the number of clusters to create (k). + * + * @note It is possible for fewer than k clusters to * be returned, for example, if there are fewer than k distinct points to cluster. Default: 2. */ @Since("0.8.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d999b9be8e8ac..7c52abdeaac22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -175,7 +175,7 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ @Since("1.3.0") @@ -187,7 +187,7 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. * * If set to -1, then topicConcentration is set automatically. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 90d8a558f10d4..b5b0e64a2a6c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -66,7 +66,7 @@ abstract class LDAModel private[clustering] extends Saveable { * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. 
*/ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index ae324f86fe6d1..7365ea1f200da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -93,9 +93,11 @@ final class EMLDAOptimizer extends LDAOptimizer { /** * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up). * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with - * care. Note that checkpoints will be cleaned up via reference counting, regardless. + * care. * * Default: true + * + * @note Checkpoints will be cleaned up via reference counting, regardless. */ @Since("2.0.0") def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = { @@ -348,7 +350,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in * each iteration. * - * Note that this should be adjusted in synch with [[LDA.setMaxIterations()]] + * @note This should be adjusted in synch with [[LDA.setMaxIterations()]] * so the entire corpus is used. Specifically, set both so that * maxIterations * miniBatchFraction >= 1. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index f0779491e6374..003d1411a9cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -39,7 +39,7 @@ private[evaluation] object AreaUnderCurve { /** * Returns the area under the given curve. * - * @param curve a RDD of ordered 2D points stored in pairs representing a curve + * @param curve an RDD of ordered 2D points stored in pairs representing a curve */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fbd217af74ecb..c94d7890cf557 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * - * Note: Users should not implement this interface. + * @note Users should not implement this interface. */ @SQLUserDefinedType(udt = classOf[VectorUDT]) @Since("1.0.0") @@ -132,7 +132,9 @@ sealed trait Vector extends Serializable { /** * Number of active entries. An "active entry" is an element which is explicitly stored, - * regardless of its value. Note that inactive entries have value 0. + * regardless of its value. + * + * @note Inactive entries have value 0. 
*/ @Since("1.4.0") def numActives: Int diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 377be6bfb9886..03866753b50ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -451,7 +451,7 @@ class BlockMatrix @Since("1.3.0") ( * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause * some performance issues until support for multiplying two sparse matrices is added. * - * Note: The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when + * @note The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when * there were blocks with duplicate indices. Now, the blocks with duplicate indices will be added * with each other. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index b03b3ecde94f4..809906a158337 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -188,8 +188,9 @@ class IndexedRowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. Note that this cannot be - * computed on matrices with more than 65535 columns. + * Computes the Gramian matrix `A^T A`. + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index ec32e37afb792..4b120332ab8d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -106,8 +106,9 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with - * more than 65535 columns. + * Computes the Gramian matrix `A^T A`. + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -168,9 +169,6 @@ class RowMatrix @Since("1.0.0") ( * ARPACK is set to 300 or k * 3, whichever is larger. The numerical tolerance for ARPACK's * eigen-decomposition is set to 1e-10. * - * @note The conditions that decide which method to use internally and the default parameters are - * subject to change. - * * @param k number of leading singular values to keep (0 < k <= n). * It might return less than k if * there are numerically zero singular values or there are not enough Ritz values @@ -180,6 +178,9 @@ class RowMatrix @Since("1.0.0") ( * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0) * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. + * + * @note The conditions that decide which method to use internally and the default parameters are + * subject to change. 
*/ @Since("1.0.0") def computeSVD( @@ -319,9 +320,11 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. Note that this cannot - * be computed on matrices with more than 65535 columns. + * Computes the covariance matrix, treating each row as an observation. + * * @return a local dense matrix of size n x n + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeCovariance(): Matrix = { @@ -369,12 +372,12 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * - * Note that this cannot be computed on matrices with more than 65535 columns. - * * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components, and * a vector of values which indicate how much variance each principal component * explains + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.6.0") def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 81e64de4e5b5d..c49e72646bf13 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -305,7 +305,8 @@ class LeastSquaresGradient extends Gradient { * :: DeveloperApi :: * Compute gradient and loss for a Hinge loss function, as used in SVM binary classification. * See also the documentation for the precise formulation. - * NOTE: This assumes that the labels are {0,1} + * + * @note This assumes that the labels are {0,1} */ @DeveloperApi class HingeGradient extends Gradient { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 0f7857b8d8627..005119616f063 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** - * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Returns an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. It needs to diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index c642573ccba6d..24e4dcccc843f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -43,14 +43,14 @@ import org.apache.spark.storage.StorageLevel /** * Model representing the result of matrix factorization. 
* - * Note: If you create the model directly using constructor, please be aware that fast prediction - * requires cached user/product features and their associated partitioners. - * * @param rank Rank for the features in this model. * @param userFeatures RDD of tuples where each tuple represents the userId and * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. + * + * @note If you create the model directly using constructor, please be aware that fast prediction + * requires cached user/product features and their associated partitioners. */ @Since("0.8.0") class MatrixFactorizationModel @Since("0.8.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f3159f7e724cc..925fdf4d7e7bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -60,15 +60,15 @@ object Statistics { * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * - * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column - * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], - * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to - * avoid recomputing the common lineage. - * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. + * + * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to + * avoid recomputing the common lineage. */ @Since("1.1.0") def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) @@ -77,12 +77,12 @@ object Statistics { * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * - * Note: the two input RDDs need to have the same number of partitions and the same number of - * elements in each partition. - * * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s + * + * @note The two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. */ @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) @@ -98,15 +98,15 @@ object Statistics { * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * - * Note: the two input RDDs need to have the same number of partitions and the same number of - * elements in each partition. - * * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. 
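The Spearman caching advice above, as a minimal sketch (`data` is an assumed RDD[Vector]):
{{{
import org.apache.spark.mllib.stat.Statistics

// Spearman re-reads the input to rank each column; caching avoids recomputing the lineage.
data.cache()
val corrMatrix = Statistics.corr(data, "spearman")
}}}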
* Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. + * + * @note The two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. */ @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) @@ -122,15 +122,15 @@ object Statistics { * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. * - * Note: the two input Vectors need to have the same size. - * `observed` cannot contain negative values. - * `expected` cannot contain nonpositive values. - * * @param observed Vector containing the observed categorical counts/relative frequencies. * @param expected Vector containing the expected categorical counts/relative frequencies. * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * + * @note The two input Vectors need to have the same size. + * `observed` cannot contain negative values. + * `expected` cannot contain nonpositive values. */ @Since("1.1.0") def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { @@ -141,11 +141,11 @@ object Statistics { * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform * distribution, with each category having an expected frequency of `1 / observed.size`. * - * Note: `observed` cannot contain negative values. - * * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * + * @note `observed` cannot contain negative values. */ @Since("1.1.0") def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 36feab7859b43..d846c43cf2913 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -75,10 +75,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -86,6 +82,10 @@ object DecisionTree extends Serializable with Logging { * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. 
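A sketch of the chi-squared input constraints noted above (`observed` must be non-negative, `expected` strictly positive; `expected` is rescaled if the sums differ):
{{{
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

val observed = Vectors.dense(4.0, 6.0, 5.0)
val expected = Vectors.dense(1.0, 1.0, 1.0)  // rescaled internally to observed's sum
val result = Statistics.chiSqTest(observed, expected)
println(result.pValue)
}}}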
*/ @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { @@ -96,10 +96,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -108,6 +104,10 @@ object DecisionTree extends Serializable with Logging { * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means * 1 internal node + 2 leaf nodes). * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.0.0") def train( @@ -123,10 +123,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -136,6 +132,10 @@ object DecisionTree extends Serializable with Logging { * 1 internal node + 2 leaf nodes). * @param numClasses Number of classes for classification. Default value of 2. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.2.0") def train( @@ -152,10 +152,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -170,6 +166,10 @@ object DecisionTree extends Serializable with Logging { * indicates that feature n is categorical with k categories * indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. 
*/ @Since("1.0.0") def train( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index de14ddf024d75..09274a2e1b2ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -42,11 +42,13 @@ trait Loss extends Serializable { /** * Method to calculate error of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. + * * @param model Model of the weak learner. * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return Measure of model error on data + * + * @note This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. */ @Since("1.2.0") def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { @@ -55,11 +57,13 @@ trait Loss extends Serializable { /** * Method to calculate loss when the predictions are already known. - * Note: This method is used in the method evaluateEachIteration to avoid recomputing the - * predicted values from previously fit trees. + * * @param prediction Predicted label. * @param label True label. * @return Measure of model error on datapoint. + * + * @note This method is used in the method evaluateEachIteration to avoid recomputing the + * predicted values from previously fit trees. */ private[spark] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 657ed0a8ecda8..299950785e420 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -187,7 +187,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param initTreeWeight: learning rate assigned to the first tree. * @param initTree: first DecisionTreeModel. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to every sample. */ @Since("1.4.0") @@ -213,7 +213,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param treeWeight: Learning rate. * @param tree: Tree using which the prediction and error should be updated. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to each sample. 
*/ @Since("1.4.0") diff --git a/pom.xml b/pom.xml index 650b4cd965b66..024b2850d0a3d 100644 --- a/pom.xml +++ b/pom.xml @@ -2476,6 +2476,13 @@ maven-javadoc-plugin -Xdoclint:all -Xdoclint:-missing + + + note + a + Note: + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2d3a95b163a76..92b45657210e1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -741,7 +741,8 @@ object Unidoc { javacOptions in (JavaUnidoc, unidoc) := Seq( "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc", "-public", - "-noqualifier", "java.lang" + "-noqualifier", "java.lang", + "-tag", """note:a:Note\:""" ), // Use GitHub repository for Scaladoc source links diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index 3b1c5519bd87e..7250eab6705a7 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -28,7 +28,7 @@ class KernelDensity(object): """ - Estimate probability density at required points given a RDD of samples + Estimate probability density at required points given an RDD of samples from the population. >>> kd = KernelDensity() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index ed6fd4bca4c54..97755807ef262 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -499,7 +499,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ - Generate a RDD of LabeledPoints. + Generate an RDD of LabeledPoints. """ return callMLlibFunc( "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures), diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a163ceafe9d3b..641787ee20e0c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1218,7 +1218,7 @@ def mergeMaps(m1, m2): def top(self, num, key=None): """ - Get the top N elements from a RDD. + Get the top N elements from an RDD. Note that this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory. @@ -1242,7 +1242,7 @@ def merge(a, b): def takeOrdered(self, num, key=None): """ - Get the N elements from a RDD ordered in ascending order or as + Get the N elements from an RDD ordered in ascending order or as specified by the optional key function. Note that this method should only be used if the resulting array is expected diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index bf27d8047a753..134424add3b62 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -144,7 +144,7 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, """ .. note:: Experimental - Create a RDD from Kafka using offset ranges for each topic and partition. + Create an RDD from Kafka using offset ranges for each topic and partition. :param sc: SparkContext object :param kafkaParams: Additional params for Kafka @@ -155,7 +155,7 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, :param valueDecoder: A function used to decode value (default is utf8_decoder) :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess meta using messageHandler (default is None). 
- :return: A RDD object + :return: An RDD object """ if leaders is None: leaders = dict() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index dc90659a676e0..0b95a8821b05a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -165,10 +165,10 @@ object Encoders { * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java * serialization. This encoder maps T into a single byte array (binary) field. * - * Note that this is extremely inefficient and should only be used as the last resort. - * * T must be publicly accessible. * + * @note This is extremely inefficient and should only be used as the last resort. + * * @since 1.6.0 */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -177,10 +177,10 @@ object Encoders { * Creates an encoder that serializes objects of type T using generic Java serialization. * This encoder maps T into a single byte array (binary) field. * - * Note that this is extremely inefficient and should only be used as the last resort. - * * T must be publicly accessible. * + * @note This is extremely inefficient and should only be used as the last resort. + * * @since 1.6.0 */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index e121044288e5a..21f3497ba06fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -23,10 +23,10 @@ import org.apache.spark.annotation.InterfaceStability * The data type representing calendar time intervals. The calendar time interval is stored * internally in two components: number of months the number of microseconds. * - * Note that calendar intervals are not comparable. - * * Please use the singleton [[DataTypes.CalendarIntervalType]]. * + * @note Calendar intervals are not comparable. + * * @since 1.5.0 */ @InterfaceStability.Stable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7a131b30eafd7..fa3b2b9de5d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -118,7 +118,7 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * Note that the internal Catalyst expression can be accessed via "expr", but this method is for + * @note The internal Catalyst expression can be accessed via "expr", but this method is for * debugging purposes only and can change in any future Spark releases. 
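A sketch of the last-resort generic encoders described below (the element type is illustrative; prefer product and primitive encoders where possible):
{{{
import org.apache.spark.sql.Encoders

// Maps the whole object to one binary column; convenient but very inefficient.
val enc = Encoders.javaSerialization[java.util.ArrayList[Integer]]
val kryoEnc = Encoders.kryo[java.util.ArrayList[Integer]]  // same idea, Kryo-based
}}}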
* * @groupname java_expr_ops Java-specific expression operators diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b5bbcee37150f..6335fc4579a28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -51,7 +51,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient * Online Computation of Quantile Summaries]] by Greenwald and Khanna. * - * Note that NaN values will be removed from the numerical column before calculation * @param col the name of the numerical column * @param probabilities a list of quantile probabilities * Each number must belong to [0, 1]. @@ -61,6 +60,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities * + * @note NaN values will be removed from the numerical column before calculation + * * @since 2.0.0 */ def approxQuantile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e0c89811ddbfa..15281f24fa628 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -218,7 +218,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. * - * Note: Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based + * @note Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based * resolution. For example: * * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3761773698df3..3c75a6a45ec86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -377,7 +377,7 @@ class Dataset[T] private[sql]( /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. - * This can be quite convenient in conversion from a RDD of tuples into a [[DataFrame]] with + * This can be quite convenient in conversion from an RDD of tuples into a [[DataFrame]] with * meaningful names. For example: * {{{ * val rdd: RDD[(Int, String)] = ... @@ -703,13 +703,13 @@ class Dataset[T] private[sql]( * df1.join(df2, "user_id") * }}} * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumn Name of the column to join on. This column must exist on both sides. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. 
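The approxQuantile note above, sketched (assuming `df` has a numeric column "value"; NaNs in it are dropped before the computation):
{{{
val Array(q1, median, q3) =
  df.stat.approxQuantile("value", Array(0.25, 0.5, 0.75), 0.01)  // relativeError = 0.01
}}}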
+ * * @group untypedrel * @since 2.0.0 */ @@ -728,13 +728,13 @@ class Dataset[T] private[sql]( * df1.join(df2, Seq("user_id", "user_name")) * }}} * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * * @group untypedrel * @since 2.0.0 */ @@ -748,14 +748,14 @@ class Dataset[T] private[sql]( * Different from other join functions, the join columns will only appear once in the output, * i.e. similar to SQL's `JOIN USING` syntax. * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * * @group untypedrel * @since 2.0.0 */ @@ -856,10 +856,10 @@ class Dataset[T] private[sql]( /** * Explicit cartesian join with another [[DataFrame]]. * - * Note that cartesian joins are very expensive without an extra filter that can be pushed down. - * * @param right Right side of the join operation. * + * @note Cartesian joins are very expensive without an extra filter that can be pushed down. + * * @group untypedrel * @since 2.1.0 */ @@ -1044,7 +1044,8 @@ class Dataset[T] private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. - * Note that the column name can also reference to a nested column like `a.b`. + * + * @note The column name can also reference to a nested column like `a.b`. * * @group untypedrel * @since 2.0.0 @@ -1053,7 +1054,8 @@ class Dataset[T] private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. - * Note that the column name can also reference to a nested column like `a.b`. + * + * @note The column name can also reference to a nested column like `a.b`. * * @group untypedrel * @since 2.0.0 @@ -1621,7 +1623,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset containing rows only in both this Dataset and another Dataset. * This is equivalent to `INTERSECT` in SQL. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel @@ -1635,7 +1637,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset containing rows in this Dataset but not in another Dataset. 
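The self-join caveat repeated in the @note tags above can be avoided by aliasing both sides, e.g. (column names are illustrative):
{{{
import org.apache.spark.sql.functions.col

val joined = df.as("l").join(df.as("r"), col("l.user_id") === col("r.user_id"))
joined.select(col("l.user_id"), col("r.user_name"))  // qualified names stay unambiguous
}}}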
* This is equivalent to `EXCEPT` in SQL. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel @@ -1648,13 +1650,13 @@ class Dataset[T] private[sql]( /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. * - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[Dataset]]. - * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. * @param seed Seed for sampling. * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[Dataset]]. + * * @group typedrel * @since 1.6.0 */ @@ -1670,12 +1672,12 @@ class Dataset[T] private[sql]( /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. * - * Note: this is NOT guaranteed to provide exactly the fraction of the total count - * of the given [[Dataset]]. - * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. * + * @note This is NOT guaranteed to provide exactly the fraction of the total count + * of the given [[Dataset]]. + * * @group typedrel * @since 1.6.0 */ @@ -2375,7 +2377,7 @@ class Dataset[T] private[sql]( * * The iterator will consume as much memory as the largest partition in this Dataset. * - * Note: this results in multiple Spark jobs, and if the input Dataset is the result + * @note this results in multiple Spark jobs, and if the input Dataset is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input Dataset should be cached first. * @@ -2453,7 +2455,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset that contains only the unique rows from this Dataset. * This is an alias for `dropDuplicates`. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 3c5cf037c578d..2fae93651b344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -181,9 +181,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) /** * A collection of methods for registering user-defined functions (UDF). - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. * * The following example registers a Scala closure as UDF: * {{{ @@ -208,6 +205,10 @@ class SQLContext private[sql](val sparkSession: SparkSession) * DataTypes.StringType); * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. 
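A sketch of the sampling note above: `fraction` is a per-row probability, so the returned count is only approximately fraction * count (`ds` is an assumed Dataset):
{{{
val sampled = ds.sample(withReplacement = false, fraction = 0.1, seed = 42L)
}}}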
+ * * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 58b2ab3957173..e09e3caa3c981 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -155,9 +155,6 @@ class SparkSession private( /** * A collection of methods for registering user-defined functions (UDF). - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. * * The following example registers a Scala closure as UDF: * {{{ @@ -182,6 +179,10 @@ class SparkSession private( * DataTypes.StringType); * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. + * * @since 2.0.0 */ def udf: UDFRegistration = sessionState.udf @@ -201,7 +202,7 @@ class SparkSession private( * Start a new session with isolated SQL configurations, temporary tables, registered * functions are isolated, but sharing the underlying [[SparkContext]] and cached data. * - * Note: Other than the [[SparkContext]], all shared state is initialized lazily. + * @note Other than the [[SparkContext]], all shared state is initialized lazily. * This method will force the initialization of the shared state to ensure that parent * and child sessions are set up with the same shared state. If the underlying catalog * implementation is Hive, this will initialize the metastore, which may take some time. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 0444ad10d34fb..6043c5ee14b54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -39,7 +39,8 @@ import org.apache.spark.util.Utils /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. - * Note that the user-defined functions must be deterministic. + * + * @note The user-defined functions must be deterministic. * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 4914a9d722a83..1b56c08f729c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -28,7 +28,7 @@ package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { - /** Map each partition of a RDD along with data in a [[StateStore]]. */ + /** Map each partition of an RDD along with data in a [[StateStore]]. */ def mapPartitionsWithStateStore[U: ClassTag]( sqlContext: SQLContext, checkpointLocation: String, @@ -49,7 +49,7 @@ package object state { storeUpdateFunction) } - /** Map each partition of a RDD along with data in a [[StateStore]]. */ + /** Map each partition of an RDD along with data in a [[StateStore]]. 
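The determinism requirement in these @note tags matters because the optimizer may duplicate or elide UDF calls; registration itself is one line (names are illustrative, `spark` is an assumed SparkSession):
{{{
// Must be a pure function of its inputs; no randomness or external state.
spark.udf.register("strLen", (s: String) => s.length)
spark.sql("SELECT strLen('spark')").show()
}}}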
*/ private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 28598af781653..36dd5f78ac137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.types.DataType /** * A user-defined function. To create one, use the `udf` functions in [[functions]]. - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. + * * As an example: * {{{ * // Defined a UDF that returns true or false based on some numeric score. @@ -37,6 +35,10 @@ import org.apache.spark.sql.types.DataType * df.select( predict(df("score")) ) * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. + * * @since 1.3.0 */ @InterfaceStability.Stable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e221c032b82f6..d5940c638acdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -476,7 +476,7 @@ object functions { * * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) * - * Note: the list of columns should match with grouping columns exactly, or empty (means all the + * @note The list of columns should match with grouping columns exactly, or empty (means all the * grouping columns). * * @group agg_funcs @@ -489,7 +489,7 @@ object functions { * * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) * - * Note: the list of columns should match with grouping columns exactly. + * @note The list of columns should match with grouping columns exactly. * * @group agg_funcs * @since 2.0.0 @@ -1120,7 +1120,7 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * - * Note that this is indeterministic when data partitions are not fixed. + * @note This is indeterministic when data partitions are not fixed. * * @group normal_funcs * @since 1.4.0 @@ -1140,7 +1140,7 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * - * Note that this is indeterministic when data partitions are not fixed. + * @note This is indeterministic when data partitions are not fixed. * * @group normal_funcs * @since 1.4.0 @@ -1159,7 +1159,7 @@ object functions { /** * Partition ID. * - * Note that this is indeterministic because it depends on data partitioning and task scheduling. + * @note This is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs * @since 1.6.0 @@ -2207,7 +2207,7 @@ object functions { * Locate the position of the first occurrence of substr column in the given string. * Returns null if either of the arguments are null. * - * NOTE: The position is not zero based, but 1 based index. 
Returns 0 if substr + * @note The position is not zero based, but 1 based index. Returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2242,7 +2242,8 @@ object functions { /** * Locate the position of the first occurrence of substr. - * NOTE: The position is not zero based, but 1 based index. Returns 0 if substr + * + * @note The position is not zero based, but 1 based index. Returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2255,7 +2256,7 @@ object functions { /** * Locate the position of the first occurrence of substr in a string column, after position pos. * - * NOTE: The position is not zero based, but 1 based index. returns 0 if substr + * @note The position is not zero based, but 1 based index. returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2369,7 +2370,8 @@ object functions { /** * Splits str around pattern (pattern is a regular expression). - * NOTE: pattern is a string representation of the regular expression. + * + * @note Pattern is a string representation of the regular expression. * * @group string_funcs * @since 1.5.0 @@ -2468,7 +2470,7 @@ object functions { * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All * pattern letters of [[java.text.SimpleDateFormat]] can be used. * - * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * @note Use when ever possible specialized functions like [[year]]. These benefit from a * specialized implementation. * * @group datetime_funcs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index dec316be7aea1..7c64e28d24724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -140,7 +140,7 @@ abstract class JdbcDialect extends Serializable { * tried in reverse order. A user-added dialect will thus be applied first, * overwriting the defaults. * - * Note that all new dialects are applied to new jdbc DataFrames only. Make + * @note All new dialects are applied to new jdbc DataFrames only. Make * sure to register your dialects first. */ @DeveloperApi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 15a48072525b2..ff6dd8cb0cf92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -69,7 +69,8 @@ trait DataSourceRegister { trait RelationProvider { /** * Returns a new base relation with the given parameters. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced * by the Map that is passed to the function. */ def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation @@ -99,7 +100,8 @@ trait RelationProvider { trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced * by the Map that is passed to the function. 
*/ def createRelation( @@ -205,7 +207,7 @@ abstract class BaseRelation { * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. * - * Note that it is always better to overestimate size than underestimate, because underestimation + * @note It is always better to overestimate size than underestimate, because underestimation * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). * * @since 1.3.0 @@ -219,7 +221,7 @@ abstract class BaseRelation { * * If `needConversion` is `false`, buildScan() should return an [[RDD]] of [[InternalRow]] * - * Note: The internal representation is not stable across releases and thus data sources outside + * @note The internal representation is not stable across releases and thus data sources outside * of Spark SQL should leave this as true. * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 5e93fc469a41f..4504582187b97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.QueryExecution * :: Experimental :: * The interface of query execution listener that can be used to analyze execution metrics. * - * Note that implementations should guarantee thread-safety as they can be invoked by + * @note Implementations should guarantee thread-safety as they can be invoked by * multiple different threads. */ @Experimental @@ -39,24 +39,26 @@ trait QueryExecutionListener { /** * A callback function that will be called when a query executed successfully. - * Note that this can be invoked by multiple different threads. * * @param funcName name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. * @param durationNs the execution time for this query in nanoseconds. + * + * @note This can be invoked by multiple different threads. */ @DeveloperApi def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit /** * A callback function that will be called when a query execution failed. - * Note that this can be invoked by multiple different threads. * * @param funcName the name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. * @param exception the exception that failed this query. + * + * @note This can be invoked by multiple different threads. 
*/ @DeveloperApi def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 0daa29b666f62..b272c8e7d79c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val allColumns = fields.map(_.name).mkString(",") val schema = StructType(fields) - // Create a RDD for the schema + // Create an RDD for the schema val rdd = sparkContext.parallelize((1 to 10000), 10).map { i => Row( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 4808d0fcbc6cc..444261da8de6a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -421,11 +421,11 @@ class StreamingContext private[streaming] ( * by "moving" them from another location within the same file system. File names * starting with . are ignored. * - * '''Note:''' We ensure that the byte array for each record in the - * resulting RDDs of the DStream has the provided record length. - * * @param directory HDFS directory to monitor for new file * @param recordLength length of each record in bytes + * + * @note We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. */ def binaryRecordsStream( directory: String, @@ -447,12 +447,12 @@ class StreamingContext private[streaming] ( * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD + * + * @note Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. */ def queueStream[T: ClassTag]( queue: Queue[RDD[T]], @@ -465,14 +465,14 @@ class StreamingContext private[streaming] ( * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. * Set as null if no RDD should be returned when empty * @tparam T Type of objects in the RDD + * + * @note Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. 
*/ def queueStream[T: ClassTag]( queue: Queue[RDD[T]], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index da9ff858853cf..aa4003c62e1e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -74,7 +74,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def repartition(numPartitions: Int): JavaPairDStream[K, V] = dstream.repartition(numPartitions) - /** Method that generates a RDD for the given Duration */ + /** Method that generates an RDD for the given Duration */ def compute(validTime: Time): JavaPairRDD[K, V] = { dstream.compute(validTime) match { case Some(rdd) => new JavaPairRDD(rdd) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 4c4376a089f59..b43b9405def97 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -218,11 +218,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * for new files and reads them as flat binary files with fixed record lengths, * yielding byte arrays * - * '''Note:''' We ensure that the byte array for each record in the - * resulting RDDs of the DStream has the provided record length. - * * @param directory HDFS directory to monitor for new files * @param recordLength The length at which to split the records + * + * @note We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. */ def binaryRecordsStream(directory: String, recordLength: Int): JavaDStream[Array[Byte]] = { ssc.binaryRecordsStream(directory, recordLength) @@ -352,13 +352,13 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: + * @param queue Queue of RDDs + * @tparam T Type of objects in the RDD + * + * @note * 1. Changes to the queue after the stream is created will not be recognized. * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. - * - * @param queue Queue of RDDs - * @tparam T Type of objects in the RDD */ def queueStream[T](queue: java.util.Queue[JavaRDD[T]]): JavaDStream[T] = { implicit val cm: ClassTag[T] = @@ -372,14 +372,14 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: - * 1. Changes to the queue after the stream is created will not be recognized. - * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD + * + * @note + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. 
Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. */ def queueStream[T]( queue: java.util.Queue[JavaRDD[T]], @@ -396,7 +396,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: + * @note * 1. Changes to the queue after the stream is created will not be recognized. * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. @@ -454,9 +454,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a new DStream in which each RDD is generated by applying a function on RDDs of * the DStreams. The order of the JavaRDDs in the transform function parameter will be the - * same as the order of corresponding DStreams in the list. Note that for adding a - * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using - * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). + * same as the order of corresponding DStreams in the list. + * + * @note For adding a JavaPairDStream in the list of JavaDStreams, convert it to a + * JavaDStream using [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). * In the transform function, convert the JavaRDD corresponding to that JavaDStream to * a JavaPairRDD using org.apache.spark.api.java.JavaPairRDD.fromJavaRDD(). */ @@ -476,9 +477,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a new DStream in which each RDD is generated by applying a function on RDDs of * the DStreams. The order of the JavaRDDs in the transform function parameter will be the - * same as the order of corresponding DStreams in the list. Note that for adding a - * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using - * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). + * same as the order of corresponding DStreams in the list. + * + * @note For adding a JavaPairDStream in the list of JavaDStreams, convert it to + * a JavaDStream using [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). * In the transform function, convert the JavaRDD corresponding to that JavaDStream to * a JavaPairRDD using org.apache.spark.api.java.JavaPairRDD.fromJavaRDD(). 
*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 7e0a2ca609c86..e23edfa506517 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -69,13 +69,13 @@ abstract class DStream[T: ClassTag] ( // Methods that should be implemented by subclasses of DStream // ======================================================================= - /** Time interval after which the DStream generates a RDD */ + /** Time interval after which the DStream generates an RDD */ def slideDuration: Duration /** List of parent DStreams on which this DStream depends on */ def dependencies: List[DStream[_]] - /** Method that generates a RDD for the given time */ + /** Method that generates an RDD for the given time */ def compute(validTime: Time): Option[RDD[T]] // ======================================================================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala index ed08191f41cc8..9512db7d7d757 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -128,7 +128,7 @@ class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: Clas super.initialize(time) } - /** Method that generates a RDD for the given time */ + /** Method that generates an RDD for the given time */ override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD val prevStateRDD = getOrCompute(validTime - slideDuration) match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index ce5a6e00fb2fe..a37fac87300b7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -186,7 +186,7 @@ class WriteAheadLogBackedBlockRDDSuite assert(rdd.collect() === data.flatten) // Verify that the block fetching is skipped when isBlockValid is set to false. - // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // This is done by using an RDD whose data is only in memory but is set to skip block fetching // Using that RDD will throw exception, as it skips block fetching even if the blocks are in // in BlockManager. if (testIsBlockValid) { From 8b1e1088eb274fb15260cd5d6d9508d42837a4d6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 19 Nov 2016 11:28:25 +0000 Subject: [PATCH 184/198] [SPARK-18353][CORE] spark.rpc.askTimeout default value is not 120s ## What changes were proposed in this pull request? Avoid hard-coding spark.rpc.askTimeout to non-default in Client; fix doc about spark.rpc.askTimeout default ## How was this patch tested? Existing tests Author: Sean Owen Closes #15833 from srowen/SPARK-18353.
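For context, a minimal sketch of the timeout fallback this patch documents, assuming a plain `String => Option[String]` lookup in place of the real `SparkConf`/`RpcTimeout` machinery (the `resolveAskTimeout` helper is illustrative, not Spark API):

    // Sketch only: spark.rpc.askTimeout falls back to spark.network.timeout,
    // and only then to the hard default of 120s.
    def resolveAskTimeout(get: String => Option[String]): String =
      get("spark.rpc.askTimeout")
        .orElse(get("spark.network.timeout"))
        .getOrElse("120s")

This is why the documented default is better expressed as `spark.network.timeout` than as a literal `120s`.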
--- core/src/main/scala/org/apache/spark/deploy/Client.scala | 4 +++- docs/configuration.md | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index ee276e1b71138..a4de3d7eaf458 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -221,7 +221,9 @@ object Client { val conf = new SparkConf() val driverArgs = new ClientArguments(args) - conf.set("spark.rpc.askTimeout", "10") + if (!conf.contains("spark.rpc.askTimeout")) { + conf.set("spark.rpc.askTimeout", "10s") + } Logger.getRootLogger.setLevel(driverArgs.logLevel) val rpcEnv = diff --git a/docs/configuration.md b/docs/configuration.md index c021a377ba105..a3b4ff01e6d92 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1184,7 +1184,7 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.askTimeout - 120s + spark.network.timeout Duration for an RPC ask operation to wait before timing out. @@ -1566,7 +1566,7 @@ Apart from these, the following properties are also available, and may be useful spark.core.connection.ack.wait.timeout - 60s + spark.network.timeout How long for the connection to wait for ack to occur before timing out and giving up. To avoid unwilling timeout caused by long pause like GC, From ded5fefb6f5c0a97bf3d7fa1c0494dc434b6ee40 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 19 Nov 2016 13:48:56 +0000 Subject: [PATCH 185/198] [SPARK-18448][CORE] Fix @since 2.1.0 on new SparkSession.close() method ## What changes were proposed in this pull request? Fix since 2.1.0 on new SparkSession.close() method. I goofed in https://github.com/apache/spark/pull/15932 because it was back-ported to 2.1 instead of just master as originally planned. Author: Sean Owen Closes #15938 from srowen/SPARK-18448.2. --- sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e09e3caa3c981..71b1880dc0715 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -652,7 +652,7 @@ class SparkSession private( /** * Synonym for `stop()`. * - * @since 2.2.0 + * @since 2.1.0 */ override def close(): Unit = stop() From ea77c81ec0db27ea4709f71dc080d00167505a7d Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Sat, 19 Nov 2016 16:02:59 -0800 Subject: [PATCH 186/198] [SPARK-17062][MESOS] add conf option to mesos dispatcher Adds --conf option to set spark configuration properties in the mesos dispatcher. Properties provided with --conf take precedence over properties within the properties file. The reason for this PR is that for simple configuration or testing purposes we need to provide a property file (ideally a shared one for a cluster) even if we just provide a single property. Manually tested. Author: Stavros Kontopoulos Author: Stavros Kontopoulos Closes #14650 from skonto/dipatcher_conf.
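A minimal sketch of the precedence rule added here, using plain Maps in place of the real properties file and `SparkConf` (the `mergeConf` helper is illustrative, not part of the patch):

    // Sketch only: parse each --conf as key=value and let CLI entries
    // override anything loaded from the properties file.
    def mergeConf(
        fromPropertiesFile: Map[String, String],
        cliConfArgs: Seq[String]): Map[String, String] = {
      val cliPairs = cliConfArgs.map { pair =>
        pair.split("=", 2) match {
          case Array(k, v) => k -> v
          case _ => throw new IllegalArgumentException(s"Spark config without '=': $pair")
        }
      }
      // Later entries win, so --conf values take precedence over the file.
      fromPropertiesFile ++ cliPairs
    }

The `split("=", 2)` limit is deliberate: it splits only on the first `=`, so values that themselves contain `=` (e.g. `spark.executor.extraJavaOptions=-Dx=y`) survive intact.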
--- .../org/apache/spark/deploy/SparkSubmit.scala | 18 ++--- .../spark/deploy/SparkSubmitArguments.scala | 6 +- .../apache/spark/util/CommandLineUtils.scala | 56 +++++++++++++++ .../scala/org/apache/spark/util/Utils.scala | 14 ++++ .../spark/deploy/SparkSubmitSuite.scala | 43 +++++++----- .../deploy/mesos/MesosClusterDispatcher.scala | 9 ++- .../MesosClusterDispatcherArguments.scala | 70 +++++++++++++++---- ...MesosClusterDispatcherArgumentsSuite.scala | 63 +++++++++++++++++ .../mesos/MesosClusterDispatcherSuite.scala | 40 +++++++++++ 9 files changed, 266 insertions(+), 53 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala create mode 100644 mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala create mode 100644 mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index c70061bc5b5bc..85f80b6971e80 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -41,12 +41,11 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver} -import org.apache.spark.{SPARK_REVISION, SPARK_VERSION, SparkException, SparkUserAppException} -import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL} +import org.apache.spark._ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Whether to submit, kill, or request the status of an application. @@ -63,7 +62,7 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. 
*/ -object SparkSubmit { +object SparkSubmit extends CommandLineUtils { // Cluster managers private val YARN = 1 @@ -87,15 +86,6 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // scalastyle:off println - // Exposed for testing - private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) - private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn(1) - } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to ____ __ @@ -115,7 +105,7 @@ object SparkSubmit { } // scalastyle:on println - def main(args: Array[String]): Unit = { + override def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { // scalastyle:off println diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f1761e7c1ec92..b1d36e1821cc7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -412,10 +412,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - value.split("=", 2).toSeq match { - case Seq(k, v) => sparkProperties(k) = v - case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") - } + val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + sparkProperties(confName) = confValue case PROXY_USER => proxyUser = value diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala new file mode 100644 index 0000000000000..d73901686b705 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.PrintStream + +import org.apache.spark.SparkException + +/** + * Contains basic command line parsing functionality and methods to parse some common Spark CLI + * options. 
+ */ +private[spark] trait CommandLineUtils { + + // Exposed for testing + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) + + private[spark] var printStream: PrintStream = System.err + + // scalastyle:off println + + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + + private[spark] def printErrorAndExit(str: String): Unit = { + printStream.println("Error: " + str) + printStream.println("Run with --help for usage help or --verbose for debug output") + exitFn(1) + } + + // scalastyle:on println + + private[spark] def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => printErrorAndExit(s"Spark config without '=': $pair") + throw new SparkException(s"Spark config without '=': $pair") + } + } + + def main(args: Array[String]): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 23b95b9f649fe..748d729554fca 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2056,6 +2056,20 @@ private[spark] object Utils extends Logging { path } + /** + * Updates Spark config with properties from a set of Properties. + * Provided properties have the highest priority. + */ + def updateSparkConfigFromProperties( + conf: SparkConf, + properties: Map[String, String]) : Unit = { + properties.filter { case (k, v) => + k.startsWith("spark.") + }.foreach { case (k, v) => + conf.set(k, v) + } + } + /** Load properties present in the given file. */ def getPropertiesFromFile(filename: String): Map[String, String] = { val file = new File(filename) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7c649e305a37e..626888022903b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -34,21 +34,11 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} -// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch -// of properties that needed to be cleared after tests. -class SparkSubmitSuite - extends SparkFunSuite - with Matchers - with BeforeAndAfterEach - with ResetSystemProperties - with Timeouts { - override def beforeEach() { - super.beforeEach() - System.setProperty("spark.testing", "true") - } +trait TestPrematureExit { + suite: SparkFunSuite => private val noOpOutputStream = new OutputStream { def write(b: Int) = {} @@ -65,16 +55,19 @@ class SparkSubmitSuite } /** Returns true if the script exits and the given search string is printed. 
*/ - private def testPrematureExit(input: Array[String], searchString: String) = { + private[spark] def testPrematureExit( + input: Array[String], + searchString: String, + mainObject: CommandLineUtils = SparkSubmit) : Unit = { val printStream = new BufferPrintStream() - SparkSubmit.printStream = printStream + mainObject.printStream = printStream @volatile var exitedCleanly = false - SparkSubmit.exitFn = (_) => exitedCleanly = true + mainObject.exitFn = (_) => exitedCleanly = true val thread = new Thread { override def run() = try { - SparkSubmit.main(input) + mainObject.main(input) } catch { // If exceptions occur after the "exit" has happened, fine to ignore them. // These represent code paths not reachable during normal execution. @@ -88,6 +81,22 @@ class SparkSubmitSuite fail(s"Search string '$searchString' not found in $joined") } } +} + +// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch +// of properties that needed to be cleared after tests. +class SparkSubmitSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with ResetSystemProperties + with Timeouts + with TestPrematureExit { + + override def beforeEach() { + super.beforeEach() + System.setProperty("spark.testing", "true") + } // scalastyle:off println test("prints usage on empty input") { diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 7d6693b4cdf5b..792ade8f0bdbd 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, Utils} /* * A dispatcher that is responsible for managing and launching drivers, and is intended to be @@ -92,8 +92,11 @@ private[mesos] class MesosClusterDispatcher( } } -private[mesos] object MesosClusterDispatcher extends Logging { - def main(args: Array[String]) { +private[mesos] object MesosClusterDispatcher + extends Logging + with CommandLineUtils { + + override def main(args: Array[String]) { Utils.initDaemon(log) val conf = new SparkConf val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 11e13441eeba6..ef08502ec8dd6 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -18,23 +18,43 @@ package org.apache.spark.deploy.mesos import scala.annotation.tailrec +import scala.collection.mutable -import org.apache.spark.SparkConf import org.apache.spark.util.{IntParam, Utils} - +import org.apache.spark.SparkConf private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { - var host = Utils.localHostName() - var port = 7077 - var name = "Spark Cluster" - var webUiPort = 8081 + var host: String = Utils.localHostName() + var port: Int = 7077 + var name: String = "Spark 
Cluster" + var webUiPort: Int = 8081 + var verbose: Boolean = false var masterUrl: String = _ var zookeeperUrl: Option[String] = None var propertiesFile: String = _ + val confProperties: mutable.HashMap[String, String] = + new mutable.HashMap[String, String]() parse(args.toList) + // scalastyle:on println propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + Utils.updateSparkConfigFromProperties(conf, confProperties) + + // scalastyle:off println + if (verbose) { + MesosClusterDispatcher.printStream.println(s"Using host: $host") + MesosClusterDispatcher.printStream.println(s"Using port: $port") + MesosClusterDispatcher.printStream.println(s"Using webUiPort: $webUiPort") + MesosClusterDispatcher.printStream.println(s"Framework Name: $name") + + Option(propertiesFile).foreach { file => + MesosClusterDispatcher.printStream.println(s"Using properties file: $file") + } + + MesosClusterDispatcher.printStream.println(s"Spark Config properties set:") + conf.getAll.foreach(println) + } @tailrec private def parse(args: List[String]): Unit = args match { @@ -58,9 +78,10 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--master" | "-m") :: value :: tail => if (!value.startsWith("mesos://")) { // scalastyle:off println - System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + MesosClusterDispatcher.printStream + .println("Cluster dispatcher only supports mesos (uri begins with mesos://)") // scalastyle:on println - System.exit(1) + MesosClusterDispatcher.exitFn(1) } masterUrl = value.stripPrefix("mesos://") parse(tail) @@ -73,28 +94,45 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: propertiesFile = value parse(tail) + case ("--conf") :: value :: tail => + val pair = MesosClusterDispatcher. 
+ parseSparkConfProperty(value) + confProperties(pair._1) = pair._2 + parse(tail) + case ("--help") :: tail => - printUsageAndExit(0) + printUsageAndExit(0) + + case ("--verbose") :: tail => + verbose = true + parse(tail) case Nil => - if (masterUrl == null) { + if (Option(masterUrl).isEmpty) { // scalastyle:off println - System.err.println("--master is required") + MesosClusterDispatcher.printStream.println("--master is required") // scalastyle:on println printUsageAndExit(1) } - case _ => + case value => + // scalastyle:off println + MesosClusterDispatcher.printStream.println(s"Unrecognized option: '${value.head}'") + // scalastyle:on println printUsageAndExit(1) } private def printUsageAndExit(exitCode: Int): Unit = { + val outStream = MesosClusterDispatcher.printStream + // scalastyle:off println - System.err.println( + outStream.println( "Usage: MesosClusterDispatcher [options]\n" + "\n" + "Options:\n" + " -h HOST, --host HOST Hostname to listen on\n" + + " --help Show this help message and exit.\n" + + " --verbose, Print additional debug output.\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + " --name NAME Framework name to show in Mesos UI\n" + @@ -102,8 +140,10 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + " Zookeeper for persistence\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + - " Default is conf/spark-defaults.conf.") + " Default is conf/spark-defaults.conf \n" + + " --conf PROP=VALUE Arbitrary Spark configuration property.\n" + + " Takes precedence over defined properties in properties-file.") // scalastyle:on println - System.exit(exitCode) + MesosClusterDispatcher.exitFn(exitCode) } } diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala new file mode 100644 index 0000000000000..b6c0b325361da --- /dev/null +++ b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.TestPrematureExit + +class MesosClusterDispatcherArgumentsSuite extends SparkFunSuite + with TestPrematureExit { + + test("test if spark config args are passed successfully") { + val args = Array[String]("--master", "mesos://localhost:5050", "--conf", "key1=value1", + "--conf", "spark.mesos.key2=value2", "--verbose") + val conf = new SparkConf() + new MesosClusterDispatcherArguments(args, conf) + + assert(conf.getOption("key1").isEmpty) + assert(conf.get("spark.mesos.key2") == "value2") + } + + test("test non conf settings") { + val masterUrl = "mesos://localhost:5050" + val port = "1212" + val zookeeperUrl = "zk://localhost:2181" + val host = "localhost" + val webUiPort = "2323" + val name = "myFramework" + + val args1 = Array("--master", masterUrl, "--verbose", "--name", name) + val args2 = Array("-p", port, "-h", host, "-z", zookeeperUrl) + val args3 = Array("--webui-port", webUiPort) + + val args = args1 ++ args2 ++ args3 + val conf = new SparkConf() + val mesosDispClusterArgs = new MesosClusterDispatcherArguments(args, conf) + + assert(mesosDispClusterArgs.verbose) + assert(mesosDispClusterArgs.confProperties.isEmpty) + assert(mesosDispClusterArgs.host == host) + assert(Option(mesosDispClusterArgs.masterUrl).isDefined) + assert(mesosDispClusterArgs.masterUrl == masterUrl.stripPrefix("mesos://")) + assert(Option(mesosDispClusterArgs.zookeeperUrl).isDefined) + assert(mesosDispClusterArgs.zookeeperUrl contains zookeeperUrl) + assert(mesosDispClusterArgs.name == name) + assert(mesosDispClusterArgs.webUiPort == webUiPort.toInt) + assert(mesosDispClusterArgs.port == port.toInt) + } +} diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala new file mode 100644 index 0000000000000..7484e3b83670d --- /dev/null +++ b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.TestPrematureExit + +class MesosClusterDispatcherSuite extends SparkFunSuite + with TestPrematureExit{ + + test("prints usage on empty input") { + testPrematureExit(Array[String](), + "Usage: MesosClusterDispatcher", MesosClusterDispatcher) + } + + test("prints usage with only --help") { + testPrematureExit(Array("--help"), + "Usage: MesosClusterDispatcher", MesosClusterDispatcher) + } + + test("prints error with unrecognized options") { + testPrematureExit(Array("--blarg"), "Unrecognized option: '--blarg'", MesosClusterDispatcher) + testPrematureExit(Array("-bleg"), "Unrecognized option: '-bleg'", MesosClusterDispatcher) + } +} From 856e0042007c789dda4539fb19a5d4580999fbf4 Mon Sep 17 00:00:00 2001 From: sethah Date: Sun, 20 Nov 2016 01:42:37 +0000 Subject: [PATCH 187/198] [SPARK-18456][ML][FOLLOWUP] Use matrix abstraction for coefficients in LogisticRegression training ## What changes were proposed in this pull request? This is a follow up to some of the discussion [here](https://github.com/apache/spark/pull/15593). During LogisticRegression training, we store the coefficients combined with intercepts as a flat vector, but a more natural abstraction is a matrix. Here, we refactor the code to use matrix where possible, which makes the code more readable and greatly simplifies the indexing. Note: We do not use a Breeze matrix for the cost function as was mentioned in the linked PR. This is because LBFGS/OWLQN require an implicit `MutableInnerProductModule[DenseMatrix[Double], Double]` which is not natively defined in Breeze. We would need to extend Breeze in Spark to define it ourselves. Also, we do not modify the `regParamL1Fun` because OWLQN in Breeze requires a `MutableEnumeratedCoordinateField[(Int, Int), DenseVector[Double]]` (since we still use a dense vector for coefficients). Here again we would have to extend Breeze inside Spark. ## How was this patch tested? This is internal code refactoring - the current unit tests passing show us that the change did not break anything. No added functionality in this patch. Author: sethah Closes #15893 from sethah/logreg_refactor. --- .../classification/LogisticRegression.scala | 115 ++++++++---------- 1 file changed, 53 insertions(+), 62 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71a7fe53c15f8..f58efd36a1c66 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -463,16 +463,11 @@ class LogisticRegression @Since("1.2.0") ( } /* - The coefficients are laid out in column major order during training. e.g. for - `numClasses = 3` and `numFeatures = 2` and `fitIntercept = true` the layout is: - - Array(beta_11, beta_21, beta_31, beta_12, beta_22, beta_32, intercept_1, intercept_2, - intercept_3) - - where beta_jk corresponds to the coefficient for class `j` and feature `k`. + The coefficients are laid out in column major order during training. Here we initialize + a column major matrix of initial coefficients. 
*/ - val initialCoefficientsWithIntercept = - Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept) + val initialCoefWithInterceptMatrix = + Matrices.zeros(numCoefficientSets, numFeaturesPlusIntercept) val initialModelIsValid = optInitialModel match { case Some(_initialModel) => @@ -491,18 +486,15 @@ class LogisticRegression @Since("1.2.0") ( } if (initialModelIsValid) { - val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray val providedCoef = optInitialModel.get.coefficientMatrix - providedCoef.foreachActive { (row, col, value) => - // convert matrix to column major for training - val flatIndex = col * numCoefficientSets + row + providedCoef.foreachActive { (classIndex, featureIndex, value) => // We need to scale the coefficients since they will be trained in the scaled space - initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col) + initialCoefWithInterceptMatrix.update(classIndex, featureIndex, + value * featuresStd(featureIndex)) } if ($(fitIntercept)) { - optInitialModel.get.interceptVector.foreachActive { (index, value) => - val coefIndex = numCoefficientSets * numFeatures + index - initialCoefWithInterceptArray(coefIndex) = value + optInitialModel.get.interceptVector.foreachActive { (classIndex, value) => + initialCoefWithInterceptMatrix.update(classIndex, numFeatures, value) } } } else if ($(fitIntercept) && isMultinomial) { @@ -532,8 +524,7 @@ class LogisticRegression @Since("1.2.0") ( val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing val rawMean = rawIntercepts.sum / rawIntercepts.length rawIntercepts.indices.foreach { i => - initialCoefficientsWithIntercept.toArray(numClasses * numFeatures + i) = - rawIntercepts(i) - rawMean + initialCoefWithInterceptMatrix.update(i, numFeatures, rawIntercepts(i) - rawMean) } } else if ($(fitIntercept)) { /* @@ -549,12 +540,12 @@ class LogisticRegression @Since("1.2.0") ( b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialCoefficientsWithIntercept.toArray(numFeatures) = math.log( - histogram(1) / histogram(0)) + initialCoefWithInterceptMatrix.update(0, numFeatures, + math.log(histogram(1) / histogram(0))) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficientsWithIntercept.asBreeze.toDenseVector) + new BDV[Double](initialCoefWithInterceptMatrix.toArray)) /* Note that in Logistic Regression, the objective history (loss + regularization) @@ -586,15 +577,24 @@ class LogisticRegression @Since("1.2.0") ( Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. 
*/ - val rawCoefficients = state.x.toArray.clone() - val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i => - val colMajorIndex = (i % numFeatures) * numCoefficientSets + i / numFeatures - val featureIndex = i % numFeatures - if (featuresStd(featureIndex) != 0.0) { - rawCoefficients(colMajorIndex) / featuresStd(featureIndex) - } else { - 0.0 + val allCoefficients = state.x.toArray.clone() + val allCoefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, + allCoefficients) + val denseCoefficientMatrix = new DenseMatrix(numCoefficientSets, numFeatures, + new Array[Double](numCoefficientSets * numFeatures), isTransposed = true) + val interceptVec = if ($(fitIntercept) || !isMultinomial) { + Vectors.zeros(numCoefficientSets) + } else { + Vectors.sparse(numCoefficientSets, Seq()) + } + // separate intercepts and coefficients from the combined matrix + allCoefMatrix.foreachActive { (classIndex, featureIndex, value) => + val isIntercept = $(fitIntercept) && (featureIndex == numFeatures) + if (!isIntercept && featuresStd(featureIndex) != 0.0) { + denseCoefficientMatrix.update(classIndex, featureIndex, + value / featuresStd(featureIndex)) } + if (isIntercept) interceptVec.toArray(classIndex) = value } if ($(regParam) == 0.0 && isMultinomial) { @@ -607,17 +607,16 @@ class LogisticRegression @Since("1.2.0") ( Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf */ - val coefficientMean = coefficientArray.sum / coefficientArray.length - coefficientArray.indices.foreach { i => coefficientArray(i) -= coefficientMean} + val denseValues = denseCoefficientMatrix.values + val coefficientMean = denseValues.sum / denseValues.length + denseCoefficientMatrix.update(_ - coefficientMean) } - val denseCoefficientMatrix = - new DenseMatrix(numCoefficientSets, numFeatures, coefficientArray, isTransposed = true) // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471 val compressedCoefficientMatrix = if (isMultinomial) { denseCoefficientMatrix } else { - val compressedVector = Vectors.dense(coefficientArray).compressed + val compressedVector = Vectors.dense(denseCoefficientMatrix.values).compressed compressedVector match { case dv: DenseVector => denseCoefficientMatrix case sv: SparseVector => @@ -626,25 +625,13 @@ class LogisticRegression @Since("1.2.0") ( } } - val interceptsArray: Array[Double] = if ($(fitIntercept)) { - Array.tabulate(numCoefficientSets) { i => - val coefIndex = numFeatures * numCoefficientSets + i - rawCoefficients(coefIndex) - } - } else { - Array.empty[Double] - } - val interceptVector = if (interceptsArray.nonEmpty && isMultinomial) { - // The intercepts are never regularized, so we always center the mean. 
- val interceptMean = interceptsArray.sum / numClasses - interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean } - Vectors.dense(interceptsArray) - } else if (interceptsArray.length == 1) { - Vectors.dense(interceptsArray) - } else { - Vectors.sparse(numCoefficientSets, Seq()) + // center the intercepts when using multinomial algorithm + if ($(fitIntercept) && isMultinomial) { + val interceptArray = interceptVec.toArray + val interceptMean = interceptArray.sum / interceptArray.length + (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } } - (compressedCoefficientMatrix, interceptVector.compressed, arrayBuilder.result()) + (compressedCoefficientMatrix, interceptVec.compressed, arrayBuilder.result()) } } @@ -1424,6 +1411,7 @@ private class LogisticAggregator( private val numFeatures = bcFeaturesStd.value.length private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures private val coefficientSize = bcCoefficients.value.size + private val numCoefficientSets = if (multinomial) numClasses else 1 if (multinomial) { require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " + s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize") @@ -1633,12 +1621,12 @@ private class LogisticAggregator( lossSum / weightSum } - def gradient: Vector = { + def gradient: Matrix = { require(weightSum > 0.0, s"The effective number of instances should be " + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) scal(1.0 / weightSum, result) - result + new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, result.toArray) } } @@ -1664,6 +1652,7 @@ private class LogisticCostFun( val featuresStd = bcFeaturesStd.value val numFeatures = featuresStd.length val numCoefficientSets = if (multinomial) numClasses else 1 + val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures val logisticAggregator = { val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) @@ -1675,24 +1664,25 @@ private class LogisticCostFun( )(seqOp, combOp, aggregationDepth) } - val totalGradientArray = logisticAggregator.gradient.toArray + val totalGradientMatrix = logisticAggregator.gradient + val coefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, coeffs.toArray) // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { var sum = 0.0 - coeffs.foreachActive { case (index, value) => + coefMatrix.foreachActive { case (classIndex, featureIndex, value) => // We do not apply regularization to the intercepts - val isIntercept = fitIntercept && index >= numCoefficientSets * numFeatures + val isIntercept = fitIntercept && (featureIndex == numFeatures) if (!isIntercept) { // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. 
sum += { if (standardization) { - totalGradientArray(index) += regParamL2 * value + val gradValue = totalGradientMatrix(classIndex, featureIndex) + totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * value) value * value } else { - val featureIndex = index / numCoefficientSets if (featuresStd(featureIndex) != 0.0) { // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to // perform this reverse standardization by penalizing each component // differently to get effectively the same objective function when // the training dataset is not standardized. val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex)) - totalGradientArray(index) += regParamL2 * temp + val gradValue = totalGradientMatrix(classIndex, featureIndex) + totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * temp) value * temp } else { 0.0 @@ -1713,6 +1704,6 @@ } bcCoeffs.destroy(blocking = false) - (logisticAggregator.loss + regVal, new BDV(totalGradientArray)) + (logisticAggregator.loss + regVal, new BDV(totalGradientMatrix.toArray)) } } From d93b6552473468df297a08c0bef9ea0bf0f5c13a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 19 Nov 2016 21:50:20 -0800 Subject: [PATCH 188/198] [SPARK-18458][CORE] Fix signed integer overflow problem at an expression in RadixSort.java ## What changes were proposed in this pull request? This PR prevents the result of an expression from going negative due to signed integer overflow (e.g. 0x10?????? * 8 < 0) by casting each operand to `long` before the calculation. Since the arithmetic is then performed in `long`, the result of the expression stays positive. ## How was this patch tested? Manually executed query82 of TPC-DS with 100TB Author: Kazuaki Ishizaki Closes #15907 from kiszk/SPARK-18458. --- .../collection/unsafe/sort/RadixSort.java | 48 ++++++++++--------- .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../unsafe/sort/RadixSortSuite.scala | 28 +++++------ 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 404361734a55b..3dd318471008b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -40,14 +42,14 @@ public class RadixSort { * of always copying the data back to position zero for efficiency.
*/ public static int sort( - LongArray array, int numRecords, int startByteIndex, int endByteIndex, + LongArray array, long numRecords, int startByteIndex, int endByteIndex, boolean desc, boolean signed) { assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 2 <= array.size(); - int inIndex = 0; - int outIndex = numRecords; + long inIndex = 0; + long outIndex = numRecords; if (numRecords > 0) { long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { @@ -55,13 +57,13 @@ public static int sort( sortAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -78,14 +80,14 @@ public static int sort( * @param signed whether this is a signed (two's complement) sort (only applies to last byte). */ private static void sortAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8; - long maxOffset = baseOffset + numRecords * 8; + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; for (long offset = baseOffset; offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); @@ -106,13 +108,13 @@ private static void sortAtByte( * significant byte. If the byte does not need sorting the array will be null. */ private static long[][] getCounts( - LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. // If all the byte values at a particular index are the same we don't need to count it. long bitwiseMax = 0; long bitwiseMin = -1L; - long maxOffset = array.getBaseOffset() + numRecords * 8; + long maxOffset = array.getBaseOffset() + numRecords * 8L; Object baseObject = array.getBaseObject(); for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); @@ -146,18 +148,18 @@ private static long[][] getCounts( * @return the input counts array. */ private static long[] transformCountsToOffsets( - long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, boolean desc, boolean signed) { assert counts.length == 256; int start = signed ? 128 : 0; // output the negative records first (values 129-255). 
if (desc) { - int pos = numRecords; + long pos = numRecords; for (int i = start; i < start + 256; i++) { pos -= counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; } } else { - int pos = 0; + long pos = 0; for (int i = start; i < start + 256; i++) { long tmp = counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; @@ -176,8 +178,8 @@ private static long[] transformCountsToOffsets( */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, - int numRecords, + long startIndex, + long numRecords, int startByteIndex, int endByteIndex, boolean desc, @@ -186,8 +188,8 @@ public static int sortKeyPrefixArray( assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + long inIndex = startIndex; + long outIndex = startIndex + numRecords * 2L; if (numRecords > 0) { long[][] counts = getKeyPrefixArrayCounts( array, startIndex, numRecords, startByteIndex, endByteIndex); @@ -196,13 +198,13 @@ public static int sortKeyPrefixArray( sortKeyPrefixArrayAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -210,7 +212,7 @@ public static int sortKeyPrefixArray( * getCounts with some added parameters but that seems to hurt in benchmarks. */ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; @@ -238,11 +240,11 @@ private static long[][] getKeyPrefixArrayCounts( * Specialization of sortAtByte() for key-prefix arrays. 
*/ private static void sortKeyPrefixArrayAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed); Object baseObject = array.getBaseObject(); long baseOffset = array.getBaseOffset() + inIndex * 8L; long maxOffset = baseOffset + numRecords * 16L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 2a71e68adafad..252a35ec6bdf5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -322,7 +322,7 @@ public UnsafeSorterIterator getSortedIterator() { if (sortComparator != null) { if (this.radixSortSupport != null) { offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index 366ffda7788d3..d5956ea32096a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator} import scala.util.Random +import com.google.common.primitives.Ints + import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray @@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom class RadixSortSuite extends SparkFunSuite with Logging { - private val N = 10000 // scale this down for more readable results + private val N = 10000L // scale this down for more readable results /** * Describes a type of sort to test, e.g. two's complement descending. 
Each sort type has @@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging { }, 2, 4, false, false, true)) - private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = { - val ref = Array.tabulate[Long](size) { i => rand } - val extended = ref ++ Array.fill[Long](size)(0) + private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { - val ref = Array.tabulate[Long](size * 2) { i => rand } - val extended = ref ++ Array.fill[Long](size * 2)(0) + private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) (new LongArray(MemoryBlock.fromLongArray(ref)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = { + private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { var i = 0 - val out = new Array[Long](length) + val out = new Array[Long](Ints.checkedCast(length)) while (i < length) { out(i) = array.get(offset + i) i += 1 @@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging { } } - private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { + private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( - buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { + buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( r1: RecordPointerAndKeyPrefix, - r2: RecordPointerAndKeyPrefix): Int = { - refCmp.compare(r1.keyPrefix, r2.keyPrefix) - } + r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix) }) } From bce9a03677f931d52491e7768aba9e4a19a7e696 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Nov 2016 21:57:09 -0800 Subject: [PATCH 189/198] [SPARK-18508][SQL] Fix documentation error for DateDiff ## What changes were proposed in this pull request? The previous documentation and example for DateDiff was wrong. ## How was this patch tested? Doc only change. Author: Reynold Xin Closes #15937 from rxin/datediff-doc. --- .../sql/catalyst/expressions/datetimeExpressions.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 9cec6be841de0..1db1d1995d942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1101,11 +1101,14 @@ case class TruncDate(date: Expression, format: Expression) * Returns the number of days from startDate to endDate. 
 */
@ExpressionDescription(
-  usage = "_FUNC_(date1, date2) - Returns the number of days between `date1` and `date2`.",
+  usage = "_FUNC_(endDate, startDate) - Returns the number of days from `startDate` to `endDate`.",
   extended = """
     Examples:
-      > SELECT _FUNC_('2009-07-30', '2009-07-31');
+      > SELECT _FUNC_('2009-07-31', '2009-07-30');
       1
+
+      > SELECT _FUNC_('2009-07-30', '2009-07-31');
+      -1
   """)
 case class DateDiff(endDate: Expression, startDate: Expression)
   extends BinaryExpression with ImplicitCastInputTypes {

From a64f25d8b403b17ff68c9575f6f35b22e5b62427 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sat, 19 Nov 2016 21:57:49 -0800
Subject: [PATCH 190/198] [SQL] Fix documentation for Concat and ConcatWs

---
 .../sql/catalyst/expressions/stringExpressions.scala  | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index e74ef9a08750e..908aa44f81c97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -40,15 +40,13 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
  * An expression that concatenates multiple input strings into a single string.
  * If any input is null, concat returns null.
  */
-// scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of `str1`, `str2`, ..., `strN`.",
+  usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.",
   extended = """
     Examples:
-      > SELECT _FUNC_('Spark','SQL');
+      > SELECT _FUNC_('Spark', 'SQL');
       SparkSQL
   """)
-// scalastyle:on line.size.limit
 case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {

   override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
@@ -89,8 +87,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
   usage = "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by `sep`.",
   extended = """
     Examples:
-      > SELECT _FUNC_(' ', Spark', 'SQL');
-       Spark SQL
+      > SELECT _FUNC_(' ', 'Spark', 'SQL');
+       Spark SQL
   """)
 // scalastyle:on line.size.limit
 case class ConcatWs(children: Seq[Expression])

From 7ca7a635242377634c302b7816ce60bd9c908527 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Sat, 19 Nov 2016 23:55:09 -0800
Subject: [PATCH 191/198] [SPARK-15214][SQL] Code-generation for Generate

## What changes were proposed in this pull request?

This PR adds code generation to `Generate`. It supports two code paths:
- General `TraversableOnce` based iteration. This is used for regular, codegen-supporting `Generator` expressions. This code path expects the expression to return a `TraversableOnce[InternalRow]` and it will iterate over the returned collection. This PR adds code generation for the `stack` generator.
- Specialized `ArrayData/MapData` based iteration. This is used for the `explode`, `posexplode` & `inline` functions and operates directly on the `ArrayData`/`MapData` result that the child of the generator returns.
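In hand-written Scala terms, the two loop shapes reduce to the following minimal sketch (illustrative only; the operator actually emits Java source through `CodegenContext`, and these helper names are not part of the patch):

```
// Specialized path: index-based loop over a materialized collection,
// mirroring how the generated code walks an ArrayData/MapData directly.
def specializedPath[T](elements: IndexedSeq[T])(consume: (Int, T) => Unit): Unit = {
  var i = 0
  while (i < elements.length) {
    consume(i, elements(i)) // the index comes for free, which is what posexplode needs
    i += 1
  }
}

// General path: drain the TraversableOnce returned by the generator.
def generalPath[T](result: TraversableOnce[T])(consume: T => Unit): Unit = {
  val it = result.toIterator
  while (it.hasNext) consume(it.next())
}
```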
### Benchmarks I have added some benchmarks and it seems we can create a nice speedup for explode: #### Environment ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 Intel(R) Core(TM) i7-4980HQ CPU 2.80GHz ``` #### Explode Array ##### Before ``` generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode array wholestage off 7377 / 7607 2.3 439.7 1.0X generate explode array wholestage on 6055 / 6086 2.8 360.9 1.2X ``` ##### After ``` generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode array wholestage off 7432 / 7696 2.3 443.0 1.0X generate explode array wholestage on 631 / 646 26.6 37.6 11.8X ``` #### Explode Map ##### Before ``` generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode map wholestage off 12792 / 12848 1.3 762.5 1.0X generate explode map wholestage on 11181 / 11237 1.5 666.5 1.1X ``` ##### After ``` generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode map wholestage off 10949 / 10972 1.5 652.6 1.0X generate explode map wholestage on 870 / 913 19.3 51.9 12.6X ``` #### Posexplode ##### Before ``` generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate posexplode array wholestage off 7547 / 7580 2.2 449.8 1.0X generate posexplode array wholestage on 5786 / 5838 2.9 344.9 1.3X ``` ##### After ``` generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate posexplode array wholestage off 7535 / 7548 2.2 449.1 1.0X generate posexplode array wholestage on 620 / 624 27.1 37.0 12.1X ``` #### Inline ##### Before ``` generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate inline array wholestage off 6935 / 6978 2.4 413.3 1.0X generate inline array wholestage on 6360 / 6400 2.6 379.1 1.1X ``` ##### After ``` generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate inline array wholestage off 6940 / 6966 2.4 413.6 1.0X generate inline array wholestage on 1002 / 1012 16.7 59.7 6.9X ``` #### Stack ##### Before ``` generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate stack wholestage off 12980 / 13104 1.3 773.7 1.0X generate stack wholestage on 11566 / 11580 1.5 689.4 1.1X ``` ##### After ``` generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate stack wholestage off 12875 / 12949 1.3 767.4 1.0X generate stack wholestage on 840 / 845 20.0 50.0 15.3X ``` ## How was this patch tested? Existing tests. 
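For a quick local sanity check that `Generate` now participates in whole-stage code generation, something along the lines of the new `WholeStageCodegenSuite` test below can be run in a shell (a hedged sketch; assumes a live `SparkSession` named `spark`):

```
import org.apache.spark.sql.execution.{GenerateExec, WholeStageCodegenExec}
import org.apache.spark.sql.functions._

val ds = spark.range(2).select(col("id"), explode(array(col("id") + 1, col("id") + 2)).as("value"))
// After this patch, GenerateExec should appear as the child of a
// WholeStageCodegenExec node in the executed plan.
val codegenned = ds.queryExecution.executedPlan.find {
  case w: WholeStageCodegenExec => w.child.isInstanceOf[GenerateExec]
  case _ => false
}.isDefined
```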
Author: Herman van Hovell Author: Herman van Hovell Closes #13065 from hvanhovell/SPARK-15214. --- .../sql/catalyst/expressions/generators.scala | 110 ++++++++-- .../SubexpressionEliminationSuite.scala | 16 +- .../spark/sql/execution/GenerateExec.scala | 202 +++++++++++++++++- .../spark/sql/GeneratorFunctionSuite.scala | 34 +++ .../org/apache/spark/sql/SQLQuerySuite.scala | 7 - .../execution/WholeStageCodegenSuite.scala | 32 ++- .../execution/benchmark/MiscBenchmark.scala | 99 ++++++++- 7 files changed, 463 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index d042bfb63d567..6c38f4998e914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -60,6 +62,26 @@ trait Generator extends Expression { * rows can be made here. */ def terminate(): TraversableOnce[InternalRow] = Nil + + /** + * Check if this generator supports code generation. + */ + def supportCodegen: Boolean = !isInstanceOf[CodegenFallback] +} + +/** + * A collection producing [[Generator]]. This trait provides a different path for code generation, + * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object. + */ +trait CollectionGenerator extends Generator { + /** The position of an element within the collection should also be returned. */ + def position: Boolean + + /** Rows will be inlined during generation. */ + def inline: Boolean + + /** The type of the returned collection object. */ + def collectionType: DataType = dataType } /** @@ -77,7 +99,9 @@ case class UserDefinedGenerator( private def initializeConverters(): Unit = { inputRow = new InterpretedProjection(children) convertToScala = { - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + val inputSchema = StructType(children.map { e => + StructField(e.simpleString, e.dataType, nullable = true) + }) CatalystTypeConverters.createToScalaConverter(inputSchema) }.asInstanceOf[InternalRow => Row] } @@ -109,8 +133,7 @@ case class UserDefinedGenerator( 1 2 3 NULL """) -case class Stack(children: Seq[Expression]) - extends Expression with Generator with CodegenFallback { +case class Stack(children: Seq[Expression]) extends Generator { private lazy val numRows = children.head.eval().asInstanceOf[Int] private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt @@ -149,21 +172,50 @@ case class Stack(children: Seq[Expression]) InternalRow(fields: _*) } } + + + /** + * Only support code generation when stack produces 50 rows or less. + */ + override def supportCodegen: Boolean = numRows <= 50 + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Rows - we write these into an array. 
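+    // One InternalRow is materialized per output row and kept in a mutable
+    // array field, which is why supportCodegen above caps numRows at 50.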
+ val rowData = ctx.freshName("rows") + ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") + val values = children.tail + val dataTypes = values.take(numFields).map(_.dataType) + val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => + val fields = Seq.tabulate(numFields) { col => + val index = row * numFields + col + if (index < values.length) values(index) else Literal(null, dataTypes(col)) + } + val eval = CreateStruct(fields).genCode(ctx) + s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" + }) + + // Create the collection. + val wrapperClass = classOf[mutable.WrappedArray[_]].getName + ctx.addMutableState( + s"$wrapperClass", + ev.value, + s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") + ev.copy(code = code, isNull = "false") + } } /** - * A base class for Explode and PosExplode + * A base class for [[Explode]] and [[PosExplode]]. */ -abstract class ExplodeBase(child: Expression, position: Boolean) - extends UnaryExpression with Generator with CodegenFallback with Serializable { +abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable { + override val inline: Boolean = false - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case _: ArrayType | _: MapType => TypeCheckResult.TypeCheckSuccess - } else { + case _ => TypeCheckResult.TypeCheckFailure( s"input to function explode should be array or map type, not ${child.dataType}") - } } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -171,7 +223,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean) case ArrayType(et, containsNull) => if (position) { new StructType() - .add("pos", IntegerType, false) + .add("pos", IntegerType, nullable = false) .add("col", et, containsNull) } else { new StructType() @@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean) case MapType(kt, vt, valueContainsNull) => if (position) { new StructType() - .add("pos", IntegerType, false) - .add("key", kt, false) + .add("pos", IntegerType, nullable = false) + .add("key", kt, nullable = false) .add("value", vt, valueContainsNull) } else { new StructType() - .add("key", kt, false) + .add("key", kt, nullable = false) .add("value", vt, valueContainsNull) } } @@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean) } } } + + override def collectionType: DataType = child.dataType + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx) + } } /** @@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean) 20 """) // scalastyle:on line.size.limit -case class Explode(child: Expression) extends ExplodeBase(child, position = false) +case class Explode(child: Expression) extends ExplodeBase { + override val position: Boolean = false +} /** * Given an input array produces a sequence of rows for each position and value in the array. 
@@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals 1 20 """) // scalastyle:on line.size.limit -case class PosExplode(child: Expression) extends ExplodeBase(child, position = true) +case class PosExplode(child: Expression) extends ExplodeBase { + override val position = true +} /** * Explodes an array of structs into a table. @@ -273,10 +335,12 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t 1 a 2 b """) -case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback { +case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator { + override val inline: Boolean = true + override val position: Boolean = false override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case ArrayType(et, _) if et.isInstanceOf[StructType] => + case ArrayType(st: StructType, _) => TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( @@ -284,9 +348,11 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with } override def elementSchema: StructType = child.dataType match { - case ArrayType(et : StructType, _) => et + case ArrayType(st: StructType, _) => st } + override def collectionType: DataType = child.dataType + private lazy val numFields = elementSchema.fields.length override def eval(input: InternalRow): TraversableOnce[InternalRow] = { @@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with yield inputArray.getStruct(i, numFields) } } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 1e39b24fe8770..2db2a043e546a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.{DataType, IntegerType} class SubexpressionEliminationSuite extends SparkFunSuite { test("Semantic equals and hash") { @@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite { test("Children of CodegenFallback") { val one = Literal(1) val two = Add(one, one) - val explode = Explode(two) - val add = Add(two, explode) + val fallback = CodegenFallbackExpression(two) + val add = Add(two, fallback) - var equivalence = new EquivalentExpressions + val equivalence = new EquivalentExpressions equivalence.addExprTree(add, true) - // the `two` inside `explode` should not be added + // the `two` inside `fallback` should not be added assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode } } + +case class CodegenFallbackExpression(child: Expression) + extends UnaryExpression with CodegenFallback { + override def dataType: DataType = child.dataType +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 19fbf0c162048..f80214af43fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -40,6 +42,10 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * + * This operator supports whole stage code generation for generators that do not implement + * terminate(). + * * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. @@ -54,7 +60,7 @@ case class GenerateExec( outer: Boolean, output: Seq[Attribute], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with CodegenSupport { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -103,5 +109,197 @@ case class GenerateExec( } } } -} + override def supportCodegen: Boolean = generator.supportCodegen + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + ctx.currentVars = input + ctx.copyResult = true + + // Add input rows to the values when we are joining + val values = if (join) { + input + } else { + Seq.empty + } + + boundGenerator match { + case e: CollectionGenerator => codeGenCollection(ctx, e, values, row) + case g => codeGenTraversableOnce(ctx, g, values, row) + } + } + + /** + * Generate code for [[CollectionGenerator]] expressions. + */ + private def codeGenCollection( + ctx: CodegenContext, + e: CollectionGenerator, + input: Seq[ExprCode], + row: ExprCode): String = { + + // Generate code for the generator. + val data = e.genCode(ctx) + + // Generate looping variables. + val index = ctx.freshName("index") + + // Add a check if the generate outer flag is true. 
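+    // ("outer" keeps the input row alive even when the generator returns no
+    // elements, emitting nulls for the generator output instead.)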
+ val checks = optionalCode(outer, data.isNull) + + // Add position + val position = if (e.position) { + Seq(ExprCode("", "false", index)) + } else { + Seq.empty + } + + // Generate code for either ArrayData or MapData + val (initMapData, updateRowData, values) = e.collectionType match { + case ArrayType(st: StructType, nullable) if e.inline => + val row = codeGenAccessor(ctx, data.value, "col", index, st, nullable, checks) + val fieldChecks = checks ++ optionalCode(nullable, row.isNull) + val columns = st.fields.toSeq.zipWithIndex.map { case (f, i) => + codeGenAccessor(ctx, row.value, f.name, i.toString, f.dataType, f.nullable, fieldChecks) + } + ("", row.code, columns) + + case ArrayType(dataType, nullable) => + ("", "", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks))) + + case MapType(keyType, valueType, valueContainsNull) => + // Materialize the key and the value arrays before we enter the loop. + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val initArrayData = + s""" + |ArrayData $keyArray = ${data.isNull} ? null : ${data.value}.keyArray(); + |ArrayData $valueArray = ${data.isNull} ? null : ${data.value}.valueArray(); + """.stripMargin + val values = Seq( + codeGenAccessor(ctx, keyArray, "key", index, keyType, nullable = false, checks), + codeGenAccessor(ctx, valueArray, "value", index, valueType, valueContainsNull, checks)) + (initArrayData, "", values) + } + + // In case of outer=true we need to make sure the loop is executed at-least once when the + // array/map contains no input. We do this by setting the looping index to -1 if there is no + // input, evaluation of the array is prevented by a check in the accessor code. + val numElements = ctx.freshName("numElements") + val init = if (outer) { + s"$numElements == 0 ? -1 : 0" + } else { + "0" + } + val numOutput = metricTerm(ctx, "numOutputRows") + s""" + |${data.code} + |$initMapData + |int $numElements = ${data.isNull} ? 0 : ${data.value}.numElements(); + |for (int $index = $init; $index < $numElements; $index++) { + | $numOutput.add(1); + | $updateRowData + | ${consume(ctx, input ++ position ++ values)} + |} + """.stripMargin + } + + /** + * Generate code for a regular [[TraversableOnce]] returning [[Generator]]. + */ + private def codeGenTraversableOnce( + ctx: CodegenContext, + e: Expression, + input: Seq[ExprCode], + row: ExprCode): String = { + + // Generate the code for the generator + val data = e.genCode(ctx) + + // Generate looping variables. + val iterator = ctx.freshName("iterator") + val hasNext = ctx.freshName("hasNext") + val current = ctx.freshName("row") + + // Add a check if the generate outer flag is true. + val checks = optionalCode(outer, s"!$hasNext") + val values = e.dataType match { + case ArrayType(st: StructType, nullable) => + st.fields.toSeq.zipWithIndex.map { case (f, i) => + codeGenAccessor(ctx, current, f.name, s"$i", f.dataType, f.nullable, checks) + } + } + + // In case of outer=true we need to make sure the loop is executed at-least-once when the + // iterator contains no input. We do this by adding an 'outer' variable which guarantees + // execution of the first iteration even if there is no input. Evaluation of the iterator is + // prevented by checks in the next() and accessor code. 
+ val numOutput = metricTerm(ctx, "numOutputRows") + if (outer) { + val outerVal = ctx.freshName("outer") + s""" + |${data.code} + |scala.collection.Iterator $iterator = ${data.value}.toIterator(); + |boolean $outerVal = true; + |while ($iterator.hasNext() || $outerVal) { + | $numOutput.add(1); + | boolean $hasNext = $iterator.hasNext(); + | InternalRow $current = (InternalRow)($hasNext? $iterator.next() : null); + | $outerVal = false; + | ${consume(ctx, input ++ values)} + |} + """.stripMargin + } else { + s""" + |${data.code} + |scala.collection.Iterator $iterator = ${data.value}.toIterator(); + |while ($iterator.hasNext()) { + | $numOutput.add(1); + | InternalRow $current = (InternalRow)($iterator.next()); + | ${consume(ctx, input ++ values)} + |} + """.stripMargin + } + } + + /** + * Generate accessor code for ArrayData and InternalRows. + */ + private def codeGenAccessor( + ctx: CodegenContext, + source: String, + name: String, + index: String, + dt: DataType, + nullable: Boolean, + initialChecks: Seq[String]): ExprCode = { + val value = ctx.freshName(name) + val javaType = ctx.javaType(dt) + val getter = ctx.getValue(source, dt, index) + val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") + if (checks.nonEmpty) { + val isNull = ctx.freshName("isNull") + val code = + s""" + |boolean $isNull = ${checks.mkString(" || ")}; + |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$javaType $value = $getter;", "false", value) + } + } + + private def optionalCode(condition: Boolean, code: => String): Seq[String] = { + if (condition) Seq(code) + else Seq.empty + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index aedc0a8d6f70b..f0995ea1d0025 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -202,4 +206,34 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"), Row(1) :: Row(2) :: Nil) } + + test("SPARK-14986: Outer lateral view with empty generate expression") { + checkAnswer( + sql("select nil from values 1 lateral view outer explode(array()) n as nil"), + Row(null) :: Nil + ) + } + + test("outer explode()") { + checkAnswer( + sql("select * from values 1, 2 lateral view outer explode(array()) a as b"), + Row(1, null) :: Row(2, null) :: Nil) + } + + test("outer generator()") { + spark.sessionState.functionRegistry.registerFunction("empty_gen", _ => EmptyGenerator()) + checkAnswer( + sql("select * from values 1, 2 lateral view outer empty_gen() a as b"), + Row(1, null) :: Row(2, null) :: Nil) + } +} + +case class EmptyGenerator() extends Generator { + override def children: Seq[Expression] = Nil + override def elementSchema: StructType = new 
StructType().add("id", IntegerType) + override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val iteratorClass = classOf[Iterator[_]].getName + ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6b517bc70f7d2..a715176d55d95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2086,13 +2086,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("SPARK-14986: Outer lateral view with empty generate expression") { - checkAnswer( - sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"), - Row(null) :: Nil - ) - } - test("data source table created in InMemoryCatalog should be able to read/write") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f26e5e7b6990d..e8ea7758cf598 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Column, Dataset, Row} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -113,4 +115,32 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } + + test("generate should be included in WholeStageCodegen") { + import org.apache.spark.sql.functions._ + val ds = spark.range(2).select( + col("id"), + explode(array(col("id") + 1, col("id") + 2)).as("value")) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[GenerateExec]).isDefined) + assert(ds.collect() === Array(Row(0, 1), Row(0, 2), Row(1, 2), Row(1, 3))) + } + + test("large stack generator should not use WholeStageCodegen") { + def createStackGenerator(rows: Int): SparkPlan = { + val id = UnresolvedAttribute("id") + val stack = Stack(Literal(rows) +: Seq.tabulate(rows)(i => Add(id, Literal(i)))) + spark.range(500).select(Column(stack)).queryExecution.executedPlan + } + val isCodeGenerated: SparkPlan => Boolean = { + case WholeStageCodegenExec(_: GenerateExec) => true + case _ => false + } + + // Only 'stack' generators that produce 50 rows or less are code generated. 
+ assert(createStackGenerator(50).find(isCodeGenerated).isDefined) + assert(createStackGenerator(100).find(isCodeGenerated).isEmpty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala index 470c78120b194..01773c238b0db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -102,7 +102,7 @@ class MiscBenchmark extends BenchmarkBase { } benchmark.run() - /** + /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- @@ -124,7 +124,7 @@ class MiscBenchmark extends BenchmarkBase { } benchmark.run() - /** + /* model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- @@ -132,4 +132,99 @@ class MiscBenchmark extends BenchmarkBase { collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X */ } + + ignore("generate explode") { + val N = 1 << 24 + runBenchmark("generate explode array", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "array(rand(), rand(), rand(), rand(), rand()) as values") + df.selectExpr("key", "explode(values) value").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate explode array wholestage off 6920 / 7129 2.4 412.5 1.0X + generate explode array wholestage on 623 / 646 26.9 37.1 11.1X + */ + + runBenchmark("generate explode map", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "map('a', rand(), 'b', rand(), 'c', rand(), 'd', rand(), 'e', rand()) pairs") + df.selectExpr("key", "explode(pairs) as (k, v)").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate explode map wholestage off 11978 / 11993 1.4 714.0 1.0X + generate explode map wholestage on 866 / 919 19.4 51.6 13.8X + */ + + runBenchmark("generate posexplode array", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "array(rand(), rand(), rand(), rand(), rand()) as values") + df.selectExpr("key", "posexplode(values) as (idx, value)").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate posexplode array wholestage off 7502 / 7513 2.2 447.1 1.0X + generate posexplode array wholestage on 617 / 623 27.2 36.8 12.2X + */ + + runBenchmark("generate inline array", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "array((rand(), rand()), (rand(), rand()), (rand(), 0.0d)) as values") + df.selectExpr("key", "inline(values) as 
(r1, r2)").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate inline array wholestage off 6901 / 6928 2.4 411.3 1.0X + generate inline array wholestage on 1001 / 1010 16.8 59.7 6.9X + */ + } + + ignore("generate regular generator") { + val N = 1 << 24 + runBenchmark("generate stack", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "id % 2 as t1", + "id % 3 as t2", + "id % 5 as t3", + "id % 7 as t4", + "id % 13 as t5") + df.selectExpr("key", "stack(4, t1, t2, t3, t4, t5)").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate stack wholestage off 12953 / 13070 1.3 772.1 1.0X + generate stack wholestage on 836 / 847 20.1 49.8 15.5X + */ + } } From c528812ce770fd8a6626e7f9d2f8ca9d1e84642b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 20 Nov 2016 09:52:03 +0000 Subject: [PATCH 192/198] [SPARK-3359][BUILD][DOCS] Print examples and disable group and tparam tags in javadoc ## What changes were proposed in this pull request? This PR proposes/fixes two things. - Remove many errors to generate javadoc with Java8 from unrecognisable tags, `tparam` and `group`. ``` [error] .../spark/mllib/target/java/org/apache/spark/ml/classification/Classifier.java:18: error: unknown tag: group [error] /** group setParam */ [error] ^ [error] .../spark/mllib/target/java/org/apache/spark/ml/classification/Classifier.java:8: error: unknown tag: tparam [error] * tparam FeaturesType Type of input features. E.g., Vector [error] ^ ... ``` It does not fully resolve the problem but remove many errors. It seems both `group` and `tparam` are unrecognisable in javadoc. It seems we can't print them pretty in javadoc in a way of `example` here because they appear differently (both examples can be found in http://spark.apache.org/docs/2.0.2/api/scala/index.html#org.apache.spark.ml.classification.Classifier). - Print `example` in javadoc. Currently, there are few `example` tag in several places. ``` ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This operation might be used to evaluate a graph ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example We might use this operation to change the vertex values ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example We can use this function to compute the in-degree of each ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function is used to update the vertices with new values based on external data. 
./graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala:   * example Loads a file in the following format:
./graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala:   * example This function is used to update the vertices with new
./graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala:   * example This function can be used to filter the graph based on some property, without
./graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala:   * example We can use the Pregel abstraction to implement PageRank:
./graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala:   * example Construct a `VertexRDD` from a plain RDD:
./repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala:   * example new SparkCommandLine(Nil).settings
./repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala:   * example addImports("org.apache.spark.SparkContext")
./sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala:   * example {{{
```

**Before**

(screenshot of the generated javadoc before the change omitted)

**After**

(screenshot of the generated javadoc after the change omitted)

## How was this patch tested?

Manually tested by `jekyll build` with Java 7 and 8

```
java version "1.7.0_80"
Java(TM) SE Runtime Environment (build 1.7.0_80-b15)
Java HotSpot(TM) 64-Bit Server VM (build 24.80-b11, mixed mode)
```

```
java version "1.8.0_45"
Java(TM) SE Runtime Environment (build 1.8.0_45-b14)
Java HotSpot(TM) 64-Bit Server VM (build 25.45-b02, mixed mode)
```

Note: this does not make sbt unidoc succeed with Java 8 yet, but it reduces the number of errors with Java 8.

Author: hyukjinkwon

Closes #15939 from HyukjinKwon/SPARK-3359-javadoc.
---
 pom.xml                  | 13 +++++++++++++
 project/SparkBuild.scala |  5 ++++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/pom.xml b/pom.xml
index 024b2850d0a3d..7c0b0b59dc62b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2477,11 +2477,24 @@
               <additionalparam>-Xdoclint:all -Xdoclint:-missing</additionalparam>
               <tags>
+                <tag>
+                  <name>example</name>
+                  <placement>a</placement>
+                  <head>Example:</head>
+                </tag>
                 <tag>
                   <name>note</name>
                   <placement>a</placement>
                   <head>Note:</head>
                 </tag>
+                <tag>
+                  <name>group</name>
+                  <placement>X</placement>
+                </tag>
+                <tag>
+                  <name>tparam</name>
+                  <placement>X</placement>
+                </tag>
               </tags>

diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 92b45657210e1..429a163d22a6d 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -742,7 +742,10 @@ object Unidoc {
       "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc",
       "-public",
       "-noqualifier", "java.lang",
-      "-tag", """note:a:Note\:"""
+      "-tag", """example:a:Example\:""",
+      "-tag", """note:a:Note\:""",
+      "-tag", "group:X",
+      "-tag", "tparam:X"
     ),

     // Use GitHub repository for Scaladoc source links

From 6659ae555a464c7a16881b660265061481c0d25f Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 20 Nov 2016 13:56:08 -0800
Subject: [PATCH 193/198] Fix Mesos build break for Scala 2.10.
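The break comes from `Option.contains`, which only exists on Scala 2.11+; a minimal illustration (values are illustrative):

```
// Option.contains was added in Scala 2.11, so the assertion is rewritten to
// compare against Some(...) directly, which compiles on 2.10 as well.
val zookeeperUrl = "zk://host:2181"
val parsed: Option[String] = Some(zookeeperUrl)
assert(parsed == Some(zookeeperUrl)) // OK on Scala 2.10 and 2.11
// assert(parsed contains zookeeperUrl) // compiles only on Scala 2.11+
```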
---
 .../deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala   | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala
index b6c0b325361da..33e7d69d53d38 100644
--- a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala
+++ b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala
@@ -55,7 +55,7 @@ class MesosClusterDispatcherArgumentsSuite extends SparkFunSuite
     assert(Option(mesosDispClusterArgs.masterUrl).isDefined)
     assert(mesosDispClusterArgs.masterUrl == masterUrl.stripPrefix("mesos://"))
     assert(Option(mesosDispClusterArgs.zookeeperUrl).isDefined)
-    assert(mesosDispClusterArgs.zookeeperUrl contains zookeeperUrl)
+    assert(mesosDispClusterArgs.zookeeperUrl == Some(zookeeperUrl))
     assert(mesosDispClusterArgs.name == name)
     assert(mesosDispClusterArgs.webUiPort == webUiPort.toInt)
     assert(mesosDispClusterArgs.port == port.toInt)

From b625a36ebc59cbacc223fc03005bc0f6d296b6e7 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 20 Nov 2016 20:00:59 -0800
Subject: [PATCH 194/198] [HOTFIX][SQL] Fix DDLSuite failure.

---
 .../org/apache/spark/sql/execution/command/DDLSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index a01073987423e..02d9d15684904 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -1426,8 +1426,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
       sql("DESCRIBE FUNCTION 'concat'"),
       Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") ::
         Row("Function: concat") ::
-        Row("Usage: concat(str1, str2, ..., strN) " +
-          "- Returns the concatenation of `str1`, `str2`, ..., `strN`.") :: Nil
+        Row("Usage: concat(str1, str2, ..., strN) - " +
+          "Returns the concatenation of str1, str2, ..., strN.") :: Nil
     )
     // extended mode
     checkAnswer(

From 658547974915ebcaae83e13e4c3bdf68d5426fda Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 21 Nov 2016 12:05:01 +0800
Subject: [PATCH 195/198] [SPARK-18467][SQL] Extracts method for preparing arguments from StaticInvoke, Invoke and NewInstance and modify to short circuit if arguments have null when `needNullCheck == true`.

## What changes were proposed in this pull request?

This PR extracts a method for preparing arguments from `StaticInvoke`, `Invoke` and `NewInstance`, and modifies them to short-circuit when an argument is `null` and `propagateNull == true`. The steps are as follows:

1. Introduce `InvokeLike` to extract the common argument-preparation logic from `StaticInvoke`, `Invoke` and `NewInstance`. `StaticInvoke` and `Invoke` previously risked exceeding the 64KB JVM method size limit when preparing arguments; after this patch they share the preparation code of `NewInstance`, which already handles that limit well.

2. Remove unneeded null checking and fix the nullability of `NewInstance`. Some nullability checks are avoided because the expression is not nullable.

3. Modify to short-circuit if arguments contain `null` when `needNullCheck == true`.
If `needNullCheck == true`, preparing arguments can be skipped as soon as one of them is found to be `null`, so the generated code now short-circuits in that case.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN

Closes #15901 from ueshin/issues/SPARK-18467.
---
 .../expressions/objects/objects.scala         | 163 +++++++++++-------
 1 file changed, 101 insertions(+), 62 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 0e3d99127ed56..0b36091ece1bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -32,6 +32,78 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.types._

+/**
+ * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]].
+ */
+trait InvokeLike extends Expression with NonSQLExpression {
+
+  def arguments: Seq[Expression]
+
+  def propagateNull: Boolean
+
+  protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable)
+
+  /**
+   * Prepares code for the arguments.
+   *
+   * - generate code for each argument.
+   * - use ctx.splitExpressions() so as not to exceed the 64KB JVM limit while preparing arguments.
+   * - avoid nullability checks that are not needed because the expression is not
+   *   nullable.
+   * - when needNullCheck == true, short circuit as soon as one argument is found to be null,
+   *   because preparing the rest of the arguments can then be skipped.
+   *
+   * @param ctx a [[CodegenContext]]
+   * @return (code to prepare arguments, argument string, result of argument null check)
+   */
+  def prepareArguments(ctx: CodegenContext): (String, String, String) = {
+
+    val resultIsNull = if (needNullCheck) {
+      val resultIsNull = ctx.freshName("resultIsNull")
+      ctx.addMutableState("boolean", resultIsNull, "")
+      resultIsNull
+    } else {
+      "false"
+    }
+    val argValues = arguments.map { e =>
+      val argValue = ctx.freshName("argValue")
+      ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
+      argValue
+    }
+
+    val argCodes = if (needNullCheck) {
+      val reset = s"$resultIsNull = false;"
+      val argCodes = arguments.zipWithIndex.map { case (e, i) =>
+        val expr = e.genCode(ctx)
+        val updateResultIsNull = if (e.nullable) {
+          s"$resultIsNull = ${expr.isNull};"
+        } else {
+          ""
+        }
+        s"""
+          if (!$resultIsNull) {
+            ${expr.code}
+            $updateResultIsNull
+            ${argValues(i)} = ${expr.value};
+          }
+        """
+      }
+      reset +: argCodes
+    } else {
+      arguments.zipWithIndex.map { case (e, i) =>
+        val expr = e.genCode(ctx)
+        s"""
+          ${expr.code}
+          ${argValues(i)} = ${expr.value};
+        """
+      }
+    }
+    val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
+
+    (argCode, argValues.mkString(", "), resultIsNull)
+  }
+}
+
 /**
  * Invokes a static function, returning the result. By default, any of the arguments being null
  * will result in returning null instead of calling the function.
@@ -50,7 +122,7 @@ case class StaticInvoke( dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression with NonSQLExpression { + propagateNull: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") @@ -62,16 +134,10 @@ case class StaticInvoke( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") - val callFunc = s"$objectName.$functionName($argString)" + val (argCode, argString, resultIsNull) = prepareArguments(ctx) - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};" - } else { - s"boolean ${ev.isNull} = false;" - } + val callFunc = s"$objectName.$functionName($argString)" // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. @@ -82,9 +148,9 @@ case class StaticInvoke( } val code = s""" - ${argGen.map(_.code).mkString("\n")} - $setIsNull - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc; + $argCode + boolean ${ev.isNull} = $resultIsNull; + final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc; $postNullCheck """ ev.copy(code = code) @@ -103,13 +169,15 @@ case class StaticInvoke( * @param functionName The name of the method to call. * @param dataType The expected return type of the function. * @param arguments An optional list of expressions, whos evaluation will be passed to the function. + * @param propagateNull When true, and any of the arguments is null, null will be returned instead + * of calling the function. */ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression with NonSQLExpression { + propagateNull: Boolean = true) extends InvokeLike { override def nullable: Boolean = true override def children: Seq[Expression] = targetObject +: arguments @@ -131,8 +199,8 @@ case class Invoke( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val obj = targetObject.genCode(ctx) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") + + val (argCode, argString, resultIsNull) = prepareArguments(ctx) val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty @@ -164,12 +232,6 @@ case class Invoke( """ } - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" - } else { - s"boolean ${ev.isNull} = ${obj.isNull};" - } - // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. 
val postNullCheck = if (ctx.defaultValue(dataType) == "null") { @@ -177,15 +239,19 @@ case class Invoke( } else { "" } + val code = s""" ${obj.code} - ${argGen.map(_.code).mkString("\n")} - $setIsNull + boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - $evaluate + if (!${obj.isNull}) { + $argCode + ${ev.isNull} = $resultIsNull; + if (!${ev.isNull}) { + $evaluate + } + $postNullCheck } - $postNullCheck """ ev.copy(code = code) } @@ -223,10 +289,10 @@ case class NewInstance( arguments: Seq[Expression], propagateNull: Boolean, dataType: DataType, - outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression { + outerPointer: Option[() => AnyRef]) extends InvokeLike { private val className = cls.getName - override def nullable: Boolean = propagateNull + override def nullable: Boolean = needNullCheck override def children: Seq[Expression] = arguments @@ -245,52 +311,25 @@ case class NewInstance( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argIsNulls = ctx.freshName("argIsNulls") - ctx.addMutableState("boolean[]", argIsNulls, - s"$argIsNulls = new boolean[${arguments.size}];") - val argValues = arguments.zipWithIndex.map { case (e, i) => - val argValue = ctx.freshName("argValue") - ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") - argValue - } - val argCodes = arguments.zipWithIndex.map { case (e, i) => - val expr = e.genCode(ctx) - expr.code + s""" - $argIsNulls[$i] = ${expr.isNull}; - ${argValues(i)} = ${expr.value}; - """ - } - val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + val (argCode, argString, resultIsNull) = prepareArguments(ctx) val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) - var isNull = ev.isNull - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s""" - boolean $isNull = false; - for (int idx = 0; idx < ${arguments.length}; idx++) { - if ($argIsNulls[idx]) { $isNull = true; break; } - } - """ - } else { - isNull = "false" - "" - } + ev.isNull = resultIsNull val constructorCall = outer.map { gen => - s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})""" + s"${gen.value}.new ${cls.getSimpleName}($argString)" }.getOrElse { - s"new $className(${argValues.mkString(", ")})" + s"new $className($argString)" } val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - $setIsNull - final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall; - """ - ev.copy(code = code, isNull = isNull) + final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + """ + ev.copy(code = code) } override def toString: String = s"newInstance($cls)" From e811fbf9ed131bccbc46f3c5701c4ff317222fd9 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 21 Nov 2016 05:36:49 -0800 Subject: [PATCH 196/198] [SPARK-18282][ML][PYSPARK] Add python clustering summaries for GMM and BKM ## What changes were proposed in this pull request? Add model summary APIs for `GaussianMixtureModel` and `BisectingKMeansModel` in pyspark. ## How was this patch tested? Unit tests. Author: sethah Closes #15777 from sethah/pyspark_cluster_summaries. 
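A note on the Scala-side refactor shown in the diffs below: every model's `setSummary` now takes an `Option`, which is what lets `copy` forward the summary unconditionally and lets tests clear it with `setSummary(None)`. A minimal sketch of the pattern, using illustrative class names rather than the real MLlib types:

```scala
// A minimal sketch of the Option-based summary pattern; `MySummary` and
// `MyModel` are illustrative stand-ins, not actual MLlib classes.
class MySummary(val numIterations: Int)

class MyModel {
  private var trainingSummary: Option[MySummary] = None

  def hasSummary: Boolean = trainingSummary.isDefined

  // Throws when no summary is attached, mirroring the MLlib behavior.
  def summary: MySummary = trainingSummary.getOrElse {
    throw new RuntimeException("No training summary available for this model")
  }

  // Accepting an Option lets one method set, clear, or forward a summary.
  def setSummary(summary: Option[MySummary]): this.type = {
    this.trainingSummary = summary
    this
  }

  // No isDefined/get dance: an absent summary propagates as None.
  def copy(): MyModel = new MyModel().setSummary(trainingSummary)
}
```

This is why the `copy` overrides in the diffs below each collapse to a single chained call.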
--- .../classification/LogisticRegression.scala | 11 +- .../spark/ml/clustering/BisectingKMeans.scala | 9 +- .../spark/ml/clustering/GaussianMixture.scala | 9 +- .../apache/spark/ml/clustering/KMeans.scala | 9 +- .../GeneralizedLinearRegression.scala | 11 +- .../ml/regression/LinearRegression.scala | 14 +- .../LogisticRegressionSuite.scala | 2 + .../ml/clustering/BisectingKMeansSuite.scala | 3 + .../ml/clustering/GaussianMixtureSuite.scala | 3 + .../spark/ml/clustering/KMeansSuite.scala | 3 + .../GeneralizedLinearRegressionSuite.scala | 2 + .../ml/regression/LinearRegressionSuite.scala | 2 + python/pyspark/ml/classification.py | 15 +- python/pyspark/ml/clustering.py | 162 +++++++++++++++++- python/pyspark/ml/regression.py | 16 +- python/pyspark/ml/tests.py | 32 ++++ 16 files changed, 256 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f58efd36a1c66..d07b4adebb08f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -648,7 +648,7 @@ class LogisticRegression @Since("1.2.0") ( $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(logRegSummary) + model.setSummary(Some(logRegSummary)) } else { model } @@ -790,9 +790,9 @@ class LogisticRegressionModel private[spark] ( } } - private[classification] def setSummary( - summary: LogisticRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[classification] + def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -887,8 +887,7 @@ class LogisticRegressionModel private[spark] ( override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, isMultinomial), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f8a606d60b2aa..e6ca3aedffd9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -95,8 +95,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -132,8 +131,8 @@ class BisectingKMeansModel private[ml] ( private var trainingSummary: Option[BisectingKMeansSummary] = None - private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -265,7 +264,7 @@ class BisectingKMeans @Since("2.0.0") ( val model = copyValues(new BisectingKMeansModel(uid, 
parentModel).setParent(this)) val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index c6035cc4c9647..92d0b7d085f12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -90,8 +90,7 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -150,8 +149,8 @@ class GaussianMixtureModel private[ml] ( private var trainingSummary: Option[GaussianMixtureSummary] = None - private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = { + this.trainingSummary = summary this } @@ -340,7 +339,7 @@ class GaussianMixture @Since("2.0.0") ( .setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logNumFeatures(model.gaussians.head.mean.size) instr.logSuccess(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 26505b4cc1501..152bd13b7a17a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -110,8 +110,7 @@ class KMeansModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { val copied = copyValues(new KMeansModel(uid, parentModel), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } /** @group setParam */ @@ -165,8 +164,8 @@ class KMeansModel private[ml] ( private var trainingSummary: Option[KMeansSummary] = None - private[clustering] def setSummary(summary: KMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 736fd3b9e0f64..3f9de1fe74c9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -270,7 +270,7 @@ class 
GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, wlsModel.diagInvAtWA.toArray, 1, getSolver) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). @@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("2.0.0") @@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.nonEmpty private[regression] - def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] ( override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(parent) + copied.setSummary(trainingSummary).setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index da7ce6b46f2ab..8ea5e1e6c453a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -225,7 +225,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) - return lrModel.setSummary(trainingSummary) + return lrModel.setSummary(Some(trainingSummary)) } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -278,7 +278,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), Array(0D)) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. 
" + "Model cannot be regularized.") @@ -400,7 +400,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), objectiveHistory) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("1.4.0") @@ -446,8 +446,9 @@ class LinearRegressionModel private[ml] ( throw new SparkException("No training summary available for this LinearRegressionModel") } - private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[regression] + def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -490,8 +491,7 @@ class LinearRegressionModel private[ml] ( @Since("1.4.0") override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 2877285eb4d59..e360542eae2ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -147,6 +147,8 @@ class LogisticRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) } test("empty probabilityCol") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 49797d938d751..fc491cd6161fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -109,6 +109,9 @@ class BisectingKMeansSuite assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 7165b63ed3b96..07299123f8a47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -111,6 +111,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 73972557d2631..c1b7242e11a8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) 
assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("KMeansModel transform with non-default feature and prediction cols") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 6a4ac1735b2cb..9b0fa67630d2e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -197,6 +197,8 @@ class GeneralizedLinearRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index df97d0b2ae7ad..0be82742a33be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -146,6 +146,8 @@ class LinearRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) model.transform(datasetWithDenseFeature) .select("label", "prediction") diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 56c8c62259e79..83e1e89347660 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -309,13 +309,16 @@ def interceptVector(self): @since("2.0.0") def summary(self): """ - Gets summary (e.g. residuals, mse, r-squared ) of model on - training set. An exception is thrown if - `trainingSummary is None`. + Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. 
""" - java_blrt_summary = self._call_java("summary") - # Note: Once multiclass is added, update this to return correct summary - return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + if self.hasSummary: + java_blrt_summary = self._call_java("summary") + # Note: Once multiclass is added, update this to return correct summary + return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 7632f05c3b68c..e58ec1e7ac296 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -17,16 +17,74 @@ from pyspark import since, keyword_only from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc -__all__ = ['BisectingKMeans', 'BisectingKMeansModel', +__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', 'KMeans', 'KMeansModel', - 'GaussianMixture', 'GaussianMixtureModel', + 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] +class ClusteringSummary(JavaWrapper): + """ + .. note:: Experimental + + Clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + + @property + @since("2.1.0") + def predictionCol(self): + """ + Name for column of predicted clusters in `predictions`. + """ + return self._call_java("predictionCol") + + @property + @since("2.1.0") + def predictions(self): + """ + DataFrame produced by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.1.0") + def featuresCol(self): + """ + Name for column of features in `predictions`. + """ + return self._call_java("featuresCol") + + @property + @since("2.1.0") + def k(self): + """ + The number of clusters the model was trained with. + """ + return self._call_java("k") + + @property + @since("2.1.0") + def cluster(self): + """ + DataFrame of predicted cluster centers for each training data point. + """ + return self._call_java("cluster") + + @property + @since("2.1.0") + def clusterSizes(self): + """ + Size of (number of data points in) each cluster. + """ + return self._call_java("clusterSizes") + + class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -56,6 +114,28 @@ def gaussiansDF(self): """ return self._call_java("gaussiansDF") + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + return GaussianMixtureSummary(self._call_java("summary")) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + @inherit_doc class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, @@ -92,6 +172,13 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> gm = GaussianMixture(k=3, tol=0.0001, ... 
maxIter=10, seed=10) >>> model = gm.fit(df) + >>> model.hasSummary + True + >>> summary = model.summary + >>> summary.k + 3 + >>> summary.clusterSizes + [2, 2, 2] >>> weights = model.weights >>> len(weights) 3 @@ -118,6 +205,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> model_path = temp_path + "/gmm_model" >>> model.save(model_path) >>> model2 = GaussianMixtureModel.load(model_path) + >>> model2.hasSummary + False >>> model2.weights == model.weights True >>> model2.gaussiansDF.show() @@ -181,6 +270,32 @@ def getK(self): return self.getOrDefault(self.k) +class GaussianMixtureSummary(ClusteringSummary): + """ + .. note:: Experimental + + Gaussian mixture clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + + @property + @since("2.1.0") + def probabilityCol(self): + """ + Name for column of predicted probability of each cluster in `predictions`. + """ + return self._call_java("probabilityCol") + + @property + @since("2.1.0") + def probability(self): + """ + DataFrame of probabilities of each cluster for each training data point. + """ + return self._call_java("probability") + + class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by KMeans. @@ -346,6 +461,27 @@ def computeCost(self, dataset): """ return self._call_java("computeCost", dataset) + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + return BisectingKMeansSummary(self._call_java("summary")) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + @inherit_doc class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, @@ -373,6 +509,13 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte 2 >>> model.computeCost(df) 2.000... + >>> model.hasSummary + True + >>> summary = model.summary + >>> summary.k + 2 + >>> summary.clusterSizes + [2, 2] >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction @@ -387,6 +530,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) + >>> model2.hasSummary + False >>> model.clusterCenters()[0] == model2.clusterCenters()[0] array([ True, True], dtype=bool) >>> model.clusterCenters()[1] == model2.clusterCenters()[1] @@ -460,6 +605,17 @@ def _create_model(self, java_model): return BisectingKMeansModel(java_model) +class BisectingKMeansSummary(ClusteringSummary): + """ + .. note:: Experimental + + Bisecting KMeans clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + pass + + @inherit_doc class LDAModel(JavaModel): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0bc319ca4d601..385391ba53fd4 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -160,8 +160,12 @@ def summary(self): training set. An exception is thrown if `trainingSummary is None`. 
""" - java_lrt_summary = self._call_java("summary") - return LinearRegressionTrainingSummary(java_lrt_summary) + if self.hasSummary: + java_lrt_summary = self._call_java("summary") + return LinearRegressionTrainingSummary(java_lrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") @@ -1459,8 +1463,12 @@ def summary(self): training set. An exception is thrown if `trainingSummary is None`. """ - java_glrt_summary = self._call_java("summary") - return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + if self.hasSummary: + java_glrt_summary = self._call_java("summary") + return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9d46cc3b4ae64..c0f0d4073564e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1097,6 +1097,38 @@ def test_logistic_regression_summary(self): sameSummary = model.evaluate(df) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + def test_gaussian_mixture_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + gmm = GaussianMixture(k=2) + model = gmm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertTrue(isinstance(s.probability, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + + def test_bisecting_kmeans_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + bkm = BisectingKMeans(k=2) + model = bkm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + class OneVsRestTests(SparkSessionTestCase): From 9f262ae163b6dca6526665b3ad12b3b2ea8fb873 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 21 Nov 2016 05:50:35 -0800 Subject: [PATCH 197/198] [SPARK-18398][SQL] Fix nullabilities of MapObjects and ExternalMapToCatalyst. ## What changes were proposed in this pull request? The nullabilities of `MapObject` can be made more strict by relying on `inputObject.nullable` and `lambdaFunction.nullable`. Also `ExternalMapToCatalyst.dataType` can be made more strict by relying on `valueConverter.nullable`. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15840 from ueshin/issues/SPARK-18398. 
---
 .../spark/sql/catalyst/expressions/objects/objects.scala | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 0b36091ece1bf..5c27179ec3b46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -461,14 +461,15 @@ case class MapObjects private(
     lambdaFunction: Expression,
     inputData: Expression) extends Expression with NonSQLExpression {
 
-  override def nullable: Boolean = true
+  override def nullable: Boolean = inputData.nullable
 
   override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
 
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
-  override def dataType: DataType = ArrayType(lambdaFunction.dataType)
+  override def dataType: DataType =
+    ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val elementJavaType = ctx.javaType(loopVarDataType)
@@ -642,7 +643,8 @@ case class ExternalMapToCatalyst private(
 
   override def foldable: Boolean = false
 
-  override def dataType: MapType = MapType(keyConverter.dataType, valueConverter.dataType)
+  override def dataType: MapType = MapType(
+    keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable)
 
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")

From 07beb5d21c6803e80733149f1560c71cd3cacc86 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Mon, 21 Nov 2016 13:57:36 +0000
Subject: [PATCH 198/198] [SPARK-18413][SQL] Add `maxConnections` JDBCOption

## What changes were proposed in this pull request?

This PR adds a new JDBCOption `maxConnections`, which specifies the maximum number of
simultaneous JDBC connections allowed. This option applies only to writing, where a
coalesce operation is applied if needed; it defaults to the number of partitions of the
RDD. Previously, SQL users could not control this, while Scala/Java/Python users could
use the `coalesce` (or `repartition`) API.

**Reported Scenario**

In cases like the following, the number of connections becomes 200 and the database
cannot handle all of them.

```sql
CREATE OR REPLACE TEMPORARY VIEW resultview
USING org.apache.spark.sql.jdbc
OPTIONS (
  url "jdbc:oracle:thin:10.129.10.111:1521:BKDB",
  dbtable "result",
  user "HIVE",
  password "HIVE"
);
-- set spark.sql.shuffle.partitions=200
INSERT OVERWRITE TABLE resultview SELECT g, count(1) AS COUNT FROM tnet.DT_LIVE_INFO
GROUP BY g
```

## How was this patch tested?

Manual. Do the following and check the Spark UI.
**Step 1 (MySQL)** ``` CREATE TABLE t1 (a INT); CREATE TABLE data (a INT); INSERT INTO data VALUES (1); INSERT INTO data VALUES (2); INSERT INTO data VALUES (3); ``` **Step 2 (Spark)** ```scala SPARK_HOME=$PWD bin/spark-shell --driver-memory 4G --driver-class-path mysql-connector-java-5.1.40-bin.jar scala> sql("SET spark.sql.shuffle.partitions=3") scala> sql("CREATE OR REPLACE TEMPORARY VIEW data USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 'data', user 'root', password '')") scala> sql("CREATE OR REPLACE TEMPORARY VIEW t1 USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 't1', user 'root', password '', maxConnections '1')") scala> sql("INSERT OVERWRITE TABLE t1 SELECT a FROM data GROUP BY a") scala> sql("CREATE OR REPLACE TEMPORARY VIEW t1 USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 't1', user 'root', password '', maxConnections '2')") scala> sql("INSERT OVERWRITE TABLE t1 SELECT a FROM data GROUP BY a") scala> sql("CREATE OR REPLACE TEMPORARY VIEW t1 USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 't1', user 'root', password '', maxConnections '3')") scala> sql("INSERT OVERWRITE TABLE t1 SELECT a FROM data GROUP BY a") scala> sql("CREATE OR REPLACE TEMPORARY VIEW t1 USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 't1', user 'root', password '', maxConnections '4')") scala> sql("INSERT OVERWRITE TABLE t1 SELECT a FROM data GROUP BY a") ``` ![maxconnections](https://cloud.githubusercontent.com/assets/9700541/20287987/ed8409c2-aa84-11e6-8aab-ae28e63fe54d.png) Author: Dongjoon Hyun Closes #15868 from dongjoon-hyun/SPARK-18413. --- docs/sql-programming-guide.md | 7 +++++++ .../sql/execution/datasources/jdbc/JDBCOptions.scala | 6 ++++++ .../sql/execution/datasources/jdbc/JdbcUtils.scala | 9 ++++++++- .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 12 ++++++++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ba3e55fc061a7..656e7ecdab0bb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1086,6 +1086,13 @@ the following case-sensitive options: + + maxConnections + + The maximum number of concurrent JDBC connections that can be used, if set. Only applies when writing. It works by limiting the operation's parallelism, which depends on the input's partition count. If its partition count exceeds this limit, the operation will coalesce the input to fewer partitions before writing. + + + isolationLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 7f419b5788c4f..d416eec6ddaec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -122,6 +122,11 @@ class JDBCOptions( case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE } + // the maximum number of connections + val maxConnections = parameters.get(JDBC_MAX_CONNECTIONS).map(_.toInt) + require(maxConnections.isEmpty || maxConnections.get > 0, + s"Invalid value `${maxConnections.get}` for parameter `$JDBC_MAX_CONNECTIONS`. 
" + + "The minimum value is 1.") } object JDBCOptions { @@ -144,4 +149,5 @@ object JDBCOptions { val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") + val JDBC_MAX_CONNECTIONS = newOption("maxConnections") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 41edb6511c2ce..cdc3c99daa1ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -667,7 +667,14 @@ object JdbcUtils extends Logging { val getConnection: () => Connection = createConnectionFactory(options) val batchSize = options.batchSize val isolationLevel = options.isolationLevel - df.foreachPartition(iterator => savePartition( + val maxConnections = options.maxConnections + val repartitionedDF = + if (maxConnections.isDefined && maxConnections.get < df.rdd.getNumPartitions) { + df.coalesce(maxConnections.get) + } else { + df + } + repartitionedDF.foreachPartition(iterator => savePartition( getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e3d3c6c3a887c..5795b4d860cb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -312,4 +312,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .options(properties.asScala) .save() } + + test("SPARK-18413: Add `maxConnections` JDBCOption") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val e = intercept[IllegalArgumentException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .option(s"${JDBCOptions.JDBC_MAX_CONNECTIONS}", "0") + .save() + }.getMessage + assert(e.contains("Invalid value `0` for parameter `maxConnections`. The minimum value is 1")) + } }