diff --git a/README.md b/README.md
index d0eca1ddea283..1e521a7e7b178 100644
--- a/README.md
+++ b/README.md
@@ -97,7 +97,7 @@ building for particular Hive and Hive Thriftserver distributions.
Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html)
in the online documentation for an overview on how to configure Spark.
-## Contributing
+## Contributing
Please review the [Contribution to Spark guide](http://spark.apache.org/contributing.html)
for information on how to get started contributing to the project.
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 03815631a604c..6fc66e2374bd9 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -384,9 +384,16 @@ private[serializer] object KryoSerializer {
classOf[HighlyCompressedMapStatus],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
+ classOf[Array[Boolean]],
classOf[Array[Byte]],
classOf[Array[Short]],
+ classOf[Array[Int]],
classOf[Array[Long]],
+ classOf[Array[Float]],
+ classOf[Array[Double]],
+ classOf[Array[Char]],
+ classOf[Array[String]],
+ classOf[Array[Array[String]]],
classOf[BoundedPriorityQueue[_]],
classOf[SparkConf]
)
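
The test changes below turn on `spark.kryo.registrationRequired`; here is a rough standalone sketch of why the extra array registrations are needed (it only reuses calls that already appear in the suite):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// With registrationRequired, Kryo refuses to serialize any class missing from the
// registration list, so primitive and String arrays must be registered up front.
val conf = new SparkConf(false).set("spark.kryo.registrationRequired", "true")
val ser = new KryoSerializer(conf).newInstance()
ser.serialize(Array(1, 2, 3))        // relies on classOf[Array[Int]] being registered
ser.serialize(Array("x", "y", "z"))  // relies on classOf[Array[String]] being registered
```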
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index a30653bb36fa1..7c3922e47fbb9 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -76,6 +76,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
}
test("basic types") {
+ val conf = new SparkConf(false)
+ conf.set("spark.kryo.registrationRequired", "true")
+
val ser = new KryoSerializer(conf).newInstance()
def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
@@ -106,6 +109,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
}
test("pairs") {
+ val conf = new SparkConf(false)
+ conf.set("spark.kryo.registrationRequired", "true")
+
val ser = new KryoSerializer(conf).newInstance()
def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
@@ -130,12 +136,16 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
}
test("Scala data structures") {
+ val conf = new SparkConf(false)
+ conf.set("spark.kryo.registrationRequired", "true")
+
val ser = new KryoSerializer(conf).newInstance()
def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check(List[Int]())
check(List[Int](1, 2, 3))
+ check(Seq[Int](1, 2, 3))
check(List[String]())
check(List[String]("x", "y", "z"))
check(None)
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 8353b7a520b8e..e99b70f7a8b47 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -154,7 +154,7 @@ Developers who compile Spark frequently may want to speed up compilation; e.g.,
developers who build with SBT). For more information about how to do this, refer to the
[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times).
-## Encrypted Filesystems
+## Encrypted Filesystems
When building on an encrypted filesystem (if your home directory is encrypted, for example), then the Spark build might fail with a "Filename too long" error. As a workaround, add the following in the configuration args of the `scala-maven-plugin` in the project `pom.xml`:
diff --git a/docs/ml-features.md b/docs/ml-features.md
index dad1c6db18f8b..e19fba249fb2d 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -1284,6 +1284,72 @@ for more details on the API.
+
+## Imputer
+
+The `Imputer` transformer completes missing values in a dataset, using either the mean or the
+median of the columns in which the missing values are located. The input columns should be of
+`DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and may
+produce incorrect values for columns containing categorical features.
+
+**Note** that all `null` values in the input columns are treated as missing, and so are also imputed.
+
+**Examples**
+
+Suppose that we have a DataFrame with the columns `a` and `b`:
+
+~~~
+ a | b
+------------|-----------
+ 1.0 | Double.NaN
+ 2.0 | Double.NaN
+ Double.NaN | 3.0
+ 4.0 | 4.0
+ 5.0 | 5.0
+~~~
+
+In this example, `Imputer` replaces all occurrences of `Double.NaN` (the default missing value)
+with the mean (the default imputation strategy) computed from the other values in the corresponding
+columns. The surrogate values for columns `a` and `b` are therefore 3.0 and 4.0 respectively. After
+transformation, the missing values in the output columns are replaced by the surrogate value for
+the relevant column.
+
+~~~
+ a | b | out_a | out_b
+------------|------------|-------|-------
+ 1.0 | Double.NaN | 1.0 | 4.0
+ 2.0 | Double.NaN | 2.0 | 4.0
+ Double.NaN | 3.0 | 3.0 | 3.0
+ 4.0 | 4.0 | 4.0 | 4.0
+ 5.0 | 5.0 | 5.0 | 5.0
+~~~
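+
+For reference, a minimal sketch of overriding the defaults (the setter names below assume the
+Scala API; `strategy` and `missingValue` are the parameters described above):
+
+{% highlight scala %}
+val imputer = new Imputer()
+  .setInputCols(Array("a", "b"))
+  .setOutputCols(Array("out_a", "out_b"))
+  .setStrategy("median")    // use the approximate median instead of the mean
+  .setMissingValue(-1.0)    // impute -1.0 instead of Double.NaN
+{% endhighlight %}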
+
+
+
+
+Refer to the [Imputer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Imputer)
+for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/ImputerExample.scala %}
+
+
+
+
+Refer to the [Imputer Java docs](api/java/org/apache/spark/ml/feature/Imputer.html)
+for more details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaImputerExample.java %}
+
+
+
+
+Refer to the [Imputer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Imputer)
+for more details on the API.
+
+{% include_example python/ml/imputer_example.py %}
+
+
+
# Feature Selectors
## VectorSlicer
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 80519525af0c3..6cbc6660e816c 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -257,7 +257,7 @@ In the API, an application is referenced by its application ID, `[app-id]`.
When running on YARN, each application may have multiple attempts, but there are attempt IDs
only for applications in cluster mode, not applications in client mode. Applications in YARN cluster mode
can be identified by their `[attempt-id]`. In the API listed below, when running in YARN cluster mode,
-`[app-id]` will actually be `[base-app-id]/[attempt-id]`, where `[base-app-id]` is the YARN application ID.
+`[app-id]` will actually be `[base-app-id]/[attempt-id]`, where `[base-app-id]` is the YARN application ID.
| Endpoint | Meaning |
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java
new file mode 100644
index 0000000000000..ac40ccd9dbd75
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+// $example on$
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.ml.feature.Imputer;
+import org.apache.spark.ml.feature.ImputerModel;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.*;
+// $example off$
+
+import static org.apache.spark.sql.types.DataTypes.*;
+
+/**
+ * An example demonstrating Imputer.
+ * Run with:
+ * bin/run-example ml.JavaImputerExample
+ */
+public class JavaImputerExample {
+ public static void main(String[] args) {
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaImputerExample")
+ .getOrCreate();
+
+ // $example on$
+ List<Row> data = Arrays.asList(
+ RowFactory.create(1.0, Double.NaN),
+ RowFactory.create(2.0, Double.NaN),
+ RowFactory.create(Double.NaN, 3.0),
+ RowFactory.create(4.0, 4.0),
+ RowFactory.create(5.0, 5.0)
+ );
+ StructType schema = new StructType(new StructField[]{
+ createStructField("a", DoubleType, false),
+ createStructField("b", DoubleType, false)
+ });
+ Dataset<Row> df = spark.createDataFrame(data, schema);
+
+ Imputer imputer = new Imputer()
+ .setInputCols(new String[]{"a", "b"})
+ .setOutputCols(new String[]{"out_a", "out_b"});
+
+ ImputerModel model = imputer.fit(df);
+ model.transform(df).show();
+ // $example off$
+
+ spark.stop();
+ }
+}
diff --git a/examples/src/main/python/ml/imputer_example.py b/examples/src/main/python/ml/imputer_example.py
new file mode 100644
index 0000000000000..b8437f827e56d
--- /dev/null
+++ b/examples/src/main/python/ml/imputer_example.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# $example on$
+from pyspark.ml.feature import Imputer
+# $example off$
+from pyspark.sql import SparkSession
+
+"""
+An example demonstrating Imputer.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/imputer_example.py
+"""
+
+if __name__ == "__main__":
+ spark = SparkSession\
+ .builder\
+ .appName("ImputerExample")\
+ .getOrCreate()
+
+ # $example on$
+ df = spark.createDataFrame([
+ (1.0, float("nan")),
+ (2.0, float("nan")),
+ (float("nan"), 3.0),
+ (4.0, 4.0),
+ (5.0, 5.0)
+ ], ["a", "b"])
+
+ imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"])
+ model = imputer.fit(df)
+
+ model.transform(df).show()
+ # $example off$
+
+ spark.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala
new file mode 100644
index 0000000000000..49e98d0c622ca
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.feature.Imputer
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+/**
+ * An example demonstrating Imputer.
+ * Run with:
+ * bin/run-example ml.ImputerExample
+ */
+object ImputerExample {
+
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession.builder
+ .appName("ImputerExample")
+ .getOrCreate()
+
+ // $example on$
+ val df = spark.createDataFrame(Seq(
+ (1.0, Double.NaN),
+ (2.0, Double.NaN),
+ (Double.NaN, 3.0),
+ (4.0, 4.0),
+ (5.0, 5.0)
+ )).toDF("a", "b")
+
+ val imputer = new Imputer()
+ .setInputCols(Array("a", "b"))
+ .setOutputCols(Array("out_a", "out_b"))
+
+ val model = imputer.fit(df)
+ model.transform(df).show()
+ // $example off$
+
+ spark.stop()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 95c1337ed5608..ec39f964e213a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
- copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
+ val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent)
+ copyValues(copied, extra)
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index cbac16345a292..36a46ca6ff4b7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml](
}
@Since("2.1.0")
- override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+ override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = {
+ val copied = new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent)
+ copyValues(copied, extra)
+ }
@Since("2.1.0")
override def write: MLWriter = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index ec4c6ad75ee23..a41bd8e689d56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
private[feature] trait ImputerParams extends Params with HasInputCols {
/**
- * The imputation strategy.
+ * The imputation strategy. Currently only "mean" and "median" are supported.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the feature.
* Default: mean
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 620e1fbb09ff7..145422a059196 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -86,7 +86,10 @@ class MinHashLSHModel private[ml](
}
@Since("2.1.0")
- override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+ override def copy(extra: ParamMap): MinHashLSHModel = {
+ val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
+ copyValues(copied, extra)
+ }
@Since("2.1.0")
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 389898666eb8e..5a3e2929f5f52 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -268,8 +268,10 @@ class RFormulaModel private[feature](
}
@Since("1.5.0")
- override def copy(extra: ParamMap): RFormulaModel = copyValues(
- new RFormulaModel(uid, resolvedFormula, pipelineModel))
+ override def copy(extra: ParamMap): RFormulaModel = {
+ val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent)
+ copyValues(copied, extra)
+ }
@Since("2.0.0")
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
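
These `copy()` overrides all follow the same pattern; a hedged sketch of the invariant they restore, mirroring what `MLTestingUtils.checkCopy` asserts in the test changes below (a training DataFrame `df` with the referenced columns is assumed to be in scope):

```scala
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.param.ParamMap

val model = new RFormula().setFormula("id ~ v1 + v2").fit(df)  // df: DataFrame with id, v1, v2
val copied = model.copy(ParamMap.empty)
assert(copied.uid == model.uid)        // uid is carried over to the new instance
assert(copied.parent == model.parent)  // only holds once copy() calls setParent(parent)
```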
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 41684d92be33a..7700099caac37 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -74,6 +74,7 @@ class MultilayerPerceptronClassifierSuite
.setMaxIter(100)
.setSolver("l-bfgs")
val model = trainer.fit(dataset)
+ MLTestingUtils.checkCopy(model)
val result = model.transform(dataset)
val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index 91eac9e733312..cc81da5c66e6d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
@@ -89,10 +89,12 @@ class BucketedRandomProjectionLSHSuite
.setOutputCol("values")
.setBucketLength(1.0)
.setSeed(12345)
- val unitVectors = brp.fit(dataset).randUnitVectors
+ val brpModel = brp.fit(dataset)
+ val unitVectors = brpModel.randUnitVectors
unitVectors.foreach { v: Vector =>
assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
}
+ MLTestingUtils.checkCopy(brpModel)
}
test("BucketedRandomProjectionLSH: test of LSH property") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index a2f009310fd7a..0ddf097a6eb22 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
@@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}
+ test("Model copy and uid checks") {
+ val mh = new MinHashLSH()
+ .setInputCol("keys")
+ .setOutputCol("values")
+ val model = mh.fit(dataset)
+ assert(mh.uid === model.uid)
+ MLTestingUtils.checkCopy(model)
+ }
+
test("hashFunction") {
val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0)))
val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index c664460d7d8bb..5cfd59e6b88a2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
val model = formula.fit(original)
+ MLTestingUtils.checkCopy(model)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = Seq(
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 5e732b4bec8fd..d912f395dafce 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -223,7 +223,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
- default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+ default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
:param wholeFile: parse one record, which may span multiple lines, per file. If None is
set, it uses the default value, ``false``.
@@ -363,7 +363,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
- default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+ default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
:param maxColumns: defines a hard limit of how many columns a record can have. If None is
set, it uses the default value, ``20480``.
:param maxCharsPerColumn: defines the maximum number of characters allowed for any given
@@ -653,7 +653,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
- default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+ default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
@@ -745,7 +745,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
- default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+ default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
:param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from
values being written should be skipped. If None is set, it
uses the default value, ``true``.
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 27d6725615a4c..3b604963415f9 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -457,7 +457,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
- default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+ default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
:param wholeFile: parse one record, which may span multiple lines, per file. If None is
set, it uses the default value, ``false``.
@@ -581,7 +581,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param timestampFormat: sets the string that indicates a timestamp format. Custom date
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
- default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+ default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
:param maxColumns: defines a hard limit of how many columns a record can have. If None is
set, it uses the default value, ``20480``.
:param maxCharsPerColumn: defines the maximum number of characters allowed for any given
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index a8693dcca539d..254eedfe77517 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.util.Shell
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate}
object ExternalCatalogUtils {
// This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -125,6 +126,38 @@ object ExternalCatalogUtils {
}
escapePathName(col) + "=" + partitionString
}
+
+ def prunePartitionsByFilter(
+ catalogTable: CatalogTable,
+ inputPartitions: Seq[CatalogTablePartition],
+ predicates: Seq[Expression],
+ defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
+ if (predicates.isEmpty) {
+ inputPartitions
+ } else {
+ val partitionSchema = catalogTable.partitionSchema
+ val partitionColumnNames = catalogTable.partitionColumnNames.toSet
+
+ val nonPartitionPruningPredicates = predicates.filterNot {
+ _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+ }
+ if (nonPartitionPruningPredicates.nonEmpty) {
+ throw new AnalysisException("Expected only partition pruning predicates: " +
+ nonPartitionPruningPredicates)
+ }
+
+ val boundPredicate =
+ InterpretedPredicate.create(predicates.reduce(And).transform {
+ case att: AttributeReference =>
+ val index = partitionSchema.indexWhere(_.name == att.name)
+ BoundReference(index, partitionSchema(index).dataType, nullable = true)
+ })
+
+ inputPartitions.filter { p =>
+ boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
+ }
+ }
+ }
}
object CatalogUtils {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index cdf618aef97c3..9ca1c71d1dcb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.types.StructType
@@ -556,9 +556,9 @@ class InMemoryCatalog(
table: String,
predicates: Seq[Expression],
defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
- // TODO: Provide an implementation
- throw new UnsupportedOperationException(
- "listPartitionsByFilter is not implemented for InMemoryCatalog")
+ val catalogTable = getTable(db, table)
+ val allPartitions = listPartitions(db, table)
+ prunePartitionsByFilter(catalogTable, allPartitions, predicates, defaultTimeZoneId)
}
// --------------------------------------------------------------------------
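
With the shared helper above, `InMemoryCatalog` can now answer `listPartitionsByFilter` instead of throwing. A hedged sketch of a call against an already-populated catalog, mirroring the new suite test further down (`catalog`, the `db2.tbl2` fixture, and its partition column `a` are assumptions taken from that test):

```scala
import java.util.TimeZone
import org.apache.spark.sql.catalyst.dsl.expressions._

// Prune the partitions of db2.tbl2 down to those whose partition column `a` equals 1.
val pruned = catalog.listPartitionsByFilter(
  "db2", "tbl2", Seq('a.int === 1), TimeZone.getDefault.getID)
pruned.map(_.spec)  // expected: only part1's spec, as in the ExternalCatalogSuite test
```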
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index c22b1ade4e64b..23ba5ed4d50dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -79,7 +79,7 @@ private[sql] class JSONOptions(
val timestampFormat: FastDateFormat =
FastDateFormat.getInstance(
- parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US)
+ parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)
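
A standalone sketch (not part of the patch) of what the default-pattern change means in practice, using the same `FastDateFormat` factory the options classes call; the comments show the outputs the two patterns are expected to produce at UTC:

```scala
import java.util.{Date, Locale, TimeZone}
import org.apache.commons.lang3.time.FastDateFormat

val utc = TimeZone.getTimeZone("UTC")
val oldFormat = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", utc, Locale.US)
val newFormat = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX", utc, Locale.US)
val epoch = new Date(0L)
println(oldFormat.format(epoch))  // 1970-01-01T00:00:00.000+00:00
println(newFormat.format(epoch))  // 1970-01-01T00:00:00.000Z -- the ISO 8601 zone designator
```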
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index 7820f39d96426..42db4398e5072 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.catalog
import java.net.URI
+import java.util.TimeZone
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -28,6 +29,8 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -436,6 +439,44 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty)
}
+ test("list partitions by filter") {
+ val tz = TimeZone.getDefault.getID
+ val catalog = newBasicCatalog()
+
+ def checkAnswer(
+ table: CatalogTable, filters: Seq[Expression], expected: Set[CatalogTablePartition])
+ : Unit = {
+
+ assertResult(expected.map(_.spec)) {
+ catalog.listPartitionsByFilter(table.database, table.identifier.identifier, filters, tz)
+ .map(_.spec).toSet
+ }
+ }
+
+ val tbl2 = catalog.getTable("db2", "tbl2")
+
+ checkAnswer(tbl2, Seq.empty, Set(part1, part2))
+ checkAnswer(tbl2, Seq('a.int <= 1), Set(part1))
+ checkAnswer(tbl2, Seq('a.int === 2), Set.empty)
+ checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2))
+ checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2))
+ checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1))
+ checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1))
+ checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty)
+ checkAnswer(tbl2, Seq('a.int === 1 || 'b.string === "x"), Set(part1))
+
+ intercept[AnalysisException] {
+ try {
+ checkAnswer(tbl2, Seq('a.int > 0 && 'col1.int > 0), Set.empty)
+ } catch {
+ // HiveExternalCatalog may be the first one to notice and throw an exception, which will
+ // then be caught and converted to a RuntimeException with a descriptive message.
+ case ex: RuntimeException if ex.getMessage.contains("MetaException") =>
+ throw new AnalysisException(ex.getMessage)
+ }
+ }
+ }
+
test("drop partitions") {
val catalog = newBasicCatalog()
assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 6c238618f2af7..2b8537c3d4a63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -320,7 +320,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
* Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
* date type.
- * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `wholeFile` (default `false`): parse one record, which may span multiple lines,
@@ -502,7 +502,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
* Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
* date type.
- * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `maxColumns` (default `20480`): defines a hard limit of how many columns
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index e973d0bc6d09b..338a6e1314d90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -477,7 +477,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
* Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
* date type.
- * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
*
@@ -583,7 +583,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
* Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
* date type.
- * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `ignoreLeadingWhiteSpace` (default `true`): a flag indicating whether or not leading
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index bddf5af23e060..c350d8bcbae97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -217,8 +217,6 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
val table = r.tableMeta
val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
val cache = sparkSession.sessionState.catalog.tableRelationCache
- val withHiveSupport =
- sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive"
val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() {
override def call(): LogicalPlan = {
@@ -233,8 +231,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
bucketSpec = table.bucketSpec,
className = table.provider.get,
options = table.storage.properties ++ pathOption,
- // TODO: improve `InMemoryCatalog` and remove this limitation.
- catalogTable = if (withHiveSupport) Some(table) else None)
+ catalogTable = Some(table))
LogicalRelation(
dataSource.resolveRelation(checkFilesExist = false),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index e7b79e0cbfd17..4994b8dc80527 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -126,7 +126,7 @@ class CSVOptions(
val timestampFormat: FastDateFormat =
FastDateFormat.getInstance(
- parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US)
+ parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index e15c30b4374bb..fb632cf2bb70e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
import org.apache.spark.sql.catalyst.json.JSONOptions
-import org.apache.spark.sql.catalyst.util.PermissiveMode
+import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -41,7 +41,7 @@ private[sql] object JsonInferSchema {
json: RDD[T],
configOptions: JSONOptions,
createParser: (JsonFactory, T) => JsonParser): StructType = {
- val shouldHandleCorruptRecord = configOptions.parseMode == PermissiveMode
+ val parseMode = configOptions.parseMode
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
// perform schema inference on each row and merge afterwards
@@ -55,20 +55,24 @@ private[sql] object JsonInferSchema {
Some(inferField(parser, configOptions))
}
} catch {
- case _: JsonParseException if shouldHandleCorruptRecord =>
- Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
- case _: JsonParseException =>
- None
+ case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
+ case PermissiveMode =>
+ Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
+ case DropMalformedMode =>
+ None
+ case FailFastMode =>
+ throw e
+ }
}
}
- }.fold(StructType(Seq()))(
- compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord))
+ }.fold(StructType(Nil))(
+ compatibleRootType(columnNameOfCorruptRecord, parseMode))
canonicalizeType(rootType) match {
case Some(st: StructType) => st
case _ =>
// canonicalizeType erases all empty structs, including the only one we want to keep
- StructType(Seq())
+ StructType(Nil)
}
}
@@ -202,19 +206,33 @@ private[sql] object JsonInferSchema {
private def withCorruptField(
struct: StructType,
- columnNameOfCorruptRecords: String): StructType = {
- if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
- // If this given struct does not have a column used for corrupt records,
- // add this field.
- val newFields: Array[StructField] =
- StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
- // Note: other code relies on this sorting for correctness, so don't remove it!
- java.util.Arrays.sort(newFields, structFieldComparator)
- StructType(newFields)
- } else {
- // Otherwise, just return this struct.
+ other: DataType,
+ columnNameOfCorruptRecords: String,
+ parseMode: ParseMode) = parseMode match {
+ case PermissiveMode =>
+ // If we see any other data type at the root level, we get records that cannot be
+ // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
+ if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
+ // If this given struct does not have a column used for corrupt records,
+ // add this field.
+ val newFields: Array[StructField] =
+ StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+ // Note: other code relies on this sorting for correctness, so don't remove it!
+ java.util.Arrays.sort(newFields, structFieldComparator)
+ StructType(newFields)
+ } else {
+ // Otherwise, just return this struct.
+ struct
+ }
+
+ case DropMalformedMode =>
+ // If corrupt record handling is disabled we retain the valid schema and discard the other.
struct
- }
+
+ case FailFastMode =>
+ // If `other` is not struct type, consider it as malformed one and throws an exception.
+ throw new RuntimeException("Failed to infer a common schema. Struct types are expected" +
+ s" but ${other.catalogString} was found.")
}
/**
@@ -222,21 +240,20 @@ private[sql] object JsonInferSchema {
*/
private def compatibleRootType(
columnNameOfCorruptRecords: String,
- shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = {
+ parseMode: ParseMode): (DataType, DataType) => DataType = {
// Since we support array of json objects at the top level,
// we need to check the element type and find the root level data type.
case (ArrayType(ty1, _), ty2) =>
- compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
+ compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
case (ty1, ArrayType(ty2, _)) =>
- compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
- // If we see any other data type at the root level, we get records that cannot be
- // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
+ compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+ // Discard null/empty documents
case (struct: StructType, NullType) => struct
case (NullType, struct: StructType) => struct
- case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
- withCorruptField(struct, columnNameOfCorruptRecords)
- case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
- withCorruptField(struct, columnNameOfCorruptRecords)
+ case (struct: StructType, o) if !o.isInstanceOf[StructType] =>
+ withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
+ case (o, struct: StructType) if !o.isInstanceOf[StructType] =>
+ withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
// If we get anything else, we call compatibleType.
// Usually, when we reach here, ty1 and ty2 are two StructTypes.
case (ty1, ty2) => compatibleType(ty1, ty2)
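
A hedged end-to-end sketch of how the three parse modes now behave during schema inference when one document is malformed (it assumes a `SparkSession` named `spark` built from this branch); note that the suite changes below drop the `.collect()` calls because FAILFAST now fails while the schema is being inferred:

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.Files

val dir = Files.createTempDirectory("json-parse-modes")
Files.write(dir.resolve("data.json"),
  "{\"a\": 1}\n{".getBytes(StandardCharsets.UTF_8))  // second line is malformed
val path = dir.toString

spark.read.option("mode", "PERMISSIVE").json(path).printSchema()     // a plus _corrupt_record
spark.read.option("mode", "DROPMALFORMED").json(path).printSchema()  // only field a
// spark.read.option("mode", "FAILFAST").json(path)                  // throws during inference
```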
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 997ca286597da..c3a9cfc08517a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -201,7 +201,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
* Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
* date type.
- * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `wholeFile` (default `false`): parse one record, which may span multiple lines,
@@ -252,7 +252,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
* Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
* date type.
- * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `maxColumns` (default `20480`): defines a hard limit of how many columns
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index d70c47f4e2379..352dba79a4c08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -766,7 +766,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
.option("header", "true")
.load(iso8601timestampsPath)
- val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US)
+ val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX", Locale.US)
val expectedTimestamps = timestamps.collect().map { r =>
// This should be ISO8601 formatted string.
Row(iso8501.format(r.toSeq.head))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index b09cef76d2be7..2ab03819964be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1041,7 +1041,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
spark.read
.option("mode", "FAILFAST")
.json(corruptRecords)
- .collect()
}
assert(exceptionOne.getMessage.contains("JsonParseException"))
@@ -1082,6 +1081,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(jsonDFTwo.schema === schemaTwo)
}
+ test("SPARK-19641: Additional corrupt records: DROPMALFORMED mode") {
+ val schema = new StructType().add("dummy", StringType)
+ // `DROPMALFORMED` mode should skip corrupt records
+ val jsonDF = spark.read
+ .option("mode", "DROPMALFORMED")
+ .json(additionalCorruptRecords)
+ checkAnswer(
+ jsonDF,
+ Row("test"))
+ assert(jsonDF.schema === schema)
+ }
+
test("Corrupt records: PERMISSIVE mode, without designated column for malformed records") {
val schema = StructType(
StructField("a", StringType, true) ::
@@ -1882,6 +1893,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
}
+ test("SPARK-19641: Handle multi-line corrupt documents (DROPMALFORMED)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val corruptRecordCount = additionalCorruptRecords.count().toInt
+ assert(corruptRecordCount === 5)
+
+ additionalCorruptRecords
+ .toDF("value")
+ // this is the minimum partition count that avoids hash collisions
+ .repartition(corruptRecordCount * 4, F.hash($"value"))
+ .write
+ .text(path)
+
+ val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path)
+ checkAnswer(jsonDF, Seq(Row("test")))
+ }
+ }
+
test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") {
withTempPath { dir =>
val path = dir.getCanonicalPath
@@ -1903,9 +1932,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
.option("wholeFile", true)
.option("mode", "FAILFAST")
.json(path)
- .collect()
}
- assert(exceptionOne.getMessage.contains("Failed to parse a value"))
+ assert(exceptionOne.getMessage.contains("Failed to infer a common schema"))
val exceptionTwo = intercept[SparkException] {
spark.read
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 33b21be37203b..f0e35dff57f7b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -1039,37 +1039,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient {
val rawTable = getRawTable(db, table)
val catalogTable = restoreTableMetadata(rawTable)
- val partitionColumnNames = catalogTable.partitionColumnNames.toSet
- val nonPartitionPruningPredicates = predicates.filterNot {
- _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
- }
- if (nonPartitionPruningPredicates.nonEmpty) {
- sys.error("Expected only partition pruning predicates: " +
- predicates.reduceLeft(And))
- }
+ val partColNameMap = buildLowerCasePartColNameMap(catalogTable)
- val partitionSchema = catalogTable.partitionSchema
- val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table))
-
- if (predicates.nonEmpty) {
- val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part =>
+ val clientPrunedPartitions =
+ client.getPartitionsByFilter(rawTable, predicates).map { part =>
part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
- val boundPredicate =
- InterpretedPredicate.create(predicates.reduce(And).transform {
- case att: AttributeReference =>
- val index = partitionSchema.indexWhere(_.name == att.name)
- BoundReference(index, partitionSchema(index).dataType, nullable = true)
- })
- clientPrunedPartitions.filter { p =>
- boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
- }
- } else {
- client.getPartitions(catalogTable).map { part =>
- part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
- }
- }
+ prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId)
}
// --------------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index d55c41e5c9f29..2e35f39839488 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -584,7 +584,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
*/
def convertFilters(table: Table, filters: Seq[Expression]): String = {
// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
- val varcharKeys = table.getPartitionKeys.asScala
+ lazy val varcharKeys = table.getPartitionKeys.asScala
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) ||
col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME))
.map(col => col.getName).toSet
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
index 4349f1aa23be0..bd54c043c6ec4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
@@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.types.StructType
@@ -50,13 +49,6 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite {
import utils._
- test("list partitions by filter") {
- val catalog = newBasicCatalog()
- val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT")
- assert(selectedPartitions.length == 1)
- assert(selectedPartitions.head.spec == part1.spec)
- }
-
test("SPARK-18647: do not put provider in table properties for Hive serde table") {
val catalog = newBasicCatalog()
val hiveTable = CatalogTable(