diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java
index 13f6046dd856b..6549cac011feb 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java
@@ -23,11 +23,19 @@
import com.google.common.annotations.VisibleForTesting;
+import org.apache.commons.lang3.SystemUtils;
import org.apache.spark.network.util.JavaUtils;
public class ExecutorDiskUtils {
- private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}");
+ private static final Pattern MULTIPLE_SEPARATORS;
+ static {
+ if (SystemUtils.IS_OS_WINDOWS) {
+ MULTIPLE_SEPARATORS = Pattern.compile("[/\\\\]+");
+ } else {
+ MULTIPLE_SEPARATORS = Pattern.compile("/{2,}");
+ }
+ }
/**
* Hashes a filename into the corresponding local directory, in a manner consistent with
@@ -50,14 +58,18 @@ public static File getFile(String[] localDirs, int subDirsPerLocalDir, String fi
* the internal code in java.io.File would normalize it later, creating a new "foo/bar"
* String copy. Unfortunately, we cannot just reuse the normalization code that java.io.File
* uses, since it is in the package-private class java.io.FileSystem.
+ *
+ * On Windows, separator "\" is used instead of "/".
+ *
+ * "\\" is a legal character in path names on Unix-like OSes, but illegal on Windows.
*/
@VisibleForTesting
static String createNormalizedInternedPathname(String dir1, String dir2, String fname) {
String pathname = dir1 + File.separator + dir2 + File.separator + fname;
Matcher m = MULTIPLE_SEPARATORS.matcher(pathname);
- pathname = m.replaceAll("/");
+ pathname = m.replaceAll(Matcher.quoteReplacement(File.separator));
// A single trailing slash needs to be taken care of separately
- if (pathname.length() > 1 && pathname.endsWith("/")) {
+ if (pathname.length() > 1 && pathname.charAt(pathname.length() - 1) == File.separatorChar) {
pathname = pathname.substring(0, pathname.length() - 1);
}
return pathname.intern();
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
index 09b31430b1eb9..6515b6ca035f7 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -25,6 +25,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.io.CharStreams;
+import org.apache.commons.lang3.SystemUtils;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -146,12 +147,19 @@ public void jsonSerializationOfExecutorRegistration() throws IOException {
@Test
public void testNormalizeAndInternPathname() {
- assertPathsMatch("/foo", "bar", "baz", "/foo/bar/baz");
- assertPathsMatch("//foo/", "bar/", "//baz", "/foo/bar/baz");
- assertPathsMatch("foo", "bar", "baz///", "foo/bar/baz");
- assertPathsMatch("/foo/", "/bar//", "/baz", "/foo/bar/baz");
- assertPathsMatch("/", "", "", "/");
- assertPathsMatch("/", "/", "/", "/");
+ String sep = File.separator;
+ String expectedPathname = sep + "foo" + sep + "bar" + sep + "baz";
+ assertPathsMatch("/foo", "bar", "baz", expectedPathname);
+ assertPathsMatch("//foo/", "bar/", "//baz", expectedPathname);
+ assertPathsMatch("/foo/", "/bar//", "/baz", expectedPathname);
+ assertPathsMatch("foo", "bar", "baz///", "foo" + sep + "bar" + sep + "baz");
+ assertPathsMatch("/", "", "", sep);
+ assertPathsMatch("/", "/", "/", sep);
+ if (SystemUtils.IS_OS_WINDOWS) {
+ assertPathsMatch("/foo\\/", "bar", "baz", expectedPathname);
+ } else {
+ assertPathsMatch("/foo\\/", "bar", "baz", sep + "foo\\" + sep + "bar" + sep + "baz");
+ }
}
private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) {
@@ -160,6 +168,6 @@ private void assertPathsMatch(String p1, String p2, String p3, String expectedPa
assertEquals(expectedPathname, normPathname);
File file = new File(normPathname);
String returnedPath = file.getPath();
- assertTrue(normPathname == returnedPath);
+ assertEquals(normPathname, returnedPath);
}
}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/utils.js b/core/src/main/resources/org/apache/spark/ui/static/utils.js
index 6fc34a9e1f7ea..2e46111bf1ba0 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/utils.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/utils.js
@@ -56,13 +56,17 @@ function formatTimeMillis(timeMillis) {
return "-";
} else {
var dt = new Date(timeMillis);
+ return formatDateString(dt);
+ }
+}
+
+function formatDateString(dt) {
return dt.getFullYear() + "-" +
padZeroes(dt.getMonth() + 1) + "-" +
padZeroes(dt.getDate()) + " " +
padZeroes(dt.getHours()) + ":" +
padZeroes(dt.getMinutes()) + ":" +
padZeroes(dt.getSeconds());
- }
}
function getTimeZone() {
@@ -161,7 +165,10 @@ function setDataTableDefaults() {
function formatDate(date) {
if (date <= 0) return "-";
- else return date.split(".")[0].replace("T", " ");
+ else {
+ var dt = new Date(date.replace("GMT", "Z"))
+ return formatDateString(dt);
+ }
}
function createRESTEndPointForExecutorsPage(appId) {
diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala
new file mode 100644
index 0000000000000..080ca0e41f793
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.util.Properties
+
+import scala.collection.JavaConverters._
+
+import com.codahale.metrics.{Counter, Gauge, MetricRegistry}
+import org.scalatest.PrivateMethodTester
+
+import org.apache.spark.SparkFunSuite
+
+class PrometheusServletSuite extends SparkFunSuite with PrivateMethodTester {
+
+ test("register metrics") {
+ val sink = createPrometheusServlet()
+
+ val gauge = new Gauge[Double] {
+ override def getValue: Double = 5.0
+ }
+
+ val counter = new Counter
+ counter.inc(10)
+
+ sink.registry.register("gauge1", gauge)
+ sink.registry.register("gauge2", gauge)
+ sink.registry.register("counter1", counter)
+
+ val metricGaugeKeys = sink.registry.getGauges.keySet.asScala
+ assert(metricGaugeKeys.equals(Set("gauge1", "gauge2")),
+ "Should contain 2 gauges metrics registered")
+
+ val metricCounterKeys = sink.registry.getCounters.keySet.asScala
+ assert(metricCounterKeys.equals(Set("counter1")),
+ "Should contain 1 counter metric registered")
+
+ val gaugeValues = sink.registry.getGauges.values.asScala
+ assert(gaugeValues.size == 2)
+ gaugeValues.foreach(gauge => assert(gauge.getValue == 5.0))
+
+ val counterValues = sink.registry.getCounters.values.asScala
+ assert(counterValues.size == 1)
+ counterValues.foreach(counter => assert(counter.getCount == 10))
+ }
+
+ test("normalize key") {
+ val key = "local-1592132938718.driver.LiveListenerBus." +
+ "listenerProcessingTime.org.apache.spark.HeartbeatReceiver"
+ val sink = createPrometheusServlet()
+ val suffix = sink invokePrivate PrivateMethod[String]('normalizeKey)(key)
+ assert(suffix == "metrics_local_1592132938718_driver_LiveListenerBus_" +
+ "listenerProcessingTime_org_apache_spark_HeartbeatReceiver_")
+ }
+
+ private def createPrometheusServlet(): PrometheusServlet =
+ new PrometheusServlet(new Properties, new MetricRegistry, securityMgr = null)
+}
diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml
index 219e6809a96f0..eea657e684495 100644
--- a/docs/_data/menu-sql.yaml
+++ b/docs/_data/menu-sql.yaml
@@ -139,7 +139,7 @@
- text: REPAIR TABLE
url: sql-ref-syntax-ddl-repair-table.html
- text: USE DATABASE
- url: sql-ref-syntax-qry-select-usedb.html
+ url: sql-ref-syntax-ddl-usedb.html
- text: Data Manipulation Statements
url: sql-ref-syntax-dml.html
subitems:
@@ -207,7 +207,7 @@
- text: CLEAR CACHE
url: sql-ref-syntax-aux-cache-clear-cache.html
- text: REFRESH TABLE
- url: sql-ref-syntax-aux-refresh-table.html
+ url: sql-ref-syntax-aux-cache-refresh-table.html
- text: REFRESH
url: sql-ref-syntax-aux-cache-refresh.html
- text: DESCRIBE
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 1f70d46d587a8..f3c479ba26547 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -359,7 +359,7 @@ Spark Standalone has 2 parts, the first is configuring the resources for the Wor
The user must configure the Workers to have a set of resources available so that it can assign them out to Executors. The spark.worker.resource.{resourceName}.amount is used to control the amount of each resource the worker has allocated. The user must also specify either spark.worker.resourcesFile or spark.worker.resource.{resourceName}.discoveryScript to specify how the Worker discovers the resources its assigned. See the descriptions above for each of those to see which method works best for your setup.
-The second part is running an application on Spark Standalone. The only special case from the standard Spark resource configs is when you are running the Driver in client mode. For a Driver in client mode, the user can specify the resources it uses via spark.driver.resourcesfile or spark.driver.resource.{resourceName}.discoveryScript. If the Driver is running on the same host as other Drivers, please make sure the resources file or discovery script only returns resources that do not conflict with other Drivers running on the same node.
+The second part is running an application on Spark Standalone. The only special case from the standard Spark resource configs is when you are running the Driver in client mode. For a Driver in client mode, the user can specify the resources it uses via spark.driver.resourcesFile or spark.driver.resource.{resourceName}.discoveryScript. If the Driver is running on the same host as other Drivers, please make sure the resources file or discovery script only returns resources that do not conflict with other Drivers running on the same node.
Note, the user does not need to specify a discovery script when submitting an application as the Worker will start each Executor with the resources it allocates to it.
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index d6550c30b9553..0c84db38afafc 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -31,7 +31,11 @@ license: |
- In Spark 3.1, `from_unixtime`, `unix_timestamp`,`to_unix_timestamp`, `to_timestamp` and `to_date` will fail if the specified datetime pattern is invalid. In Spark 3.0 or earlier, they result `NULL`.
- In Spark 3.1, casting numeric to timestamp will be forbidden by default. It's strongly recommended to use dedicated functions: TIMESTAMP_SECONDS, TIMESTAMP_MILLIS and TIMESTAMP_MICROS. Or you can set `spark.sql.legacy.allowCastNumericToTimestamp` to true to work around it. See more details in SPARK-31710.
-
+
+## Upgrading from Spark SQL 3.0 to 3.0.1
+
+- In Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Since version 3.0.1, the timestamp type inference is disabled by default. Set the JSON option `inferTimestamp` to `true` to enable such type inference.
+
## Upgrading from Spark SQL 2.4 to 3.0
### Dataset/DataFrame APIs
diff --git a/docs/sql-ref-syntax-aux-cache-cache-table.md b/docs/sql-ref-syntax-aux-cache-cache-table.md
index 193e209d792b3..fdef3d657dfa3 100644
--- a/docs/sql-ref-syntax-aux-cache-cache-table.md
+++ b/docs/sql-ref-syntax-aux-cache-cache-table.md
@@ -78,5 +78,5 @@ CACHE TABLE testCache OPTIONS ('storageLevel' 'DISK_ONLY') SELECT * FROM testDat
* [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html)
* [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html)
-* [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html)
+* [REFRESH TABLE](sql-ref-syntax-aux-cache-refresh-table.html)
* [REFRESH](sql-ref-syntax-aux-cache-refresh.html)
diff --git a/docs/sql-ref-syntax-aux-cache-clear-cache.md b/docs/sql-ref-syntax-aux-cache-clear-cache.md
index ee33e6a98296d..a27cd83c146a3 100644
--- a/docs/sql-ref-syntax-aux-cache-clear-cache.md
+++ b/docs/sql-ref-syntax-aux-cache-clear-cache.md
@@ -39,5 +39,5 @@ CLEAR CACHE;
* [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html)
* [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html)
-* [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html)
+* [REFRESH TABLE](sql-ref-syntax-aux-cache-refresh-table.html)
* [REFRESH](sql-ref-syntax-aux-cache-refresh.html)
diff --git a/docs/sql-ref-syntax-aux-refresh-table.md b/docs/sql-ref-syntax-aux-cache-refresh-table.md
similarity index 100%
rename from docs/sql-ref-syntax-aux-refresh-table.md
rename to docs/sql-ref-syntax-aux-cache-refresh-table.md
diff --git a/docs/sql-ref-syntax-aux-cache-refresh.md b/docs/sql-ref-syntax-aux-cache-refresh.md
index 82bc12da5d1ac..b10e6fb47aaf7 100644
--- a/docs/sql-ref-syntax-aux-cache-refresh.md
+++ b/docs/sql-ref-syntax-aux-cache-refresh.md
@@ -53,4 +53,4 @@ REFRESH "hdfs://path/to/table";
* [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html)
* [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html)
* [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html)
-* [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html)
+* [REFRESH TABLE](sql-ref-syntax-aux-cache-refresh-table.html)
diff --git a/docs/sql-ref-syntax-aux-cache-uncache-table.md b/docs/sql-ref-syntax-aux-cache-uncache-table.md
index c5a8fbbe08281..96a691e4c3931 100644
--- a/docs/sql-ref-syntax-aux-cache-uncache-table.md
+++ b/docs/sql-ref-syntax-aux-cache-uncache-table.md
@@ -48,5 +48,5 @@ UNCACHE TABLE t1;
* [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html)
* [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html)
-* [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html)
+* [REFRESH TABLE](sql-ref-syntax-aux-cache-refresh-table.html)
* [REFRESH](sql-ref-syntax-aux-cache-refresh.html)
diff --git a/docs/sql-ref-syntax-aux-cache.md b/docs/sql-ref-syntax-aux-cache.md
index 418b8cc3403b5..0ccb1c61a0da5 100644
--- a/docs/sql-ref-syntax-aux-cache.md
+++ b/docs/sql-ref-syntax-aux-cache.md
@@ -22,5 +22,5 @@ license: |
* [CACHE TABLE statement](sql-ref-syntax-aux-cache-cache-table.html)
* [UNCACHE TABLE statement](sql-ref-syntax-aux-cache-uncache-table.html)
* [CLEAR CACHE statement](sql-ref-syntax-aux-cache-clear-cache.html)
- * [REFRESH TABLE statement](sql-ref-syntax-aux-refresh-table.html)
+ * [REFRESH TABLE statement](sql-ref-syntax-aux-cache-refresh-table.html)
* [REFRESH statement](sql-ref-syntax-aux-cache-refresh.html)
\ No newline at end of file
diff --git a/docs/sql-ref-syntax-qry-select-usedb.md b/docs/sql-ref-syntax-ddl-usedb.md
similarity index 100%
rename from docs/sql-ref-syntax-qry-select-usedb.md
rename to docs/sql-ref-syntax-ddl-usedb.md
diff --git a/docs/sql-ref-syntax-ddl.md b/docs/sql-ref-syntax-ddl.md
index 82fbf0498a20f..cb3e04c0ec910 100644
--- a/docs/sql-ref-syntax-ddl.md
+++ b/docs/sql-ref-syntax-ddl.md
@@ -34,4 +34,4 @@ Data Definition Statements are used to create or modify the structure of databas
* [DROP VIEW](sql-ref-syntax-ddl-drop-view.html)
* [TRUNCATE TABLE](sql-ref-syntax-ddl-truncate-table.html)
* [REPAIR TABLE](sql-ref-syntax-ddl-repair-table.html)
- * [USE DATABASE](sql-ref-syntax-qry-select-usedb.html)
+ * [USE DATABASE](sql-ref-syntax-ddl-usedb.html)
diff --git a/docs/sql-ref-syntax.md b/docs/sql-ref-syntax.md
index d78a01fd655a2..4bf1858428d98 100644
--- a/docs/sql-ref-syntax.md
+++ b/docs/sql-ref-syntax.md
@@ -36,7 +36,7 @@ Spark SQL is Apache Spark's module for working with structured data. The SQL Syn
* [DROP VIEW](sql-ref-syntax-ddl-drop-view.html)
* [REPAIR TABLE](sql-ref-syntax-ddl-repair-table.html)
* [TRUNCATE TABLE](sql-ref-syntax-ddl-truncate-table.html)
- * [USE DATABASE](sql-ref-syntax-qry-select-usedb.html)
+ * [USE DATABASE](sql-ref-syntax-ddl-usedb.html)
### DML Statements
@@ -82,7 +82,7 @@ Spark SQL is Apache Spark's module for working with structured data. The SQL Syn
* [LIST FILE](sql-ref-syntax-aux-resource-mgmt-list-file.html)
* [LIST JAR](sql-ref-syntax-aux-resource-mgmt-list-jar.html)
* [REFRESH](sql-ref-syntax-aux-cache-refresh.html)
- * [REFRESH TABLE](sql-ref-syntax-aux-refresh-table.html)
+ * [REFRESH TABLE](sql-ref-syntax-aux-cache-refresh-table.html)
* [RESET](sql-ref-syntax-aux-conf-mgmt-reset.html)
* [SET](sql-ref-syntax-aux-conf-mgmt-set.html)
* [SHOW COLUMNS](sql-ref-syntax-aux-show-columns.html)
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
index c6f52d676422c..969dee0a39696 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
@@ -31,13 +31,13 @@ class AvroDataSourceV2 extends FileDataSourceV2 {
override def getTable(options: CaseInsensitiveStringMap): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
AvroTable(tableName, sparkSession, options, paths, None, fallbackFileFormat)
}
override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
AvroTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala
index e9ea38161d3c0..9f3428db484c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala
@@ -44,7 +44,7 @@ private[classification] trait ClassificationSummary extends Serializable {
@Since("3.1.0")
def labelCol: String
- /** Field in "predictions" which gives the weight of each instance as a vector. */
+ /** Field in "predictions" which gives the weight of each instance. */
@Since("3.1.0")
def weightCol: String
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 1f3f291644f93..233e8e5bcdc88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -22,6 +22,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
import org.apache.spark.rdd.RDD
@@ -269,4 +270,26 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* @return predicted label
*/
protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax
+
+ /**
+ * If the rawPrediction and prediction columns are set, this method returns the current model,
+ * otherwise it generates new columns for them and sets them as columns on a new copy of
+ * the current model
+ */
+ private[classification] def findSummaryModel():
+ (ClassificationModel[FeaturesType, M], String, String) = {
+ val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
+ copy(ParamMap.empty)
+ .setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
+ .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+ } else if ($(rawPredictionCol).isEmpty) {
+ copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
+ java.util.UUID.randomUUID.toString)
+ } else if ($(predictionCol).isEmpty) {
+ copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+ } else {
+ this
+ }
+ (model, model.getRawPredictionCol, model.getPredictionCol)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 1659bbb1d34b3..4adc527c89b36 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -394,27 +394,6 @@ class LinearSVCModel private[classification] (
@Since("3.1.0")
override def summary: LinearSVCTrainingSummary = super.summary
- /**
- * If the rawPrediction and prediction columns are set, this method returns the current model,
- * otherwise it generates new columns for them and sets them as columns on a new copy of
- * the current model
- */
- private[classification] def findSummaryModel(): (LinearSVCModel, String, String) = {
- val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
- copy(ParamMap.empty)
- .setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
- .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
- } else if ($(rawPredictionCol).isEmpty) {
- copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
- java.util.UUID.randomUUID.toString)
- } else if ($(predictionCol).isEmpty) {
- copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
- } else {
- this
- }
- (model, model.getRawPredictionCol, model.getPredictionCol)
- }
-
/**
* Evaluates the model on a test dataset.
*
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 20d619334f7b9..47b3e2de7695c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1158,27 +1158,6 @@ class LogisticRegressionModel private[spark] (
s"(numClasses=${numClasses}), use summary instead.")
}
- /**
- * If the probability and prediction columns are set, this method returns the current model,
- * otherwise it generates new columns for them and sets them as columns on a new copy of
- * the current model
- */
- private[classification] def findSummaryModel():
- (LogisticRegressionModel, String, String) = {
- val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
- copy(ParamMap.empty)
- .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
- .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
- } else if ($(probabilityCol).isEmpty) {
- copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
- } else if ($(predictionCol).isEmpty) {
- copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
- } else {
- this
- }
- (model, model.getProbabilityCol, model.getPredictionCol)
- }
-
/**
* Evaluates the model on a test dataset.
*
@@ -1451,7 +1430,7 @@ sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegre
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param weightCol field in "predictions" which gives the weight of each instance as a vector.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class LogisticRegressionTrainingSummaryImpl(
@@ -1476,7 +1455,7 @@ private class LogisticRegressionTrainingSummaryImpl(
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param weightCol field in "predictions" which gives the weight of each instance as a vector.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
*/
private class LogisticRegressionSummaryImpl(
@transient override val predictions: DataFrame,
@@ -1497,7 +1476,7 @@ private class LogisticRegressionSummaryImpl(
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param weightCol field in "predictions" which gives the weight of each instance as a vector.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class BinaryLogisticRegressionTrainingSummaryImpl(
@@ -1522,7 +1501,7 @@ private class BinaryLogisticRegressionTrainingSummaryImpl(
* each class as a double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param weightCol field in "predictions" which gives the weight of each instance as a vector.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
*/
private class BinaryLogisticRegressionSummaryImpl(
predictions: DataFrame,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 9758e3ca72c38..1caaeccd7b0d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -229,6 +230,27 @@ abstract class ProbabilisticClassificationModel[
argMax
}
}
+
+ /**
+ * If the probability and prediction columns are set, this method returns the current model,
+ * otherwise it generates new columns for them and sets them as columns on a new copy of
+ * the current model
+ */
+ override private[classification] def findSummaryModel():
+ (ProbabilisticClassificationModel[FeaturesType, M], String, String) = {
+ val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
+ copy(ParamMap.empty)
+ .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
+ .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+ } else if ($(probabilityCol).isEmpty) {
+ copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
+ } else if ($(predictionCol).isEmpty) {
+ copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+ } else {
+ this
+ }
+ (model, model.getProbabilityCol, model.getPredictionCol)
+ }
}
private[ml] object ProbabilisticClassificationModel {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index a316e472d9674..f9ce62b91924b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -166,7 +166,35 @@ class RandomForestClassifier @Since("1.4.0") (
val numFeatures = trees.head.numFeatures
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
- new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
+ createModel(dataset, trees, numFeatures, numClasses)
+ }
+
+ private def createModel(
+ dataset: Dataset[_],
+ trees: Array[DecisionTreeClassificationModel],
+ numFeatures: Int,
+ numClasses: Int): RandomForestClassificationModel = {
+ val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses))
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+ val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+ val rfSummary = if (numClasses <= 2) {
+ new BinaryRandomForestClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ Array(0.0))
+ } else {
+ new RandomForestClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ Array(0.0))
+ }
+ model.setSummary(Some(rfSummary))
}
@Since("1.4.1")
@@ -204,7 +232,8 @@ class RandomForestClassificationModel private[ml] (
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
- with MLWritable with Serializable {
+ with MLWritable with Serializable
+ with HasTrainingSummary[RandomForestClassificationTrainingSummary] {
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -228,6 +257,44 @@ class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
+ /**
+ * Gets summary of model on training set. An exception is thrown
+ * if `hasSummary` is false.
+ */
+ @Since("3.1.0")
+ override def summary: RandomForestClassificationTrainingSummary = super.summary
+
+ /**
+ * Gets summary of model on training set. An exception is thrown
+ * if `hasSummary` is false or it is a multiclass model.
+ */
+ @Since("3.1.0")
+ def binarySummary: BinaryRandomForestClassificationTrainingSummary = summary match {
+ case b: BinaryRandomForestClassificationTrainingSummary => b
+ case _ =>
+ throw new RuntimeException("Cannot create a binary summary for a non-binary model " +
+ s"(numClasses=${numClasses}), use summary instead.")
+ }
+
+ /**
+ * Evaluates the model on a test dataset.
+ *
+ * @param dataset Test dataset to evaluate model on.
+ */
+ @Since("3.1.0")
+ def evaluate(dataset: Dataset[_]): RandomForestClassificationSummary = {
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+ if (numClasses > 2) {
+ new RandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
+ predictionColName, $(labelCol), weightColName)
+ } else {
+ new BinaryRandomForestClassificationSummaryImpl(summaryModel.transform(dataset),
+ probabilityColName, predictionColName, $(labelCol), weightColName)
+ }
+ }
+
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
@@ -388,3 +455,113 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
}
}
+
+/**
+ * Abstraction for multiclass RandomForestClassification results for a given model.
+ */
+sealed trait RandomForestClassificationSummary extends ClassificationSummary {
+ /**
+ * Convenient method for casting to BinaryRandomForestClassificationSummary.
+ * This method will throw an Exception if the summary is not a binary summary.
+ */
+ @Since("3.1.0")
+ def asBinary: BinaryRandomForestClassificationSummary = this match {
+ case b: BinaryRandomForestClassificationSummary => b
+ case _ =>
+ throw new RuntimeException("Cannot cast to a binary summary.")
+ }
+}
+
+/**
+ * Abstraction for multiclass RandomForestClassification training results.
+ */
+sealed trait RandomForestClassificationTrainingSummary extends RandomForestClassificationSummary
+ with TrainingSummary
+
+/**
+ * Abstraction for BinaryRandomForestClassification results for a given model.
+ */
+sealed trait BinaryRandomForestClassificationSummary extends BinaryClassificationSummary
+
+/**
+ * Abstraction for BinaryRandomForestClassification training results.
+ */
+sealed trait BinaryRandomForestClassificationTrainingSummary extends
+ BinaryRandomForestClassificationSummary with RandomForestClassificationTrainingSummary
+
+/**
+ * Multiclass RandomForestClassification training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class RandomForestClassificationTrainingSummaryImpl(
+ predictions: DataFrame,
+ predictionCol: String,
+ labelCol: String,
+ weightCol: String,
+ override val objectiveHistory: Array[Double])
+ extends RandomForestClassificationSummaryImpl(
+ predictions, predictionCol, labelCol, weightCol)
+ with RandomForestClassificationTrainingSummary
+
+/**
+ * Multiclass RandomForestClassification results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ */
+private class RandomForestClassificationSummaryImpl(
+ @transient override val predictions: DataFrame,
+ override val predictionCol: String,
+ override val labelCol: String,
+ override val weightCol: String)
+ extends RandomForestClassificationSummary
+
+/**
+ * Binary RandomForestClassification training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param scoreCol field in "predictions" which gives the probability of each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class BinaryRandomForestClassificationTrainingSummaryImpl(
+ predictions: DataFrame,
+ scoreCol: String,
+ predictionCol: String,
+ labelCol: String,
+ weightCol: String,
+ override val objectiveHistory: Array[Double])
+ extends BinaryRandomForestClassificationSummaryImpl(
+ predictions, scoreCol, predictionCol, labelCol, weightCol)
+ with BinaryRandomForestClassificationTrainingSummary
+
+/**
+ * Binary RandomForestClassification results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param scoreCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ */
+private class BinaryRandomForestClassificationSummaryImpl(
+ predictions: DataFrame,
+ override val scoreCol: String,
+ predictionCol: String,
+ labelCol: String,
+ weightCol: String)
+ extends RandomForestClassificationSummaryImpl(
+ predictions, predictionCol, labelCol, weightCol)
+ with BinaryRandomForestClassificationSummary
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index ecee531c88a8f..56eadff6df078 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -342,7 +342,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
blorModel2.summary.asBinary.weightedPrecision relTol 1e-6)
assert(blorModel.summary.asBinary.weightedRecall ~==
blorModel2.summary.asBinary.weightedRecall relTol 1e-6)
- assert(blorModel.summary.asBinary.asBinary.areaUnderROC ~==
+ assert(blorModel.summary.asBinary.areaUnderROC ~==
blorModel2.summary.asBinary.areaUnderROC relTol 1e-6)
assert(mlorSummary.accuracy ~== mlorSummary2.accuracy relTol 1e-6)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index e30e93ad4628c..645a436fa0ad6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
/**
* Test suite for [[RandomForestClassifier]].
@@ -296,6 +297,115 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("summary for binary and multiclass") {
+ val arr = new Array[LabeledPoint](300)
+ for (i <- 0 until 300) {
+ if (i < 100) {
+ arr(i) = new LabeledPoint(0.0, Vectors.dense(2.0, 2.0))
+ } else if (i < 200) {
+ arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+ } else {
+ arr(i) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0))
+ }
+ }
+ val rdd = sc.parallelize(arr)
+ val multinomialDataset = spark.createDataFrame(rdd)
+
+ val rf = new RandomForestClassifier()
+
+ val brfModel = rf.fit(binaryDataset)
+ assert(brfModel.summary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary])
+ assert(brfModel.summary.asBinary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary])
+ assert(brfModel.binarySummary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary])
+ assert(brfModel.summary.totalIterations === 0)
+ assert(brfModel.binarySummary.totalIterations === 0)
+
+ val mrfModel = rf.fit(multinomialDataset)
+ assert(mrfModel.summary.isInstanceOf[RandomForestClassificationTrainingSummary])
+ withClue("cannot get binary summary for multiclass model") {
+ intercept[RuntimeException] {
+ mrfModel.binarySummary
+ }
+ }
+ withClue("cannot cast summary to binary summary multiclass model") {
+ intercept[RuntimeException] {
+ mrfModel.summary.asBinary
+ }
+ }
+ assert(mrfModel.summary.totalIterations === 0)
+
+ val brfSummary = brfModel.evaluate(binaryDataset)
+ val mrfSummary = mrfModel.evaluate(multinomialDataset)
+ assert(brfSummary.isInstanceOf[BinaryRandomForestClassificationSummary])
+ assert(mrfSummary.isInstanceOf[RandomForestClassificationSummary])
+
+ assert(brfSummary.accuracy === brfModel.summary.accuracy)
+ assert(brfSummary.weightedPrecision === brfModel.summary.weightedPrecision)
+ assert(brfSummary.weightedRecall === brfModel.summary.weightedRecall)
+ assert(brfSummary.asBinary.areaUnderROC ~== brfModel.summary.asBinary.areaUnderROC relTol 1e-6)
+
+ // verify instance weight works
+ val rf2 = new RandomForestClassifier()
+ .setWeightCol("weight")
+
+ val binaryDatasetWithWeight =
+ binaryDataset.select(col("label"), col("features"), lit(2.5).as("weight"))
+
+ val multinomialDatasetWithWeight =
+ multinomialDataset.select(col("label"), col("features"), lit(10.0).as("weight"))
+
+ val brfModel2 = rf2.fit(binaryDatasetWithWeight)
+ assert(brfModel2.summary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary])
+ assert(brfModel2.summary.asBinary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary])
+ assert(brfModel2.binarySummary.isInstanceOf[BinaryRandomForestClassificationTrainingSummary])
+
+ val mrfModel2 = rf2.fit(multinomialDatasetWithWeight)
+ assert(mrfModel2.summary.isInstanceOf[RandomForestClassificationTrainingSummary])
+ withClue("cannot get binary summary for multiclass model") {
+ intercept[RuntimeException] {
+ mrfModel2.binarySummary
+ }
+ }
+ withClue("cannot cast summary to binary summary multiclass model") {
+ intercept[RuntimeException] {
+ mrfModel2.summary.asBinary
+ }
+ }
+
+ val brfSummary2 = brfModel2.evaluate(binaryDatasetWithWeight)
+ val mrfSummary2 = mrfModel2.evaluate(multinomialDatasetWithWeight)
+ assert(brfSummary2.isInstanceOf[BinaryRandomForestClassificationSummary])
+ assert(mrfSummary2.isInstanceOf[RandomForestClassificationSummary])
+
+ assert(brfSummary2.accuracy === brfModel2.summary.accuracy)
+ assert(brfSummary2.weightedPrecision === brfModel2.summary.weightedPrecision)
+ assert(brfSummary2.weightedRecall === brfModel2.summary.weightedRecall)
+ assert(brfSummary2.asBinary.areaUnderROC ~==
+ brfModel2.summary.asBinary.areaUnderROC relTol 1e-6)
+
+ assert(brfSummary.accuracy ~== brfSummary2.accuracy relTol 1e-6)
+ assert(brfSummary.weightedPrecision ~== brfSummary2.weightedPrecision relTol 1e-6)
+ assert(brfSummary.weightedRecall ~== brfSummary2.weightedRecall relTol 1e-6)
+ assert(brfSummary.asBinary.areaUnderROC ~== brfSummary2.asBinary.areaUnderROC relTol 1e-6)
+
+ assert(brfModel.summary.asBinary.accuracy ~==
+ brfModel2.summary.asBinary.accuracy relTol 1e-6)
+ assert(brfModel.summary.asBinary.weightedPrecision ~==
+ brfModel2.summary.asBinary.weightedPrecision relTol 1e-6)
+ assert(brfModel.summary.asBinary.weightedRecall ~==
+ brfModel2.summary.asBinary.weightedRecall relTol 1e-6)
+ assert(brfModel.summary.asBinary.areaUnderROC ~==
+ brfModel2.summary.asBinary.areaUnderROC relTol 1e-6)
+
+ assert(mrfSummary.accuracy ~== mrfSummary2.accuracy relTol 1e-6)
+ assert(mrfSummary.weightedPrecision ~== mrfSummary2.weightedPrecision relTol 1e-6)
+ assert(mrfSummary.weightedRecall ~== mrfSummary2.weightedRecall relTol 1e-6)
+
+ assert(mrfModel.summary.accuracy ~== mrfModel2.summary.accuracy relTol 1e-6)
+ assert(mrfModel.summary.weightedPrecision ~== mrfModel2.summary.weightedPrecision relTol 1e-6)
+ assert(mrfModel.summary.weightedRecall ~== mrfModel2.summary.weightedRecall relTol 1e-6)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index bdd37c99df0a8..d70932a1bc6fc 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -46,6 +46,9 @@
'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
'GBTClassifier', 'GBTClassificationModel',
'RandomForestClassifier', 'RandomForestClassificationModel',
+ 'RandomForestClassificationSummary', 'RandomForestClassificationTrainingSummary',
+ 'BinaryRandomForestClassificationSummary',
+ 'BinaryRandomForestClassificationTrainingSummary',
'NaiveBayes', 'NaiveBayesModel',
'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
'OneVsRest', 'OneVsRestModel',
@@ -1762,7 +1765,7 @@ def setMinWeightFractionPerNode(self, value):
class RandomForestClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClassificationModel,
_RandomForestClassifierParams, JavaMLWritable,
- JavaMLReadable):
+ JavaMLReadable, HasTrainingSummary):
"""
Model fitted by RandomForestClassifier.
@@ -1790,6 +1793,80 @@ def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
+ @property
+ @since("3.1.0")
+ def summary(self):
+ """
+ Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model
+ trained on the training set. An exception is thrown if `trainingSummary is None`.
+ """
+ if self.hasSummary:
+ if self.numClasses <= 2:
+ return BinaryRandomForestClassificationTrainingSummary(
+ super(RandomForestClassificationModel, self).summary)
+ else:
+ return RandomForestClassificationTrainingSummary(
+ super(RandomForestClassificationModel, self).summary)
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
+
+ @since("3.1.0")
+ def evaluate(self, dataset):
+ """
+ Evaluates the model on a test dataset.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ if not isinstance(dataset, DataFrame):
+ raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_rf_summary = self._call_java("evaluate", dataset)
+ if self.numClasses <= 2:
+ return BinaryRandomForestClassificationSummary(java_rf_summary)
+ else:
+ return RandomForestClassificationSummary(java_rf_summary)
+
+
+class RandomForestClassificationSummary(_ClassificationSummary):
+ """
+ Abstraction for RandomForestClassification results for a given model.
+ .. versionadded:: 3.1.0
+ """
+ pass
+
+
+@inherit_doc
+class RandomForestClassificationTrainingSummary(RandomForestClassificationSummary,
+ _TrainingSummary):
+ """
+ Abstraction for RandomForestClassification training results.
+ .. versionadded:: 3.1.0
+ """
+ pass
+
+
+@inherit_doc
+class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary):
+ """
+ BinaryRandomForestClassification results for a given model.
+
+ .. versionadded:: 3.1.0
+ """
+ pass
+
+
+@inherit_doc
+class BinaryRandomForestClassificationTrainingSummary(BinaryRandomForestClassificationSummary,
+ RandomForestClassificationTrainingSummary):
+ """
+ BinaryRandomForestClassification training results for a given model.
+
+ .. versionadded:: 3.1.0
+ """
+ pass
+
class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
"""
diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py
index 19acd194f4ddf..7d905793188bf 100644
--- a/python/pyspark/ml/tests/test_training_summary.py
+++ b/python/pyspark/ml/tests/test_training_summary.py
@@ -22,7 +22,9 @@
basestring = str
from pyspark.ml.classification import BinaryLogisticRegressionSummary, LinearSVC, \
- LinearSVCSummary, LogisticRegression, LogisticRegressionSummary
+ LinearSVCSummary, BinaryRandomForestClassificationSummary, LogisticRegression, \
+ LogisticRegressionSummary, RandomForestClassificationSummary, \
+ RandomForestClassifier
from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
@@ -235,6 +237,81 @@ def test_linear_svc_summary(self):
self.assertTrue(isinstance(sameSummary, LinearSVCSummary))
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
+ def test_binary_randomforest_classification_summary(self):
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ rf = RandomForestClassifier(weightCol="weight")
+ model = rf.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertEqual(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.labels, list))
+ self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.precisionByLabel, list))
+ self.assertTrue(isinstance(s.recallByLabel, list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+ self.assertTrue(isinstance(s.roc, DataFrame))
+ self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
+ self.assertTrue(isinstance(s.pr, DataFrame))
+ self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
+ self.assertAlmostEqual(s.accuracy, 1.0, 2)
+ self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
+ self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
+ self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned, Scala version runs full test
+ sameSummary = model.evaluate(df)
+ self.assertTrue(isinstance(sameSummary, BinaryRandomForestClassificationSummary))
+ self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
+
+ def test_multiclass_randomforest_classification_summary(self):
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], [])),
+ (2.0, 2.0, Vectors.dense(2.0)),
+ (2.0, 2.0, Vectors.dense(1.9))],
+ ["label", "weight", "features"])
+ rf = RandomForestClassifier(weightCol="weight")
+ model = rf.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertEqual(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.labels, list))
+ self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.precisionByLabel, list))
+ self.assertTrue(isinstance(s.recallByLabel, list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+ self.assertAlmostEqual(s.accuracy, 1.0, 2)
+ self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
+ self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
+ self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned, Scala version runs full test
+ sameSummary = model.evaluate(df)
+ self.assertTrue(isinstance(sameSummary, RandomForestClassificationSummary))
+ self.assertFalse(isinstance(sameSummary, BinaryRandomForestClassificationSummary))
+ self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
+
def test_gaussian_mixture_summary(self):
data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
(Vectors.sparse(1, [], []),)]
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b0498d0298785..b5a7c18904b14 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1433,8 +1433,12 @@ def timestamp_seconds(col):
>>> from pyspark.sql.functions import timestamp_seconds
>>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
>>> time_df = spark.createDataFrame([(1230219000,)], ['unix_time'])
- >>> time_df.select(timestamp_seconds(time_df.unix_time).alias('ts')).collect()
- [Row(ts=datetime.datetime(2008, 12, 25, 7, 30))]
+ >>> time_df.select(timestamp_seconds(time_df.unix_time).alias('ts')).show()
+ +-------------------+
+ | ts|
+ +-------------------+
+ |2008-12-25 07:30:00|
+ +-------------------+
>>> spark.conf.unset("spark.sql.session.timeZone")
"""
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index be4fa20a04327..61891c478dbe4 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -265,7 +265,10 @@ def newSession(self):
@since(3.0)
def getActiveSession(cls):
"""
- Returns the active SparkSession for the current thread, returned by the builder.
+ Returns the active SparkSession for the current thread, returned by the builder
+
+ :return: :class:`SparkSession` if an active session exists for the current thread
+
>>> s = SparkSession.getActiveSession()
>>> l = [('Alice', 1)]
>>> rdd = s.sparkContext.parallelize(l)
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 9a6a43914bca3..5ca624a8d66cb 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -211,9 +211,13 @@ private[spark] class ApplicationMaster(
final def run(): Int = {
try {
val attemptID = if (isClusterMode) {
- // Set the web ui port to be ephemeral for yarn so we don't conflict with
- // other spark processes running on the same box
- System.setProperty(UI_PORT.key, "0")
+ // Set the web ui port to be ephemeral for yarn if not set explicitly
+ // so we don't conflict with other spark processes running on the same box
+ // If set explicitly, Web UI will attempt to run on UI_PORT and try
+ // incrementally until UI_PORT + `spark.port.maxRetries`
+ if (System.getProperty(UI_PORT.key) == null) {
+ System.setProperty(UI_PORT.key, "0")
+ }
// Set the master and deploy mode property to match the requested mode.
System.setProperty("spark.master", "yarn")
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportStatistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportStatistics.java
index b839fd5a4a726..1e0c9ca7c7e4b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportStatistics.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsReportStatistics.java
@@ -23,9 +23,9 @@
* A mix in interface for {@link Scan}. Data sources can implement this interface to
* report statistics to Spark.
*
- * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the
- * data source. Implementations that return more accurate statistics based on pushed operators will
- * not improve query performance until the planner can push operators before getting stats.
+ * As of Spark 3.0, statistics are reported to the optimizer after operators are pushed to the
+ * data source. Implementations may return more accurate statistics based on pushed operators
+ * which may improve query performance by providing better information to the optimizer.
*
* @since 3.0.0
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 9e325d0c2e4e1..9c99acaa994b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -337,7 +337,8 @@ trait CheckAnalysis extends PredicateHelper {
def ordinalNumber(i: Int): String = i match {
case 0 => "first"
case 1 => "second"
- case i => s"${i}th"
+ case 2 => "third"
+ case i => s"${i + 1}th"
}
val ref = dataTypes(operator.children.head)
operator.children.tail.zipWithIndex.foreach { case (child, ti) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 81de086e78f91..4cbff62e16cc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -183,7 +183,7 @@ object ResolveHints {
val hintName = hint.name.toUpperCase(Locale.ROOT)
def createRepartitionByExpression(
- numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+ numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
if (sortOrders.nonEmpty) throw new IllegalArgumentException(
s"""Invalid partitionExprs specified: $sortOrders
@@ -208,11 +208,11 @@ object ResolveHints {
throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")
case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
- createRepartitionByExpression(numPartitions, param.tail)
+ createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(numPartitions: Int, _*) if shuffle =>
- createRepartitionByExpression(numPartitions, param.tail)
+ createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(_*) if shuffle =>
- createRepartitionByExpression(conf.numShufflePartitions, param)
+ createRepartitionByExpression(None, param)
}
}
@@ -224,7 +224,7 @@ object ResolveHints {
val hintName = hint.name.toUpperCase(Locale.ROOT)
def createRepartitionByExpression(
- numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+ numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
if (invalidParams.nonEmpty) {
throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
@@ -239,11 +239,11 @@ object ResolveHints {
hint.parameters match {
case param @ Seq(IntegerLiteral(numPartitions), _*) =>
- createRepartitionByExpression(numPartitions, param.tail)
+ createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(numPartitions: Int, _*) =>
- createRepartitionByExpression(numPartitions, param.tail)
+ createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(_*) =>
- createRepartitionByExpression(conf.numShufflePartitions, param)
+ createRepartitionByExpression(None, param)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 0e4ff4f9f2cb4..a1277217b1b3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import java.text.ParseException
import java.time.{DateTimeException, LocalDate, LocalDateTime, ZoneId}
import java.time.format.DateTimeParseException
-import java.time.temporal.IsoFields
import java.util.Locale
import org.apache.commons.text.StringEscapeUtils
@@ -386,7 +385,7 @@ case class DayOfYear(child: Expression) extends GetDateField {
override val funcName = "getDayInYear"
}
-abstract class NumberToTimestampBase extends UnaryExpression
+abstract class IntegralToTimestampBase extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {
protected def upScaleFactor: Long
@@ -408,19 +407,66 @@ abstract class NumberToTimestampBase extends UnaryExpression
}
}
+// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(seconds) - Creates timestamp from the number of seconds since UTC epoch.",
+ usage = "_FUNC_(seconds) - Creates timestamp from the number of seconds (can be fractional) since UTC epoch.",
examples = """
Examples:
> SELECT _FUNC_(1230219000);
2008-12-25 07:30:00
+ > SELECT _FUNC_(1230219000.123);
+ 2008-12-25 07:30:00.123
""",
group = "datetime_funcs",
since = "3.1.0")
-case class SecondsToTimestamp(child: Expression)
- extends NumberToTimestampBase {
+// scalastyle:on line.size.limit
+case class SecondsToTimestamp(child: Expression) extends UnaryExpression
+ with ExpectsInputTypes with NullIntolerant {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+ override def dataType: DataType = TimestampType
- override def upScaleFactor: Long = MICROS_PER_SECOND
+ override def nullable: Boolean = child.dataType match {
+ case _: FloatType | _: DoubleType => true
+ case _ => child.nullable
+ }
+
+ @transient
+ private lazy val evalFunc: Any => Any = child.dataType match {
+ case _: IntegralType => input =>
+ Math.multiplyExact(input.asInstanceOf[Number].longValue(), MICROS_PER_SECOND)
+ case _: DecimalType => input =>
+ val operand = new java.math.BigDecimal(MICROS_PER_SECOND)
+ input.asInstanceOf[Decimal].toJavaBigDecimal.multiply(operand).longValueExact()
+ case _: FloatType => input =>
+ val f = input.asInstanceOf[Float]
+ if (f.isNaN || f.isInfinite) null else (f * MICROS_PER_SECOND).toLong
+ case _: DoubleType => input =>
+ val d = input.asInstanceOf[Double]
+ if (d.isNaN || d.isInfinite) null else (d * MICROS_PER_SECOND).toLong
+ }
+
+ override def nullSafeEval(input: Any): Any = evalFunc(input)
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.dataType match {
+ case _: IntegralType =>
+ defineCodeGen(ctx, ev, c => s"java.lang.Math.multiplyExact($c, ${MICROS_PER_SECOND}L)")
+ case _: DecimalType =>
+ val operand = s"new java.math.BigDecimal($MICROS_PER_SECOND)"
+ defineCodeGen(ctx, ev, c => s"$c.toJavaBigDecimal().multiply($operand).longValueExact()")
+ case other =>
+ nullSafeCodeGen(ctx, ev, c => {
+ val typeStr = CodeGenerator.boxedType(other)
+ s"""
+ |if ($typeStr.isNaN($c) || $typeStr.isInfinite($c)) {
+ | ${ev.isNull} = true;
+ |} else {
+ | ${ev.value} = (long)($c * $MICROS_PER_SECOND);
+ |}
+ |""".stripMargin
+ })
+ }
override def prettyName: String = "timestamp_seconds"
}
@@ -437,7 +483,7 @@ case class SecondsToTimestamp(child: Expression)
since = "3.1.0")
// scalastyle:on line.size.limit
case class MillisToTimestamp(child: Expression)
- extends NumberToTimestampBase {
+ extends IntegralToTimestampBase {
override def upScaleFactor: Long = MICROS_PER_MILLIS
@@ -456,7 +502,7 @@ case class MillisToTimestamp(child: Expression)
since = "3.1.0")
// scalastyle:on line.size.limit
case class MicrosToTimestamp(child: Expression)
- extends NumberToTimestampBase {
+ extends IntegralToTimestampBase {
override def upScaleFactor: Long = 1L
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 05a5ff45b8fb1..527618b8e2c5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -211,7 +211,9 @@ trait PredicateHelper extends Logging {
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
*/
- protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
+ protected def conjunctiveNormalForm(
+ condition: Expression,
+ groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
val postOrderNodes = postOrderTraversal(condition)
val resultStack = new mutable.Stack[Seq[Expression]]
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
@@ -226,8 +228,8 @@ trait PredicateHelper extends Logging {
// For each side, there is no need to expand predicates of the same references.
// So here we can aggregate predicates of the same qualifier as one single predicate,
// for reducing the size of pushed down predicates and corresponding codegen.
- val right = groupExpressionsByQualifier(resultStack.pop())
- val left = groupExpressionsByQualifier(resultStack.pop())
+ val right = groupExpsFunc(resultStack.pop())
+ val left = groupExpsFunc(resultStack.pop())
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
if (left.size * right.size > maxCnfNodeCount) {
logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
@@ -249,8 +251,36 @@ trait PredicateHelper extends Logging {
resultStack.top
}
- private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
- expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
+ /**
+ * Convert an expression to conjunctive normal form when pushing predicates through Join.
+ * When expanding predicates, we can group them by qualifier, avoiding the generation of
+ * unnecessary expressions, to control the length of the final result since there are multiple tables.
+ *
+ * @param condition the condition to be converted
+ * @return the CNF result as a sequence of disjunctive expressions. If the number of expressions
+ * exceeds the threshold on converting `Or`, `Seq.empty` is returned.
+ */
+ def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = {
+ conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+ expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
+ }
+
+ /**
+ * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
+ * When expanding predicates, this method groups expressions by their references to reduce
+ * the size of pushed down predicates and corresponding codegen. In partition pruning strategies,
+ * we split filters by [[splitConjunctivePredicates]] and identify partition filters by checking
+ * whether their references are a subset of partCols. If we group expressions by reference when
+ * expanding predicates of [[Or]], it won't impact the final predicate pruning result since
+ * [[splitConjunctivePredicates]] won't split [[Or]] expressions.
+ *
+ * @param condition the condition to be converted
+ * @return the CNF result as a sequence of disjunctive expressions. If the number of expressions
+ * exceeds the threshold on converting `Or`, `Seq.empty` is returned.
+ */
+ def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
+ conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+ expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index f9222f5af54da..70a673bb42457 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -133,7 +133,7 @@ private[sql] class JSONOptions(
* Enables inferring of TimestampType from strings matched to the timestamp pattern
* defined by the timestampFormat option.
*/
- val inferTimestamp: Boolean = parameters.get("inferTimestamp").map(_.toBoolean).getOrElse(true)
+ val inferTimestamp: Boolean = parameters.get("inferTimestamp").map(_.toBoolean).getOrElse(false)
/** Build a Jackson [[JsonFactory]] using JSON options. */
def buildJsonFactory(): JsonFactory = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 43738204c6704..8d5dbc7dc90eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -123,7 +123,8 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
normalize(GetStructField(expr, i))
}
- CreateStruct(fields)
+ val struct = CreateStruct(fields)
+ KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
case _ if expr.dataType.isInstanceOf[ArrayType] =>
val ArrayType(et, containsNull) = expr.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
index 109e5f993c02e..47e9527ead7c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
@@ -38,7 +38,7 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ Join(left, right, joinType, Some(joinCondition), hint)
if canPushThrough(joinType) =>
- val predicates = conjunctiveNormalForm(joinCondition)
+ val predicates = CNFWithGroupExpressionsByQualifier(joinCondition)
if (predicates.isEmpty) {
j
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 03571a740df3e..d08bcb1420176 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -411,12 +411,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val mergeCondition = expression(ctx.mergeCondition)
- val matchedClauses = ctx.matchedClause()
- if (matchedClauses.size() > 2) {
- throw new ParseException("There should be at most 2 'WHEN MATCHED' clauses.",
- matchedClauses.get(2))
- }
- val matchedActions = matchedClauses.asScala.map {
+ val matchedActions = ctx.matchedClause().asScala.map {
clause => {
if (clause.matchedAction().DELETE() != null) {
DeleteAction(Option(clause.matchedCond).map(expression))
@@ -435,12 +430,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}
}
- val notMatchedClauses = ctx.notMatchedClause()
- if (notMatchedClauses.size() > 1) {
- throw new ParseException("There should be at most 1 'WHEN NOT MATCHED' clause.",
- notMatchedClauses.get(1))
- }
- val notMatchedActions = notMatchedClauses.asScala.map {
+ val notMatchedActions = ctx.notMatchedClause().asScala.map {
clause => {
if (clause.notMatchedAction().INSERT() != null) {
val condition = Option(clause.notMatchedCond).map(expression)
@@ -468,13 +458,15 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
throw new ParseException("There must be at least one WHEN clause in a MERGE statement", ctx)
}
// children being empty means that the condition is not set
- if (matchedActions.length == 2 && matchedActions.head.children.isEmpty) {
- throw new ParseException("When there are 2 MATCHED clauses in a MERGE statement, " +
- "the first MATCHED clause must have a condition", ctx)
- }
- if (matchedActions.groupBy(_.getClass).mapValues(_.size).exists(_._2 > 1)) {
- throw new ParseException(
- "UPDATE and DELETE can appear at most once in MATCHED clauses in a MERGE statement", ctx)
+ val matchedActionSize = matchedActions.length
+ if (matchedActionSize >= 2 && !matchedActions.init.forall(_.condition.nonEmpty)) {
+ throw new ParseException("When there are more than one MATCHED clauses in a MERGE " +
+ "statement, only the last MATCHED clause can omit the condition.", ctx)
+ }
+ val notMatchedActionSize = notMatchedActions.length
+ if (notMatchedActionSize >= 2 && !notMatchedActions.init.forall(_.condition.nonEmpty)) {
+ throw new ParseException("When there are more than one NOT MATCHED clauses in a MERGE " +
+ "statement, only the last NOT MATCHED clause can omit the condition.", ctx)
}
MergeIntoTable(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 579157a6f2f2e..b4120d9f64cc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -346,25 +346,25 @@ case class MergeIntoTable(
override def children: Seq[LogicalPlan] = Seq(targetTable, sourceTable)
}
-sealed abstract class MergeAction(
- condition: Option[Expression]) extends Expression with Unevaluable {
+sealed abstract class MergeAction extends Expression with Unevaluable {
+ def condition: Option[Expression]
override def foldable: Boolean = false
override def nullable: Boolean = false
override def dataType: DataType = throw new UnresolvedException(this, "nullable")
override def children: Seq[Expression] = condition.toSeq
}
-case class DeleteAction(condition: Option[Expression]) extends MergeAction(condition)
+case class DeleteAction(condition: Option[Expression]) extends MergeAction
case class UpdateAction(
condition: Option[Expression],
- assignments: Seq[Assignment]) extends MergeAction(condition) {
+ assignments: Seq[Assignment]) extends MergeAction {
override def children: Seq[Expression] = condition.toSeq ++ assignments
}
case class InsertAction(
condition: Option[Expression],
- assignments: Seq[Assignment]) extends MergeAction(condition) {
+ assignments: Seq[Assignment]) extends MergeAction {
override def children: Seq[Expression] = condition.toSeq ++ assignments
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
index a81d8f79d6fcc..c6d21540f27d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
@@ -42,6 +42,7 @@ class CatalogManager(
defaultSessionCatalog: CatalogPlugin,
val v1SessionCatalog: SessionCatalog) extends Logging {
import CatalogManager.SESSION_CATALOG_NAME
+ import CatalogV2Util._
private val catalogs = mutable.HashMap.empty[String, CatalogPlugin]
@@ -106,13 +107,15 @@ class CatalogManager(
}
def setCurrentNamespace(namespace: Array[String]): Unit = synchronized {
- if (currentCatalog.name() == SESSION_CATALOG_NAME) {
- if (namespace.length != 1) {
+ currentCatalog match {
+ case _ if isSessionCatalog(currentCatalog) && namespace.length == 1 =>
+ v1SessionCatalog.setCurrentDatabase(namespace.head)
+ case _ if isSessionCatalog(currentCatalog) =>
throw new NoSuchNamespaceException(namespace)
- }
- v1SessionCatalog.setCurrentDatabase(namespace.head)
- } else {
- _currentNamespace = Some(namespace)
+ case catalog: SupportsNamespaces if !catalog.namespaceExists(namespace) =>
+ throw new NoSuchNamespaceException(namespace)
+ case _ =>
+ _currentNamespace = Some(namespace)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 189152374b0d1..c15ec49e14282 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -831,4 +831,57 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}
}
}
+
+ test("SPARK-32131: Fix wrong column index when we have more than two columns" +
+ " during union and set operations" ) {
+ val firstTable = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", DoubleType)(),
+ AttributeReference("c", IntegerType)(),
+ AttributeReference("d", FloatType)())
+
+ val secondTable = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", TimestampType)(),
+ AttributeReference("c", IntegerType)(),
+ AttributeReference("d", FloatType)())
+
+ val thirdTable = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", DoubleType)(),
+ AttributeReference("c", TimestampType)(),
+ AttributeReference("d", FloatType)())
+
+ val fourthTable = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", DoubleType)(),
+ AttributeReference("c", IntegerType)(),
+ AttributeReference("d", TimestampType)())
+
+ val r1 = Union(firstTable, secondTable)
+ val r2 = Union(firstTable, thirdTable)
+ val r3 = Union(firstTable, fourthTable)
+ val r4 = Except(firstTable, secondTable, isAll = false)
+ val r5 = Intersect(firstTable, secondTable, isAll = false)
+
+ assertAnalysisError(r1,
+ Seq("Union can only be performed on tables with the compatible column types. " +
+ "timestamp <> double at the second column of the second table"))
+
+ assertAnalysisError(r2,
+ Seq("Union can only be performed on tables with the compatible column types. " +
+ "timestamp <> int at the third column of the second table"))
+
+ assertAnalysisError(r3,
+ Seq("Union can only be performed on tables with the compatible column types. " +
+ "timestamp <> float at the 4th column of the second table"))
+
+ assertAnalysisError(r4,
+ Seq("Except can only be performed on tables with the compatible column types. " +
+ "timestamp <> double at the second column of the second table"))
+
+ assertAnalysisError(r5,
+ Seq("Intersect can only be performed on tables with the compatible column types. " +
+ "timestamp <> double at the second column of the second table"))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index d3bd5d07a0932..513f1d001f757 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -163,7 +163,7 @@ class ResolveHintsSuite extends AnalysisTest {
checkAnalysis(
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
- Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
+ Seq(AttributeReference("a", IntegerType)()), testRelation, None))
val e = intercept[IllegalArgumentException] {
checkAnalysis(
@@ -187,7 +187,7 @@ class ResolveHintsSuite extends AnalysisTest {
"REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
- testRelation, conf.numShufflePartitions))
+ testRelation, None))
val errMsg2 = "REPARTITION Hint parameter should include columns, but"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
index b449ed5cc0d07..793abccd79405 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
@@ -43,7 +43,7 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
// Check CNF conversion with expected expression, assuming the input has non-empty result.
private def checkCondition(input: Expression, expected: Expression): Unit = {
- val cnf = conjunctiveNormalForm(input)
+ val cnf = CNFWithGroupExpressionsByQualifier(input)
assert(cnf.nonEmpty)
val result = cnf.reduceLeft(And)
assert(result.semanticEquals(expected))
@@ -113,14 +113,14 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
if (maxCount < 36) {
- assert(conjunctiveNormalForm(input).isEmpty)
+ assert(CNFWithGroupExpressionsByQualifier(input).isEmpty)
} else {
- assert(conjunctiveNormalForm(input).nonEmpty)
+ assert(CNFWithGroupExpressionsByQualifier(input).nonEmpty)
}
if (maxCount < 9) {
- assert(conjunctiveNormalForm(input2).isEmpty)
+ assert(CNFWithGroupExpressionsByQualifier(input2).isEmpty)
} else {
- assert(conjunctiveNormalForm(input2).nonEmpty)
+ assert(CNFWithGroupExpressionsByQualifier(input2).nonEmpty)
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 4edf95d8f994b..85492084d51ac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -1142,28 +1142,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("SPARK-31710:Adds TIMESTAMP_SECONDS, " +
- "TIMESTAMP_MILLIS and TIMESTAMP_MICROS functions") {
- checkEvaluation(SecondsToTimestamp(Literal(1230219000)), 1230219000L * MICROS_PER_SECOND)
- checkEvaluation(SecondsToTimestamp(Literal(-1230219000)), -1230219000L * MICROS_PER_SECOND)
- checkEvaluation(SecondsToTimestamp(Literal(null, IntegerType)), null)
- checkEvaluation(MillisToTimestamp(Literal(1230219000123L)), 1230219000123L * MICROS_PER_MILLIS)
- checkEvaluation(MillisToTimestamp(
- Literal(-1230219000123L)), -1230219000123L * MICROS_PER_MILLIS)
- checkEvaluation(MillisToTimestamp(Literal(null, IntegerType)), null)
- checkEvaluation(MicrosToTimestamp(Literal(1230219000123123L)), 1230219000123123L)
- checkEvaluation(MicrosToTimestamp(Literal(-1230219000123123L)), -1230219000123123L)
- checkEvaluation(MicrosToTimestamp(Literal(null, IntegerType)), null)
- checkExceptionInExpression[ArithmeticException](
- SecondsToTimestamp(Literal(1230219000123123L)), "long overflow")
- checkExceptionInExpression[ArithmeticException](
- SecondsToTimestamp(Literal(-1230219000123123L)), "long overflow")
- checkExceptionInExpression[ArithmeticException](
- MillisToTimestamp(Literal(92233720368547758L)), "long overflow")
- checkExceptionInExpression[ArithmeticException](
- MillisToTimestamp(Literal(-92233720368547758L)), "long overflow")
- }
-
test("Consistent error handling for datetime formatting and parsing functions") {
def checkException[T <: Exception : ClassTag](c: String): Unit = {
@@ -1194,4 +1172,118 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
new ParseToTimestamp(Literal("11:11 PM"), Literal("mm:ss a")).child,
Timestamp.valueOf("1970-01-01 12:11:11.0"))
}
+
+ def testIntegralInput(testFunc: Number => Unit): Unit = {
+ def checkResult(input: Long): Unit = {
+ if (input.toByte == input) {
+ testFunc(input.toByte)
+ } else if (input.toShort == input) {
+ testFunc(input.toShort)
+ } else if (input.toInt == input) {
+ testFunc(input.toInt)
+ } else {
+ testFunc(input)
+ }
+ }
+ checkResult(0)
+ checkResult(Byte.MaxValue)
+ checkResult(Byte.MinValue)
+ checkResult(Short.MaxValue)
+ checkResult(Short.MinValue)
+ checkResult(Int.MaxValue)
+ checkResult(Int.MinValue)
+ checkResult(Int.MaxValue.toLong + 100)
+ checkResult(Int.MinValue.toLong - 100)
+ }
+
+ test("TIMESTAMP_SECONDS") {
+ def testIntegralFunc(value: Number): Unit = {
+ checkEvaluation(
+ SecondsToTimestamp(Literal(value)),
+ Instant.ofEpochSecond(value.longValue()))
+ }
+
+ // test null input
+ checkEvaluation(
+ SecondsToTimestamp(Literal(null, IntegerType)),
+ null)
+
+ // test integral input
+ testIntegralInput(testIntegralFunc)
+ // test overflow
+ checkExceptionInExpression[ArithmeticException](
+ SecondsToTimestamp(Literal(Long.MaxValue, LongType)), EmptyRow, "long overflow")
+
+ def testFractionalInput(input: String): Unit = {
+ Seq(input.toFloat, input.toDouble, Decimal(input)).foreach { value =>
+ checkEvaluation(
+ SecondsToTimestamp(Literal(value)),
+ (input.toDouble * MICROS_PER_SECOND).toLong)
+ }
+ }
+
+ testFractionalInput("1.0")
+ testFractionalInput("-1.0")
+ testFractionalInput("1.234567")
+ testFractionalInput("-1.234567")
+
+ // test overflow for decimal input
+ checkExceptionInExpression[ArithmeticException](
+ SecondsToTimestamp(Literal(Decimal("9" * 38))), "Overflow"
+ )
+ // test truncation error for decimal input
+ checkExceptionInExpression[ArithmeticException](
+ SecondsToTimestamp(Literal(Decimal("0.1234567"))), "Rounding necessary"
+ )
+
+ // test NaN
+ checkEvaluation(
+ SecondsToTimestamp(Literal(Double.NaN)),
+ null)
+ checkEvaluation(
+ SecondsToTimestamp(Literal(Float.NaN)),
+ null)
+ // double input can truncate
+ checkEvaluation(
+ SecondsToTimestamp(Literal(123.456789123)),
+ Instant.ofEpochSecond(123, 456789000))
+ }
+
+ test("TIMESTAMP_MILLIS") {
+ def testIntegralFunc(value: Number): Unit = {
+ checkEvaluation(
+ MillisToTimestamp(Literal(value)),
+ Instant.ofEpochMilli(value.longValue()))
+ }
+
+ // test null input
+ checkEvaluation(
+ MillisToTimestamp(Literal(null, IntegerType)),
+ null)
+
+ // test integral input
+ testIntegralInput(testIntegralFunc)
+ // test overflow
+ checkExceptionInExpression[ArithmeticException](
+ MillisToTimestamp(Literal(Long.MaxValue, LongType)), EmptyRow, "long overflow")
+ }
+
+ test("TIMESTAMP_MICROS") {
+ def testIntegralFunc(value: Number): Unit = {
+ checkEvaluation(
+ MicrosToTimestamp(Literal(value)),
+ value.longValue())
+ }
+
+ // test null input
+ checkEvaluation(
+ MicrosToTimestamp(Literal(null, IntegerType)),
+ null)
+
+ // test integral input
+ testIntegralInput(testIntegralFunc)
+ // test max/min input
+ testIntegralFunc(Long.MaxValue)
+ testIntegralFunc(Long.MinValue)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
index bce917c80f93c..8290b38e33934 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
@@ -35,22 +35,29 @@ class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper {
assert(inferSchema.inferField(parser) === expectedType)
}
- def checkTimestampType(pattern: String, json: String): Unit = {
- checkType(Map("timestampFormat" -> pattern), json, TimestampType)
+ def checkTimestampType(pattern: String, json: String, inferTimestamp: Boolean): Unit = {
+ checkType(
+ Map("timestampFormat" -> pattern, "inferTimestamp" -> inferTimestamp.toString),
+ json,
+ if (inferTimestamp) TimestampType else StringType)
}
test("inferring timestamp type") {
- Seq("legacy", "corrected").foreach { legacyParserPolicy =>
- withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> legacyParserPolicy) {
- checkTimestampType("yyyy", """{"a": "2018"}""")
- checkTimestampType("yyyy=MM", """{"a": "2018=12"}""")
- checkTimestampType("yyyy MM dd", """{"a": "2018 12 02"}""")
- checkTimestampType(
- "yyyy-MM-dd'T'HH:mm:ss.SSS",
- """{"a": "2018-12-02T21:04:00.123"}""")
- checkTimestampType(
- "yyyy-MM-dd'T'HH:mm:ss.SSSSSSXXX",
- """{"a": "2018-12-02T21:04:00.123567+01:00"}""")
+ Seq(true, false).foreach { inferTimestamp =>
+ Seq("legacy", "corrected").foreach { legacyParserPolicy =>
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> legacyParserPolicy) {
+ checkTimestampType("yyyy", """{"a": "2018"}""", inferTimestamp)
+ checkTimestampType("yyyy=MM", """{"a": "2018=12"}""", inferTimestamp)
+ checkTimestampType("yyyy MM dd", """{"a": "2018 12 02"}""", inferTimestamp)
+ checkTimestampType(
+ "yyyy-MM-dd'T'HH:mm:ss.SSS",
+ """{"a": "2018-12-02T21:04:00.123"}""",
+ inferTimestamp)
+ checkTimestampType(
+ "yyyy-MM-dd'T'HH:mm:ss.SSSSSSXXX",
+ """{"a": "2018-12-02T21:04:00.123567+01:00"}""",
+ inferTimestamp)
+ }
}
}
}
@@ -71,16 +78,19 @@ class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper {
}
test("skip decimal type inferring") {
- Seq("legacy", "corrected").foreach { legacyParserPolicy =>
- withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> legacyParserPolicy) {
- checkType(
- options = Map(
- "prefersDecimal" -> "false",
- "timestampFormat" -> "yyyyMMdd.HHmmssSSS"
- ),
- json = """{"a": "20181202.210400123"}""",
- dt = TimestampType
- )
+ Seq(true, false).foreach { inferTimestamp =>
+ Seq("legacy", "corrected").foreach { legacyParserPolicy =>
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> legacyParserPolicy) {
+ checkType(
+ options = Map(
+ "prefersDecimal" -> "false",
+ "timestampFormat" -> "yyyyMMdd.HHmmssSSS",
+ "inferTimestamp" -> inferTimestamp.toString
+ ),
+ json = """{"a": "20181202.210400123"}""",
+ dt = if (inferTimestamp) TimestampType else StringType
+ )
+ }
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index 6499b5d8e7974..e802449a69743 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -1134,58 +1134,74 @@ class DDLParserSuite extends AnalysisTest {
}
}
- test("merge into table: at most two matched clauses") {
- val exc = intercept[ParseException] {
- parsePlan(
- """
- |MERGE INTO testcat1.ns1.ns2.tbl AS target
- |USING testcat2.ns1.ns2.tbl AS source
- |ON target.col1 = source.col1
- |WHEN MATCHED AND (target.col2='delete') THEN DELETE
- |WHEN MATCHED AND (target.col2='update1') THEN UPDATE SET target.col2 = source.col2
- |WHEN MATCHED AND (target.col2='update2') THEN UPDATE SET target.col2 = source.col2
- |WHEN NOT MATCHED AND (target.col2='insert')
- |THEN INSERT (target.col1, target.col2) values (source.col1, source.col2)
- """.stripMargin)
- }
-
- assert(exc.getMessage.contains("There should be at most 2 'WHEN MATCHED' clauses."))
+ test("merge into table: multi matched and not matched clauses") {
+ parseCompare(
+ """
+ |MERGE INTO testcat1.ns1.ns2.tbl AS target
+ |USING testcat2.ns1.ns2.tbl AS source
+ |ON target.col1 = source.col1
+ |WHEN MATCHED AND (target.col2='delete') THEN DELETE
+ |WHEN MATCHED AND (target.col2='update1') THEN UPDATE SET target.col2 = 1
+ |WHEN MATCHED AND (target.col2='update2') THEN UPDATE SET target.col2 = 2
+ |WHEN NOT MATCHED AND (target.col2='insert1')
+ |THEN INSERT (target.col1, target.col2) values (source.col1, 1)
+ |WHEN NOT MATCHED AND (target.col2='insert2')
+ |THEN INSERT (target.col1, target.col2) values (source.col1, 2)
+ """.stripMargin,
+ MergeIntoTable(
+ SubqueryAlias("target", UnresolvedRelation(Seq("testcat1", "ns1", "ns2", "tbl"))),
+ SubqueryAlias("source", UnresolvedRelation(Seq("testcat2", "ns1", "ns2", "tbl"))),
+ EqualTo(UnresolvedAttribute("target.col1"), UnresolvedAttribute("source.col1")),
+ Seq(DeleteAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("delete")))),
+ UpdateAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("update1"))),
+ Seq(Assignment(UnresolvedAttribute("target.col2"), Literal(1)))),
+ UpdateAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("update2"))),
+ Seq(Assignment(UnresolvedAttribute("target.col2"), Literal(2))))),
+ Seq(InsertAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("insert1"))),
+ Seq(Assignment(UnresolvedAttribute("target.col1"), UnresolvedAttribute("source.col1")),
+ Assignment(UnresolvedAttribute("target.col2"), Literal(1)))),
+ InsertAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("insert2"))),
+ Seq(Assignment(UnresolvedAttribute("target.col1"), UnresolvedAttribute("source.col1")),
+ Assignment(UnresolvedAttribute("target.col2"), Literal(2)))))))
}
- test("merge into table: at most one not matched clause") {
+ test("merge into table: only the last matched clause can omit the condition") {
val exc = intercept[ParseException] {
parsePlan(
"""
|MERGE INTO testcat1.ns1.ns2.tbl AS target
|USING testcat2.ns1.ns2.tbl AS source
|ON target.col1 = source.col1
- |WHEN MATCHED AND (target.col2='delete') THEN DELETE
- |WHEN MATCHED AND (target.col2='update1') THEN UPDATE SET target.col2 = source.col2
- |WHEN NOT MATCHED AND (target.col2='insert1')
- |THEN INSERT (target.col1, target.col2) values (source.col1, source.col2)
- |WHEN NOT MATCHED AND (target.col2='insert2')
+ |WHEN MATCHED AND (target.col2 == 'update1') THEN UPDATE SET target.col2 = 1
+ |WHEN MATCHED THEN UPDATE SET target.col2 = 2
+ |WHEN MATCHED THEN DELETE
+ |WHEN NOT MATCHED AND (target.col2='insert')
|THEN INSERT (target.col1, target.col2) values (source.col1, source.col2)
""".stripMargin)
}
- assert(exc.getMessage.contains("There should be at most 1 'WHEN NOT MATCHED' clause."))
+ assert(exc.getMessage.contains("only the last MATCHED clause can omit the condition"))
}
- test("merge into table: the first matched clause must have a condition if there's a second") {
+ test("merge into table: only the last not matched clause can omit the condition") {
val exc = intercept[ParseException] {
parsePlan(
"""
|MERGE INTO testcat1.ns1.ns2.tbl AS target
|USING testcat2.ns1.ns2.tbl AS source
|ON target.col1 = source.col1
+ |WHEN MATCHED AND (target.col2 == 'update') THEN UPDATE SET target.col2 = source.col2
|WHEN MATCHED THEN DELETE
- |WHEN MATCHED THEN UPDATE SET target.col2 = source.col2
- |WHEN NOT MATCHED AND (target.col2='insert')
+ |WHEN NOT MATCHED AND (target.col2='insert1')
+ |THEN INSERT (target.col1, target.col2) values (source.col1, 1)
+ |WHEN NOT MATCHED
+ |THEN INSERT (target.col1, target.col2) values (source.col1, 2)
+ |WHEN NOT MATCHED
|THEN INSERT (target.col1, target.col2) values (source.col1, source.col2)
""".stripMargin)
}
- assert(exc.getMessage.contains("the first MATCHED clause must have a condition"))
+ assert(exc.getMessage.contains("only the last NOT MATCHED clause can omit the condition"))
}
test("merge into table: there must be a when (not) matched condition") {
@@ -1201,26 +1217,6 @@ class DDLParserSuite extends AnalysisTest {
assert(exc.getMessage.contains("There must be at least one WHEN clause in a MERGE statement"))
}
- test("merge into table: there can be only a single use DELETE or UPDATE") {
- Seq("UPDATE SET *", "DELETE").foreach { op =>
- val exc = intercept[ParseException] {
- parsePlan(
- s"""
- |MERGE INTO testcat1.ns1.ns2.tbl AS target
- |USING testcat2.ns1.ns2.tbl AS source
- |ON target.col1 = source.col1
- |WHEN MATCHED AND (target.col2='delete') THEN $op
- |WHEN MATCHED THEN $op
- |WHEN NOT MATCHED AND (target.col2='insert')
- |THEN INSERT (target.col1, target.col2) values (source.col1, source.col2)
- """.stripMargin)
- }
-
- assert(exc.getMessage.contains(
- "UPDATE and DELETE can appear at most once in MATCHED clauses"))
- }
- }
-
test("show tables") {
comparePlans(
parsePlan("SHOW TABLES"),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
index 17d326019f86b..7dd0753fcf777 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
@@ -19,9 +19,12 @@ package org.apache.spark.sql.connector.catalog
import java.net.URI
+import scala.collection.JavaConverters._
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2SessionCatalog, NoSuchNamespaceException}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.connector.InMemoryTableCatalog
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -108,6 +111,19 @@ class CatalogManagerSuite extends SparkFunSuite {
assert(v1SessionCatalog.getCurrentDatabase == "default")
catalogManager.setCurrentNamespace(Array("test2"))
assert(v1SessionCatalog.getCurrentDatabase == "default")
+
+ // Check namespace existence if currentCatalog implements SupportsNamespaces.
+ conf.setConfString("spark.sql.catalog.testCatalog", classOf[InMemoryTableCatalog].getName)
+ catalogManager.setCurrentCatalog("testCatalog")
+ catalogManager.currentCatalog.asInstanceOf[InMemoryTableCatalog]
+ .createNamespace(Array("test3"), Map.empty[String, String].asJava)
+ assert(v1SessionCatalog.getCurrentDatabase == "default")
+ catalogManager.setCurrentNamespace(Array("test3"))
+ assert(v1SessionCatalog.getCurrentDatabase == "default")
+
+ intercept[NoSuchNamespaceException] {
+ catalogManager.setCurrentNamespace(Array("ns1", "ns2"))
+ }
}
}
diff --git a/sql/core/benchmarks/JsonBenchmark-jdk11-results.txt b/sql/core/benchmarks/JsonBenchmark-jdk11-results.txt
index d0cd591da4c94..2d506f03d9f7e 100644
--- a/sql/core/benchmarks/JsonBenchmark-jdk11-results.txt
+++ b/sql/core/benchmarks/JsonBenchmark-jdk11-results.txt
@@ -7,106 +7,106 @@ OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-106
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
JSON schema inferring: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-No encoding 68879 68993 116 1.5 688.8 1.0X
-UTF-8 is set 115270 115602 455 0.9 1152.7 0.6X
+No encoding 73307 73400 141 1.4 733.1 1.0X
+UTF-8 is set 143834 143925 152 0.7 1438.3 0.5X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
count a short column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-No encoding 47452 47538 113 2.1 474.5 1.0X
-UTF-8 is set 77330 77354 30 1.3 773.3 0.6X
+No encoding 50894 51065 292 2.0 508.9 1.0X
+UTF-8 is set 98462 99455 1173 1.0 984.6 0.5X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
count a wide column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-No encoding 60470 60900 534 0.2 6047.0 1.0X
-UTF-8 is set 104733 104931 189 0.1 10473.3 0.6X
+No encoding 64011 64969 1001 0.2 6401.1 1.0X
+UTF-8 is set 102757 102984 311 0.1 10275.7 0.6X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
select wide row: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-No encoding 130302 131072 976 0.0 260604.6 1.0X
-UTF-8 is set 150860 151284 377 0.0 301720.1 0.9X
+No encoding 132559 133561 1010 0.0 265117.3 1.0X
+UTF-8 is set 151458 152129 611 0.0 302915.4 0.9X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Select a subset of 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Select 10 columns 18619 18684 99 0.5 1861.9 1.0X
-Select 1 column 24227 24270 38 0.4 2422.7 0.8X
+Select 10 columns 21148 21202 87 0.5 2114.8 1.0X
+Select 1 column 24701 24724 21 0.4 2470.1 0.9X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
creation of JSON parser per line: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Short column without encoding 7947 7971 21 1.3 794.7 1.0X
-Short column with UTF-8 12700 12753 58 0.8 1270.0 0.6X
-Wide column without encoding 92632 92955 463 0.1 9263.2 0.1X
-Wide column with UTF-8 147013 147170 188 0.1 14701.3 0.1X
+Short column without encoding 6945 6998 59 1.4 694.5 1.0X
+Short column with UTF-8 11510 11569 51 0.9 1151.0 0.6X
+Wide column without encoding 95004 95795 790 0.1 9500.4 0.1X
+Wide column with UTF-8 149223 149409 276 0.1 14922.3 0.0X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
JSON functions: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Text read 713 734 19 14.0 71.3 1.0X
-from_json 22019 22429 456 0.5 2201.9 0.0X
-json_tuple 27987 28047 74 0.4 2798.7 0.0X
-get_json_object 21468 21870 350 0.5 2146.8 0.0X
+Text read 649 652 3 15.4 64.9 1.0X
+from_json 22284 22393 99 0.4 2228.4 0.0X
+json_tuple 32310 32824 484 0.3 3231.0 0.0X
+get_json_object 22111 22751 568 0.5 2211.1 0.0X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Dataset of json strings: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Text read 2887 2910 24 17.3 57.7 1.0X
-schema inferring 31793 31843 43 1.6 635.9 0.1X
-parsing 36791 37104 294 1.4 735.8 0.1X
+Text read 2894 2903 8 17.3 57.9 1.0X
+schema inferring 26724 26785 62 1.9 534.5 0.1X
+parsing 37502 37632 131 1.3 750.0 0.1X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Json files in the per-line mode: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Text read 10570 10611 45 4.7 211.4 1.0X
-Schema inferring 48729 48763 41 1.0 974.6 0.2X
-Parsing without charset 35490 35648 141 1.4 709.8 0.3X
-Parsing with UTF-8 63853 63994 163 0.8 1277.1 0.2X
+Text read 10994 11010 16 4.5 219.9 1.0X
+Schema inferring 45654 45677 37 1.1 913.1 0.2X
+Parsing without charset 34476 34559 73 1.5 689.5 0.3X
+Parsing with UTF-8 56987 57002 13 0.9 1139.7 0.2X
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Create a dataset of timestamps 2187 2190 5 4.6 218.7 1.0X
-to_json(timestamp) 16262 16503 323 0.6 1626.2 0.1X
-write timestamps to files 11679 11692 12 0.9 1167.9 0.2X
-Create a dataset of dates 2297 2310 12 4.4 229.7 1.0X
-to_json(date) 10904 10956 46 0.9 1090.4 0.2X
-write dates to files 6610 6645 35 1.5 661.0 0.3X
+Create a dataset of timestamps 2150 2188 35 4.7 215.0 1.0X
+to_json(timestamp) 17874 18080 294 0.6 1787.4 0.1X
+write timestamps to files 12518 12538 34 0.8 1251.8 0.2X
+Create a dataset of dates 2298 2310 18 4.4 229.8 0.9X
+to_json(date) 11673 11703 27 0.9 1167.3 0.2X
+write dates to files 7121 7135 12 1.4 712.1 0.3X
OpenJDK 64-Bit Server VM 11.0.7+10-post-Ubuntu-2ubuntu218.04 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-read timestamp text from files 2524 2530 9 4.0 252.4 1.0X
-read timestamps from files 41002 41052 59 0.2 4100.2 0.1X
-infer timestamps from files 84621 84939 526 0.1 8462.1 0.0X
-read date text from files 2292 2302 9 4.4 229.2 1.1X
-read date from files 16954 16976 21 0.6 1695.4 0.1X
-timestamp strings 3067 3077 13 3.3 306.7 0.8X
-parse timestamps from Dataset[String] 48690 48971 243 0.2 4869.0 0.1X
-infer timestamps from Dataset[String] 97463 97786 338 0.1 9746.3 0.0X
-date strings 3952 3956 3 2.5 395.2 0.6X
-parse dates from Dataset[String] 24210 24241 30 0.4 2421.0 0.1X
-from_json(timestamp) 71710 72242 629 0.1 7171.0 0.0X
-from_json(date) 42465 42481 13 0.2 4246.5 0.1X
+read timestamp text from files 2616 2641 34 3.8 261.6 1.0X
+read timestamps from files 37481 37517 58 0.3 3748.1 0.1X
+infer timestamps from files 84774 84964 201 0.1 8477.4 0.0X
+read date text from files 2362 2365 3 4.2 236.2 1.1X
+read date from files 16583 16612 29 0.6 1658.3 0.2X
+timestamp strings 3927 3963 40 2.5 392.7 0.7X
+parse timestamps from Dataset[String] 52827 53004 243 0.2 5282.7 0.0X
+infer timestamps from Dataset[String] 101108 101644 769 0.1 10110.8 0.0X
+date strings 4886 4906 26 2.0 488.6 0.5X
+parse dates from Dataset[String] 27623 27694 62 0.4 2762.3 0.1X
+from_json(timestamp) 71764 71887 124 0.1 7176.4 0.0X
+from_json(date) 46200 46314 99 0.2 4620.0 0.1X
diff --git a/sql/core/benchmarks/JsonBenchmark-results.txt b/sql/core/benchmarks/JsonBenchmark-results.txt
index 46d2410fb47c3..c22118f91b3fc 100644
--- a/sql/core/benchmarks/JsonBenchmark-results.txt
+++ b/sql/core/benchmarks/JsonBenchmark-results.txt
@@ -7,106 +7,106 @@ OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aw
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
JSON schema inferring: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-No encoding 63981 64044 56 1.6 639.8 1.0X
-UTF-8 is set 112672 113350 962 0.9 1126.7 0.6X
+No encoding 63839 64000 263 1.6 638.4 1.0X
+UTF-8 is set 124633 124945 429 0.8 1246.3 0.5X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
count a short column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-No encoding 51256 51449 180 2.0 512.6 1.0X
-UTF-8 is set 83694 83859 148 1.2 836.9 0.6X
+No encoding 51720 51901 157 1.9 517.2 1.0X
+UTF-8 is set 91161 91190 25 1.1 911.6 0.6X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
count a wide column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-No encoding 58440 59097 569 0.2 5844.0 1.0X
-UTF-8 is set 102746 102883 198 0.1 10274.6 0.6X
+No encoding 58486 59038 714 0.2 5848.6 1.0X
+UTF-8 is set 103045 103350 358 0.1 10304.5 0.6X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
select wide row: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-No encoding 128982 129304 356 0.0 257965.0 1.0X
-UTF-8 is set 147247 147415 231 0.0 294494.1 0.9X
+No encoding 134909 135024 105 0.0 269818.6 1.0X
+UTF-8 is set 154418 154593 155 0.0 308836.7 0.9X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Select a subset of 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Select 10 columns 18837 19048 331 0.5 1883.7 1.0X
-Select 1 column 24707 24723 14 0.4 2470.7 0.8X
+Select 10 columns 19538 19620 70 0.5 1953.8 1.0X
+Select 1 column 26142 26159 15 0.4 2614.2 0.7X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
creation of JSON parser per line: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Short column without encoding 8218 8234 17 1.2 821.8 1.0X
-Short column with UTF-8 12374 12438 107 0.8 1237.4 0.7X
-Wide column without encoding 136918 137298 345 0.1 13691.8 0.1X
-Wide column with UTF-8 176961 177142 257 0.1 17696.1 0.0X
+Short column without encoding 8103 8162 53 1.2 810.3 1.0X
+Short column with UTF-8 13104 13150 58 0.8 1310.4 0.6X
+Wide column without encoding 135280 135593 375 0.1 13528.0 0.1X
+Wide column with UTF-8 175189 175483 278 0.1 17518.9 0.0X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
JSON functions: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Text read 1268 1278 12 7.9 126.8 1.0X
-from_json 23348 23479 176 0.4 2334.8 0.1X
-json_tuple 29606 30221 1024 0.3 2960.6 0.0X
-get_json_object 21898 22148 226 0.5 2189.8 0.1X
+Text read 1225 1234 8 8.2 122.5 1.0X
+from_json 22482 22552 95 0.4 2248.2 0.1X
+json_tuple 30203 30338 146 0.3 3020.3 0.0X
+get_json_object 22219 22245 26 0.5 2221.9 0.1X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Dataset of json strings: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Text read 5887 5944 49 8.5 117.7 1.0X
-schema inferring 46696 47054 312 1.1 933.9 0.1X
-parsing 32336 32450 129 1.5 646.7 0.2X
+Text read 5897 5904 10 8.5 117.9 1.0X
+schema inferring 30282 30340 50 1.7 605.6 0.2X
+parsing 33304 33577 289 1.5 666.1 0.2X
Preparing data for benchmarking ...
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Json files in the per-line mode: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Text read 9756 9769 11 5.1 195.1 1.0X
-Schema inferring 51318 51433 108 1.0 1026.4 0.2X
-Parsing without charset 43609 43743 118 1.1 872.2 0.2X
-Parsing with UTF-8 60775 60844 106 0.8 1215.5 0.2X
+Text read 9710 9757 80 5.1 194.2 1.0X
+Schema inferring 35929 35939 9 1.4 718.6 0.3X
+Parsing without charset 39175 39227 87 1.3 783.5 0.2X
+Parsing with UTF-8 59188 59294 109 0.8 1183.8 0.2X
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Create a dataset of timestamps 1998 2015 17 5.0 199.8 1.0X
-to_json(timestamp) 18156 18317 263 0.6 1815.6 0.1X
-write timestamps to files 12912 12917 5 0.8 1291.2 0.2X
-Create a dataset of dates 2209 2270 53 4.5 220.9 0.9X
-to_json(date) 9433 9489 90 1.1 943.3 0.2X
-write dates to files 6915 6923 8 1.4 691.5 0.3X
+Create a dataset of timestamps 1967 1977 9 5.1 196.7 1.0X
+to_json(timestamp) 17086 17304 371 0.6 1708.6 0.1X
+write timestamps to files 12691 12716 28 0.8 1269.1 0.2X
+Create a dataset of dates 2192 2217 39 4.6 219.2 0.9X
+to_json(date) 10541 10656 137 0.9 1054.1 0.2X
+write dates to files 7259 7311 46 1.4 725.9 0.3X
OpenJDK 64-Bit Server VM 1.8.0_252-8u252-b09-1~18.04-b09 on Linux 4.15.0-1063-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-read timestamp text from files 2395 2412 17 4.2 239.5 1.0X
-read timestamps from files 47269 47334 89 0.2 4726.9 0.1X
-infer timestamps from files 91806 91851 67 0.1 9180.6 0.0X
-read date text from files 2118 2133 13 4.7 211.8 1.1X
-read date from files 17267 17340 115 0.6 1726.7 0.1X
-timestamp strings 3906 3935 26 2.6 390.6 0.6X
-parse timestamps from Dataset[String] 52244 52534 279 0.2 5224.4 0.0X
-infer timestamps from Dataset[String] 100488 100714 198 0.1 10048.8 0.0X
-date strings 4572 4584 12 2.2 457.2 0.5X
-parse dates from Dataset[String] 26749 26768 17 0.4 2674.9 0.1X
-from_json(timestamp) 71414 71867 556 0.1 7141.4 0.0X
-from_json(date) 45322 45549 250 0.2 4532.2 0.1X
+read timestamp text from files 2318 2326 13 4.3 231.8 1.0X
+read timestamps from files 43345 43627 258 0.2 4334.5 0.1X
+infer timestamps from files 89570 89621 59 0.1 8957.0 0.0X
+read date text from files 2099 2107 9 4.8 209.9 1.1X
+read date from files 18000 18065 98 0.6 1800.0 0.1X
+timestamp strings 3937 3956 32 2.5 393.7 0.6X
+parse timestamps from Dataset[String] 56001 56429 539 0.2 5600.1 0.0X
+infer timestamps from Dataset[String] 109410 109963 559 0.1 10941.0 0.0X
+date strings 4530 4540 9 2.2 453.0 0.5X
+parse dates from Dataset[String] 29723 29767 72 0.3 2972.3 0.1X
+from_json(timestamp) 74106 74619 728 0.1 7410.6 0.0X
+from_json(date) 46599 46632 32 0.2 4659.9 0.0X
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 0855fa13fa79a..c2ed4c079d3cf 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -150,6 +150,11 @@
mssql-jdbc
test
+
+ com.oracle.database.jdbc
+ ojdbc8
+ test
+
org.apache.parquet
parquet-avro
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 52cec8b202885..7d86c48015406 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -248,12 +248,17 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
* `HadoopFsRelation` node(s) as part of its logical plan.
*/
def recacheByPath(spark: SparkSession, resourcePath: String): Unit = {
- val (fs, qualifiedPath) = {
- val path = new Path(resourcePath)
- val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
- (fs, fs.makeQualified(path))
- }
+ val path = new Path(resourcePath)
+ val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+ recacheByPath(spark, path, fs)
+ }
+ /**
+ * Tries to re-cache all the cache entries that contain `resourcePath` in one or more
+ * `HadoopFsRelation` node(s) as part of its logical plan.
+ */
+ def recacheByPath(spark: SparkSession, resourcePath: Path, fs: FileSystem): Unit = {
+ val qualifiedPath = fs.makeQualified(resourcePath)
recacheByCondition(spark, _.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 458e11b97db6b..78808ff21394c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -207,7 +207,7 @@ case class FileSourceScanExec(
private def isDynamicPruningFilter(e: Expression): Boolean =
e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
- @transient private lazy val selectedPartitions: Array[PartitionDirectory] = {
+ @transient lazy val selectedPartitions: Array[PartitionDirectory] = {
val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
val startTime = System.nanoTime()
val ret =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 078813b7d631d..3a2c673229c20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -746,7 +746,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ctx: QueryOrganizationContext,
expressions: Seq[Expression],
query: LogicalPlan): LogicalPlan = {
- RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+ RepartitionByExpression(expressions, query, None)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index f11972115e09f..fe733f4238e1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -192,7 +192,7 @@ case class InsertIntoHadoopFsRelationCommand(
// refresh cached files in FileIndex
fileIndex.foreach(_.refresh())
// refresh data cache if table is cached
- sparkSession.catalog.refreshByPath(outputPath.toString)
+ sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, outputPath, fs)
if (catalogTable.nonEmpty) {
CommandUtils.updateTableStats(sparkSession, catalogTable.get)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index a7129fb14d1a6..576a826faf894 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types.StructType
* its underlying [[FileScan]]. And the partition filters will be removed in the filters of
* returned logical plan.
*/
-private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
+private[sql] object PruneFileSourcePartitions
+ extends Rule[LogicalPlan] with PredicateHelper {
private def getPartitionKeyFiltersAndDataFilters(
sparkSession: SparkSession,
@@ -87,8 +88,12 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_,
_))
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
+ val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
+ val finalPredicates = if (predicates.nonEmpty) predicates else filters
val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
- fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
+ fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates,
+ logicalRelation.output)
+
if (partitionKeyFilters.nonEmpty) {
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
val prunedFsRelation =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala
index 99882b0f7c7b0..28097c35401c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala
@@ -32,10 +32,12 @@ object SchemaMergeUtils extends Logging {
*/
def mergeSchemasInParallel(
sparkSession: SparkSession,
+ parameters: Map[String, String],
files: Seq[FileStatus],
schemaReader: (Seq[FileStatus], Configuration, Boolean) => Seq[StructType])
: Option[StructType] = {
- val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf())
+ val serializedConf = new SerializableConfiguration(
+ sparkSession.sessionState.newHadoopConfWithOptions(parameters))
// !! HACK ALERT !!
// Here is a hack for Parquet, but it can be used by Orc as well.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
index c21e16bcf1280..16b244cc617ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
private[jdbc] class BasicConnectionProvider(driver: Driver, options: JDBCOptions)
- extends ConnectionProvider {
+ extends ConnectionProvider {
def getConnection(): Connection = {
val properties = getAdditionalProperties()
options.asConnectionProperties.entrySet().asScala.foreach { e =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
index 6c310ced37883..ce45be442ccc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
@@ -64,6 +64,10 @@ private[jdbc] object ConnectionProvider extends Logging {
logDebug("MS SQL connection provider found")
new MSSQLConnectionProvider(driver, options)
+ case OracleConnectionProvider.driverClass =>
+ logDebug("Oracle connection provider found")
+ new OracleConnectionProvider(driver, options)
+
case _ =>
throw new IllegalArgumentException(s"Driver ${options.driverClass} does not support " +
"Kerberos authentication")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala
index cf9729639c03c..095821cf83890 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/DB2ConnectionProvider.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
private[sql] class DB2ConnectionProvider(driver: Driver, options: JDBCOptions)
- extends SecureConnectionProvider(driver, options) {
+ extends SecureConnectionProvider(driver, options) {
override val appEntry: String = "JaasClient"
override def getConnection(): Connection = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala
index 589f13cf6ad5f..3c0286654a8ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala
@@ -22,7 +22,7 @@ import java.sql.Driver
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
private[jdbc] class MariaDBConnectionProvider(driver: Driver, options: JDBCOptions)
- extends SecureConnectionProvider(driver, options) {
+ extends SecureConnectionProvider(driver, options) {
override val appEntry: String = {
"Krb5ConnectorContext"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProvider.scala
new file mode 100644
index 0000000000000..c2b71b35b8128
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProvider.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+import java.security.PrivilegedExceptionAction
+import java.sql.{Connection, Driver}
+import java.util.Properties
+
+import org.apache.hadoop.security.UserGroupInformation
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+
+private[sql] class OracleConnectionProvider(driver: Driver, options: JDBCOptions)
+ extends SecureConnectionProvider(driver, options) {
+ override val appEntry: String = "kprb5module"
+
+ override def getConnection(): Connection = {
+ setAuthenticationConfigIfNeeded()
+ UserGroupInformation.loginUserFromKeytabAndReturnUGI(options.principal, options.keytab).doAs(
+ new PrivilegedExceptionAction[Connection]() {
+ override def run(): Connection = {
+ OracleConnectionProvider.super.getConnection()
+ }
+ }
+ )
+ }
+
+ override def getAdditionalProperties(): Properties = {
+ val result = new Properties()
+ // This prop is needed to turn on kerberos authentication in the JDBC driver.
+ // The possible values can be found in AnoServices public interface
+ // The value is coming from AUTHENTICATION_KERBEROS5 final String in driver version 19.6.0.0
+ result.put("oracle.net.authentication_services", "(KERBEROS5)");
+ result
+ }
+
+ override def setAuthenticationConfigIfNeeded(): Unit = SecurityConfigurationLock.synchronized {
+ val (parent, configEntry) = getConfigWithAppEntry()
+ if (configEntry == null || configEntry.isEmpty) {
+ setAuthenticationConfig(parent)
+ }
+ }
+}
+
+private[sql] object OracleConnectionProvider {
+ val driverClass = "oracle.jdbc.OracleDriver"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala
index 73034dcb9c2e0..fa9232e00bd88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala
@@ -23,7 +23,7 @@ import java.util.Properties
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
private[jdbc] class PostgresConnectionProvider(driver: Driver, options: JDBCOptions)
- extends SecureConnectionProvider(driver, options) {
+ extends SecureConnectionProvider(driver, options) {
override val appEntry: String = {
val parseURL = driver.getClass.getMethod("parseURL", classOf[String], classOf[Properties])
val properties = parseURL.invoke(driver, options.url, null).asInstanceOf[Properties]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala
index fa75fc8c28fbf..24eec63a7244f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala
@@ -33,7 +33,7 @@ import org.apache.spark.util.SecurityUtils
private[connection] object SecurityConfigurationLock
private[jdbc] abstract class SecureConnectionProvider(driver: Driver, options: JDBCOptions)
- extends BasicConnectionProvider(driver, options) with Logging {
+ extends BasicConnectionProvider(driver, options) with Logging {
override def getConnection(): Connection = {
setAuthenticationConfigIfNeeded()
super.getConnection()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index eea9b2a8f9613..d274bcd0edd2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -109,7 +109,7 @@ object OrcUtils extends Logging {
val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
if (orcOptions.mergeSchema) {
SchemaMergeUtils.mergeSchemasInParallel(
- sparkSession, files, OrcUtils.readOrcSchemasInParallel)
+ sparkSession, options, files, OrcUtils.readOrcSchemasInParallel)
} else {
OrcUtils.readSchema(sparkSession, files)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 71874104fcf4f..68f49f9442579 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -475,6 +475,7 @@ object ParquetFileFormat extends Logging {
* S3 nodes).
*/
def mergeSchemasInParallel(
+ parameters: Map[String, String],
filesToTouch: Seq[FileStatus],
sparkSession: SparkSession): Option[StructType] = {
val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString
@@ -490,7 +491,7 @@ object ParquetFileFormat extends Logging {
.map(ParquetFileFormat.readSchemaFromFooter(_, converter))
}
- SchemaMergeUtils.mergeSchemasInParallel(sparkSession, filesToTouch, reader)
+ SchemaMergeUtils.mergeSchemasInParallel(sparkSession, parameters, filesToTouch, reader)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 7e7dba92f37b5..b91d75c55c513 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -104,7 +104,7 @@ object ParquetUtils {
.orElse(filesByType.data.headOption)
.toSeq
}
- ParquetFileFormat.mergeSchemasInParallel(filesToTouch, sparkSession)
+ ParquetFileFormat.mergeSchemasInParallel(parameters, filesToTouch, sparkSession)
}
case class FileTypes(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
index 30a964d7e643f..bbe8835049fa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
@@ -18,7 +18,10 @@ package org.apache.spark.sql.execution.datasources.v2
import java.util
+import scala.collection.JavaConverters._
+
import com.fasterxml.jackson.databind.ObjectMapper
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
@@ -53,14 +56,16 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister {
paths ++ Option(map.get("path")).toSeq
}
- protected def getTableName(paths: Seq[String]): String = {
- val name = shortName() + " " + paths.map(qualifiedPathName).mkString(",")
+ protected def getTableName(map: CaseInsensitiveStringMap, paths: Seq[String]): String = {
+ val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(
+ map.asCaseSensitiveMap().asScala.toMap)
+ val name = shortName() + " " + paths.map(qualifiedPathName(_, hadoopConf)).mkString(",")
Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, name)
}
- private def qualifiedPathName(path: String): String = {
+ private def qualifiedPathName(path: String, hadoopConf: Configuration): String = {
val hdfsPath = new Path(path)
- val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
+ val fs = hdfsPath.getFileSystem(hadoopConf)
hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toString
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
index 1f99d4282f6da..69d001b4a615c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
@@ -31,13 +31,13 @@ class CSVDataSourceV2 extends FileDataSourceV2 {
override def getTable(options: CaseInsensitiveStringMap): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
CSVTable(tableName, sparkSession, options, paths, None, fallbackFileFormat)
}
override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
CSVTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala
index 7a0949e586cd8..9c4e3b8c78026 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala
@@ -31,13 +31,13 @@ class JsonDataSourceV2 extends FileDataSourceV2 {
override def getTable(options: CaseInsensitiveStringMap): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
JsonTable(tableName, sparkSession, options, paths, None, fallbackFileFormat)
}
override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
JsonTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
index 8665af33b976a..fa2febdc5a984 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
@@ -31,13 +31,13 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
override def getTable(options: CaseInsensitiveStringMap): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
OrcTable(tableName, sparkSession, options, paths, None, fallbackFileFormat)
}
override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
OrcTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala
index 8cb6186c12ff3..7e7ca964de28f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala
@@ -31,13 +31,13 @@ class ParquetDataSourceV2 extends FileDataSourceV2 {
override def getTable(options: CaseInsensitiveStringMap): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
ParquetTable(tableName, sparkSession, options, paths, None, fallbackFileFormat)
}
override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
ParquetTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala
index 049c717effa26..43bcb61f25962 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala
@@ -31,13 +31,13 @@ class TextDataSourceV2 extends FileDataSourceV2 {
override def getTable(options: CaseInsensitiveStringMap): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
TextTable(tableName, sparkSession, options, paths, None, fallbackFileFormat)
}
override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
val paths = getPaths(options)
- val tableName = getTableName(paths)
+ val tableName = getTableName(options, paths)
TextTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 32245470d8f5d..ecaf4f8160a06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -45,8 +45,7 @@ object FileStreamSink extends Logging {
val hdfsPath = new Path(singlePath)
val fs = hdfsPath.getFileSystem(hadoopConf)
if (fs.isDirectory(hdfsPath)) {
- val metadataPath = new Path(hdfsPath, metadataDir)
- checkEscapedMetadataPath(fs, metadataPath, sqlConf)
+ val metadataPath = getMetadataLogPath(fs, hdfsPath, sqlConf)
fs.exists(metadataPath)
} else {
false
@@ -55,6 +54,12 @@ object FileStreamSink extends Logging {
}
}
+ def getMetadataLogPath(fs: FileSystem, path: Path, sqlConf: SQLConf): Path = {
+ val metadataDir = new Path(path, FileStreamSink.metadataDir)
+ FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sqlConf)
+ metadataDir
+ }
+
def checkEscapedMetadataPath(fs: FileSystem, metadataPath: Path, sqlConf: SQLConf): Unit = {
if (sqlConf.getConf(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED)
&& StreamExecution.containsSpecialCharsInPath(metadataPath)) {
@@ -125,14 +130,12 @@ class FileStreamSink(
partitionColumnNames: Seq[String],
options: Map[String, String]) extends Sink with Logging {
+ import FileStreamSink._
+
private val hadoopConf = sparkSession.sessionState.newHadoopConf()
private val basePath = new Path(path)
- private val logPath = {
- val metadataDir = new Path(basePath, FileStreamSink.metadataDir)
- val fs = metadataDir.getFileSystem(hadoopConf)
- FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sparkSession.sessionState.conf)
- metadataDir
- }
+ private val logPath = getMetadataLogPath(basePath.getFileSystem(hadoopConf), basePath,
+ sparkSession.sessionState.conf)
private val fileLog =
new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toString)
diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
index 06765627f5545..bdf11f51db532 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
@@ -2,13 +2,18 @@
-- [SPARK-31710] TIMESTAMP_SECONDS, TIMESTAMP_MILLISECONDS and TIMESTAMP_MICROSECONDS to timestamp transfer
select TIMESTAMP_SECONDS(1230219000),TIMESTAMP_SECONDS(-1230219000),TIMESTAMP_SECONDS(null);
+select TIMESTAMP_SECONDS(1.23), TIMESTAMP_SECONDS(1.23d), TIMESTAMP_SECONDS(FLOAT(1.23));
select TIMESTAMP_MILLIS(1230219000123),TIMESTAMP_MILLIS(-1230219000123),TIMESTAMP_MILLIS(null);
select TIMESTAMP_MICROS(1230219000123123),TIMESTAMP_MICROS(-1230219000123123),TIMESTAMP_MICROS(null);
--- overflow exception:
+-- overflow exception
select TIMESTAMP_SECONDS(1230219000123123);
select TIMESTAMP_SECONDS(-1230219000123123);
select TIMESTAMP_MILLIS(92233720368547758);
select TIMESTAMP_MILLIS(-92233720368547758);
+-- truncate exception
+select TIMESTAMP_SECONDS(0.1234567);
+-- truncation is OK for float/double
+select TIMESTAMP_SECONDS(0.1234567d), TIMESTAMP_SECONDS(FLOAT(0.1234567));
-- [SPARK-16836] current_date and current_timestamp literals
select current_date = current_date(), current_timestamp = current_timestamp();
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql
index 481b5e8cc7700..0a16f118f0455 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql
@@ -72,7 +72,7 @@ SELECT Count(DISTINCT( t1a )),
FROM t1
WHERE t1d IN (SELECT t2d
FROM t2
- ORDER BY t2c
+ ORDER BY t2c, t2d
LIMIT 2)
GROUP BY t1b
ORDER BY t1b DESC NULLS FIRST
@@ -93,7 +93,7 @@ SELECT Count(DISTINCT( t1a )),
FROM t1
WHERE t1d NOT IN (SELECT t2d
FROM t2
- ORDER BY t2b DESC nulls first
+ ORDER BY t2b DESC nulls first, t2d
LIMIT 1)
GROUP BY t1b
ORDER BY t1b NULLS last
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out
index 26adb40ce1b14..484b67677a91b 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out
@@ -1,15 +1,23 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 103
+-- Number of queries: 106
-- !query
select TIMESTAMP_SECONDS(1230219000),TIMESTAMP_SECONDS(-1230219000),TIMESTAMP_SECONDS(null)
-- !query schema
-struct
+struct
-- !query output
2008-12-25 07:30:00 1931-01-07 00:30:00 NULL
+-- !query
+select TIMESTAMP_SECONDS(1.23), TIMESTAMP_SECONDS(1.23d), TIMESTAMP_SECONDS(FLOAT(1.23))
+-- !query schema
+struct
+-- !query output
+1969-12-31 16:00:01.23 1969-12-31 16:00:01.23 1969-12-31 16:00:01.23
+
+
-- !query
select TIMESTAMP_MILLIS(1230219000123),TIMESTAMP_MILLIS(-1230219000123),TIMESTAMP_MILLIS(null)
-- !query schema
@@ -62,6 +70,23 @@ java.lang.ArithmeticException
long overflow
+-- !query
+select TIMESTAMP_SECONDS(0.1234567)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Rounding necessary
+
+
+-- !query
+select TIMESTAMP_SECONDS(0.1234567d), TIMESTAMP_SECONDS(FLOAT(0.1234567))
+-- !query schema
+struct
+-- !query output
+1969-12-31 16:00:00.123456 1969-12-31 16:00:00.123456
+
+
-- !query
select current_date = current_date(), current_timestamp = current_timestamp()
-- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out
index 15092f0a27c1f..edb49e575f52e 100644
--- a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out
@@ -1,15 +1,23 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 103
+-- Number of queries: 106
-- !query
select TIMESTAMP_SECONDS(1230219000),TIMESTAMP_SECONDS(-1230219000),TIMESTAMP_SECONDS(null)
-- !query schema
-struct
+struct
-- !query output
2008-12-25 07:30:00 1931-01-07 00:30:00 NULL
+-- !query
+select TIMESTAMP_SECONDS(1.23), TIMESTAMP_SECONDS(1.23d), TIMESTAMP_SECONDS(FLOAT(1.23))
+-- !query schema
+struct
+-- !query output
+1969-12-31 16:00:01.23 1969-12-31 16:00:01.23 1969-12-31 16:00:01.23
+
+
-- !query
select TIMESTAMP_MILLIS(1230219000123),TIMESTAMP_MILLIS(-1230219000123),TIMESTAMP_MILLIS(null)
-- !query schema
@@ -62,6 +70,23 @@ java.lang.ArithmeticException
long overflow
+-- !query
+select TIMESTAMP_SECONDS(0.1234567)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Rounding necessary
+
+
+-- !query
+select TIMESTAMP_SECONDS(0.1234567d), TIMESTAMP_SECONDS(FLOAT(0.1234567))
+-- !query schema
+struct
+-- !query output
+1969-12-31 16:00:00.123456 1969-12-31 16:00:00.123456
+
+
-- !query
select current_date = current_date(), current_timestamp = current_timestamp()
-- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
index b80f36e9c2347..9f9351a4809af 100755
--- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
@@ -1,15 +1,23 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 103
+-- Number of queries: 106
-- !query
select TIMESTAMP_SECONDS(1230219000),TIMESTAMP_SECONDS(-1230219000),TIMESTAMP_SECONDS(null)
-- !query schema
-struct
+struct
-- !query output
2008-12-25 07:30:00 1931-01-07 00:30:00 NULL
+-- !query
+select TIMESTAMP_SECONDS(1.23), TIMESTAMP_SECONDS(1.23d), TIMESTAMP_SECONDS(FLOAT(1.23))
+-- !query schema
+struct
+-- !query output
+1969-12-31 16:00:01.23 1969-12-31 16:00:01.23 1969-12-31 16:00:01.23
+
+
-- !query
select TIMESTAMP_MILLIS(1230219000123),TIMESTAMP_MILLIS(-1230219000123),TIMESTAMP_MILLIS(null)
-- !query schema
@@ -62,6 +70,23 @@ java.lang.ArithmeticException
long overflow
+-- !query
+select TIMESTAMP_SECONDS(0.1234567)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Rounding necessary
+
+
+-- !query
+select TIMESTAMP_SECONDS(0.1234567d), TIMESTAMP_SECONDS(FLOAT(0.1234567))
+-- !query schema
+struct
+-- !query output
+1969-12-31 16:00:00.123456 1969-12-31 16:00:00.123456
+
+
-- !query
select current_date = current_date(), current_timestamp = current_timestamp()
-- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out
index 1c335445114c7..e24538b9138ba 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out
@@ -103,7 +103,7 @@ SELECT Count(DISTINCT( t1a )),
FROM t1
WHERE t1d IN (SELECT t2d
FROM t2
- ORDER BY t2c
+ ORDER BY t2c, t2d
LIMIT 2)
GROUP BY t1b
ORDER BY t1b DESC NULLS FIRST
@@ -136,7 +136,7 @@ SELECT Count(DISTINCT( t1a )),
FROM t1
WHERE t1d NOT IN (SELECT t2d
FROM t2
- ORDER BY t2b DESC nulls first
+ ORDER BY t2b DESC nulls first, t2d
LIMIT 1)
GROUP BY t1b
ORDER BY t1b NULLS last
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f7438f3ffec04..09f30bb5e2c77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1028,4 +1028,16 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(df, Row("abellina", 2) :: Row("mithunr", 1) :: Nil)
}
}
+
+ test("SPARK-32136: NormalizeFloatingNumbers should work on null struct") {
+ val df = Seq(
+ A(None),
+ A(Some(B(None))),
+ A(Some(B(Some(1.0))))).toDF
+ val groupBy = df.groupBy("b").agg(count("*"))
+ checkAnswer(groupBy, Row(null, 1) :: Row(Row(null), 1) :: Row(Row(1.0), 1) :: Nil)
+ }
}
+
+case class B(c: Option[Double])
+case class A(b: Option[B])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index d8157d3c779b9..231a8f2aa7ddd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -18,12 +18,14 @@
package org.apache.spark.sql
import java.io.{File, FileNotFoundException}
+import java.net.URI
import java.nio.file.{Files, StandardOpenOption}
import java.util.Locale
import scala.collection.mutable
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{LocalFileSystem, Path}
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
@@ -845,19 +847,15 @@ class FileBasedDataSourceSuite extends QueryTest
test("SPARK-31935: Hadoop file system config should be effective in data source options") {
Seq("parquet", "").foreach { format =>
- withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> format) {
+ withSQLConf(
+ SQLConf.USE_V1_SOURCE_LIST.key -> format,
+ "fs.file.impl" -> classOf[FakeFileSystemRequiringDSOption].getName,
+ "fs.file.impl.disable.cache" -> "true") {
withTempDir { dir =>
- val path = dir.getCanonicalPath
- val defaultFs = "nonexistFS://nonexistFS"
- val expectMessage = "No FileSystem for scheme nonexistFS"
- val message1 = intercept[java.io.IOException] {
- spark.range(10).write.option("fs.defaultFS", defaultFs).parquet(path)
- }.getMessage
- assert(message1.filterNot(Set(':', '"').contains) == expectMessage)
- val message2 = intercept[java.io.IOException] {
- spark.read.option("fs.defaultFS", defaultFs).parquet(path)
- }.getMessage
- assert(message2.filterNot(Set(':', '"').contains) == expectMessage)
+ val path = "file:" + dir.getCanonicalPath.stripPrefix("file:")
+ spark.range(10).write.option("ds_option", "value").mode("overwrite").parquet(path)
+ checkAnswer(
+ spark.read.option("ds_option", "value").parquet(path), spark.range(10).toDF())
}
}
}
@@ -932,3 +930,10 @@ object TestingUDT {
override def userClass: Class[NullData] = classOf[NullData]
}
}
+
+class FakeFileSystemRequiringDSOption extends LocalFileSystem {
+ override def initialize(name: URI, conf: Configuration): Unit = {
+ super.initialize(name, conf)
+ require(conf.get("ds_option", "") == "value")
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 8462ce5a6c44f..f7f4df8f2d2e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1479,11 +1479,26 @@ class DataSourceV2SQLSuite
assert(exception.getMessage.contains("Database 'ns1' not found"))
}
- test("Use: v2 catalog is used and namespace does not exist") {
- // Namespaces are not required to exist for v2 catalogs.
- sql("USE testcat.ns1.ns2")
- val catalogManager = spark.sessionState.catalogManager
- assert(catalogManager.currentNamespace === Array("ns1", "ns2"))
+ test("SPARK-31100: Use: v2 catalog that implements SupportsNamespaces is used " +
+ "and namespace not exists") {
+ // Namespaces are required to exist for v2 catalogs that implements SupportsNamespaces.
+ val exception = intercept[NoSuchNamespaceException] {
+ sql("USE testcat.ns1.ns2")
+ }
+ assert(exception.getMessage.contains("Namespace 'ns1.ns2' not found"))
+ }
+
+ test("SPARK-31100: Use: v2 catalog that does not implement SupportsNameSpaces is used " +
+ "and namespace does not exist") {
+ // Namespaces are not required to exist for v2 catalogs
+ // that does not implement SupportsNamespaces.
+ withSQLConf("spark.sql.catalog.dummy" -> classOf[BasicInMemoryTableCatalog].getName) {
+ val catalogManager = spark.sessionState.catalogManager
+
+ sql("USE dummy.ns1")
+ assert(catalogManager.currentCatalog.name() == "dummy")
+ assert(catalogManager.currentNamespace === Array("ns1"))
+ }
}
test("ShowCurrentNamespace: basic tests") {
@@ -1505,6 +1520,8 @@ class DataSourceV2SQLSuite
sql("USE testcat")
testShowCurrentNamespace("testcat", "")
+
+ sql("CREATE NAMESPACE testcat.ns1.ns2")
sql("USE testcat.ns1.ns2")
testShowCurrentNamespace("testcat", "ns1.ns2")
}
@@ -2342,6 +2359,7 @@ class DataSourceV2SQLSuite
spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
val sessionCatalogName = CatalogManager.SESSION_CATALOG_NAME
+ sql("CREATE NAMESPACE testcat.ns1.ns2")
sql("USE testcat.ns1.ns2")
sql("CREATE TABLE t USING foo AS SELECT 1 col")
checkAnswer(spark.table("t"), Row(1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 06574a9f8fd2c..1991f139e48c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -199,20 +199,20 @@ class SparkSqlParserSuite extends AnalysisTest {
assertEqual(s"$baseSql distribute by a, b",
RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
basePlan,
- numPartitions = newConf.numShufflePartitions))
+ None))
assertEqual(s"$baseSql distribute by a sort by b",
Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
global = false,
RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
basePlan,
- numPartitions = newConf.numShufflePartitions)))
+ None)))
assertEqual(s"$baseSql cluster by a, b",
Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
global = false,
RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
basePlan,
- numPartitions = newConf.numShufflePartitions)))
+ None)))
}
test("pipeline concatenation") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 27d9748476c98..c696d3f648ed1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -23,7 +23,7 @@ import java.net.URI
import org.apache.log4j.Level
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
@@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
}
+ private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = {
+ // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
+ val plan = df.queryExecution.executedPlan
+ assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+ val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+ case s: ShuffleExchangeExec => s
+ }
+ assert(shuffle.size == 1)
+ assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
+ }
+
test("Change merge join to broadcast join") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -1040,14 +1051,8 @@ class AdaptiveQueryExecSuite
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)
- // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
- val plan = df1.queryExecution.executedPlan
- assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
- val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
- case s: ShuffleExchangeExec => s
- }
- assert(shuffle.size == 1)
- assert(shuffle(0).outputPartitioning.numPartitions == 10)
+ checkInitialPartitionNum(df1, 10)
+ checkInitialPartitionNum(df2, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
@@ -1081,14 +1086,8 @@ class AdaptiveQueryExecSuite
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)
- // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
- val plan = df1.queryExecution.executedPlan
- assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
- val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
- case s: ShuffleExchangeExec => s
- }
- assert(shuffle.size == 1)
- assert(shuffle(0).outputPartitioning.numPartitions == 10)
+ checkInitialPartitionNum(df1, 10)
+ checkInitialPartitionNum(df2, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
@@ -1100,4 +1099,52 @@ class AdaptiveQueryExecSuite
}
}
}
+
+ test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
+ Seq(true, false).foreach { enableAQE =>
+ withTempView("test") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+ spark.range(10).toDF.createTempView("test")
+
+ val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
+ val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
+ val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
+ val df4 = spark.sql("SELECT * from test CLUSTER BY id")
+
+ val partitionsNum1 = df1.rdd.collectPartitions().length
+ val partitionsNum2 = df2.rdd.collectPartitions().length
+ val partitionsNum3 = df3.rdd.collectPartitions().length
+ val partitionsNum4 = df4.rdd.collectPartitions().length
+
+ if (enableAQE) {
+ assert(partitionsNum1 < 10)
+ assert(partitionsNum2 < 10)
+ assert(partitionsNum3 < 10)
+ assert(partitionsNum4 < 10)
+
+ checkInitialPartitionNum(df1, 10)
+ checkInitialPartitionNum(df2, 10)
+ checkInitialPartitionNum(df3, 10)
+ checkInitialPartitionNum(df4, 10)
+ } else {
+ assert(partitionsNum1 === 10)
+ assert(partitionsNum2 === 10)
+ assert(partitionsNum3 === 10)
+ assert(partitionsNum4 === 10)
+ }
+
+ // Don't coalesce partitions if the number of partitions is specified.
+ val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
+ val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
+ assert(df5.rdd.collectPartitions().length == 10)
+ assert(df6.rdd.collectPartitions().length == 10)
+ }
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProviderSuite.scala
new file mode 100644
index 0000000000000..13cde32ddbe4e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/OracleConnectionProviderSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+class OracleConnectionProviderSuite extends ConnectionProviderSuiteBase {
+ test("setAuthenticationConfigIfNeeded must set authentication if not set") {
+ val driver = registerDriver(OracleConnectionProvider.driverClass)
+ val provider = new OracleConnectionProvider(driver,
+ options("jdbc:oracle:thin:@//localhost/xe"))
+
+ testSecureConnectionProvider(provider)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala
index 56930880ed5da..0dbd6b5754afb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala
@@ -430,7 +430,7 @@ object JsonBenchmark extends SqlBasedBenchmark {
}
readBench.addCase("infer timestamps from files", numIters) { _ =>
- spark.read.json(timestampDir).noop()
+ spark.read.option("inferTimestamp", true).json(timestampDir).noop()
}
val dateSchema = new StructType().add("date", DateType)
@@ -460,7 +460,7 @@ object JsonBenchmark extends SqlBasedBenchmark {
}
readBench.addCase("infer timestamps from Dataset[String]", numIters) { _ =>
- spark.read.json(timestampStr).noop()
+ spark.read.option("inferTimestamp", true).json(timestampStr).noop()
}
def dateStr: Dataset[String] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 6344ec6be4878..c7448b12626be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2610,7 +2610,9 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson
}
test("inferring timestamp type") {
- def schemaOf(jsons: String*): StructType = spark.read.json(jsons.toDS).schema
+ def schemaOf(jsons: String*): StructType = {
+ spark.read.option("inferTimestamp", true).json(jsons.toDS).schema
+ }
assert(schemaOf(
"""{"a":"2018-12-17T10:11:12.123-01:00"}""",
@@ -2633,6 +2635,7 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson
val timestampsWithFormatPath = s"${dir.getCanonicalPath}/timestampsWithFormat.json"
val timestampsWithFormat = spark.read
.option("timestampFormat", "dd/MM/yyyy HH:mm")
+ .option("inferTimestamp", true)
.json(datesRecords)
assert(timestampsWithFormat.schema === customSchema)
@@ -2645,6 +2648,7 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson
val readBack = spark.read
.option("timestampFormat", "yyyy-MM-dd HH:mm:ss")
.option(DateTimeUtils.TIMEZONE_OPTION, "UTC")
+ .option("inferTimestamp", true)
.json(timestampsWithFormatPath)
assert(readBack.schema === customSchema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
index 73873684f6aaf..b70fd7476ed98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -213,9 +213,7 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll {
Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten
val schema = SchemaMergeUtils.mergeSchemasInParallel(
- spark,
- fileStatuses,
- schemaReader)
+ spark, Map.empty, fileStatuses, schemaReader)
assert(schema.isDefined)
assert(schema.get == StructType(Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index d20a07f420e87..8b922aaed4e68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -22,6 +22,9 @@ import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{LocalDate, LocalDateTime, ZoneId}
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators}
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
@@ -106,10 +109,18 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
/**
- * Takes single level `inputDF` dataframe to generate multi-level nested
- * dataframes as new test data.
+ * Takes a sequence of products `data` to generate multi-level nested
+ * dataframes as new test data. It tests both non-nested and nested dataframes
+ * which are written and read back with Parquet datasource.
+ *
+ * This is different from [[ParquetTest.withParquetDataFrame]] which does not
+ * test nested cases.
*/
- private def withNestedDataFrame(inputDF: DataFrame)
+ private def withNestedParquetDataFrame[T <: Product: ClassTag: TypeTag](data: Seq[T])
+ (runTest: (DataFrame, String, Any => Any) => Unit): Unit =
+ withNestedParquetDataFrame(spark.createDataFrame(data))(runTest)
+
+ private def withNestedParquetDataFrame(inputDF: DataFrame)
(runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
assert(inputDF.schema.fields.length == 1)
assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
@@ -138,8 +149,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
"`a.b`.`c.d`", // one level nesting with column names containing `dots`
(x: Any) => Row(x)
)
- ).foreach { case (df, colName, resultFun) =>
- runTest(df, colName, resultFun)
+ ).foreach { case (newDF, colName, resultFun) =>
+ withTempPath { file =>
+ newDF.write.format(dataSourceName).save(file.getCanonicalPath)
+ readParquetFile(file.getCanonicalPath) { df => runTest(df, colName, resultFun) }
+ }
}
}
@@ -155,7 +169,9 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
import testImplicits._
val df = data.map(i => Tuple1(Timestamp.valueOf(i))).toDF()
- withNestedDataFrame(df) { case (inputDF, colName, fun) =>
+ withNestedParquetDataFrame(df) { case (parquetDF, colName, fun) =>
+ implicit val df: DataFrame = parquetDF
+
def resultFun(tsStr: String): Any = {
val parsed = if (java8Api) {
LocalDateTime.parse(tsStr.replace(" ", "T"))
@@ -166,36 +182,35 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
fun(parsed)
}
- withParquetDataFrame(inputDF) { implicit df =>
- val tsAttr = df(colName).expr
- assert(df(colName).expr.dataType === TimestampType)
-
- checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]],
- data.map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(tsAttr === ts1.ts, classOf[Eq[_]], resultFun(ts1))
- checkFilterPredicate(tsAttr <=> ts1.ts, classOf[Eq[_]], resultFun(ts1))
- checkFilterPredicate(tsAttr =!= ts1.ts, classOf[NotEq[_]],
- Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(tsAttr < ts2.ts, classOf[Lt[_]], resultFun(ts1))
- checkFilterPredicate(tsAttr > ts1.ts, classOf[Gt[_]],
- Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i))))
- checkFilterPredicate(tsAttr <= ts1.ts, classOf[LtEq[_]], resultFun(ts1))
- checkFilterPredicate(tsAttr >= ts4.ts, classOf[GtEq[_]], resultFun(ts4))
-
- checkFilterPredicate(Literal(ts1.ts) === tsAttr, classOf[Eq[_]], resultFun(ts1))
- checkFilterPredicate(Literal(ts1.ts) <=> tsAttr, classOf[Eq[_]], resultFun(ts1))
- checkFilterPredicate(Literal(ts2.ts) > tsAttr, classOf[Lt[_]], resultFun(ts1))
- checkFilterPredicate(Literal(ts3.ts) < tsAttr, classOf[Gt[_]], resultFun(ts4))
- checkFilterPredicate(Literal(ts1.ts) >= tsAttr, classOf[LtEq[_]], resultFun(ts1))
- checkFilterPredicate(Literal(ts4.ts) <= tsAttr, classOf[GtEq[_]], resultFun(ts4))
-
- checkFilterPredicate(!(tsAttr < ts4.ts), classOf[GtEq[_]], resultFun(ts4))
- checkFilterPredicate(tsAttr < ts2.ts || tsAttr > ts3.ts, classOf[Operators.Or],
- Seq(Row(resultFun(ts1)), Row(resultFun(ts4))))
- }
+
+ val tsAttr = df(colName).expr
+ assert(df(colName).expr.dataType === TimestampType)
+
+ checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]],
+ data.map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(tsAttr === ts1.ts, classOf[Eq[_]], resultFun(ts1))
+ checkFilterPredicate(tsAttr <=> ts1.ts, classOf[Eq[_]], resultFun(ts1))
+ checkFilterPredicate(tsAttr =!= ts1.ts, classOf[NotEq[_]],
+ Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(tsAttr < ts2.ts, classOf[Lt[_]], resultFun(ts1))
+ checkFilterPredicate(tsAttr > ts1.ts, classOf[Gt[_]],
+ Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i))))
+ checkFilterPredicate(tsAttr <= ts1.ts, classOf[LtEq[_]], resultFun(ts1))
+ checkFilterPredicate(tsAttr >= ts4.ts, classOf[GtEq[_]], resultFun(ts4))
+
+ checkFilterPredicate(Literal(ts1.ts) === tsAttr, classOf[Eq[_]], resultFun(ts1))
+ checkFilterPredicate(Literal(ts1.ts) <=> tsAttr, classOf[Eq[_]], resultFun(ts1))
+ checkFilterPredicate(Literal(ts2.ts) > tsAttr, classOf[Lt[_]], resultFun(ts1))
+ checkFilterPredicate(Literal(ts3.ts) < tsAttr, classOf[Gt[_]], resultFun(ts4))
+ checkFilterPredicate(Literal(ts1.ts) >= tsAttr, classOf[LtEq[_]], resultFun(ts1))
+ checkFilterPredicate(Literal(ts4.ts) <= tsAttr, classOf[GtEq[_]], resultFun(ts4))
+
+ checkFilterPredicate(!(tsAttr < ts4.ts), classOf[GtEq[_]], resultFun(ts4))
+ checkFilterPredicate(tsAttr < ts2.ts || tsAttr > ts3.ts, classOf[Operators.Or],
+ Seq(Row(resultFun(ts1)), Row(resultFun(ts4))))
}
}
@@ -226,272 +241,264 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
test("filter pushdown - boolean") {
val data = (true :: false :: Nil).map(b => Tuple1.apply(Option(b)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val booleanAttr = df(colName).expr
- assert(df(colName).expr.dataType === BooleanType)
-
- checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]],
- Seq(Row(resultFun(true)), Row(resultFun(false))))
-
- checkFilterPredicate(booleanAttr === true, classOf[Eq[_]], resultFun(true))
- checkFilterPredicate(booleanAttr <=> true, classOf[Eq[_]], resultFun(true))
- checkFilterPredicate(booleanAttr =!= true, classOf[NotEq[_]], resultFun(false))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val booleanAttr = df(colName).expr
+ assert(df(colName).expr.dataType === BooleanType)
+
+ checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]],
+ Seq(Row(resultFun(true)), Row(resultFun(false))))
+
+ checkFilterPredicate(booleanAttr === true, classOf[Eq[_]], resultFun(true))
+ checkFilterPredicate(booleanAttr <=> true, classOf[Eq[_]], resultFun(true))
+ checkFilterPredicate(booleanAttr =!= true, classOf[NotEq[_]], resultFun(false))
}
}
test("filter pushdown - tinyint") {
val data = (1 to 4).map(i => Tuple1(Option(i.toByte)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val tinyIntAttr = df(colName).expr
- assert(df(colName).expr.dataType === ByteType)
-
- checkFilterPredicate(tinyIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(tinyIntAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(tinyIntAttr === 1.toByte, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(tinyIntAttr <=> 1.toByte, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(tinyIntAttr =!= 1.toByte, classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(tinyIntAttr < 2.toByte, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(tinyIntAttr > 3.toByte, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(tinyIntAttr <= 1.toByte, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(tinyIntAttr >= 4.toByte, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(Literal(1.toByte) === tinyIntAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(1.toByte) <=> tinyIntAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(2.toByte) > tinyIntAttr, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(Literal(3.toByte) < tinyIntAttr, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(Literal(1.toByte) >= tinyIntAttr, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(Literal(4.toByte) <= tinyIntAttr, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(!(tinyIntAttr < 4.toByte), classOf[GtEq[_]], resultFun(4))
- checkFilterPredicate(tinyIntAttr < 2.toByte || tinyIntAttr > 3.toByte,
- classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4))))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val tinyIntAttr = df(colName).expr
+ assert(df(colName).expr.dataType === ByteType)
+
+ checkFilterPredicate(tinyIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(tinyIntAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(tinyIntAttr === 1.toByte, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(tinyIntAttr <=> 1.toByte, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(tinyIntAttr =!= 1.toByte, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(tinyIntAttr < 2.toByte, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(tinyIntAttr > 3.toByte, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(tinyIntAttr <= 1.toByte, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(tinyIntAttr >= 4.toByte, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1.toByte) === tinyIntAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1.toByte) <=> tinyIntAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2.toByte) > tinyIntAttr, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(Literal(3.toByte) < tinyIntAttr, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(Literal(1.toByte) >= tinyIntAttr, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4.toByte) <= tinyIntAttr, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(tinyIntAttr < 4.toByte), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(tinyIntAttr < 2.toByte || tinyIntAttr > 3.toByte,
+ classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4))))
}
}
test("filter pushdown - smallint") {
val data = (1 to 4).map(i => Tuple1(Option(i.toShort)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val smallIntAttr = df(colName).expr
- assert(df(colName).expr.dataType === ShortType)
-
- checkFilterPredicate(smallIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(smallIntAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(smallIntAttr === 1.toShort, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(smallIntAttr <=> 1.toShort, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(smallIntAttr =!= 1.toShort, classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(smallIntAttr < 2.toShort, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(smallIntAttr > 3.toShort, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(smallIntAttr <= 1.toShort, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(smallIntAttr >= 4.toShort, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(Literal(1.toShort) === smallIntAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(1.toShort) <=> smallIntAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(2.toShort) > smallIntAttr, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(Literal(3.toShort) < smallIntAttr, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(Literal(1.toShort) >= smallIntAttr, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(Literal(4.toShort) <= smallIntAttr, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(!(smallIntAttr < 4.toShort), classOf[GtEq[_]], resultFun(4))
- checkFilterPredicate(smallIntAttr < 2.toShort || smallIntAttr > 3.toShort,
- classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4))))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val smallIntAttr = df(colName).expr
+ assert(df(colName).expr.dataType === ShortType)
+
+ checkFilterPredicate(smallIntAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(smallIntAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(smallIntAttr === 1.toShort, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(smallIntAttr <=> 1.toShort, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(smallIntAttr =!= 1.toShort, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(smallIntAttr < 2.toShort, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(smallIntAttr > 3.toShort, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(smallIntAttr <= 1.toShort, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(smallIntAttr >= 4.toShort, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1.toShort) === smallIntAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1.toShort) <=> smallIntAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2.toShort) > smallIntAttr, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(Literal(3.toShort) < smallIntAttr, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(Literal(1.toShort) >= smallIntAttr, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4.toShort) <= smallIntAttr, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(smallIntAttr < 4.toShort), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(smallIntAttr < 2.toShort || smallIntAttr > 3.toShort,
+ classOf[Operators.Or], Seq(Row(resultFun(1)), Row(resultFun(4))))
}
}
test("filter pushdown - integer") {
val data = (1 to 4).map(i => Tuple1(Option(i)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val intAttr = df(colName).expr
- assert(df(colName).expr.dataType === IntegerType)
-
- checkFilterPredicate(intAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(intAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(intAttr === 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(intAttr <=> 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(intAttr =!= 1, classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(intAttr < 2, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(intAttr > 3, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(intAttr <= 1, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(intAttr >= 4, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(Literal(1) === intAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(1) <=> intAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(2) > intAttr, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(Literal(3) < intAttr, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(Literal(1) >= intAttr, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(Literal(4) <= intAttr, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(!(intAttr < 4), classOf[GtEq[_]], resultFun(4))
- checkFilterPredicate(intAttr < 2 || intAttr > 3, classOf[Operators.Or],
- Seq(Row(resultFun(1)), Row(resultFun(4))))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val intAttr = df(colName).expr
+ assert(df(colName).expr.dataType === IntegerType)
+
+ checkFilterPredicate(intAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(intAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(intAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(intAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(intAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(intAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(intAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(intAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(intAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === intAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1) <=> intAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2) > intAttr, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(Literal(3) < intAttr, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(Literal(1) >= intAttr, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4) <= intAttr, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(intAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(intAttr < 2 || intAttr > 3, classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
}
}
test("filter pushdown - long") {
val data = (1 to 4).map(i => Tuple1(Option(i.toLong)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val longAttr = df(colName).expr
- assert(df(colName).expr.dataType === LongType)
-
- checkFilterPredicate(longAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(longAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(longAttr === 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(longAttr <=> 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(longAttr =!= 1, classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(longAttr < 2, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(longAttr > 3, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(longAttr <= 1, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(longAttr >= 4, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(Literal(1) === longAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(1) <=> longAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(2) > longAttr, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(Literal(3) < longAttr, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(Literal(1) >= longAttr, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(Literal(4) <= longAttr, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(!(longAttr < 4), classOf[GtEq[_]], resultFun(4))
- checkFilterPredicate(longAttr < 2 || longAttr > 3, classOf[Operators.Or],
- Seq(Row(resultFun(1)), Row(resultFun(4))))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val longAttr = df(colName).expr
+ assert(df(colName).expr.dataType === LongType)
+
+ checkFilterPredicate(longAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(longAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(longAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(longAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(longAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(longAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(longAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(longAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(longAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === longAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1) <=> longAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2) > longAttr, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(Literal(3) < longAttr, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(Literal(1) >= longAttr, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4) <= longAttr, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(longAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(longAttr < 2 || longAttr > 3, classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
}
}
test("filter pushdown - float") {
val data = (1 to 4).map(i => Tuple1(Option(i.toFloat)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val floatAttr = df(colName).expr
- assert(df(colName).expr.dataType === FloatType)
-
- checkFilterPredicate(floatAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(floatAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(floatAttr === 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(floatAttr <=> 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(floatAttr =!= 1, classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(floatAttr < 2, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(floatAttr > 3, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(floatAttr <= 1, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(floatAttr >= 4, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(Literal(1) === floatAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(1) <=> floatAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(2) > floatAttr, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(Literal(3) < floatAttr, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(Literal(1) >= floatAttr, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(Literal(4) <= floatAttr, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(!(floatAttr < 4), classOf[GtEq[_]], resultFun(4))
- checkFilterPredicate(floatAttr < 2 || floatAttr > 3, classOf[Operators.Or],
- Seq(Row(resultFun(1)), Row(resultFun(4))))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val floatAttr = df(colName).expr
+ assert(df(colName).expr.dataType === FloatType)
+
+ checkFilterPredicate(floatAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(floatAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(floatAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(floatAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(floatAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(floatAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(floatAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(floatAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(floatAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === floatAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1) <=> floatAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2) > floatAttr, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(Literal(3) < floatAttr, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(Literal(1) >= floatAttr, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4) <= floatAttr, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(floatAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(floatAttr < 2 || floatAttr > 3, classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
}
}
test("filter pushdown - double") {
val data = (1 to 4).map(i => Tuple1(Option(i.toDouble)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val doubleAttr = df(colName).expr
- assert(df(colName).expr.dataType === DoubleType)
-
- checkFilterPredicate(doubleAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(doubleAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(doubleAttr === 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(doubleAttr <=> 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(doubleAttr =!= 1, classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(doubleAttr < 2, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(doubleAttr > 3, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(doubleAttr <= 1, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(doubleAttr >= 4, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(Literal(1) === doubleAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(1) <=> doubleAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(2) > doubleAttr, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(Literal(3) < doubleAttr, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(Literal(1) >= doubleAttr, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(Literal(4) <= doubleAttr, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(!(doubleAttr < 4), classOf[GtEq[_]], resultFun(4))
- checkFilterPredicate(doubleAttr < 2 || doubleAttr > 3, classOf[Operators.Or],
- Seq(Row(resultFun(1)), Row(resultFun(4))))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val doubleAttr = df(colName).expr
+ assert(df(colName).expr.dataType === DoubleType)
+
+ checkFilterPredicate(doubleAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(doubleAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(doubleAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(doubleAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(doubleAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(doubleAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(doubleAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(doubleAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(doubleAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === doubleAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1) <=> doubleAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2) > doubleAttr, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(Literal(3) < doubleAttr, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(Literal(1) >= doubleAttr, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4) <= doubleAttr, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(doubleAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(doubleAttr < 2 || doubleAttr > 3, classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
}
}
test("filter pushdown - string") {
val data = (1 to 4).map(i => Tuple1(Option(i.toString)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val stringAttr = df(colName).expr
- assert(df(colName).expr.dataType === StringType)
-
- checkFilterPredicate(stringAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(stringAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i.toString))))
-
- checkFilterPredicate(stringAttr === "1", classOf[Eq[_]], resultFun("1"))
- checkFilterPredicate(stringAttr <=> "1", classOf[Eq[_]], resultFun("1"))
- checkFilterPredicate(stringAttr =!= "1", classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i.toString))))
-
- checkFilterPredicate(stringAttr < "2", classOf[Lt[_]], resultFun("1"))
- checkFilterPredicate(stringAttr > "3", classOf[Gt[_]], resultFun("4"))
- checkFilterPredicate(stringAttr <= "1", classOf[LtEq[_]], resultFun("1"))
- checkFilterPredicate(stringAttr >= "4", classOf[GtEq[_]], resultFun("4"))
-
- checkFilterPredicate(Literal("1") === stringAttr, classOf[Eq[_]], resultFun("1"))
- checkFilterPredicate(Literal("1") <=> stringAttr, classOf[Eq[_]], resultFun("1"))
- checkFilterPredicate(Literal("2") > stringAttr, classOf[Lt[_]], resultFun("1"))
- checkFilterPredicate(Literal("3") < stringAttr, classOf[Gt[_]], resultFun("4"))
- checkFilterPredicate(Literal("1") >= stringAttr, classOf[LtEq[_]], resultFun("1"))
- checkFilterPredicate(Literal("4") <= stringAttr, classOf[GtEq[_]], resultFun("4"))
-
- checkFilterPredicate(!(stringAttr < "4"), classOf[GtEq[_]], resultFun("4"))
- checkFilterPredicate(stringAttr < "2" || stringAttr > "3", classOf[Operators.Or],
- Seq(Row(resultFun("1")), Row(resultFun("4"))))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val stringAttr = df(colName).expr
+ assert(df(colName).expr.dataType === StringType)
+
+ checkFilterPredicate(stringAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(stringAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i.toString))))
+
+ checkFilterPredicate(stringAttr === "1", classOf[Eq[_]], resultFun("1"))
+ checkFilterPredicate(stringAttr <=> "1", classOf[Eq[_]], resultFun("1"))
+ checkFilterPredicate(stringAttr =!= "1", classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i.toString))))
+
+ checkFilterPredicate(stringAttr < "2", classOf[Lt[_]], resultFun("1"))
+ checkFilterPredicate(stringAttr > "3", classOf[Gt[_]], resultFun("4"))
+ checkFilterPredicate(stringAttr <= "1", classOf[LtEq[_]], resultFun("1"))
+ checkFilterPredicate(stringAttr >= "4", classOf[GtEq[_]], resultFun("4"))
+
+ checkFilterPredicate(Literal("1") === stringAttr, classOf[Eq[_]], resultFun("1"))
+ checkFilterPredicate(Literal("1") <=> stringAttr, classOf[Eq[_]], resultFun("1"))
+ checkFilterPredicate(Literal("2") > stringAttr, classOf[Lt[_]], resultFun("1"))
+ checkFilterPredicate(Literal("3") < stringAttr, classOf[Gt[_]], resultFun("4"))
+ checkFilterPredicate(Literal("1") >= stringAttr, classOf[LtEq[_]], resultFun("1"))
+ checkFilterPredicate(Literal("4") <= stringAttr, classOf[GtEq[_]], resultFun("4"))
+
+ checkFilterPredicate(!(stringAttr < "4"), classOf[GtEq[_]], resultFun("4"))
+ checkFilterPredicate(stringAttr < "2" || stringAttr > "3", classOf[Operators.Or],
+ Seq(Row(resultFun("1")), Row(resultFun("4"))))
}
}
@@ -501,38 +508,37 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
val data = (1 to 4).map(i => Tuple1(Option(i.b)))
- import testImplicits._
- withNestedDataFrame(data.toDF()) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val binaryAttr: Expression = df(colName).expr
- assert(df(colName).expr.dataType === BinaryType)
-
- checkFilterPredicate(binaryAttr === 1.b, classOf[Eq[_]], resultFun(1.b))
- checkFilterPredicate(binaryAttr <=> 1.b, classOf[Eq[_]], resultFun(1.b))
-
- checkFilterPredicate(binaryAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(binaryAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i.b))))
-
- checkFilterPredicate(binaryAttr =!= 1.b, classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i.b))))
-
- checkFilterPredicate(binaryAttr < 2.b, classOf[Lt[_]], resultFun(1.b))
- checkFilterPredicate(binaryAttr > 3.b, classOf[Gt[_]], resultFun(4.b))
- checkFilterPredicate(binaryAttr <= 1.b, classOf[LtEq[_]], resultFun(1.b))
- checkFilterPredicate(binaryAttr >= 4.b, classOf[GtEq[_]], resultFun(4.b))
-
- checkFilterPredicate(Literal(1.b) === binaryAttr, classOf[Eq[_]], resultFun(1.b))
- checkFilterPredicate(Literal(1.b) <=> binaryAttr, classOf[Eq[_]], resultFun(1.b))
- checkFilterPredicate(Literal(2.b) > binaryAttr, classOf[Lt[_]], resultFun(1.b))
- checkFilterPredicate(Literal(3.b) < binaryAttr, classOf[Gt[_]], resultFun(4.b))
- checkFilterPredicate(Literal(1.b) >= binaryAttr, classOf[LtEq[_]], resultFun(1.b))
- checkFilterPredicate(Literal(4.b) <= binaryAttr, classOf[GtEq[_]], resultFun(4.b))
-
- checkFilterPredicate(!(binaryAttr < 4.b), classOf[GtEq[_]], resultFun(4.b))
- checkFilterPredicate(binaryAttr < 2.b || binaryAttr > 3.b, classOf[Operators.Or],
- Seq(Row(resultFun(1.b)), Row(resultFun(4.b))))
- }
+ withNestedParquetDataFrame(data) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val binaryAttr: Expression = df(colName).expr
+ assert(df(colName).expr.dataType === BinaryType)
+
+ checkFilterPredicate(binaryAttr === 1.b, classOf[Eq[_]], resultFun(1.b))
+ checkFilterPredicate(binaryAttr <=> 1.b, classOf[Eq[_]], resultFun(1.b))
+
+ checkFilterPredicate(binaryAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(binaryAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i.b))))
+
+ checkFilterPredicate(binaryAttr =!= 1.b, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i.b))))
+
+ checkFilterPredicate(binaryAttr < 2.b, classOf[Lt[_]], resultFun(1.b))
+ checkFilterPredicate(binaryAttr > 3.b, classOf[Gt[_]], resultFun(4.b))
+ checkFilterPredicate(binaryAttr <= 1.b, classOf[LtEq[_]], resultFun(1.b))
+ checkFilterPredicate(binaryAttr >= 4.b, classOf[GtEq[_]], resultFun(4.b))
+
+ checkFilterPredicate(Literal(1.b) === binaryAttr, classOf[Eq[_]], resultFun(1.b))
+ checkFilterPredicate(Literal(1.b) <=> binaryAttr, classOf[Eq[_]], resultFun(1.b))
+ checkFilterPredicate(Literal(2.b) > binaryAttr, classOf[Lt[_]], resultFun(1.b))
+ checkFilterPredicate(Literal(3.b) < binaryAttr, classOf[Gt[_]], resultFun(4.b))
+ checkFilterPredicate(Literal(1.b) >= binaryAttr, classOf[LtEq[_]], resultFun(1.b))
+ checkFilterPredicate(Literal(4.b) <= binaryAttr, classOf[GtEq[_]], resultFun(4.b))
+
+ checkFilterPredicate(!(binaryAttr < 4.b), classOf[GtEq[_]], resultFun(4.b))
+ checkFilterPredicate(binaryAttr < 2.b || binaryAttr > 3.b, classOf[Operators.Or],
+ Seq(Row(resultFun(1.b)), Row(resultFun(4.b))))
}
}
@@ -546,56 +552,57 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
Seq(false, true).foreach { java8Api =>
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
- val df = data.map(i => Tuple1(Date.valueOf(i))).toDF()
- withNestedDataFrame(df) { case (inputDF, colName, fun) =>
+ val dates = data.map(i => Tuple1(Date.valueOf(i))).toDF()
+ withNestedParquetDataFrame(dates) { case (inputDF, colName, fun) =>
+ implicit val df: DataFrame = inputDF
+
def resultFun(dateStr: String): Any = {
val parsed = if (java8Api) LocalDate.parse(dateStr) else Date.valueOf(dateStr)
fun(parsed)
}
- withParquetDataFrame(inputDF) { implicit df =>
- val dateAttr: Expression = df(colName).expr
- assert(df(colName).expr.dataType === DateType)
-
- checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
- data.map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]],
- resultFun("2018-03-18"))
- checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]],
- resultFun("2018-03-18"))
- checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]],
- Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]],
- resultFun("2018-03-18"))
- checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]],
- resultFun("2018-03-21"))
- checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]],
- resultFun("2018-03-18"))
- checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]],
- resultFun("2018-03-21"))
-
- checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]],
- resultFun("2018-03-18"))
- checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]],
- resultFun("2018-03-18"))
- checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]],
- resultFun("2018-03-18"))
- checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]],
- resultFun("2018-03-21"))
- checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]],
- resultFun("2018-03-18"))
- checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]],
- resultFun("2018-03-21"))
-
- checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]],
- resultFun("2018-03-21"))
- checkFilterPredicate(
- dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date,
- classOf[Operators.Or],
- Seq(Row(resultFun("2018-03-18")), Row(resultFun("2018-03-21"))))
- }
+
+ val dateAttr: Expression = df(colName).expr
+ assert(df(colName).expr.dataType === DateType)
+
+ checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
+ data.map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]],
+ resultFun("2018-03-18"))
+ checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]],
+ resultFun("2018-03-18"))
+ checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]],
+ Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]],
+ resultFun("2018-03-18"))
+ checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]],
+ resultFun("2018-03-21"))
+ checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]],
+ resultFun("2018-03-18"))
+ checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]],
+ resultFun("2018-03-21"))
+
+ checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]],
+ resultFun("2018-03-18"))
+ checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]],
+ resultFun("2018-03-18"))
+ checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]],
+ resultFun("2018-03-18"))
+ checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]],
+ resultFun("2018-03-21"))
+ checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]],
+ resultFun("2018-03-18"))
+ checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]],
+ resultFun("2018-03-21"))
+
+ checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]],
+ resultFun("2018-03-21"))
+ checkFilterPredicate(
+ dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date,
+ classOf[Operators.Or],
+ Seq(Row(resultFun("2018-03-18")), Row(resultFun("2018-03-21"))))
}
}
}
@@ -603,7 +610,9 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
test("filter pushdown - timestamp") {
Seq(true, false).foreach { java8Api =>
- withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
+ withSQLConf(
+ SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString,
+ SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> "CORRECTED") {
// spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS
val millisData = Seq(
"1000-06-14 08:28:53.123",
@@ -630,11 +639,14 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
ParquetOutputTimestampType.INT96.toString) {
import testImplicits._
- withParquetDataFrame(
- millisData.map(i => Tuple1(Timestamp.valueOf(i))).toDF()) { implicit df =>
- val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
- assertResult(None) {
- createParquetFilters(schema).createFilter(sources.IsNull("_1"))
+ withTempPath { file =>
+ millisData.map(i => Tuple1(Timestamp.valueOf(i))).toDF
+ .write.format(dataSourceName).save(file.getCanonicalPath)
+ readParquetFile(file.getCanonicalPath) { df =>
+ val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
+ assertResult(None) {
+ createParquetFilters(schema).createFilter(sources.IsNull("_1"))
+ }
}
}
}
@@ -653,36 +665,36 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
val rdd =
spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i))))
val dataFrame = spark.createDataFrame(rdd, StructType.fromDDL(s"a decimal($precision, 2)"))
- withNestedDataFrame(dataFrame) { case (inputDF, colName, resultFun) =>
- withParquetDataFrame(inputDF) { implicit df =>
- val decimalAttr: Expression = df(colName).expr
- assert(df(colName).expr.dataType === DecimalType(precision, 2))
-
- checkFilterPredicate(decimalAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(decimalAttr.isNotNull, classOf[NotEq[_]],
- (1 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(decimalAttr === 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(decimalAttr <=> 1, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(decimalAttr =!= 1, classOf[NotEq[_]],
- (2 to 4).map(i => Row.apply(resultFun(i))))
-
- checkFilterPredicate(decimalAttr < 2, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(decimalAttr > 3, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(decimalAttr <= 1, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(decimalAttr >= 4, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(Literal(1) === decimalAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(1) <=> decimalAttr, classOf[Eq[_]], resultFun(1))
- checkFilterPredicate(Literal(2) > decimalAttr, classOf[Lt[_]], resultFun(1))
- checkFilterPredicate(Literal(3) < decimalAttr, classOf[Gt[_]], resultFun(4))
- checkFilterPredicate(Literal(1) >= decimalAttr, classOf[LtEq[_]], resultFun(1))
- checkFilterPredicate(Literal(4) <= decimalAttr, classOf[GtEq[_]], resultFun(4))
-
- checkFilterPredicate(!(decimalAttr < 4), classOf[GtEq[_]], resultFun(4))
- checkFilterPredicate(decimalAttr < 2 || decimalAttr > 3, classOf[Operators.Or],
- Seq(Row(resultFun(1)), Row(resultFun(4))))
- }
+ withNestedParquetDataFrame(dataFrame) { case (inputDF, colName, resultFun) =>
+ implicit val df: DataFrame = inputDF
+
+ val decimalAttr: Expression = df(colName).expr
+ assert(df(colName).expr.dataType === DecimalType(precision, 2))
+
+ checkFilterPredicate(decimalAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(decimalAttr.isNotNull, classOf[NotEq[_]],
+ (1 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(decimalAttr === 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(decimalAttr <=> 1, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(decimalAttr =!= 1, classOf[NotEq[_]],
+ (2 to 4).map(i => Row.apply(resultFun(i))))
+
+ checkFilterPredicate(decimalAttr < 2, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(decimalAttr > 3, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(decimalAttr <= 1, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(decimalAttr >= 4, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(Literal(1) === decimalAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(1) <=> decimalAttr, classOf[Eq[_]], resultFun(1))
+ checkFilterPredicate(Literal(2) > decimalAttr, classOf[Lt[_]], resultFun(1))
+ checkFilterPredicate(Literal(3) < decimalAttr, classOf[Gt[_]], resultFun(4))
+ checkFilterPredicate(Literal(1) >= decimalAttr, classOf[LtEq[_]], resultFun(1))
+ checkFilterPredicate(Literal(4) <= decimalAttr, classOf[GtEq[_]], resultFun(4))
+
+ checkFilterPredicate(!(decimalAttr < 4), classOf[GtEq[_]], resultFun(4))
+ checkFilterPredicate(decimalAttr < 2 || decimalAttr > 3, classOf[Operators.Or],
+ Seq(Row(resultFun(1)), Row(resultFun(4))))
}
}
}
@@ -1195,8 +1207,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
test("SPARK-16371 Do not push down filters when inner name and outer name are the same") {
- import testImplicits._
- withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i))).toDF()) { implicit df =>
+ withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df =>
// Here the schema becomes as below:
//
// root
@@ -1336,10 +1347,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
test("filter pushdown - StringStartsWith") {
- withParquetDataFrame {
- import testImplicits._
- (1 to 4).map(i => Tuple1(i + "str" + i)).toDF()
- } { implicit df =>
+ withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df =>
checkFilterPredicate(
'_1.startsWith("").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
@@ -1385,10 +1393,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
// SPARK-28371: make sure filter is null-safe.
- withParquetDataFrame {
- import testImplicits._
- Seq(Tuple1[String](null)).toDF()
- } { implicit df =>
+ withParquetDataFrame(Seq(Tuple1[String](null))) { implicit df =>
checkFilterPredicate(
'_1.startsWith("blah").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
@@ -1607,7 +1612,7 @@ class ParquetV1FilterSuite extends ParquetFilterSuite {
expected: Seq[Row]): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct
- Seq(("parquet", true), ("", false)).map { case (pushdownDsList, nestedPredicatePushdown) =>
+ Seq(("parquet", true), ("", false)).foreach { case (pushdownDsList, nestedPredicatePushdown) =>
withSQLConf(
SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 79c32976f02ec..2dc8a062bb73d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -85,7 +85,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
* Writes `data` to a Parquet file, reads it back and check file contents.
*/
protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = {
- withParquetDataFrame(data.toDF())(r => checkAnswer(r, data.map(Row.fromTuple)))
+ withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple)))
}
test("basic data types (without binary)") {
@@ -97,7 +97,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
test("raw binary") {
val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte)))
- withParquetDataFrame(data.toDF()) { df =>
+ withParquetDataFrame(data) { df =>
assertResult(data.map(_._1.mkString(",")).sorted) {
df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted
}
@@ -200,7 +200,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
testStandardAndLegacyModes("struct") {
val data = (1 to 4).map(i => Tuple1((i, s"val_$i")))
- withParquetDataFrame(data.toDF()) { df =>
+ withParquetDataFrame(data) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(struct) =>
Row(Row(struct.productIterator.toSeq: _*))
@@ -217,7 +217,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
)
)
}
- withParquetDataFrame(data.toDF()) { df =>
+ withParquetDataFrame(data) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(array) =>
Row(array.map(struct => Row(struct.productIterator.toSeq: _*)))
@@ -236,7 +236,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
)
)
}
- withParquetDataFrame(data.toDF()) { df =>
+ withParquetDataFrame(data) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(array) =>
Row(array.map { case Tuple1(Tuple1(str)) => Row(Row(str))})
@@ -246,7 +246,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
testStandardAndLegacyModes("nested struct with array of array as field") {
val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i")))))
- withParquetDataFrame(data.toDF()) { df =>
+ withParquetDataFrame(data) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(struct) =>
Row(Row(struct.productIterator.toSeq: _*))
@@ -263,7 +263,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
)
)
}
- withParquetDataFrame(data.toDF()) { df =>
+ withParquetDataFrame(data) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(m) =>
Row(m.map { case (k, v) => Row(k.productIterator.toSeq: _*) -> v })
@@ -280,7 +280,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
)
)
}
- withParquetDataFrame(data.toDF()) { df =>
+ withParquetDataFrame(data) { df =>
// Structs are converted to `Row`s
checkAnswer(df, data.map { case Tuple1(m) =>
Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*)))
@@ -296,7 +296,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
null.asInstanceOf[java.lang.Float],
null.asInstanceOf[java.lang.Double])
- withParquetDataFrame((allNulls :: Nil).toDF()) { df =>
+ withParquetDataFrame(allNulls :: Nil) { df =>
val rows = df.collect()
assert(rows.length === 1)
assert(rows.head === Row(Seq.fill(5)(null): _*))
@@ -309,7 +309,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
None.asInstanceOf[Option[Long]],
None.asInstanceOf[Option[String]])
- withParquetDataFrame((allNones :: Nil).toDF()) { df =>
+ withParquetDataFrame(allNones :: Nil) { df =>
val rows = df.collect()
assert(rows.length === 1)
assert(rows.head === Row(Seq.fill(3)(null): _*))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 105f025adc0ad..db8ee724c01c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -63,18 +63,12 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest {
(f: String => Unit): Unit = withDataSourceFile(data)(f)
/**
- * Writes `df` dataframe to a Parquet file and reads it back as a [[DataFrame]],
+ * Writes `data` to a Parquet file and reads it back as a [[DataFrame]],
* which is then passed to `f`. The Parquet file will be deleted after `f` returns.
*/
- protected def withParquetDataFrame(df: DataFrame, testVectorized: Boolean = true)
- (f: DataFrame => Unit): Unit = {
- withTempPath { file =>
- withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> "CORRECTED") {
- df.write.format(dataSourceName).save(file.getCanonicalPath)
- }
- readFile(file.getCanonicalPath, testVectorized)(f)
- }
- }
+ protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
+ (data: Seq[T], testVectorized: Boolean = true)
+ (f: DataFrame => Unit): Unit = withDataSourceDataFrame(data, testVectorized)(f)
/**
* Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index a25451bef62fd..4ccab58d24fed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -555,10 +555,12 @@ abstract class FileStreamSinkSuite extends StreamTest {
}
}
- val fs = new Path(outputDir.getCanonicalPath).getFileSystem(
- spark.sessionState.newHadoopConf())
- val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark,
- outputDir.getCanonicalPath)
+ val outputDirPath = new Path(outputDir.getCanonicalPath)
+ val hadoopConf = spark.sessionState.newHadoopConf()
+ val fs = outputDirPath.getFileSystem(hadoopConf)
+ val logPath = FileStreamSink.getMetadataLogPath(fs, outputDirPath, conf)
+
+ val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark, logPath.toString)
val allFiles = sinkLog.allFiles()
// only files from non-empty partition should be logged
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
index da6e4c52cf3a7..c4885f2842597 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
@@ -21,8 +21,8 @@ import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.CastSupport
-import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, ExternalCatalogUtils, HiveTableRelation}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, SubqueryExpression}
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -41,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf
* TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source.
*/
private[sql] class PruneHiveTablePartitions(session: SparkSession)
- extends Rule[LogicalPlan] with CastSupport {
+ extends Rule[LogicalPlan] with CastSupport with PredicateHelper {
override val conf: SQLConf = session.sessionState.conf
@@ -103,7 +103,9 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
- val partitionKeyFilters = getPartitionKeyFilters(filters, relation)
+ val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
+ val finalPredicates = if (predicates.nonEmpty) predicates else filters
+ val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation)
if (partitionKeyFilters.nonEmpty) {
val newPartitions = prunePartitions(relation, partitionKeyFilters)
val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 7f2eb14956dc1..356b92b4652b3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -70,9 +70,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
if (orcOptions.mergeSchema) {
SchemaMergeUtils.mergeSchemasInParallel(
- sparkSession,
- files,
- OrcFileOperator.readOrcSchemasInParallel)
+ sparkSession, options, files, OrcFileOperator.readOrcSchemasInParallel)
} else {
val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles
OrcFileOperator.readSchema(
diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-9f6fcc8c1a29c793c2238bad91453e9f b/sql/hive/src/test/resources/golden/timestamp cast #3-0-9f6fcc8c1a29c793c2238bad91453e9f
new file mode 100644
index 0000000000000..f99e724db6af8
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp cast #3-0-9f6fcc8c1a29c793c2238bad91453e9f
@@ -0,0 +1,2 @@
+1.2
+
diff --git a/sql/hive/src/test/resources/golden/timestamp cast #4-0-e9286317470d42e9f8122bc98a2c1ce1 b/sql/hive/src/test/resources/golden/timestamp cast #4-0-e9286317470d42e9f8122bc98a2c1ce1
new file mode 100644
index 0000000000000..decdb1d30e6a6
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp cast #4-0-e9286317470d42e9f8122bc98a2c1ce1
@@ -0,0 +1,2 @@
+-1.2
+
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
index 473a93bf129df..270595b0011e9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
@@ -181,41 +181,25 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite {
"INSERT overwrite directory \"fs://localhost/tmp\" select 1 as a"))
}
- test("SPARK-31061: alterTable should be able to change table provider") {
+ test("SPARK-31061: alterTable should be able to change table provider/hive") {
val catalog = newBasicCatalog()
- val parquetTable = CatalogTable(
- identifier = TableIdentifier("parq_tbl", Some("db1")),
- tableType = CatalogTableType.MANAGED,
- storage = storageFormat.copy(locationUri = Some(new URI("file:/some/path"))),
- schema = new StructType().add("col1", "int").add("col2", "string"),
- provider = Some("parquet"))
- catalog.createTable(parquetTable, ignoreIfExists = false)
-
- val rawTable = externalCatalog.getTable("db1", "parq_tbl")
- assert(rawTable.provider === Some("parquet"))
-
- val fooTable = parquetTable.copy(provider = Some("foo"))
- catalog.alterTable(fooTable)
- val alteredTable = externalCatalog.getTable("db1", "parq_tbl")
- assert(alteredTable.provider === Some("foo"))
- }
-
- test("SPARK-31061: alterTable should be able to change table provider from hive") {
- val catalog = newBasicCatalog()
- val hiveTable = CatalogTable(
- identifier = TableIdentifier("parq_tbl", Some("db1")),
- tableType = CatalogTableType.MANAGED,
- storage = storageFormat,
- schema = new StructType().add("col1", "int").add("col2", "string"),
- provider = Some("hive"))
- catalog.createTable(hiveTable, ignoreIfExists = false)
-
- val rawTable = externalCatalog.getTable("db1", "parq_tbl")
- assert(rawTable.provider === Some("hive"))
-
- val fooTable = rawTable.copy(provider = Some("foo"))
- catalog.alterTable(fooTable)
- val alteredTable = externalCatalog.getTable("db1", "parq_tbl")
- assert(alteredTable.provider === Some("foo"))
+ Seq("parquet", "hive").foreach( provider => {
+ val tableDDL = CatalogTable(
+ identifier = TableIdentifier("parq_tbl", Some("db1")),
+ tableType = CatalogTableType.MANAGED,
+ storage = storageFormat,
+ schema = new StructType().add("col1", "int"),
+ provider = Some(provider))
+ catalog.dropTable("db1", "parq_tbl", true, true)
+ catalog.createTable(tableDDL, ignoreIfExists = false)
+
+ val rawTable = externalCatalog.getTable("db1", "parq_tbl")
+ assert(rawTable.provider === Some(provider))
+
+ val fooTable = rawTable.copy(provider = Some("foo"))
+ catalog.alterTable(fooTable)
+ val alteredTable = externalCatalog.getTable("db1", "parq_tbl")
+ assert(alteredTable.provider === Some("foo"))
+ })
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 2b42444ceeaa1..e5628c33b5ec8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -564,12 +564,18 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd
assert(-1 == res.get(0))
}
- test("timestamp cast #3") {
+ createQueryTest("timestamp cast #3",
+ "SELECT CAST(TIMESTAMP_SECONDS(1.2) AS DOUBLE) FROM src LIMIT 1")
+
+ createQueryTest("timestamp cast #4",
+ "SELECT CAST(TIMESTAMP_SECONDS(-1.2) AS DOUBLE) FROM src LIMIT 1")
+
+ test("timestamp cast #5") {
val res = sql("SELECT CAST(TIMESTAMP_SECONDS(1200) AS INT) FROM src LIMIT 1").collect().head
assert(1200 == res.getInt(0))
}
- test("timestamp cast #4") {
+ test("timestamp cast #6") {
val res = sql("SELECT CAST(TIMESTAMP_SECONDS(-1200) AS INT) FROM src LIMIT 1").collect().head
assert(-1200 == res.getInt(0))
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
index c9c36992906a8..24aecb0274ece 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
@@ -19,22 +19,22 @@ package org.apache.spark.sql.hive.execution
import org.scalatest.Matchers._
-import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions.broadcast
-import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType
-class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
+
+ override def format: String = "parquet"
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil
@@ -108,4 +108,10 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
}
}
}
+
+ override def getScanExecPartitionSize(plan: SparkPlan): Long = {
+ plan.collectFirst {
+ case p: FileSourceScanExec => p
+ }.get.selectedPartitions.length
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
index e41709841a736..c29e889c3a941 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.execution.SparkPlan
-class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
+
+ override def format(): String = "hive"
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
@@ -32,7 +32,7 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes
EliminateSubqueryAliases, new PruneHiveTablePartitions(spark)) :: Nil
}
- test("SPARK-15616 statistics pruned after going throuhg PruneHiveTablePartitions") {
+ test("SPARK-15616: statistics pruned after going through PruneHiveTablePartitions") {
withTable("test", "temp") {
sql(
s"""
@@ -54,4 +54,10 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes
Optimize.execute(analyzed2).stats.sizeInBytes)
}
}
+
+ override def getScanExecPartitionSize(plan: SparkPlan): Long = {
+ plan.collectFirst {
+ case p: HiveTableScanExec => p
+ }.get.prunedPartitions.size
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
new file mode 100644
index 0000000000000..d088061cdc6e5
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with TestHiveSingleton {
+
+ protected def format: String
+
+ test("SPARK-28169: Convert scan predicate condition to CNF") {
+ withTempView("temp") {
+ withTable("t") {
+ sql(
+ s"""
+ |CREATE TABLE t(i INT, p STRING)
+ |USING $format
+ |PARTITIONED BY (p)""".stripMargin)
+
+ spark.range(0, 1000, 1).selectExpr("id as col")
+ .createOrReplaceTempView("temp")
+
+ for (part <- Seq(1, 2, 3, 4)) {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE t PARTITION (p='$part')
+ |SELECT col FROM temp""".stripMargin)
+ }
+
+ assertPrunedPartitions(
+ "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2)
+ assertPrunedPartitions(
+ "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (i = 1 OR p = '2')", 4)
+ assertPrunedPartitions(
+ "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '3' AND i = 3 )", 2)
+ assertPrunedPartitions(
+ "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '2' OR p = '3')", 3)
+ assertPrunedPartitions(
+ "SELECT * FROM t", 4)
+ assertPrunedPartitions(
+ "SELECT * FROM t WHERE p = '1' AND i = 2", 1)
+ assertPrunedPartitions(
+ """
+ |SELECT i, COUNT(1) FROM (
+ |SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)
+ |) tmp GROUP BY i
+ """.stripMargin, 2)
+ }
+ }
+ }
+
+ protected def assertPrunedPartitions(query: String, expected: Long): Unit = {
+ val plan = sql(query).queryExecution.sparkPlan
+ assert(getScanExecPartitionSize(plan) == expected)
+ }
+
+ protected def getScanExecPartitionSize(plan: SparkPlan): Long
+}