diff --git a/appveyor.yml b/appveyor.yml index a4da5f9040ded..1fd91daae9015 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -42,8 +42,8 @@ install: # Install maven and dependencies - ps: .\dev\appveyor-install-dependencies.ps1 # Required package for R unit tests - - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival', 'arrow'), repos='https://cloud.r-project.org/')" - - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival'); packageVersion('arrow')" + - cmd: Rscript -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival', 'arrow'), repos='https://cloud.r-project.org/')" + - cmd: Rscript -e "pkg_list <- as.data.frame(installed.packages()[,c(1, 3:4)]); pkg_list[is.na(pkg_list$Priority), 1:2, drop = FALSE]" build_script: # '-Djna.nosys=true' is required to avoid kernel32.dll load failure. diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index f7f8a0e0e9061..d4394ebcfd258 100755 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -80,6 +80,10 @@ a:not([href]):hover { padding: 0; } +.navbar-brand a:hover { + text-decoration: none; +} + .navbar .navbar-nav .nav-link { height: 50px; padding: 10px 15px 10px; diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index b13028f868072..6606d317e7b86 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -90,7 +90,8 @@ case class FetchFailed( extends TaskFailedReason { override def toErrorString: String = { val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString - s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapIndex=$mapIndex, " + + val mapIndexString = if (mapIndex == Int.MinValue) "Unknown" else mapIndex.toString + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapIndex=$mapIndexString, " + s"mapId=$mapId, reduceId=$reduceId, message=\n$message\n)" } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala index b1adc3c112ed3..a542d2b8cb27c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala @@ -75,14 +75,29 @@ private class HistoryServerDiskManager( // Go through the recorded store directories and remove any that may have been removed by // external code. - val orphans = listing.view(classOf[ApplicationStoreInfo]).asScala.filter { info => - !new File(info.path).exists() - }.toSeq + val (existences, orphans) = listing + .view(classOf[ApplicationStoreInfo]) + .asScala + .toSeq + .partition { info => + new File(info.path).exists() + } orphans.foreach { info => listing.delete(info.getClass(), info.path) } + // Reading level db would trigger table file compaction, then it may cause size of level db + // directory changed. When service restarts, "currentUsage" is calculated from real directory + // size. Update "ApplicationStoreInfo.size" to ensure "currentUsage" equals + // sum of "ApplicationStoreInfo.size". 
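A note on the DiskBlockManager change further down in this patch: replacing `newDir.mkdir()` with `Files.createDirectory` trades a bare boolean for an exception that carries the actual failure cause. A minimal sketch of the difference, using a throwaway temp directory rather than a real Spark local dir:

```scala
import java.nio.file.Files

// Hypothetical directory, not a path Spark itself uses.
val base = Files.createTempDirectory("spark-local-example")
val subDir = base.resolve("00")

// java.io.File.mkdir() only reports `false` on failure, hiding whether the parent is
// missing, the disk is full, or permissions are wrong. Files.createDirectory instead
// throws NoSuchFileException, FileAlreadyExistsException or AccessDeniedException with
// the offending path, which is why the DiskBlockManager change can drop its own
// hand-rolled "Failed to create local dir" IOException.
if (!Files.exists(subDir)) {
  Files.createDirectory(subDir)
}
```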
+ existences.foreach { info => + val fileSize = sizeOf(new File(info.path)) + if (fileSize != info.size) { + listing.write(info.copy(size = fileSize)) + } + } + logInfo("Initialized disk manager: " + s"current usage = ${Utils.bytesToString(currentUsage.get())}, " + s"max usage = ${Utils.bytesToString(maxUsage)}") diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f2113947f6bf5..bf76eef443e81 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, IOException} +import java.nio.file.Files import java.util.UUID import org.apache.spark.SparkConf @@ -69,8 +70,8 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea old } else { val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - if (!newDir.exists() && !newDir.mkdir()) { - throw new IOException(s"Failed to create local dir in $newDir.") + if (!newDir.exists()) { + Files.createDirectory(newDir.toPath) } subDirs(dirId)(subDirId) = newDir newDir diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ced3f9d15720d..f3372501f471b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -1078,8 +1078,12 @@ private[spark] object JsonProtocol { val blockManagerAddress = blockManagerIdFromJson(json \ "Block Manager Address") val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Long] - val mapIndex = (json \ "Map Index") match { - case JNothing => 0 + val mapIndex = json \ "Map Index" match { + case JNothing => + // Note, we use the invalid value Int.MinValue here to fill the map index for backward + // compatibility. Otherwise, the fetch failed event will be dropped when the history + // server loads the event log written by the Spark version before 3.0. + Int.MinValue case x => x.extract[Int] } val reduceId = (json \ "Reduce ID").extract[Int] diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala index f78469e132490..b17880a733615 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala @@ -158,4 +158,50 @@ class HistoryServerDiskManagerSuite extends SparkFunSuite with BeforeAndAfter { assert(manager.approximateSize(50L, true) > 50L) } + test("SPARK-32024: update ApplicationStoreInfo.size during initializing") { + val manager = mockManager() + val leaseA = manager.lease(2) + doReturn(3L).when(manager).sizeOf(meq(leaseA.tmpPath)) + val dstA = leaseA.commit("app1", None) + assert(manager.free() === 0) + assert(manager.committed() === 3) + // Listing store tracks dstA now. + assert(store.read(classOf[ApplicationStoreInfo], dstA.getAbsolutePath).size === 3) + + // Simulate: service restarts, new disk manager (manager1) is initialized. + val manager1 = mockManager() + // Simulate: event KVstore compaction before restart, directory size reduces. 
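Tying together the `TaskEndReason` and `JsonProtocol` changes above: an event log written before Spark 3.0 has no `"Map Index"` field, so the deserializer now fills in `Int.MinValue` instead of a misleading `0`, and `toErrorString` renders it as `Unknown`. A small sketch with illustrative values:

```scala
import org.apache.spark.FetchFailed
import org.apache.spark.storage.BlockManagerId

// What JsonProtocol.taskEndReasonFromJson produces for an old event with no "Map Index":
val fromOldLog = FetchFailed(
  BlockManagerId("exec-1", "host-1", 7337),  // illustrative block manager
  shuffleId = 3,
  mapId = 42L,
  mapIndex = Int.MinValue,                   // placeholder for "not recorded"
  reduceId = 7,
  message = "fetch failed")

// Prints "... mapIndex=Unknown, ..." rather than pretending the index was 0.
println(fromOldLog.toErrorString)
```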
+ doReturn(2L).when(manager1).sizeOf(meq(dstA)) + doReturn(2L).when(manager1).sizeOf(meq(new File(testDir, "apps"))) + manager1.initialize() + // "ApplicationStoreInfo.size" is updated for dstA. + assert(store.read(classOf[ApplicationStoreInfo], dstA.getAbsolutePath).size === 2) + assert(manager1.free() === 1) + // If "ApplicationStoreInfo.size" is not correctly updated, "IllegalStateException" + // would be thrown. + val leaseB = manager1.lease(2) + assert(manager1.free() === 1) + doReturn(2L).when(manager1).sizeOf(meq(leaseB.tmpPath)) + val dstB = leaseB.commit("app2", None) + assert(manager1.committed() === 2) + // Listing store tracks dstB only, dstA is evicted by "makeRoom()". + assert(store.read(classOf[ApplicationStoreInfo], dstB.getAbsolutePath).size === 2) + + val manager2 = mockManager() + // Simulate: cache entities are written after replaying, directory size increases. + doReturn(3L).when(manager2).sizeOf(meq(dstB)) + doReturn(3L).when(manager2).sizeOf(meq(new File(testDir, "apps"))) + manager2.initialize() + // "ApplicationStoreInfo.size" is updated for dstB. + assert(store.read(classOf[ApplicationStoreInfo], dstB.getAbsolutePath).size === 3) + assert(manager2.free() === 0) + val leaseC = manager2.lease(2) + doReturn(2L).when(manager2).sizeOf(meq(leaseC.tmpPath)) + val dstC = leaseC.commit("app3", None) + assert(manager2.free() === 1) + assert(manager2.committed() === 2) + // Listing store tracks dstC only, dstB is evicted by "makeRoom()". + assert(store.read(classOf[ApplicationStoreInfo], dstC.getAbsolutePath).size === 2) + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 8737cd5bb3241..6ede98d55f094 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -313,8 +313,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) } - // TODO (SPARK-31723): re-enable it - ignore("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) val request = mock[HttpServletRequest] diff --git a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionExtendedSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionExtendedSuite.scala index 02c72fa349a79..4de5aaeab5c51 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionExtendedSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionExtendedSuite.scala @@ -32,17 +32,17 @@ import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend class WorkerDecommissionExtendedSuite extends SparkFunSuite with LocalSparkContext { private val conf = new org.apache.spark.SparkConf() .setAppName(getClass.getName) - .set(SPARK_MASTER, "local-cluster[20,1,512]") + .set(SPARK_MASTER, "local-cluster[5,1,512]") .set(EXECUTOR_MEMORY, "512m") .set(DYN_ALLOCATION_ENABLED, true) .set(DYN_ALLOCATION_SHUFFLE_TRACKING_ENABLED, true) - .set(DYN_ALLOCATION_INITIAL_EXECUTORS, 20) + .set(DYN_ALLOCATION_INITIAL_EXECUTORS, 5) .set(WORKER_DECOMMISSION_ENABLED, true) test("Worker decommission and executor idle timeout") { sc = 
new SparkContext(conf.set(DYN_ALLOCATION_EXECUTOR_IDLE_TIMEOUT.key, "10s")) withSpark(sc) { sc => - TestUtils.waitUntilExecutorsUp(sc, 20, 60000) + TestUtils.waitUntilExecutorsUp(sc, 5, 60000) val rdd1 = sc.parallelize(1 to 10, 2) val rdd2 = rdd1.map(x => (1, x)) val rdd3 = rdd2.reduceByKey(_ + _) @@ -54,10 +54,10 @@ class WorkerDecommissionExtendedSuite extends SparkFunSuite with LocalSparkConte } } - test("Decommission 19 executors from 20 executors in total") { + test("Decommission 4 executors from 5 executors in total") { sc = new SparkContext(conf) withSpark(sc) { sc => - TestUtils.waitUntilExecutorsUp(sc, 20, 60000) + TestUtils.waitUntilExecutorsUp(sc, 5, 60000) val rdd1 = sc.parallelize(1 to 100000, 200) val rdd2 = rdd1.map(x => (x % 100, x)) val rdd3 = rdd2.reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 955589fc5b47b..c75e98f39758d 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -334,7 +334,7 @@ class JsonProtocolSuite extends SparkFunSuite { val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) .removeField({ _._1 == "Map Index" }) val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 16L, - 0, 19, "ignored") + Int.MinValue, 19, "ignored") assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } diff --git a/dev/run-tests.py b/dev/run-tests.py index ec04c37857d96..223072cbe7bfb 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -363,7 +363,8 @@ def build_spark_assembly_sbt(extra_profiles, checkstyle=False): if checkstyle: run_java_style_checks(build_profiles) - build_spark_unidoc_sbt(extra_profiles) + if not os.environ.get("AMPLAB_JENKINS"): + build_spark_unidoc_sbt(extra_profiles) def build_apache_spark(build_tool, extra_profiles): @@ -648,7 +649,7 @@ def main(): # if "DOCS" in changed_modules and test_env == "amplab_jenkins": # build_spark_documentation() - if any(m.should_run_build_tests for m in test_modules): + if any(m.should_run_build_tests for m in test_modules) and test_env != "amplab_jenkins": run_build_tests() # spark build diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 0c84db38afafc..d3138ae319160 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -30,8 +30,6 @@ license: | - In Spark 3.1, `from_unixtime`, `unix_timestamp`,`to_unix_timestamp`, `to_timestamp` and `to_date` will fail if the specified datetime pattern is invalid. In Spark 3.0 or earlier, they result `NULL`. - - In Spark 3.1, casting numeric to timestamp will be forbidden by default. It's strongly recommended to use dedicated functions: TIMESTAMP_SECONDS, TIMESTAMP_MILLIS and TIMESTAMP_MICROS. Or you can set `spark.sql.legacy.allowCastNumericToTimestamp` to true to work around it. See more details in SPARK-31710. - ## Upgrading from Spark SQL 3.0 to 3.0.1 - In Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Since version 3.0.1, the timestamp type inference is disabled by default. Set the JSON option `inferTimestamp` to `true` to enable such type inference. 
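The migration note removed above corresponds to flipping the default of `spark.sql.legacy.allowCastNumericToTimestamp` to `true` later in this patch, so numeric-to-timestamp casts are allowed again out of the box. Illustrative only, assuming a running `SparkSession` named `spark`:

```scala
// Allowed again under the new default; the number is interpreted as seconds since epoch.
spark.sql("SELECT CAST(1230219000 AS TIMESTAMP)").show()

// Still the recommended spelling, since the unit is explicit (per the removed note).
spark.sql("SELECT TIMESTAMP_SECONDS(1230219000)").show()

// The stricter behavior can be re-enabled explicitly if desired.
spark.conf.set("spark.sql.legacy.allowCastNumericToTimestamp", "false")
```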
diff --git a/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh b/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh index 00885a3b62327..343bc01651318 100755 --- a/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh +++ b/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh @@ -18,7 +18,7 @@ dpkg-divert --add /bin/systemctl && ln -sT /bin/true /bin/systemctl apt update -apt install -y mariadb-plugin-gssapi-server +apt install -y mariadb-plugin-gssapi-server=1:10.4.12+maria~bionic echo "gssapi_keytab_path=/docker-entrypoint-initdb.d/mariadb.keytab" >> /etc/mysql/mariadb.conf.d/auth_gssapi.cnf echo "gssapi_principal_name=mariadb/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM" >> /etc/mysql/mariadb.conf.d/auth_gssapi.cnf docker-entrypoint.sh mysqld diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 0ee895a95a288..8336df8e34ae0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1220,10 +1220,41 @@ class GeneralizedLinearRegressionSummary private[regression] ( private[regression] lazy val link: Link = familyLink.link + /** + * summary row containing: + * numInstances, weightSum, deviance, rss, weighted average of label - offset. + */ + private lazy val glrSummary = { + val devUDF = udf { (label: Double, pred: Double, weight: Double) => + family.deviance(label, pred, weight) + } + val devCol = sum(devUDF(label, prediction, weight)) + + val rssCol = if (model.getFamily.toLowerCase(Locale.ROOT) != Binomial.name && + model.getFamily.toLowerCase(Locale.ROOT) != Poisson.name) { + val rssUDF = udf { (label: Double, pred: Double, weight: Double) => + (label - pred) * (label - pred) * weight / family.variance(pred) + } + sum(rssUDF(label, prediction, weight)) + } else { + lit(Double.NaN) + } + + val avgCol = if (model.getFitIntercept && + (!model.hasOffsetCol || (model.hasOffsetCol && family == Gaussian && link == Identity))) { + sum((label - offset) * weight) / sum(weight) + } else { + lit(Double.NaN) + } + + predictions + .select(count(label), sum(weight), devCol, rssCol, avgCol) + .head() + } + /** Number of instances in DataFrame predictions. */ @Since("2.2.0") - lazy val numInstances: Long = predictions.count() - + lazy val numInstances: Long = glrSummary.getLong(0) /** * Name of features. If the name cannot be retrieved from attributes, @@ -1335,9 +1366,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ if (!model.hasOffsetCol || (model.hasOffsetCol && family == Gaussian && link == Identity)) { - val agg = predictions.agg(sum(weight.multiply( - label.minus(offset))), sum(weight)).first() - link.link(agg.getDouble(0) / agg.getDouble(1)) + link.link(glrSummary.getDouble(4)) } else { // Create empty feature column and fit intercept only model using param setting from model val featureNull = "feature_" + java.util.UUID.randomUUID.toString @@ -1362,12 +1391,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( * The deviance for the fitted model. 
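The `glrSummary` refactor above collapses what used to be several separate aggregation jobs (instance count, weight sum, deviance, the Pearson residual sum of squares, and the weighted mean of label minus offset) into a single `select(...).head()` pass over `predictions`. A stripped-down sketch of the pattern, with hypothetical column names and a stand-in deviance function:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// Hypothetical per-row contribution, standing in for family.deviance(label, pred, weight).
val devianceUDF = udf { (label: Double, pred: Double, weight: Double) =>
  weight * (label - pred) * (label - pred)
}

def summarize(predictions: DataFrame): (Long, Double, Double) = {
  // One job computes every aggregate, instead of predictions.count(), predictions.agg(...), ...
  val row = predictions.select(
    count(col("label")),
    sum(col("weight")),
    sum(devianceUDF(col("label"), col("prediction"), col("weight")))
  ).head()
  (row.getLong(0), row.getDouble(1), row.getDouble(2))
}
```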
*/ @Since("2.0.0") - lazy val deviance: Double = { - predictions.select(label, prediction, weight).rdd.map { - case Row(label: Double, pred: Double, weight: Double) => - family.deviance(label, pred, weight) - }.sum() - } + lazy val deviance: Double = glrSummary.getDouble(2) /** * The dispersion of the fitted model. @@ -1381,14 +1405,14 @@ class GeneralizedLinearRegressionSummary private[regression] ( model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { 1.0 } else { - val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0) + val rss = glrSummary.getDouble(3) rss / degreesOfFreedom } /** Akaike Information Criterion (AIC) for the fitted model. */ @Since("2.0.0") lazy val aic: Double = { - val weightSum = predictions.select(weight).agg(sum(weight)).first().getDouble(0) + val weightSum = glrSummary.getDouble(1) val t = predictions.select( label, prediction, weight).rdd.map { case Row(label: Double, pred: Double, weight: Double) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index d9f09c097292a..de559142a9261 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -1037,7 +1037,7 @@ class LinearRegressionSummary private[regression] ( } /** Number of instances in DataFrame predictions */ - lazy val numInstances: Long = predictions.count() + lazy val numInstances: Long = metrics.count /** Degrees of freedom */ @Since("2.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index b697d2746ce7b..7938427544bd9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -131,4 +131,6 @@ class RegressionMetrics @Since("2.0.0") ( 1 - SSerr / SStot } } + + private[spark] def count: Long = summary.count } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 04a3fc4b63050..60c54dfc98a58 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -476,6 +476,7 @@ object SparkParallelTestGrouping { "org.apache.spark.ml.classification.LogisticRegressionSuite", "org.apache.spark.ml.classification.LinearSVCSuite", "org.apache.spark.sql.SQLQueryTestSuite", + "org.apache.spark.sql.hive.thriftserver.SparkExecuteStatementOperationSuite", "org.apache.spark.sql.hive.thriftserver.ThriftServerQueryTestSuite", "org.apache.spark.sql.hive.thriftserver.SparkSQLEnvSuite", "org.apache.spark.sql.hive.thriftserver.ui.ThriftServerPageSuite", diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py index 76119432662ba..cc6167e619285 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py @@ -545,13 +545,13 @@ def f(pdf): def test_grouped_over_window_with_key(self): - data = [(0, 1, "2018-03-10T00:00:00+00:00", False), - (1, 2, "2018-03-11T00:00:00+00:00", False), - (2, 2, "2018-03-12T00:00:00+00:00", False), - (3, 3, "2018-03-15T00:00:00+00:00", False), - (4, 3, "2018-03-16T00:00:00+00:00", False), - (5, 3, "2018-03-17T00:00:00+00:00", False), - (6, 3, "2018-03-21T00:00:00+00:00", False)] + data = [(0, 1, "2018-03-10T00:00:00+00:00", [0]), + (1, 2, 
"2018-03-11T00:00:00+00:00", [0]), + (2, 2, "2018-03-12T00:00:00+00:00", [0]), + (3, 3, "2018-03-15T00:00:00+00:00", [0]), + (4, 3, "2018-03-16T00:00:00+00:00", [0]), + (5, 3, "2018-03-17T00:00:00+00:00", [0]), + (6, 3, "2018-03-21T00:00:00+00:00", [0])] expected_window = [ {'start': datetime.datetime(2018, 3, 10, 0, 0), @@ -562,30 +562,43 @@ def test_grouped_over_window_with_key(self): 'end': datetime.datetime(2018, 3, 25, 0, 0)}, ] - expected = {0: (1, expected_window[0]), - 1: (2, expected_window[0]), - 2: (2, expected_window[0]), - 3: (3, expected_window[1]), - 4: (3, expected_window[1]), - 5: (3, expected_window[1]), - 6: (3, expected_window[2])} + expected_key = {0: (1, expected_window[0]), + 1: (2, expected_window[0]), + 2: (2, expected_window[0]), + 3: (3, expected_window[1]), + 4: (3, expected_window[1]), + 5: (3, expected_window[1]), + 6: (3, expected_window[2])} + + # id -> array of group with len of num records in window + expected = {0: [1], + 1: [2, 2], + 2: [2, 2], + 3: [3, 3, 3], + 4: [3, 3, 3], + 5: [3, 3, 3], + 6: [3]} df = self.spark.createDataFrame(data, ['id', 'group', 'ts', 'result']) df = df.select(col('id'), col('group'), col('ts').cast('timestamp'), col('result')) - @pandas_udf(df.schema, PandasUDFType.GROUPED_MAP) def f(key, pdf): group = key[0] window_range = key[1] - # Result will be True if group and window range equal to expected - is_expected = pdf.id.apply(lambda id: (expected[id][0] == group and - expected[id][1] == window_range)) - return pdf.assign(result=is_expected) - result = df.groupby('group', window('ts', '5 days')).apply(f).select('result').collect() + # Make sure the key with group and window values are correct + for _, i in pdf.id.iteritems(): + assert expected_key[i][0] == group, "{} != {}".format(expected_key[i][0], group) + assert expected_key[i][1] == window_range, \ + "{} != {}".format(expected_key[i][1], window_range) - # Check that all group and window_range values from udf matched expected - self.assertTrue(all([r[0] for r in result])) + return pdf.assign(result=[[group] * len(pdf)] * len(pdf)) + + result = df.groupby('group', window('ts', '5 days')).applyInPandas(f, df.schema)\ + .select('id', 'result').collect() + + for r in result: + self.assertListEqual(expected[r[0]], r[1]) def test_case_insensitive_grouping_column(self): # SPARK-31915: case-insensitive grouping column should work. diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 320a68dffe7a3..ddd13ca3a01be 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -116,6 +116,9 @@ class NullType(DataType): __metaclass__ = DataTypeSingleton + def simpleString(self): + return 'unknown' + class AtomicType(DataType): """An internal type used to represent everything that is not diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 691fde8d48f94..b383e037e1ed8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -1461,8 +1461,7 @@ nonReserved ; // NOTE: If you add a new token in the list below, you should update the list of keywords -// in `docs/sql-keywords.md`. If the token is a non-reserved keyword, -// please update `ansiNonReserved` and `nonReserved` as well. +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`. 
//============================ // Start of the keywords list diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d08a6382f738b..f92cf377bff12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1050,12 +1050,10 @@ class Analyzer( val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) val query = addStaticPartitionColumns(r, i.query, staticPartitions) - val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && - conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC if (!i.overwrite) { AppendData.byPosition(r, query) - } else if (dynamicPartitionOverwrite) { + } else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) { OverwritePartitionsDynamic.byPosition(r, query) } else { OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9c99acaa994b8..43dd0979eed7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -158,6 +158,11 @@ trait CheckAnalysis extends PredicateHelper { case g: GroupingID => failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup") + case e: Expression if e.children.exists(_.isInstanceOf[WindowFunction]) && + !e.isInstanceOf[WindowExpression] => + val w = e.children.find(_.isInstanceOf[WindowFunction]).get + failAnalysis(s"Window function $w requires an OVER clause.") + case w @ WindowExpression(AggregateExpression(_, _, true, _, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 2a0a944e4849c..a40604045978c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -34,6 +34,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case AlterTableAddColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => + cols.foreach(c => failNullType(c.dataType)) cols.foreach(c => failCharType(c.dataType)) val changes = cols.map { col => TableChange.addColumn( @@ -47,6 +48,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableReplaceColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => + cols.foreach(c => failNullType(c.dataType)) cols.foreach(c => failCharType(c.dataType)) val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { case Some(table) => @@ -69,6 +71,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case a @ AlterTableAlterColumnStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => + a.dataType.foreach(failNullType) a.dataType.foreach(failCharType) val colName = a.column.toArray val typeChange = 
a.dataType.map { newDataType => @@ -145,6 +148,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ CreateTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + assertNoNullTypeInSchema(c.tableSchema) assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( catalog.asTableCatalog, @@ -157,6 +161,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ CreateTableAsSelectStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + if (c.asSelect.resolved) { + assertNoNullTypeInSchema(c.asSelect.schema) + } CreateTableAsSelect( catalog.asTableCatalog, tbl.asIdentifier, @@ -172,6 +179,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ ReplaceTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + assertNoNullTypeInSchema(c.tableSchema) assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( catalog.asTableCatalog, @@ -184,6 +192,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ ReplaceTableAsSelectStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + if (c.asSelect.resolved) { + assertNoNullTypeInSchema(c.asSelect.schema) + } ReplaceTableAsSelect( catalog.asTableCatalog, tbl.asIdentifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1b4a705e804f1..cf7cc3a5e16ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -539,3 +539,61 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def prettyName: String = "str_to_map" } + +/** + * Adds/replaces field in struct by name. 
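The `failNullType` / `assertNoNullTypeInSchema` calls threaded through `ResolveCatalogs` above mean a `NullType` column (spelled `void` once the parser change later in this patch lands, and rendered as `unknown`) can no longer appear in a created, replaced or altered table schema. A rough illustration of what gets rejected, assuming a v2 catalog registered as `testcat` and a provider named `foo`:

```scala
// Both are expected to fail analysis with
// "Cannot create tables with unknown type." (message added in CatalogV2Util below).
spark.sql("CREATE TABLE testcat.ns.tbl (c VOID) USING foo")
spark.sql("ALTER TABLE testcat.ns.tbl ADD COLUMN c2 VOID")

// CTAS is covered too once the query is resolved: a bare NULL column has NullType.
spark.sql("CREATE TABLE testcat.ns.tbl2 USING foo AS SELECT null AS c")
```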
+ */ +case class WithFields( + structExpr: Expression, + names: Seq[String], + valExprs: Seq[Expression]) extends Unevaluable { + + assert(names.length == valExprs.length) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!structExpr.dataType.isInstanceOf[StructType]) { + TypeCheckResult.TypeCheckFailure( + "struct argument should be struct type, got: " + structExpr.dataType.catalogString) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def children: Seq[Expression] = structExpr +: valExprs + + override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType] + + override def foldable: Boolean = structExpr.foldable && valExprs.forall(_.foldable) + + override def nullable: Boolean = structExpr.nullable + + override def prettyName: String = "with_fields" + + lazy val evalExpr: Expression = { + val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { + case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression]) + } + + val addOrReplaceExprs = names.zip(valExprs) + + val resolver = SQLConf.get.resolver + val newExprs = addOrReplaceExprs.foldLeft(existingExprs) { + case (resultExprs, newExpr @ (newExprName, _)) => + if (resultExprs.exists(x => resolver(x._1, newExprName))) { + resultExprs.map { + case (name, _) if resolver(name, newExprName) => newExpr + case x => x + } + } else { + resultExprs :+ newExpr + } + }.flatMap { case (name, expr) => Seq(Literal(name), expr) } + + val expr = CreateNamedStruct(newExprs) + if (structExpr.nullable) { + If(IsNull(structExpr), Literal(null, expr.dataType), expr) + } else { + expr + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9c600c9d39cf7..89ff4facd25a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -57,7 +57,7 @@ object ExtractValue { val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), - ordinal, fields.length, containsNull) + ordinal, fields.length, containsNull || fields(ordinal).nullable) case (_: ArrayType, _) => GetArrayItem(child, extraction) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index a1277217b1b3a..551cbc3161cc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1450,8 +1450,7 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr extends RuntimeReplaceable { def this(left: Expression, format: Expression) { - this(left, Option(format), - Cast(SecondsToTimestamp(UnixTimestamp(left, format)), DateType)) + this(left, Option(format), Cast(GetTimestamp(left, format), DateType)) } def this(left: Expression) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index f79dabf758c14..1c33a2c7c3136 100644 
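On the `complexTypeExtractors` one-liner above (SPARK-32167): when a struct field is itself nullable, extracting it across an array must mark the result's elements nullable even if the array bans null entries, otherwise rows like `[{"a": null}]` are declared non-null and can be mis-read downstream. A DataFrame-level sketch, with the schema built by hand to pin the nullability and assuming a `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("arr", ArrayType(
    StructType(Seq(StructField("a", IntegerType, nullable = true))),
    containsNull = false))))            // the array itself never holds null structs

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(Seq(Row(1), Row(null))))),
  schema)

// Extracting "a" from every element: the element type must stay nullable, because the
// field is nullable even though the array's containsNull is false.
df.select(col("arr.a")).printSchema()
// expected after the fix: array<int> with containsNull = true
```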
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -39,7 +39,18 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { // Remove redundant field extraction. case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => createNamedStruct.valExprs(ordinal) - + case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) => + val name = w.dataType(ordinal).name + val matches = names.zip(valExprs).filter(_._1 == name) + if (matches.nonEmpty) { + // return last matching element as that is the final value for the field being extracted. + // For example, if a user submits a query like this: + // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")` + // we want to return `lit(2)` (and not `lit(1)`). + matches.last._2 + } else { + GetStructField(struct, ordinal, maybeName) + } // Remove redundant array indexing. case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) => // Instead of selecting the field on the entire array, select it from each member diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 118f41f9cd232..0c8666b72cace 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -149,10 +149,12 @@ object NestedColumnAliasing { case _ => false } + // Note that when we group by extractors with their references, we should remove + // cosmetic variations. val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences) val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]] .filter(!_.references.subsetOf(exclusiveAttrSet)) - .groupBy(_.references.head) + .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute]) .flatMap { case (attr, nestedFields: Seq[ExtractValue]) => // Remove redundant `ExtractValue`s if they share the same parent nest field. // For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`. @@ -174,9 +176,12 @@ object NestedColumnAliasing { // If all nested fields of `attr` are used, we don't need to introduce new aliases. // By default, ColumnPruning rule uses `attr` already. + // Note that we need to remove cosmetic variations first, so we only count a + // nested field once. 
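For context on the `WithFields` expression and the new `SimplifyExtractValueOps` case above: the expression backs a `Column.withField` API (not shown in this excerpt), and when the same field is written twice only the last value survives, which is exactly what the optimizer case preserves when folding a later `getField`. A hedged usage sketch, assuming that API and a `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val df = spark.range(1).select(struct(lit(1).as("a"), lit(2).as("b")).as("s"))

// Add or replace fields by name; replacing "b" twice keeps the last value.
val result = df.select(
  $"s".withField("c", lit(3)).withField("b", lit(10)).withField("b", lit(20)).getField("b"))

// Expected to fold to a plain literal 20 via the GetStructField-on-WithFields case above.
result.show()
```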
if (nestedFieldToAlias.nonEmpty && - nestedFieldToAlias - .map { case (nestedField, _) => totalFieldNum(nestedField.dataType) } + dedupNestedFields.map(_.canonicalized) + .distinct + .map { nestedField => totalFieldNum(nestedField.dataType) } .sum < totalFieldNum(attr.dataType)) { Some(attr.exprId -> nestedFieldToAlias) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e800ee3b93f51..1b141572cc7f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -107,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateSerialization, RemoveRedundantAliases, RemoveNoopOperators, + CombineWithFields, SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules @@ -207,7 +208,8 @@ abstract class Optimizer(catalogManager: CatalogManager) CollapseProject, RemoveNoopOperators) :+ // This batch must be executed after the `RewriteSubquery` batch, which creates joins. - Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) + Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ + Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression) // remove any batches with no rules. this may happen when subclasses do not add optional rules. batches.filter(_.rules.nonEmpty) @@ -240,7 +242,8 @@ abstract class Optimizer(catalogManager: CatalogManager) PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: - NormalizeFloatingNumbers.ruleName :: Nil + NormalizeFloatingNumbers.ruleName :: + ReplaceWithFieldsExpression.ruleName :: Nil /** * Optimize all the subqueries inside expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala new file mode 100644 index 0000000000000..05c90864e4bb0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.WithFields +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + + +/** + * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression. 
+ */ +object CombineWithFields extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => + WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) + } +} + +/** + * Replaces [[WithFields]] expression with an evaluable expression. + */ +object ReplaceWithFieldsExpression extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case w: WithFields => w.evalExpr + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d08bcb1420176..6b41a8b22fbee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2203,6 +2203,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging DecimalType(precision.getText.toInt, 0) case ("decimal" | "dec" | "numeric", precision :: scale :: Nil) => DecimalType(precision.getText.toInt, scale.getText.toInt) + case ("void", Nil) => NullType case ("interval", Nil) => CalendarIntervalType case (dt, params) => val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index e1f329352592f..d130a13282cc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.AlterTable import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.types.{ArrayType, DataType, HIVE_TYPE_STRING, HiveStringType, MapType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, HIVE_TYPE_STRING, HiveStringType, MapType, NullType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils @@ -346,4 +346,23 @@ private[sql] object CatalogV2Util { } } } + + def failNullType(dt: DataType): Unit = { + def containsNullType(dt: DataType): Boolean = dt match { + case ArrayType(et, _) => containsNullType(et) + case MapType(kt, vt, _) => containsNullType(kt) || containsNullType(vt) + case StructType(fields) => fields.exists(f => containsNullType(f.dataType)) + case _ => dt.isInstanceOf[NullType] + } + if (containsNullType(dt)) { + throw new AnalysisException( + "Cannot create tables with unknown type.") + } + } + + def assertNoNullTypeInSchema(schema: StructType): Unit = { + schema.foreach { f => + failNullType(f.dataType) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3149d14c1ddcc..31dd943eeba2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2634,7 +2634,7 @@ object SQLConf { "when false, forbid the cast, more details in SPARK-31710") .version("3.1.0") 
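Pulling together the `void` parser rule above and the `NullType.simpleString` change just below: `void` now parses to `NullType` (the spelling Hive uses), and `NullType` prints as `unknown` rather than `null`, since `null` reads like a literal. A quick check of both, using only catalyst classes:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.NullType

// New DDL spelling accepted by AstBuilder's primitive-type rule.
assert(CatalystSqlParser.parseDataType("void") == NullType)

// New rendering, used e.g. in error messages such as
// "Cannot create tables with unknown type."
assert(NullType.simpleString == "unknown")
assert(NullType.catalogString == "unknown")   // catalogString delegates to simpleString
```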
.booleanConf - .createWithDefault(false) + .createWithDefault(true) val COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_ENABLED = buildConf("spark.sql.bucketing.coalesceBucketsInSortMergeJoin.enabled") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index 14097a5280d50..6c9a1d69ca681 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -32,6 +32,10 @@ class NullType private() extends DataType { override def defaultSize: Int = 1 private[spark] override def asNullable: NullType = this + + // "null" is mainly used to represent a literal in Spark, + // it's better to avoid using it for data types. + override def simpleString: String = "unknown" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c15ec49e14282..c0be49af2107d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -884,4 +884,15 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq("Intersect can only be performed on tables with the compatible column types. " + "timestamp <> double at the second column of the second table")) } + + test("SPARK-31975: Throw user facing error when use WindowFunction directly") { + assertAnalysisError(testRelation2.select(RowNumber()), + Seq("Window function row_number() requires an OVER clause.")) + + assertAnalysisError(testRelation2.select(Sum(RowNumber())), + Seq("Window function row_number() requires an OVER clause.")) + + assertAnalysisError(testRelation2.select(RowNumber() + 1), + Seq("Window function row_number() requires an OVER clause.")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 76ec450a4d7c6..4ab288a34cb08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -49,9 +49,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } protected def checkNullCast(from: DataType, to: DataType): Unit = { - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null) - } + checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null) } test("null cast") { @@ -240,9 +238,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkCast(1.5, 1.5f) checkCast(1.5, "1.5") - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) - } + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) } test("cast from string") { @@ -309,19 +305,17 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType), 5.toLong) - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation( - cast(cast(cast(cast(cast(cast("5", 
ByteType), TimestampType), - DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 5.toShort) - checkEvaluation( - cast(cast(cast(cast(cast(cast("5", TimestampType, UTC_OPT), ByteType), - DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - null) - checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), - ByteType), TimestampType), LongType), StringType), ShortType), - 5.toShort) - } + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), + 5.toShort) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", TimestampType, UTC_OPT), ByteType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), + null) + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), + ByteType), TimestampType), LongType), StringType), ShortType), + 5.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) @@ -383,31 +377,29 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, FloatType), 15.003f) checkEvaluation(cast(ts, DoubleType), 15.003) - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation(cast(cast(tss, ShortType), TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation(cast(cast(tss, IntegerType), TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation(cast(cast(tss, LongType), TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimestampType), FloatType), - millis.toFloat / MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimestampType), DoubleType), - millis.toDouble / MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), - Decimal(1)) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimestampType), FloatType), + millis.toFloat / MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimestampType), DoubleType), + millis.toDouble / MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), + Decimal(1)) - // A test for higher precision than millis - checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) + // A test for higher precision than millis + checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) - checkEvaluation(cast(Double.NaN, TimestampType), null) - checkEvaluation(cast(1.0 / 0.0, TimestampType), null) - checkEvaluation(cast(Float.NaN, TimestampType), null) - checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) - } + checkEvaluation(cast(Double.NaN, TimestampType), null) + checkEvaluation(cast(1.0 / 0.0, TimestampType), null) + checkEvaluation(cast(Float.NaN, TimestampType), null) + checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) } test("cast from array") { @@ -1036,10 +1028,8 @@ class CastSuite extends CastSuiteBase { test("cast from int 2") { checkEvaluation(cast(1, LongType), 1.toLong) - 
withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) - checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) - } + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) @@ -1323,7 +1313,7 @@ class CastSuite extends CastSuiteBase { } } - test("SPARK-31710:fail casting from numeric to timestamp by default") { + test("SPARK-31710: fail casting from numeric to timestamp if it is forbidden") { Seq(true, false).foreach { enable => withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> enable.toString) { assert(cast(2.toByte, TimestampType).resolved == enable) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 3df7d02fb6604..dbe43709d1d35 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext @@ -159,6 +160,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) } + test("SPARK-32167: nullability of GetArrayStructFields") { + val resolver = SQLConf.get.resolver + + val array1 = ArrayType( + new StructType().add("a", "int", nullable = true), + containsNull = false) + val data1 = Literal.create(Seq(Row(null)), array1) + val get1 = ExtractValue(data1, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(get1.containsNull) + + val array2 = ArrayType( + new StructType().add("a", "int", nullable = false), + containsNull = true) + val data2 = Literal.create(Seq(null), array2) + val get2 = ExtractValue(data2, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(get2.containsNull) + + val array3 = ArrayType( + new StructType().add("a", "int", nullable = false), + containsNull = false) + val data3 = Literal.create(Seq(Row(1)), array3) + val get3 = ExtractValue(data3, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(!get3.containsNull) + } + test("CreateArray") { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala index 3c826e812b5cc..76d6890cc8f6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala @@ -254,13 +254,13 @@ class SelectedFieldSuite extends AnalysisTest { StructField("col3", ArrayType(StructType( StructField("field1", StructType( StructField("subfield1", IntegerType, nullable = false) :: Nil)) - 
:: Nil), containsNull = false), nullable = false) + :: Nil), containsNull = true), nullable = false) } testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") { StructField("col3", ArrayType(StructType( StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) - :: Nil), containsNull = false), nullable = false) + :: Nil), containsNull = true), nullable = false) } // |-- col1: string (nullable = false) @@ -471,7 +471,7 @@ class SelectedFieldSuite extends AnalysisTest { testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field1 as foo") { StructField("col2", MapType( ArrayType(StructType( - StructField("field1", StringType) :: Nil), containsNull = false), + StructField("field1", StringType) :: Nil), containsNull = true), ArrayType(StructType( StructField("field3", StructType( StructField("subfield3", IntegerType) :: @@ -482,7 +482,7 @@ class SelectedFieldSuite extends AnalysisTest { StructField("col2", MapType( ArrayType(StructType( StructField("field2", StructType( - StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false), + StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = true), ArrayType(StructType( StructField("field3", StructType( StructField("subfield3", IntegerType) :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala new file mode 100644 index 0000000000000..a3e0bbc57e639 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +class CombineWithFieldsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil + } + + private val testRelation = LocalRelation('a.struct('a1.int)) + + test("combines two WithFields") { + val originalQuery = testRelation + .select(Alias( + WithFields( + WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), + Seq("c1"), + Seq(Literal(5))), "out")()) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")()) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("combines three WithFields") { + val originalQuery = testRelation + .select(Alias( + WithFields( + WithFields( + WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), + Seq("c1"), + Seq(Literal(5))), + Seq("d1"), + Seq(Literal(6))), "out")()) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")()) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index d55746002783a..c71e7dbe7d6f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -452,4 +452,61 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) } + + private val structAttr = 'struct1.struct('a.int) + private val testStructRelation = LocalRelation(structAttr) + + test("simplify GetStructField on WithFields that is not changing the attribute being extracted") { + val query = testStructRelation.select( + GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAtt") + val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt") + checkRule(query, expected) + } + + test("simplify GetStructField on WithFields that is changing the attribute being extracted") { + val query = testStructRelation.select( + GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "outerAtt") + val expected = testStructRelation.select(Literal(1) as "outerAtt") + checkRule(query, expected) + } + + test( + "simplify GetStructField on WithFields that is changing the attribute being extracted twice") { + val query = testStructRelation + .select(GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1, + Some("b")) as "outerAtt") + val expected = testStructRelation.select(Literal(2) as "outerAtt") + checkRule(query, expected) + } + + test("collapse multiple GetStructField on the same WithFields") { + val query = 
testStructRelation + .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2") + .select( + GetStructField('struct2, 0, Some("a")) as "struct1A", + GetStructField('struct2, 1, Some("b")) as "struct1B") + val expected = testStructRelation.select( + GetStructField('struct1, 0, Some("a")) as "struct1A", + Literal(2) as "struct1B") + checkRule(query, expected) + } + + test("collapse multiple GetStructField on different WithFields") { + val query = testStructRelation + .select( + WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2", + WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3") + .select( + GetStructField('struct2, 0, Some("a")) as "struct2A", + GetStructField('struct2, 1, Some("b")) as "struct2B", + GetStructField('struct3, 0, Some("a")) as "struct3A", + GetStructField('struct3, 1, Some("b")) as "struct3B") + val expected = testStructRelation + .select( + GetStructField('struct1, 0, Some("a")) as "struct2A", + Literal(2) as "struct2B", + GetStructField('struct1, 0, Some("a")) as "struct3A", + Literal(3) as "struct3B") + checkRule(query, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index d519fdf378786..655b1d26d6c90 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -61,6 +61,7 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("varchAr(20)", StringType) checkDataType("cHaR(27)", StringType) checkDataType("BINARY", BinaryType) + checkDataType("void", NullType) checkDataType("interval", CalendarIntervalType) checkDataType("array", ArrayType(DoubleType, true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 3d7026e180cd1..616fc72320caf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector +import java.time.{Instant, ZoneId} +import java.time.temporal.ChronoUnit import java.util import scala.collection.JavaConverters._ @@ -25,12 +27,13 @@ import scala.collection.mutable import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -46,10 +49,15 @@ class InMemoryTable( private val allowUnsupportedTransforms = properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean - partitioning.foreach { t => - if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) { - throw new 
IllegalArgumentException(s"Transform $t must be IdentityTransform") - } + partitioning.foreach { + case _: IdentityTransform => + case _: YearsTransform => + case _: MonthsTransform => + case _: DaysTransform => + case _: HoursTransform => + case _: BucketTransform => + case t if !allowUnsupportedTransforms => + throw new IllegalArgumentException(s"Transform $t is not a supported transform") } // The key `Seq[Any]` is the partition values. @@ -66,8 +74,14 @@ class InMemoryTable( } } + private val UTC = ZoneId.of("UTC") + private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate + private def getKey(row: InternalRow): Seq[Any] = { - def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = { + def extractor( + fieldNames: Array[String], + schema: StructType, + row: InternalRow): (Any, DataType) = { val index = schema.fieldIndex(fieldNames(0)) val value = row.toSeq(schema).apply(index) if (fieldNames.length > 1) { @@ -78,10 +92,44 @@ class InMemoryTable( throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}") } } else { - value + (value, schema(index).dataType) } } - partCols.map(fieldNames => extractor(fieldNames, schema, row)) + + partitioning.map { + case IdentityTransform(ref) => + extractor(ref.fieldNames, schema, row)._1 + case YearsTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days: Int, DateType) => + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) + } + case MonthsTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days: Int, DateType) => + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate) + } + case DaysTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days, DateType) => + days + case (micros: Long, TimestampType) => + ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + } + case HoursTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (micros: Long, TimestampType) => + ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + } + case BucketTransform(numBuckets, ref) => + (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets + } } def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e6f7b1d723af6..da542c67d9c51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -871,6 +871,72 @@ class Column(val expr: Expression) extends Logging { */ def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) } + // scalastyle:off line.size.limit + /** + * An expression that adds/replaces field in `StructType` by name. 
+ * + * {{{ + * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + * df.select($"struct_col".withField("c", lit(3))) + * // result: {"a":1,"b":2,"c":3} + * + * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + * df.select($"struct_col".withField("b", lit(3))) + * // result: {"a":1,"b":3} + * + * val df = sql("SELECT CAST(NULL AS struct) struct_col") + * df.select($"struct_col".withField("c", lit(3))) + * // result: null of type struct + * + * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col") + * df.select($"struct_col".withField("b", lit(100))) + * // result: {"a":1,"b":100,"b":100} + * + * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + * df.select($"struct_col".withField("a.c", lit(3))) + * // result: {"a":{"a":1,"b":2,"c":3}} + * + * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col") + * df.select($"struct_col".withField("a.c", lit(3))) + * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields + * }}} + * + * @group expr_ops + * @since 3.1.0 + */ + // scalastyle:on line.size.limit + def withField(fieldName: String, col: Column): Column = withExpr { + require(fieldName != null, "fieldName cannot be null") + require(col != null, "col cannot be null") + + val nameParts = if (fieldName.isEmpty) { + fieldName :: Nil + } else { + CatalystSqlParser.parseMultipartIdentifier(fieldName) + } + withFieldHelper(expr, nameParts, Nil, col.expr) + } + + private def withFieldHelper( + struct: Expression, + namePartsRemaining: Seq[String], + namePartsDone: Seq[String], + value: Expression) : WithFields = { + val name = namePartsRemaining.head + if (namePartsRemaining.length == 1) { + WithFields(struct, name :: Nil, value :: Nil) + } else { + val newNamesRemaining = namePartsRemaining.tail + val newNamesDone = namePartsDone :+ name + val newValue = withFieldHelper( + struct = UnresolvedExtractValue(struct, Literal(name)), + namePartsRemaining = newNamesRemaining, + namePartsDone = newNamesDone, + value = value) + WithFields(struct, name :: Nil, newValue :: Nil) + } + } + /** * An expression that gets a field by name in a `StructType`. 
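[Editor's note, illustrative only, not part of the patch] The withFieldHelper above expands a dotted field name into nested WithFields expressions, rebuilding each intermediate struct on the way down: for a name like "a.d" it wraps a WithFields around an UnresolvedExtractValue of the intermediate field "a". A minimal sketch of the two equivalent spellings, assuming a spark-shell session and a DataFrame df with a struct column a whose field a is itself a struct (the same shape as the structLevel2 test fixture later in this diff):

    import org.apache.spark.sql.functions.lit
    import spark.implicits._   // for the $"..." column syntax

    // Dotted path: withFieldHelper rewrites "a.d" into a WithFields over
    // UnresolvedExtractValue('a, "a"), adding field "d" to the inner struct.
    val viaPath   = df.withColumn("a", $"a".withField("a.d", lit(4)))

    // Hand-written equivalent: replace the inner struct explicitly.
    val viaNested = df.withColumn("a", $"a".withField("a", $"a.a".withField("d", lit(4))))

Both spellings yield the same rows and schema, which is what the "withField should add field to nested struct" test later in this diff asserts by exercising both forms.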
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index bf90875e511f8..bc3f38a35834d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -48,6 +48,7 @@ class ResolveSessionCatalog( override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case AlterTableAddColumnsStatement( nameParts @ SessionCatalogAndTable(catalog, tbl), cols) => + cols.foreach(c => failNullType(c.dataType)) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => if (!DDLUtils.isHiveTable(v1Table.v1Table)) { @@ -76,6 +77,7 @@ class ResolveSessionCatalog( case AlterTableReplaceColumnsStatement( nameParts @ SessionCatalogAndTable(catalog, tbl), cols) => + cols.foreach(c => failNullType(c.dataType)) val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { case Some(_: V1Table) => throw new AnalysisException("REPLACE COLUMNS is only supported with v2 tables.") @@ -100,6 +102,7 @@ class ResolveSessionCatalog( case a @ AlterTableAlterColumnStatement( nameParts @ SessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => + a.dataType.foreach(failNullType) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => if (!DDLUtils.isHiveTable(v1Table.v1Table)) { @@ -268,6 +271,7 @@ class ResolveSessionCatalog( // session catalog and the table provider is not v2. case c @ CreateTableStatement( SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + assertNoNullTypeInSchema(c.tableSchema) val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { if (!DDLUtils.isHiveTable(Some(provider))) { @@ -292,6 +296,9 @@ class ResolveSessionCatalog( case c @ CreateTableAsSelectStatement( SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + if (c.asSelect.resolved) { + assertNoNullTypeInSchema(c.asSelect.schema) + } val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { val tableDesc = buildCatalogTable(tbl.asTableIdentifier, new StructType, @@ -319,6 +326,7 @@ class ResolveSessionCatalog( // session catalog and the table provider is not v2. 
case c @ ReplaceTableStatement( SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + assertNoNullTypeInSchema(c.tableSchema) val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { throw new AnalysisException("REPLACE TABLE is only supported with v2 tables.") @@ -336,6 +344,9 @@ class ResolveSessionCatalog( case c @ ReplaceTableAsSelectStatement( SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + if (c.asSelect.resolved) { + assertNoNullTypeInSchema(c.asSelect.schema) + } val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { throw new AnalysisException("REPLACE TABLE AS SELECT is only supported with v2 tables.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 95343e2872def..60cacda9f5f1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogV2Util.assertNoNullTypeInSchema import org.apache.spark.sql.connector.expressions.{FieldReference, RewritableTransform} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 @@ -292,6 +293,8 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi "in the table definition of " + table.identifier, sparkSession.sessionState.conf.caseSensitiveAnalysis) + assertNoNullTypeInSchema(schema) + val normalizedPartCols = normalizePartitionColumns(schema, table) val normalizedBucketSpec = normalizeBucketSpec(schema, table) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index e4e7887017a1d..c199df676ced3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -40,7 +40,7 @@ case class BatchScanExec( override def hashCode(): Int = batch.hashCode() - override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions() + @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions() override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 520afad287648..7fe3263630820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -198,7 +198,7 @@ object EvaluatePython { case udt: UserDefinedType[_] => makeFromJava(udt.sqlType) - case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty) + case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty) } private def 
nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 8898a11ec08fb..c39adac4ac680 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -34,7 +34,7 @@ | org.apache.spark.sql.catalyst.expressions.Ascii | ascii | SELECT ascii('222') | struct | | org.apache.spark.sql.catalyst.expressions.Asin | asin | SELECT asin(0) | struct | | org.apache.spark.sql.catalyst.expressions.Asinh | asinh | SELECT asinh(0) | struct | -| org.apache.spark.sql.catalyst.expressions.AssertTrue | assert_true | SELECT assert_true(0 < 1) | struct | +| org.apache.spark.sql.catalyst.expressions.AssertTrue | assert_true | SELECT assert_true(0 < 1) | struct | | org.apache.spark.sql.catalyst.expressions.Atan | atan | SELECT atan(0) | struct | | org.apache.spark.sql.catalyst.expressions.Atan2 | atan2 | SELECT atan2(0, 0) | struct | | org.apache.spark.sql.catalyst.expressions.Atanh | atanh | SELECT atanh(0) | struct | diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out index f6720f6c5faa4..02747718c91df 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out @@ -5,7 +5,7 @@ -- !query select null, Null, nUll -- !query schema -struct +struct -- !query output NULL NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index 9943b93c431df..2dd6960682740 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -49,7 +49,7 @@ two 2 -- !query select * from values ("one", null), ("two", null) as data(a, b) -- !query schema -struct +struct -- !query output one NULL two NULL diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index f6720f6c5faa4..02747718c91df 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -5,7 +5,7 @@ -- !query select null, Null, nUll -- !query schema -struct +struct -- !query output NULL NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out index bd8ffb82ee129..8d34bf293ef2b 100644 --- a/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out @@ -7,7 +7,7 @@ select typeof(null) -- !query schema struct -- !query output -null +unknown -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/select.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/select.sql.out index 1e59036b979b4..8b32bd6ce1995 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/select.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/select.sql.out @@ -308,7 +308,7 @@ struct<1:int> -- !query select foo.* from (select null) as foo -- !query schema -struct +struct -- !query output NULL @@ -316,7 +316,7 @@ NULL -- !query select foo.* from 
(select 'xyzzy',1,null) as foo -- !query schema -struct +struct -- !query output xyzzy 1 NULL diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 26a44a85841e0..b905f9e038619 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -5,7 +5,7 @@ -- !query SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null) -- !query schema -struct +struct -- !query output x y NULL @@ -21,7 +21,7 @@ NULL x -- !query SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null) -- !query schema -struct +struct -- !query output x y NULL @@ -29,7 +29,7 @@ x y NULL -- !query SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null) -- !query schema -struct +struct -- !query output y x NULL diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out index d78d347bc9802..0680a873fbf8f 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out @@ -49,7 +49,7 @@ two 2 -- !query select udf(a), b from values ("one", null), ("two", null) as data(a, b) -- !query schema -struct +struct -- !query output one NULL two NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa06484a73d95..131ab1b94f59e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -923,4 +923,503 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString)) assert(inSet.sql === "('a' IN ('a', 'b'))") } + + def checkAnswerAndSchema( + df: => DataFrame, + expectedAnswer: Seq[Row], + expectedSchema: StructType): Unit = { + + checkAnswer(df, expectedAnswer) + assert(df.schema == expectedSchema) + } + + private lazy val structType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false))) + + private lazy val structLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil), + StructType(Seq(StructField("a", structType, nullable = false)))) + + private lazy val nullStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Nil), + StructType(Seq(StructField("a", structType, nullable = true)))) + + private lazy val structLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false))), + nullable = false)))) + + private lazy val nullStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(null)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = true))), + nullable = false)))) + + private lazy val structLevel3: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( 
+ StructField("a", StructType(Seq( + StructField("a", structType, nullable = false))), + nullable = false))), + nullable = false)))) + + test("withField should throw an exception if called on a non-StructType column") { + intercept[AnalysisException] { + testData.withColumn("key", $"key".withField("a", lit(2))) + }.getMessage should include("struct argument should be struct type, got: int") + } + + test("withField should throw an exception if either fieldName or col argument are null") { + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField(null, lit(2))) + }.getMessage should include("fieldName cannot be null") + + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField("b", null)) + }.getMessage should include("col cannot be null") + + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField(null, null)) + }.getMessage should include("fieldName cannot be null") + } + + test("withField should throw an exception if any intermediate structs don't exist") { + intercept[AnalysisException] { + structLevel2.withColumn("a", 'a.withField("x.b", lit(2))) + }.getMessage should include("No such struct field x in a") + + intercept[AnalysisException] { + structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2))) + }.getMessage should include("No such struct field x in a") + } + + test("withField should throw an exception if intermediate field is not a struct") { + intercept[AnalysisException] { + structLevel1.withColumn("a", 'a.withField("b.a", lit(2))) + }.getMessage should include("struct argument should be struct type, got: int") + } + + test("withField should throw an exception if intermediate field reference is ambiguous") { + intercept[AnalysisException] { + val structLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false), + StructField("a", structType, nullable = false))), + nullable = false)))) + + structLevel2.withColumn("a", 'a.withField("a.b", lit(2))) + }.getMessage should include("Ambiguous reference to fields") + } + + test("withField should add field with no name") { + checkAnswerAndSchema( + structLevel1.withColumn("a", $"a".withField("", lit(4))), + Row(Row(1, null, 3, 4)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add field to struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(4))), + Row(Row(1, null, 3, 4)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add field to null struct") { + checkAnswerAndSchema( + nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))), + Row(null) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = true)))) + } + 
+ test("withField should add field to nested null struct") { + checkAnswerAndSchema( + nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))), + Row(Row(null)) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + } + + test("withField should add null field to struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))), + Row(Row(1, null, 3, null)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = true))), + nullable = false)))) + } + + test("withField should add multiple fields to struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), + Row(Row(1, null, 3, 4, 5)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add field to nested struct") { + Seq( + structLevel2.withColumn("a", 'a.withField("a.d", lit(4))), + structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4)))) + ).foreach { df => + checkAnswerAndSchema( + df, + Row(Row(Row(1, null, 3, 4))) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should add field to deeply nested struct") { + checkAnswerAndSchema( + structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))), + Row(Row(Row(Row(1, null, 3, 4)))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("withField should replace field in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("b", lit(2))), + Row(Row(1, 2, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should replace field in null struct") { + checkAnswerAndSchema( + nullStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), + Row(null) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", IntegerType, nullable = 
false))), + nullable = true)))) + } + + test("withField should replace field in nested null struct") { + checkAnswerAndSchema( + nullStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))), + Row(Row(null)) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + } + + test("withField should replace field with null value in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))), + Row(Row(1, null, null)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true))), + nullable = false)))) + } + + test("withField should replace multiple fields in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), + Row(Row(10, 20, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should replace field in nested struct") { + Seq( + structLevel2.withColumn("a", $"a".withField("a.b", lit(2))), + structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2)))) + ).foreach { df => + checkAnswerAndSchema( + df, + Row(Row(Row(1, 2, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should replace field in deeply nested struct") { + checkAnswerAndSchema( + structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))), + Row(Row(Row(Row(1, 2, 3)))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("withField should replace all fields with given name in struct") { + val structLevel1 = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("b", lit(100))), + Row(Row(1, 100, 100)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should replace fields in struct in given order") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))), + Row(Row(1, 20, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + 
StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add field and then replace same field in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))), + Row(Row(1, null, 3, 5)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should handle fields with dots in their name if correctly quoted") { + val df: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a.b", StructType(Seq( + StructField("c.d", IntegerType, nullable = false), + StructField("e.f", IntegerType, nullable = true), + StructField("g.h", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))), + Row(Row(Row(1, 2, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a.b", StructType(Seq( + StructField("c.d", IntegerType, nullable = false), + StructField("e.f", IntegerType, nullable = false), + StructField("g.h", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + intercept[AnalysisException] { + df.withColumn("a", 'a.withField("a.b.e.f", lit(2))) + }.getMessage should include("No such struct field a in a.b") + } + + private lazy val mixedCaseStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 1)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + test("withField should replace field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + Row(Row(2, 1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("A", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + Row(Row(1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + } + + test("withField should add field to struct because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + Row(Row(1, 1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false), + StructField("A", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + Row(Row(1, 1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = 
false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + } + + private lazy val mixedCaseStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, 1), Row(1, 1))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("B", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + test("withField should replace nested field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))), + Row(Row(Row(2, 1), Row(1, 1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("A", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("B", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))), + Row(Row(Row(1, 1), Row(2, 1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("b", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should throw an exception because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + intercept[AnalysisException] { + mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))) + }.getMessage should include("No such struct field A in a, B") + + intercept[AnalysisException] { + mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))) + }.getMessage should include("No such struct field b in a, B") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index 6b503334f9f23..bdcf7230e3211 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{ArrayType, StructType} class ComplexTypesSuite extends QueryTest with SharedSparkSession { + import testImplicits._ override def beforeAll(): Unit = { super.beforeAll() @@ -106,4 +110,11 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession { checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) } + + test("SPARK-32167: get field from an array of struct") { + val innerStruct = new StructType().add("i", "int", nullable = true) + val schema = new StructType().add("arr", 
ArrayType(innerStruct, containsNull = false)) + val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema) + checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index ac2ebd8bd748b..508eefafd0754 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -336,7 +336,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(years($"ts")) .create() @@ -350,7 +349,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(months($"ts")) .create() @@ -364,7 +362,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(days($"ts")) .create() @@ -378,7 +375,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(hours($"ts")) .create() @@ -391,7 +387,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo test("Create: partitioned by bucket(4, id)") { spark.table("source") .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(bucket(4, $"id")) .create() @@ -596,7 +591,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", lit("America/Los_Angeles") as "timezone")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy( years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"), years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified") @@ -624,7 +618,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", lit("America/Los_Angeles") as "timezone")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(bucket(4, $"ts.timezone")) .create() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 231a8f2aa7ddd..daa262d581cb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -405,7 +405,7 @@ class FileBasedDataSourceSuite extends QueryTest "" } def errorMessage(format: String): String = { - s"$format data source does not support null data type." 
+ s"$format data source does not support unknown data type." } withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1List) { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index f7f4df8f2d2e9..85aea3ce41ecc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.connector +import java.sql.Timestamp +import java.time.LocalDate + import scala.collection.JavaConverters._ import org.apache.spark.SparkException @@ -27,7 +30,7 @@ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} -import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION} import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.SimpleScanSource import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} @@ -1647,7 +1650,6 @@ class DataSourceV2SQLSuite """ |CREATE TABLE testcat.t (id int, `a.b` string) USING foo |CLUSTERED BY (`a.b`) INTO 4 BUCKETS - |OPTIONS ('allow-unsupported-transforms'=true) """.stripMargin) val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] @@ -2494,6 +2496,38 @@ class DataSourceV2SQLSuite } } + test("SPARK-32168: INSERT OVERWRITE - hidden days partition - dynamic mode") { + def testTimestamp(daysOffset: Int): Timestamp = { + Timestamp.valueOf(LocalDate.of(2020, 1, 1 + daysOffset).atStartOfDay()) + } + + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTable(t1) { + val df = spark.createDataFrame(Seq( + (testTimestamp(1), "a"), + (testTimestamp(2), "b"), + (testTimestamp(3), "c"))).toDF("ts", "data") + df.createOrReplaceTempView("source_view") + + sql(s"CREATE TABLE $t1 (ts timestamp, data string) " + + s"USING $v2Format PARTITIONED BY (days(ts))") + sql(s"INSERT INTO $t1 VALUES " + + s"(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " + + s"(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT ts, data FROM source_view") + + val expected = spark.createDataFrame(Seq( + (testTimestamp(1), "a"), + (testTimestamp(2), "b"), + (testTimestamp(3), "c"), + (testTimestamp(4), "keep"))).toDF("ts", "data") + + verifyTable(t1, expected) + } + } + } + private def testV1Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala index b88ad5218fcd2..2cc7a1f994645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala @@ -446,21 +446,18 @@ trait InsertIntoSQLOnlyTests } } - test("InsertInto: overwrite - multiple static partitions - dynamic 
mode") { - // Since all partitions are provided statically, this should be supported by everyone - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = s"${catalogAndNamespace}tbl" - withTableAndData(t1) { view => - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + - s"USING $v2Format PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view") - verifyTable(t1, Seq( - (2, "a", 2), - (2, "b", 2), - (2, "c", 2), - (4, "keep", 2)).toDF("id", "data", "p")) - } + dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - dynamic mode") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view") + verifyTable(t1, Seq( + (2, "a", 2), + (2, "b", 2), + (2, "c", 2), + (4, "keep", 2)).toDF("id", "data", "p")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 8b859e951b9b9..d51eafa5a8aed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -497,6 +497,26 @@ abstract class SchemaPruningSuite Row(Row("Janet", null, "Jones"), "Jones") ::Nil) } + testSchemaPruning("SPARK-32163: nested pruning should work even with cosmetic variations") { + withTempView("contact_alias") { + sql("select * from contacts") + .repartition(100, col("name.first"), col("name.last")) + .selectExpr("name").createOrReplaceTempView("contact_alias") + + val query1 = sql("select name.first from contact_alias") + checkScan(query1, "struct>") + checkAnswer(query1, Row("Jane") :: Row("John") :: Row("Jim") :: Row("Janet") ::Nil) + + sql("select * from contacts") + .select(explode(col("friends.first")), col("friends")) + .createOrReplaceTempView("contact_alias") + + val query2 = sql("select friends.middle, col from contact_alias") + checkScan(query2, "struct>>") + checkAnswer(query2, Row(Array("Z."), "Susan") :: Nil) + } + } + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 1404ece76449e..57ed15a76a893 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -18,26 +18,22 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction -import java.sql.{Date, Timestamp} -import java.util.{Arrays, Map => JMap, UUID} +import java.util.{Arrays, Map => JMap} import java.util.concurrent.RejectedExecutionException import scala.collection.JavaConverters._ import 
scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.shims.Utils import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession -import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLContext} import org.apache.spark.sql.execution.HiveResult.{getTimeFormatters, toHiveString, TimeFormatters} -import org.apache.spark.sql.execution.command.SetCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -315,16 +311,11 @@ private[hive] class SparkExecuteStatementOperation( } else { logError(s"Error executing query with $statementId, currentState $currentState, ", e) setState(OperationState.ERROR) + HiveThriftServer2.eventManager.onStatementError( + statementId, e.getMessage, SparkUtils.exceptionString(e)) e match { - case hiveException: HiveSQLException => - HiveThriftServer2.eventManager.onStatementError( - statementId, hiveException.getMessage, SparkUtils.exceptionString(hiveException)) - throw hiveException - case _ => - val root = ExceptionUtils.getRootCause(e) - HiveThriftServer2.eventManager.onStatementError( - statementId, root.getMessage, SparkUtils.exceptionString(root)) - throw new HiveSQLException("Error running query: " + root.toString, root) + case _: HiveSQLException => throw e + case _ => throw new HiveSQLException("Error running query: " + e.toString, e) } } } finally { @@ -342,8 +333,8 @@ private[hive] class SparkExecuteStatementOperation( synchronized { if (!getStatus.getState.isTerminal) { logInfo(s"Cancel query with $statementId") - cleanup() setState(OperationState.CANCELED) + cleanup() HiveThriftServer2.eventManager.onStatementCanceled(statementId) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala index 55070e035b944..01ef78cde8956 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala @@ -17,17 +17,13 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.UUID - -import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType -import org.apache.hive.service.cli.{HiveSQLException, OperationState} +import org.apache.hive.service.cli.OperationState import org.apache.hive.service.cli.operation.GetCatalogsOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext -import org.apache.spark.util.{Utils => SparkUtils} /** * Spark's own GetCatalogsOperation @@ -62,22 +58,8 @@ private[hive] class SparkGetCatalogsOperation( authorizeMetaGets(HiveOperationType.GET_CATALOGS, null) } setState(OperationState.FINISHED) - } catch { - case e: Throwable => - logError(s"Error executing get catalogs operation with $statementId", e) - setState(OperationState.ERROR) - e match { - case hiveException: HiveSQLException => - 
HiveThriftServer2.eventManager.onStatementError( - statementId, hiveException.getMessage, SparkUtils.exceptionString(hiveException)) - throw hiveException - case _ => - val root = ExceptionUtils.getRootCause(e) - HiveThriftServer2.eventManager.onStatementError( - statementId, root.getMessage, SparkUtils.exceptionString(root)) - throw new HiveSQLException("Error getting catalogs: " + root.toString, root) - } - } + } catch onError() + HiveThriftServer2.eventManager.onStatementFinish(statementId) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala index ca8ad5e6ad134..d42732f426681 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.hive.thriftserver.ThriftserverShimUtils.toJavaSQLType import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{Utils => SparkUtils} /** * Spark's own SparkGetColumnsOperation @@ -122,22 +121,8 @@ private[hive] class SparkGetColumnsOperation( } } setState(OperationState.FINISHED) - } catch { - case e: Throwable => - logError(s"Error executing get columns operation with $statementId", e) - setState(OperationState.ERROR) - e match { - case hiveException: HiveSQLException => - HiveThriftServer2.eventManager.onStatementError( - statementId, hiveException.getMessage, SparkUtils.exceptionString(hiveException)) - throw hiveException - case _ => - val root = ExceptionUtils.getRootCause(e) - HiveThriftServer2.eventManager.onStatementError( - statementId, root.getMessage, SparkUtils.exceptionString(root)) - throw new HiveSQLException("Error getting columns: " + root.toString, root) - } - } + } catch onError() + HiveThriftServer2.eventManager.onStatementFinish(statementId) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala index f5e647bfd4f38..cf5dbae93a365 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala @@ -98,22 +98,8 @@ private[hive] class SparkGetFunctionsOperation( } } setState(OperationState.FINISHED) - } catch { - case e: Throwable => - logError(s"Error executing get functions operation with $statementId", e) - setState(OperationState.ERROR) - e match { - case hiveException: HiveSQLException => - HiveThriftServer2.eventManager.onStatementError( - statementId, hiveException.getMessage, SparkUtils.exceptionString(hiveException)) - throw hiveException - case _ => - val root = ExceptionUtils.getRootCause(e) - HiveThriftServer2.eventManager.onStatementError( - statementId, root.getMessage, SparkUtils.exceptionString(root)) - throw new HiveSQLException("Error getting functions: " + root.toString, root) - } - } + } catch onError() + HiveThriftServer2.eventManager.onStatementFinish(statementId) } } diff --git 
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
index 74220986fcd34..16fd502048e80 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
@@ -81,22 +81,8 @@ private[hive] class SparkGetSchemasOperation(
         rowSet.addRow(Array[AnyRef](globalTempViewDb, DEFAULT_HIVE_CATALOG))
       }
       setState(OperationState.FINISHED)
-    } catch {
-      case e: Throwable =>
-        logError(s"Error executing get schemas operation with $statementId", e)
-        setState(OperationState.ERROR)
-        e match {
-          case hiveException: HiveSQLException =>
-            HiveThriftServer2.eventManager.onStatementError(
-              statementId, hiveException.getMessage, SparkUtils.exceptionString(hiveException))
-            throw hiveException
-          case _ =>
-            val root = ExceptionUtils.getRootCause(e)
-            HiveThriftServer2.eventManager.onStatementError(
-              statementId, root.getMessage, SparkUtils.exceptionString(root))
-            throw new HiveSQLException("Error getting schemas: " + root.toString, root)
-        }
-    }
+    } catch onError()
+
     HiveThriftServer2.eventManager.onStatementFinish(statementId)
   }
 }
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
index 1cf9c3a731af5..9e31b8baad78e 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.thriftserver
 import java.util.UUID
-import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
 import org.apache.hive.service.cli._
 import org.apache.hive.service.cli.operation.GetTableTypesOperation
@@ -28,7 +27,6 @@ import org.apache.hive.service.cli.session.HiveSession
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType
-import org.apache.spark.util.{Utils => SparkUtils}
 /**
  * Spark's own GetTableTypesOperation
@@ -69,22 +67,8 @@ private[hive] class SparkGetTableTypesOperation(
         rowSet.addRow(Array[AnyRef](tableType))
       }
       setState(OperationState.FINISHED)
-    } catch {
-      case e: Throwable =>
-        logError(s"Error executing get table types operation with $statementId", e)
-        setState(OperationState.ERROR)
-        e match {
-          case hiveException: HiveSQLException =>
-            HiveThriftServer2.eventManager.onStatementError(
-              statementId, hiveException.getMessage, SparkUtils.exceptionString(hiveException))
-            throw hiveException
-          case _ =>
-            val root = ExceptionUtils.getRootCause(e)
-            HiveThriftServer2.eventManager.onStatementError(
-              statementId, root.getMessage, SparkUtils.exceptionString(root))
-            throw new HiveSQLException("Error getting table types: " + root.toString, root)
-        }
-    }
+    } catch onError()
+
     HiveThriftServer2.eventManager.onStatementFinish(statementId)
   }
 }
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
index a1d21e2d60c63..0d4b9b392f074 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
@@ -17,14 +17,12 @@ package org.apache.spark.sql.hive.thriftserver
-import java.util.{List => JList, UUID}
+import java.util.{List => JList}
 import java.util.regex.Pattern
 import scala.collection.JavaConverters._
-import org.apache.commons.lang3.exception.ExceptionUtils
-import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
-import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObjectUtils
+import org.apache.hadoop.hive.ql.security.authorization.plugin.{HiveOperationType, HivePrivilegeObjectUtils}
 import org.apache.hive.service.cli._
 import org.apache.hive.service.cli.operation.GetTablesOperation
 import org.apache.hive.service.cli.session.HiveSession
@@ -33,7 +31,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType._
 import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.util.{Utils => SparkUtils}
 /**
  * Spark's own GetTablesOperation
@@ -111,22 +108,8 @@ private[hive] class SparkGetTablesOperation(
         }
       }
       setState(OperationState.FINISHED)
-    } catch {
-      case e: Throwable =>
-        logError(s"Error executing get tables operation with $statementId", e)
-        setState(OperationState.ERROR)
-        e match {
-          case hiveException: HiveSQLException =>
-            HiveThriftServer2.eventManager.onStatementError(
-              statementId, hiveException.getMessage, SparkUtils.exceptionString(hiveException))
-            throw hiveException
-          case _ =>
-            val root = ExceptionUtils.getRootCause(e)
-            HiveThriftServer2.eventManager.onStatementError(
-              statementId, root.getMessage, SparkUtils.exceptionString(root))
-            throw new HiveSQLException("Error getting tables: " + root.toString, root)
-        }
-    }
+    } catch onError()
+
     HiveThriftServer2.eventManager.onStatementFinish(statementId)
   }
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
index e38139d60df60..c2568ad4ada0a 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
@@ -19,15 +19,13 @@ package org.apache.spark.sql.hive.thriftserver
 import java.util.UUID
-import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
-import org.apache.hive.service.cli.{HiveSQLException, OperationState}
+import org.apache.hive.service.cli.OperationState
 import org.apache.hive.service.cli.operation.GetTypeInfoOperation
 import org.apache.hive.service.cli.session.HiveSession
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SQLContext
-import org.apache.spark.util.{Utils => SparkUtils}
 /**
  * Spark's own GetTypeInfoOperation
@@ -87,22 +85,8 @@ private[hive] class SparkGetTypeInfoOperation(
         rowSet.addRow(rowData)
       })
       setState(OperationState.FINISHED)
-    } catch {
-      case e: Throwable =>
-        logError(s"Error executing get type info with $statementId", e)
-        setState(OperationState.ERROR)
-        e match {
-          case hiveException: HiveSQLException =>
-            HiveThriftServer2.eventManager.onStatementError(
-              statementId, hiveException.getMessage, SparkUtils.exceptionString(hiveException))
-            throw hiveException
-          case _ =>
-            val root = ExceptionUtils.getRootCause(e)
-            HiveThriftServer2.eventManager.onStatementError(
-              statementId, root.getMessage, SparkUtils.exceptionString(root))
-            throw new HiveSQLException("Error getting type info: " + root.toString, root)
-        }
-    }
+    } catch onError()
+
     HiveThriftServer2.eventManager.onStatementFinish(statementId)
   }
 }
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
index 3da568cfa256e..446669d08e76b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
@@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver
-import org.apache.hive.service.cli.OperationState
+import org.apache.hive.service.cli.{HiveSQLException, OperationState}
 import org.apache.hive.service.cli.operation.Operation
 import org.apache.spark.SparkContext
@@ -46,8 +46,8 @@ private[hive] trait SparkOperation extends Operation with Logging {
   }
   abstract override def close(): Unit = {
-    cleanup()
     super.close()
+    cleanup()
     logInfo(s"Close statement with $statementId")
     HiveThriftServer2.eventManager.onOperationClosed(statementId)
   }
@@ -93,4 +93,16 @@ private[hive] trait SparkOperation extends Operation with Logging {
     case t => throw new IllegalArgumentException(s"Unknown table type is found: $t")
   }
+
+  protected def onError(): PartialFunction[Throwable, Unit] = {
+    case e: Throwable =>
+      logError(s"Error operating $getType with $statementId", e)
+      super.setState(OperationState.ERROR)
+      HiveThriftServer2.eventManager.onStatementError(
+        statementId, e.getMessage, Utils.exceptionString(e))
+      e match {
+        case _: HiveSQLException => throw e
+        case _ => throw new HiveSQLException(s"Error operating $getType ${e.getMessage}", e)
+      }
+  }
 }
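Note: the repeated catch blocks removed in the files above are replaced by the shared onError() handler defined in SparkOperation. The trick is that Scala's catch clause accepts any expression of type PartialFunction[Throwable, Unit], so every operation can write "} catch onError()" instead of copying the block. A tiny standalone sketch of the same pattern (illustrative only; opName and the messages are made up, not the Spark code):

object CatchHandlerSketch {
  // A reusable error handler; `catch` accepts any PartialFunction[Throwable, Unit].
  def onError(opName: String): PartialFunction[Throwable, Unit] = {
    case e: IllegalStateException => throw e                                         // rethrow known errors unchanged
    case e: Throwable => throw new RuntimeException(s"Error operating $opName", e)   // wrap everything else
  }

  def run(): Unit = {
    try {
      throw new IllegalArgumentException("boom")
    } catch onError("demo")   // same shape as `} catch onError()` above
  }

  def main(args: Array[String]): Unit = {
    try run() catch { case e: RuntimeException => println(e.getMessage) }   // prints: Error operating demo
  }
}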
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala
index 3d7933fba17d8..5f17607585521 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.thriftserver
 import java.io.File
 import java.sql.{DriverManager, Statement}
+import java.util
 import scala.collection.JavaConverters._
 import scala.concurrent.duration._
@@ -27,7 +28,12 @@ import scala.util.Try
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hadoop.hive.ql.metadata.Hive
 import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.hive.service.cli.thrift.ThriftCLIService
+import org.apache.hive.jdbc.HttpBasicAuthInterceptor
+import org.apache.hive.service.auth.PlainSaslHelper
+import org.apache.hive.service.cli.thrift.{ThriftCLIService, ThriftCLIServiceClient}
+import org.apache.http.impl.client.HttpClientBuilder
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.apache.thrift.transport.{THttpClient, TSocket}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
@@ -76,8 +82,9 @@ trait SharedThriftServer extends SharedSparkSession {
     s"jdbc:hive2://localhost:$serverPort/"
   }
+  protected def user: String = System.getProperty("user.name")
+
   protected def withJdbcStatement(fs: (Statement => Unit)*): Unit = {
-    val user = System.getProperty("user.name")
     require(serverPort != 0, "Failed to bind an actual port for HiveThriftServer2")
     val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
@@ -91,6 +98,29 @@ trait SharedThriftServer extends SharedSparkSession {
     }
   }
+  protected def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = {
+    require(serverPort != 0, "Failed to bind an actual port for HiveThriftServer2")
+    val transport = mode match {
+      case ServerMode.binary =>
+        val rawTransport = new TSocket("localhost", serverPort)
+        PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+      case ServerMode.http =>
+        val interceptor = new HttpBasicAuthInterceptor(
+          user,
+          "anonymous",
+          null, null, true, new util.HashMap[String, String]())
+        new THttpClient(
+          s"http://localhost:$serverPort/cliservice",
+          HttpClientBuilder.create.addInterceptorFirst(interceptor).build())
+    }
+
+    val protocol = new TBinaryProtocol(transport)
+    val client = new ThriftCLIServiceClient(new ThriftserverShimUtils.Client(protocol))
+
+    transport.open()
+    try f(client) finally transport.close()
+  }
+
   private def startThriftServer(attempt: Int): Unit = {
     logInfo(s"Trying to start HiveThriftServer2: mode=$mode, attempt=$attempt")
     val sqlContext = spark.newSession().sqlContext
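Note: a rough usage sketch for the withCLIServiceClient helper added above, written as a test body inside a suite that mixes in SharedThriftServer might use it (openSession/executeStatement/closeSession are the standard CLIServiceClient calls; the SQL text is arbitrary):

withCLIServiceClient { client =>
  val sessionHandle = client.openSession(user, "")
  val confOverlay = new java.util.HashMap[String, String]()
  try {
    // Talks to the server over the raw Thrift API instead of JDBC, so tests can
    // assert on HiveSQLException details exactly as a Thrift client would see them.
    client.executeStatement(sessionHandle, "SELECT 1", confOverlay)
  } finally {
    client.closeSession(sessionHandle)
  }
}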
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
index 13df3fabc4919..4c2f29e0bf394 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
@@ -17,10 +17,25 @@ package org.apache.spark.sql.hive.thriftserver
+import java.util
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.duration._
+
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hive.service.cli.OperationState
+import org.apache.hive.service.cli.session.{HiveSession, HiveSessionImpl}
+import org.mockito.Mockito.{doReturn, mock, spy, when, RETURNS_DEEP_STUBS}
+import org.mockito.invocation.InvocationOnMock
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2EventManager
+import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType}
-class SparkExecuteStatementOperationSuite extends SparkFunSuite {
+class SparkExecuteStatementOperationSuite extends SparkFunSuite with SharedSparkSession {
+
   test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") {
     val field1 = StructField("NULL", NullType)
     val field2 = StructField("(IF(true, NULL, NULL))", NullType)
@@ -42,4 +57,68 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite {
     assert(columns.get(1).getType().getName == "INT")
     assert(columns.get(1).getComment() == "")
   }
+
+  Seq(
+    (OperationState.CANCELED, (_: SparkExecuteStatementOperation).cancel()),
+    (OperationState.CLOSED, (_: SparkExecuteStatementOperation).close())
+  ).foreach { case (finalState, transition) =>
+    test("SPARK-32057 SparkExecuteStatementOperation should not transiently become ERROR " +
+      s"before being set to $finalState") {
+      val hiveSession = new HiveSessionImpl(ThriftserverShimUtils.testedProtocolVersions.head,
+        "username", "password", new HiveConf, "ip address")
+      hiveSession.open(new util.HashMap)
+
+      HiveThriftServer2.eventManager = mock(classOf[HiveThriftServer2EventManager])
+
+      val spySqlContext = spy(sqlContext)
+
+      // When cancel() is called on the operation, cleanup causes an exception to be thrown inside
+      // of execute(). This should not cause the state to become ERROR. The exception here will be
+      // triggered in our custom cleanup().
+      val signal = new Semaphore(0)
+      val dataFrame = mock(classOf[DataFrame], RETURNS_DEEP_STUBS)
+      when(dataFrame.collect()).thenAnswer((_: InvocationOnMock) => {
+        signal.acquire()
+        throw new RuntimeException("Operation was cancelled by test cleanup.")
+      })
+      val statement = "stmt"
+      doReturn(dataFrame, Nil: _*).when(spySqlContext).sql(statement)
+
+      val executeStatementOperation = new MySparkExecuteStatementOperation(spySqlContext,
+        hiveSession, statement, signal, finalState)
+
+      val run = new Thread() {
+        override def run(): Unit = executeStatementOperation.runInternal()
+      }
+      assert(executeStatementOperation.getStatus.getState === OperationState.INITIALIZED)
+      run.start()
+      eventually(timeout(5.seconds)) {
+        assert(executeStatementOperation.getStatus.getState === OperationState.RUNNING)
+      }
+      transition(executeStatementOperation)
+      run.join()
+      assert(executeStatementOperation.getStatus.getState === finalState)
+    }
+  }
+
+  private class MySparkExecuteStatementOperation(
+      sqlContext: SQLContext,
+      hiveSession: HiveSession,
+      statement: String,
+      signal: Semaphore,
+      finalState: OperationState)
+    extends SparkExecuteStatementOperation(sqlContext, hiveSession, statement,
+      new util.HashMap, false) {
+
+    override def cleanup(): Unit = {
+      super.cleanup()
+      signal.release()
+      // At this point, operation should already be in finalState (set by either close() or
+      // cancel()). We want to check if it stays in finalState after the exception thrown by
+      // releasing the semaphore propagates. We hence need to sleep for a short while.
+      Thread.sleep(1000)
+      // State should not be ERROR
+      assert(getStatus.getState === finalState)
+    }
+  }
 }
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index 1382eb2d79f38..fd3a638c4fa44 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -17,6 +17,10 @@ package org.apache.spark.sql.hive.thriftserver
+import java.sql.SQLException
+
+import org.apache.hive.service.cli.HiveSQLException
+
 trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
   test("the scratch dir will be deleted during server start but recreated with new operation") {
@@ -45,6 +49,36 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
       assert(cacheManager.isEmpty)
     }
   }
+
+  test("Full stack traces as error message for jdbc or thrift client") {
+    val sql = "select date_sub(date'2011-11-11', '1.2')"
+    withCLIServiceClient { client =>
+      val sessionHandle = client.openSession(user, "")
+
+      val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String]
+      val e = intercept[HiveSQLException] {
+        client.executeStatement(
+          sessionHandle,
+          sql,
+          confOverlay)
+      }
+
+      assert(e.getMessage
+        .contains("The second argument of 'date_sub' function needs to be an integer."))
+      assert(!e.getMessage.contains("" +
+        "java.lang.NumberFormatException: invalid input syntax for type numeric: 1.2"))
+    }
+
+    withJdbcStatement { statement =>
+      val e = intercept[SQLException] {
+        statement.executeQuery(sql)
+      }
+      assert(e.getMessage
+        .contains("The second argument of 'date_sub' function needs to be an integer."))
+      assert(e.getMessage.contains("" +
+        "java.lang.NumberFormatException: invalid input syntax for type numeric: 1.2"))
+    }
+  }
 }
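Note: the two Java hunks below make GetOperationStatus return the whole exception rather than only its top-level message, which is what lets the JDBC assertion above observe the nested NumberFormatException while the direct Thrift call still surfaces only the HiveSQLException message. Hadoop's StringUtils.stringifyException essentially renders the full cause chain into a string; a rough Scala equivalent for illustration only (not the Hadoop source):

import java.io.{PrintWriter, StringWriter}

object StringifySketch {
  // Renders the full stack trace, including every "Caused by:" frame, into a String.
  def stringifyException(e: Throwable): String = {
    val sw = new StringWriter()
    e.printStackTrace(new PrintWriter(sw, true))
    sw.toString
  }
}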
diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
index 783e5795aca76..ff533769b5b84 100644
--- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
+++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
@@ -564,7 +564,8 @@ public TGetOperationStatusResp GetOperationStatus(TGetOperationStatusReq req) th
       if (opException != null) {
         resp.setSqlState(opException.getSQLState());
         resp.setErrorCode(opException.getErrorCode());
-        resp.setErrorMessage(opException.getMessage());
+        resp.setErrorMessage(org.apache.hadoop.util.StringUtils
+            .stringifyException(opException));
       }
       resp.setStatus(OK_STATUS);
     } catch (Exception e) {
diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
index e46799a1c427d..914d6d3612596 100644
--- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
+++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
@@ -566,7 +566,8 @@ public TGetOperationStatusResp GetOperationStatus(TGetOperationStatusReq req) th
       if (opException != null) {
         resp.setSqlState(opException.getSQLState());
         resp.setErrorCode(opException.getErrorCode());
-        resp.setErrorMessage(opException.getMessage());
+        resp.setErrorMessage(org.apache.hadoop.util.StringUtils
+            .stringifyException(opException));
       }
       resp.setStatus(OK_STATUS);
     } catch (Exception e) {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index b9c98f4ea15e9..2b1eb05e22cc7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, ScriptTransformation, Statistics}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.CatalogV2Util.assertNoNullTypeInSchema
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
 import org.apache.spark.sql.execution.datasources.CreateTable
@@ -225,6 +226,8 @@ case class RelationConversions(
           isConvertible(tableDesc) && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_CTAS) =>
         // validation is required to be done here before relation conversion.
         DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema))
+        // This is for CREATE TABLE .. STORED AS PARQUET/ORC AS SELECT null
+        assertNoNullTypeInSchema(query.schema)
         OptimizedCreateHiveTableAsSelectCommand(
           tableDesc, query, query.output.map(_.name), mode)
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index e8cf4ad5d9f28..774fb5b4b9ad5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.connector.FakeV2Provider
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER
 import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils}
@@ -2309,6 +2310,126 @@ class HiveDDLSuite
     }
   }
+  test("SPARK-20680: Spark-sql do not support for unknown column datatype") {
+    withTable("t") {
+      withView("tabUnknownType") {
+        hiveClient.runSqlHive("CREATE TABLE t (t1 int)")
+        hiveClient.runSqlHive("INSERT INTO t VALUES (3)")
+        hiveClient.runSqlHive("CREATE VIEW tabUnknownType AS SELECT NULL AS col FROM t")
+        checkAnswer(spark.table("tabUnknownType"), Row(null))
+        // No exception is thrown
+        val desc = spark.sql("DESC tabUnknownType").collect().toSeq
+        assert(desc.contains(Row("col", NullType.simpleString, null)))
+      }
+    }
+
+    // Forbid CTAS with unknown type
+    withTable("t1", "t2", "t3") {
+      val e1 = intercept[AnalysisException] {
+        spark.sql("CREATE TABLE t1 USING PARQUET AS SELECT null as null_col")
+      }.getMessage
+      assert(e1.contains("Cannot create tables with unknown type"))
+
+      val e2 = intercept[AnalysisException] {
+        spark.sql("CREATE TABLE t2 AS SELECT null as null_col")
+      }.getMessage
+      assert(e2.contains("Cannot create tables with unknown type"))
+
+      val e3 = intercept[AnalysisException] {
+        spark.sql("CREATE TABLE t3 STORED AS PARQUET AS SELECT null as null_col")
+      }.getMessage
+      assert(e3.contains("Cannot create tables with unknown type"))
+    }
+
+    // Forbid Replace table AS SELECT with unknown type
+    withTable("t") {
+      val v2Source = classOf[FakeV2Provider].getName
+      val e = intercept[AnalysisException] {
+        spark.sql(s"CREATE OR REPLACE TABLE t USING $v2Source AS SELECT null as null_col")
+      }.getMessage
+      assert(e.contains("Cannot create tables with unknown type"))
+    }
+
+    // Forbid creating table with VOID type in Spark
+    withTable("t1", "t2", "t3", "t4") {
+      val e1 = intercept[AnalysisException] {
+        spark.sql(s"CREATE TABLE t1 (v VOID) USING PARQUET")
+      }.getMessage
+      assert(e1.contains("Cannot create tables with unknown type"))
+      val e2 = intercept[AnalysisException] {
+        spark.sql(s"CREATE TABLE t2 (v VOID) USING hive")
+      }.getMessage
+      assert(e2.contains("Cannot create tables with unknown type"))
+      val e3 = intercept[AnalysisException] {
+        spark.sql(s"CREATE TABLE t3 (v VOID)")
+      }.getMessage
+      assert(e3.contains("Cannot create tables with unknown type"))
+      val e4 = intercept[AnalysisException] {
+        spark.sql(s"CREATE TABLE t4 (v VOID) STORED AS PARQUET")
+      }.getMessage
+      assert(e4.contains("Cannot create tables with unknown type"))
+    }
+
+    // Forbid Replace table with VOID type
+    withTable("t") {
+      val v2Source = classOf[FakeV2Provider].getName
+      val e = intercept[AnalysisException] {
+        spark.sql(s"CREATE OR REPLACE TABLE t (v VOID) USING $v2Source")
+      }.getMessage
+      assert(e.contains("Cannot create tables with unknown type"))
+    }
+
+    // Make sure spark.catalog.createTable with null type will fail
+    val schema1 = new StructType().add("c", NullType)
+    assertHiveTableNullType(schema1)
+    assertDSTableNullType(schema1)
+
+    val schema2 = new StructType()
+      .add("c", StructType(Seq(StructField.apply("c1", NullType))))
+    assertHiveTableNullType(schema2)
+    assertDSTableNullType(schema2)
+
+    val schema3 = new StructType().add("c", ArrayType(NullType))
+    assertHiveTableNullType(schema3)
+    assertDSTableNullType(schema3)
+
+    val schema4 = new StructType()
+      .add("c", MapType(StringType, NullType))
+    assertHiveTableNullType(schema4)
+    assertDSTableNullType(schema4)
+
+    val schema5 = new StructType()
+      .add("c", MapType(NullType, StringType))
+    assertHiveTableNullType(schema5)
+    assertDSTableNullType(schema5)
+  }
+
+  private def assertHiveTableNullType(schema: StructType): Unit = {
+    withTable("t") {
+      val e = intercept[AnalysisException] {
+        spark.catalog.createTable(
+          tableName = "t",
+          source = "hive",
+          schema = schema,
+          options = Map("fileFormat" -> "parquet"))
+      }.getMessage
+      assert(e.contains("Cannot create tables with unknown type"))
+    }
+  }
+
+  private def assertDSTableNullType(schema: StructType): Unit = {
+    withTable("t") {
+      val e = intercept[AnalysisException] {
+        spark.catalog.createTable(
+          tableName = "t",
+          source = "json",
+          schema = schema,
+          options = Map.empty[String, String])
+      }.getMessage
+      assert(e.contains("Cannot create tables with unknown type"))
+    }
+  }
+
   test("SPARK-21216: join with a streaming DataFrame") {
     import org.apache.spark.sql.execution.streaming.MemoryStream
     import testImplicits._
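Note: the tests above all hinge on rejecting NullType (surfaced to users as the "unknown type", spelled VOID in SQL) anywhere in a schema, which is what the assertNoNullTypeInSchema call added to HiveStrategies enforces for CTAS. A hedged sketch of that kind of recursive schema check, illustrative only and not the actual CatalogV2Util implementation:

import org.apache.spark.sql.types._

object UnknownTypeCheckSketch {
  // Returns true if NullType appears anywhere in the (possibly nested) data type.
  def containsUnknownType(dt: DataType): Boolean = dt match {
    case NullType => true
    case s: StructType => s.fields.exists(f => containsUnknownType(f.dataType))
    case a: ArrayType => containsUnknownType(a.elementType)
    case m: MapType => containsUnknownType(m.keyType) || containsUnknownType(m.valueType)
    case _ => false
  }
}
// e.g. UnknownTypeCheckSketch.containsUnknownType(
//   new StructType().add("c", MapType(StringType, NullType))) returns true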
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala
index 91fd8a47339fc..61c48c6f9c115 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala
@@ -121,7 +121,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton {
     msg = intercept[AnalysisException] {
       sql("select null").write.mode("overwrite").orc(orcDir)
     }.getMessage
-    assert(msg.contains("ORC data source does not support null data type."))
+    assert(msg.contains("ORC data source does not support unknown data type."))
     msg = intercept[AnalysisException] {
       spark.udf.register("testType", () => new IntervalData())