diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index fc54d89a1a4b3..27994ed76b2af 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -4098,14 +4098,13 @@ test_that("catalog APIs, listTables, getTable, listColumns, listFunctions, funct c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) expect_equal(collect(c)[[1]][[1]], "speed") expect_error(listColumns("zxwtyswklpf", "default"), - paste("Error in listColumns : analysis error - Table", - "'zxwtyswklpf' does not exist in database 'default'")) + paste("Table or view not found: spark_catalog.default.zxwtyswklpf")) f <- listFunctions() expect_true(nrow(f) >= 200) # 250 expect_equal(colnames(f), c("name", "catalog", "namespace", "description", "className", "isTemporary")) - expect_equal(take(orderBy(f, "className"), 1)$className, + expect_equal(take(orderBy(filter(f, "className IS NOT NULL"), "className"), 1)$className, "org.apache.spark.sql.catalyst.expressions.Abs") expect_error(listFunctions("zxwtyswklpf_db"), paste("Error in listFunctions : no such database - Database", diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index c3b7db15d21d9..3d18d20518410 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -457,18 +457,20 @@ static ConcurrentMap reloadRegisteredExecutors(D throws IOException { ConcurrentMap registeredExecutors = Maps.newConcurrentMap(); if (db != null) { - DBIterator itr = db.iterator(); - itr.seek(APP_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); - while (itr.hasNext()) { - Map.Entry e = itr.next(); - String key = new String(e.getKey(), StandardCharsets.UTF_8); - if (!key.startsWith(APP_KEY_PREFIX)) { - break; + try (DBIterator itr = db.iterator()) { + itr.seek(APP_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry e = itr.next(); + String key = new String(e.getKey(), StandardCharsets.UTF_8); + if (!key.startsWith(APP_KEY_PREFIX)) { + break; + } + AppExecId id = parseDbAppExecKey(key); + logger.info("Reloading registered executors: " + id.toString()); + ExecutorShuffleInfo shuffleInfo = + mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); + registeredExecutors.put(id, shuffleInfo); } - AppExecId id = parseDbAppExecKey(key); - logger.info("Reloading registered executors: " + id.toString()); - ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); - registeredExecutors.put(id, shuffleInfo); } } return registeredExecutors; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java index 677adc76bffc5..9483e48ca446c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java @@ -417,9 +417,7 @@ void removeAppAttemptPathInfoFromDB(String appId, int attemptId) { if (db != null) { try { byte[] key = getDbAppAttemptPathsKey(appAttemptId); - if (db.get(key) != null) { - db.delete(key); - } + 
db.delete(key); } catch (Exception e) { logger.error("Failed to remove the application attempt {} local path in DB", appAttemptId, e); @@ -909,39 +907,40 @@ void reloadAndCleanUpAppShuffleInfo(DB db) throws IOException { List reloadActiveAppAttemptsPathInfo(DB db) throws IOException { List dbKeysToBeRemoved = new ArrayList<>(); if (db != null) { - DBIterator itr = db.iterator(); - itr.seek(APP_ATTEMPT_PATH_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); - while (itr.hasNext()) { - Map.Entry entry = itr.next(); - String key = new String(entry.getKey(), StandardCharsets.UTF_8); - if (!key.startsWith(APP_ATTEMPT_PATH_KEY_PREFIX)) { - break; - } - AppAttemptId appAttemptId = parseDbAppAttemptPathsKey(key); - AppPathsInfo appPathsInfo = mapper.readValue(entry.getValue(), AppPathsInfo.class); - logger.debug("Reloading Application paths info for application {}", appAttemptId); - appsShuffleInfo.compute(appAttemptId.appId, - (appId, existingAppShuffleInfo) -> { - if (existingAppShuffleInfo == null || - existingAppShuffleInfo.attemptId < appAttemptId.attemptId) { - if (existingAppShuffleInfo != null) { - AppAttemptId existingAppAttemptId = new AppAttemptId( - existingAppShuffleInfo.appId, existingAppShuffleInfo.attemptId); - try { - // Add the former outdated DB key to deletion list - dbKeysToBeRemoved.add(getDbAppAttemptPathsKey(existingAppAttemptId)); - } catch (IOException e) { - logger.error("Failed to get the DB key for {}", existingAppAttemptId, e); + try (DBIterator itr = db.iterator()) { + itr.seek(APP_ATTEMPT_PATH_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry entry = itr.next(); + String key = new String(entry.getKey(), StandardCharsets.UTF_8); + if (!key.startsWith(APP_ATTEMPT_PATH_KEY_PREFIX)) { + break; + } + AppAttemptId appAttemptId = parseDbAppAttemptPathsKey(key); + AppPathsInfo appPathsInfo = mapper.readValue(entry.getValue(), AppPathsInfo.class); + logger.debug("Reloading Application paths info for application {}", appAttemptId); + appsShuffleInfo.compute(appAttemptId.appId, + (appId, existingAppShuffleInfo) -> { + if (existingAppShuffleInfo == null || + existingAppShuffleInfo.attemptId < appAttemptId.attemptId) { + if (existingAppShuffleInfo != null) { + AppAttemptId existingAppAttemptId = new AppAttemptId( + existingAppShuffleInfo.appId, existingAppShuffleInfo.attemptId); + try { + // Add the former outdated DB key to deletion list + dbKeysToBeRemoved.add(getDbAppAttemptPathsKey(existingAppAttemptId)); + } catch (IOException e) { + logger.error("Failed to get the DB key for {}", existingAppAttemptId, e); + } } + return new AppShuffleInfo( + appAttemptId.appId, appAttemptId.attemptId, appPathsInfo); + } else { + // Add the current DB key to deletion list as it is outdated + dbKeysToBeRemoved.add(entry.getKey()); + return existingAppShuffleInfo; } - return new AppShuffleInfo( - appAttemptId.appId, appAttemptId.attemptId, appPathsInfo); - } else { - // Add the current DB key to deletion list as it is outdated - dbKeysToBeRemoved.add(entry.getKey()); - return existingAppShuffleInfo; - } - }); + }); + } } } return dbKeysToBeRemoved; @@ -954,41 +953,44 @@ List reloadActiveAppAttemptsPathInfo(DB db) throws IOException { List reloadFinalizedAppAttemptsShuffleMergeInfo(DB db) throws IOException { List dbKeysToBeRemoved = new ArrayList<>(); if (db != null) { - DBIterator itr = db.iterator(); - itr.seek(APP_ATTEMPT_SHUFFLE_FINALIZE_STATUS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); - while (itr.hasNext()) { - Map.Entry entry = itr.next(); - String key 
= new String(entry.getKey(), StandardCharsets.UTF_8); - if (!key.startsWith(APP_ATTEMPT_SHUFFLE_FINALIZE_STATUS_KEY_PREFIX)) { - break; - } - AppAttemptShuffleMergeId partitionId = parseDbAppAttemptShufflePartitionKey(key); - logger.debug("Reloading finalized shuffle info for partitionId {}", partitionId); - AppShuffleInfo appShuffleInfo = appsShuffleInfo.get(partitionId.appId); - if (appShuffleInfo != null && appShuffleInfo.attemptId == partitionId.attemptId) { - appShuffleInfo.shuffles.compute(partitionId.shuffleId, - (shuffleId, existingMergePartitionInfo) -> { - if (existingMergePartitionInfo == null || - existingMergePartitionInfo.shuffleMergeId < partitionId.shuffleMergeId) { - if (existingMergePartitionInfo != null) { - AppAttemptShuffleMergeId appAttemptShuffleMergeId = - new AppAttemptShuffleMergeId(appShuffleInfo.appId, appShuffleInfo.attemptId, - shuffleId, existingMergePartitionInfo.shuffleMergeId); - try{ - dbKeysToBeRemoved.add( - getDbAppAttemptShufflePartitionKey(appAttemptShuffleMergeId)); - } catch (Exception e) { - logger.error("Error getting the DB key for {}", appAttemptShuffleMergeId, e); + try (DBIterator itr = db.iterator()) { + itr.seek(APP_ATTEMPT_SHUFFLE_FINALIZE_STATUS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry entry = itr.next(); + String key = new String(entry.getKey(), StandardCharsets.UTF_8); + if (!key.startsWith(APP_ATTEMPT_SHUFFLE_FINALIZE_STATUS_KEY_PREFIX)) { + break; + } + AppAttemptShuffleMergeId partitionId = parseDbAppAttemptShufflePartitionKey(key); + logger.debug("Reloading finalized shuffle info for partitionId {}", partitionId); + AppShuffleInfo appShuffleInfo = appsShuffleInfo.get(partitionId.appId); + if (appShuffleInfo != null && appShuffleInfo.attemptId == partitionId.attemptId) { + appShuffleInfo.shuffles.compute(partitionId.shuffleId, + (shuffleId, existingMergePartitionInfo) -> { + if (existingMergePartitionInfo == null || + existingMergePartitionInfo.shuffleMergeId < partitionId.shuffleMergeId) { + if (existingMergePartitionInfo != null) { + AppAttemptShuffleMergeId appAttemptShuffleMergeId = + new AppAttemptShuffleMergeId( + appShuffleInfo.appId, appShuffleInfo.attemptId, + shuffleId, existingMergePartitionInfo.shuffleMergeId); + try{ + dbKeysToBeRemoved.add( + getDbAppAttemptShufflePartitionKey(appAttemptShuffleMergeId)); + } catch (Exception e) { + logger.error("Error getting the DB key for {}", + appAttemptShuffleMergeId, e); + } } + return new AppShuffleMergePartitionsInfo(partitionId.shuffleMergeId, true); + } else { + dbKeysToBeRemoved.add(entry.getKey()); + return existingMergePartitionInfo; } - return new AppShuffleMergePartitionsInfo(partitionId.shuffleMergeId, true); - } else { - dbKeysToBeRemoved.add(entry.getKey()); - return existingMergePartitionInfo; - } - }); - } else { - dbKeysToBeRemoved.add(entry.getKey()); + }); + } else { + dbKeysToBeRemoved.add(entry.getKey()); + } } } } diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 9295239f9964f..af3f9b112fb98 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -341,18 +341,19 @@ private void loadSecretsFromDb() throws IOException { logger.info("Recovery location is: " + secretsFile.getPath()); if (db != null) { logger.info("Going to reload spark 
shuffle data"); - DBIterator itr = db.iterator(); - itr.seek(APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); - while (itr.hasNext()) { - Map.Entry e = itr.next(); - String key = new String(e.getKey(), StandardCharsets.UTF_8); - if (!key.startsWith(APP_CREDS_KEY_PREFIX)) { - break; + try (DBIterator itr = db.iterator()) { + itr.seek(APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry e = itr.next(); + String key = new String(e.getKey(), StandardCharsets.UTF_8); + if (!key.startsWith(APP_CREDS_KEY_PREFIX)) { + break; + } + String id = parseDbAppKey(key); + ByteBuffer secret = mapper.readValue(e.getValue(), ByteBuffer.class); + logger.info("Reloading tokens for app: " + id); + secretManager.registerApp(id, secret); } - String id = parseDbAppKey(key); - ByteBuffer secret = mapper.readValue(e.getValue(), ByteBuffer.class); - logger.info("Reloading tokens for app: " + id); - secretManager.registerApp(id, secret); } } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 72a03a4d1fbbc..589a1ffa713fd 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -931,8 +931,12 @@ package object config { private[spark] val LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED = ConfigBuilder("spark.scheduler.listenerbus.metrics.maxListenerClassesTimed") .internal() + .doc("The number of listeners that have timers to track the elapsed time of" + + "processing events. If 0 is set, disables this feature. If -1 is set," + + "it sets no limit to the number.") .version("2.3.0") .intConf + .checkValue(_ >= -1, "The number of listeners should be larger than -1.") .createWithDefault(128) private[spark] val LISTENER_BUS_LOG_SLOW_EVENT_ENABLED = diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 4be4e7a88753b..104038fc209d3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -292,10 +292,14 @@ private[spark] class LiveListenerBusMetrics(conf: SparkConf) val maxTimed = conf.get(LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED) perListenerClassTimers.get(className).orElse { if (perListenerClassTimers.size == maxTimed) { - logError(s"Not measuring processing time for listener class $className because a " + - s"maximum of $maxTimed listener classes are already timed.") + if (maxTimed != 0) { + // Explicitly disabled. + logError(s"Not measuring processing time for listener class $className because a " + + s"maximum of $maxTimed listener classes are already timed.") + } None } else { + // maxTimed is either -1 (no limit), or an explicit number. 
perListenerClassTimers(className) = metricRegistry.timer(MetricRegistry.name("listenerProcessingTime", className)) perListenerClassTimers.get(className) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index d72744c5cc348..dca915e0a97ac 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -599,6 +599,22 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(bus.getQueueCapacity(EVENT_LOG_QUEUE) == Some(2)) } + test("SPARK-39973: Suppress error logs when the number of timers is set to 0") { + sc = new SparkContext( + "local", + "SparkListenerSuite", + new SparkConf().set( + LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED.key, 0.toString)) + val testAppender = new LogAppender("Error logger for timers") + withLogAppender(testAppender) { + sc.addSparkListener(new SparkListener { }) + sc.addSparkListener(new SparkListener { }) + } + assert(!testAppender.loggingEvents + .exists(_.getMessage.getFormattedMessage.contains( + "Not measuring processing time for listener"))) + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 1b3027cdae662..854919a9af657 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -25,7 +25,7 @@ automaton/1.11-8//automaton-1.11-8.jar avro-ipc/1.11.0//avro-ipc-1.11.0.jar avro-mapred/1.11.0//avro-mapred-1.11.0.jar avro/1.11.0//avro-1.11.0.jar -aws-java-sdk-bundle/1.11.1026//aws-java-sdk-bundle-1.11.1026.jar +aws-java-sdk-bundle/1.12.262//aws-java-sdk-bundle-1.12.262.jar azure-data-lake-store-sdk/2.3.9//azure-data-lake-store-sdk-2.3.9.jar azure-keyvault-core/1.0.0//azure-keyvault-core-1.0.0.jar azure-storage/7.0.1//azure-storage-7.0.1.jar @@ -52,7 +52,6 @@ commons-math3/3.6.1//commons-math3-3.6.1.jar commons-pool/1.5.4//commons-pool-1.5.4.jar commons-text/1.9//commons-text-1.9.jar compress-lzf/1.1//compress-lzf-1.1.jar -cos_api-bundle/5.6.19//cos_api-bundle-5.6.19.jar curator-client/2.13.0//curator-client-2.13.0.jar curator-framework/2.13.0//curator-framework-2.13.0.jar curator-recipes/2.13.0//curator-recipes-2.13.0.jar @@ -66,18 +65,17 @@ generex/1.0.2//generex-1.0.2.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.2.4//gson-2.2.4.jar guava/14.0.1//guava-14.0.1.jar -hadoop-aliyun/3.3.3//hadoop-aliyun-3.3.3.jar -hadoop-annotations/3.3.3//hadoop-annotations-3.3.3.jar -hadoop-aws/3.3.3//hadoop-aws-3.3.3.jar -hadoop-azure-datalake/3.3.3//hadoop-azure-datalake-3.3.3.jar -hadoop-azure/3.3.3//hadoop-azure-3.3.3.jar -hadoop-client-api/3.3.3//hadoop-client-api-3.3.3.jar -hadoop-client-runtime/3.3.3//hadoop-client-runtime-3.3.3.jar -hadoop-cloud-storage/3.3.3//hadoop-cloud-storage-3.3.3.jar -hadoop-cos/3.3.3//hadoop-cos-3.3.3.jar -hadoop-openstack/3.3.3//hadoop-openstack-3.3.3.jar +hadoop-aliyun/3.3.4//hadoop-aliyun-3.3.4.jar +hadoop-annotations/3.3.4//hadoop-annotations-3.3.4.jar +hadoop-aws/3.3.4//hadoop-aws-3.3.4.jar +hadoop-azure-datalake/3.3.4//hadoop-azure-datalake-3.3.4.jar +hadoop-azure/3.3.4//hadoop-azure-3.3.4.jar +hadoop-client-api/3.3.4//hadoop-client-api-3.3.4.jar +hadoop-client-runtime/3.3.4//hadoop-client-runtime-3.3.4.jar +hadoop-cloud-storage/3.3.4//hadoop-cloud-storage-3.3.4.jar +hadoop-openstack/3.3.4//hadoop-openstack-3.3.4.jar 
hadoop-shaded-guava/1.1.1//hadoop-shaded-guava-1.1.1.jar -hadoop-yarn-server-web-proxy/3.3.3//hadoop-yarn-server-web-proxy-3.3.3.jar +hadoop-yarn-server-web-proxy/3.3.4//hadoop-yarn-server-web-proxy-3.3.4.jar hive-beeline/2.3.9//hive-beeline-2.3.9.jar hive-cli/2.3.9//hive-cli-2.3.9.jar hive-common/2.3.9//hive-common-2.3.9.jar diff --git a/pom.xml b/pom.xml index da1e3a8a1e5eb..d6404e80c9892 100644 --- a/pom.xml +++ b/pom.xml @@ -115,7 +115,7 @@ 1.7.36 2.18.0 - 3.3.3 + 3.3.4 2.5.0 ${hadoop.version} 3.6.2 diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index c66396bf4de41..ea495445426fb 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -235,8 +235,10 @@ Aggregate Functions max max_by mean + median min min_by + mode percentile_approx product skewness diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 548750d712025..10c9ab5f6d239 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -164,7 +164,7 @@ def getDatabase(self, dbName: str) -> Database: Examples -------- >>> spark.catalog.getDatabase("default") - Database(name='default', catalog=None, description='default database', ... + Database(name='default', catalog='spark_catalog', description='default database', ... >>> spark.catalog.getDatabase("spark_catalog.default") Database(name='default', catalog='spark_catalog', description='default database', ... """ @@ -376,9 +376,9 @@ def getFunction(self, functionName: str) -> Function: -------- >>> func = spark.sql("CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'") >>> spark.catalog.getFunction("my_func1") - Function(name='my_func1', catalog=None, namespace=['default'], ... + Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ... >>> spark.catalog.getFunction("default.my_func1") - Function(name='my_func1', catalog=None, namespace=['default'], ... + Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ... >>> spark.catalog.getFunction("spark_catalog.default.my_func1") Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ... >>> spark.catalog.getFunction("my_func2") diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2997be08872c2..e73c70d8ca06f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -187,6 +187,40 @@ def abs(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("abs", col) +def mode(col: "ColumnOrName") -> Column: + """ + Returns the most frequent value in a group. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column that the value will be returned + + Returns + ------- + :class:`~pyspark.sql.Column` + the most frequent value in a group. + + Examples + -------- + >>> df = spark.createDataFrame([ + ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), + ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), + ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], + ... 
schema=("course", "year", "earnings")) + >>> df.groupby("course").agg(mode("year")).show() + +------+----------+ + |course|mode(year)| + +------+----------+ + | Java| 2012| + |dotNET| 2012| + +------+----------+ + """ + return _invoke_function_over_columns("mode", col) + + @since(1.3) def max(col: "ColumnOrName") -> Column: """ @@ -305,6 +339,40 @@ def mean(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("mean", col) +def median(col: "ColumnOrName") -> Column: + """ + Returns the median of the values in a group. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column that the value will be returned + + Returns + ------- + :class:`~pyspark.sql.Column` + the median of the values in a group. + + Examples + -------- + >>> df = spark.createDataFrame([ + ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), + ... ("Java", 2012, 22000), ("dotNET", 2012, 10000), + ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], + ... schema=("course", "year", "earnings")) + >>> df.groupby("course").agg(median("earnings")).show() + +------+----------------+ + |course|median(earnings)| + +------+----------------+ + | Java| 22000.0| + |dotNET| 10000.0| + +------+----------------+ + """ + return _invoke_function_over_columns("median", col) + + @since(1.3) def sumDistinct(col: "ColumnOrName") -> Column: """ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index bece13684e087..2fbe76aa5ae92 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -25,7 +25,6 @@ from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin -from pyspark.sql.types import StructType, StructField, IntegerType, StringType if TYPE_CHECKING: from pyspark.sql._typing import LiteralType @@ -112,20 +111,53 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: Examples -------- - >>> gdf = df.groupBy(df.name) - >>> sorted(gdf.agg({"*": "count"}).collect()) - [Row(name='Alice', count(1)=1), Row(name='Bob', count(1)=1)] - >>> from pyspark.sql import functions as F - >>> sorted(gdf.agg(F.min(df.age)).collect()) - [Row(name='Alice', min(age)=2), Row(name='Bob', min(age)=5)] - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]) + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 3|Alice| + | 5| Bob| + | 10| Bob| + +---+-----+ + + Group-by name, and count each group. + + >>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show() + +-----+--------+ + | name|count(1)| + +-----+--------+ + |Alice| 2| + | Bob| 2| + +-----+--------+ + + Group-by name, and calculate the minimum age. + + >>> df.groupBy(df.name).agg(F.min(df.age)).sort("name").show() + +-----+--------+ + | name|min(age)| + +-----+--------+ + |Alice| 2| + | Bob| 5| + +-----+--------+ + + Same as above but uses pandas UDF. + >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def min_udf(v): ... return v.min() - >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP - [Row(name='Alice', min_udf(age)=2), Row(name='Bob', min_udf(age)=5)] + ... 
+ >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show() # doctest: +SKIP + +-----+------------+ + | name|min_udf(age)| + +-----+------------+ + |Alice| 2| + | Bob| 5| + +-----+------------+ """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -145,8 +177,27 @@ def count(self) -> DataFrame: Examples -------- - >>> sorted(df.groupBy(df.age).count().collect()) - [Row(age=2, count=1), Row(age=5, count=1)] + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]) + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 3|Alice| + | 5| Bob| + | 10| Bob| + +---+-----+ + + Group-by name, and count each group. + + >>> df.groupBy(df.name).count().sort("name").show() + +-----+-----+ + | name|count| + +-----+-----+ + |Alice| 2| + | Bob| 2| + +-----+-----+ """ @df_varargs_api @@ -161,13 +212,6 @@ def mean(self, *cols: str) -> DataFrame: ---------- cols : str column names. Non-numeric columns are ignored. - - Examples - -------- - >>> df.groupBy().mean('age').collect() - [Row(avg(age)=3.5)] - >>> df3.groupBy().mean('age', 'height').collect() - [Row(avg(age)=3.5, avg(height)=82.5)] """ @df_varargs_api @@ -185,10 +229,37 @@ def avg(self, *cols: str) -> DataFrame: Examples -------- - >>> df.groupBy().avg('age').collect() - [Row(avg(age)=3.5)] - >>> df3.groupBy().avg('age', 'height').collect() - [Row(avg(age)=3.5, avg(height)=82.5)] + >>> df = spark.createDataFrame([ + ... (2, "Alice", 80), (3, "Alice", 100), + ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df.show() + +---+-----+------+ + |age| name|height| + +---+-----+------+ + | 2|Alice| 80| + | 3|Alice| 100| + | 5| Bob| 120| + | 10| Bob| 140| + +---+-----+------+ + + Group-by name, and calculate the mean of the age in each group. + + >>> df.groupBy("name").avg('age').sort("name").show() + +-----+--------+ + | name|avg(age)| + +-----+--------+ + |Alice| 2.5| + | Bob| 7.5| + +-----+--------+ + + Calculate the mean of the age and height in all data. + + >>> df.groupBy().avg('age', 'height').show() + +--------+-----------+ + |avg(age)|avg(height)| + +--------+-----------+ + | 5.0| 110.0| + +--------+-----------+ """ @df_varargs_api @@ -199,10 +270,37 @@ def max(self, *cols: str) -> DataFrame: Examples -------- - >>> df.groupBy().max('age').collect() - [Row(max(age)=5)] - >>> df3.groupBy().max('age', 'height').collect() - [Row(max(age)=5, max(height)=85)] + >>> df = spark.createDataFrame([ + ... (2, "Alice", 80), (3, "Alice", 100), + ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df.show() + +---+-----+------+ + |age| name|height| + +---+-----+------+ + | 2|Alice| 80| + | 3|Alice| 100| + | 5| Bob| 120| + | 10| Bob| 140| + +---+-----+------+ + + Group-by name, and calculate the max of the age in each group. + + >>> df.groupBy("name").max("age").sort("name").show() + +-----+--------+ + | name|max(age)| + +-----+--------+ + |Alice| 3| + | Bob| 10| + +-----+--------+ + + Calculate the max of the age and height in all data. + + >>> df.groupBy().max("age", "height").show() + +--------+-----------+ + |max(age)|max(height)| + +--------+-----------+ + | 10| 140| + +--------+-----------+ """ @df_varargs_api @@ -218,10 +316,37 @@ def min(self, *cols: str) -> DataFrame: Examples -------- - >>> df.groupBy().min('age').collect() - [Row(min(age)=2)] - >>> df3.groupBy().min('age', 'height').collect() - [Row(min(age)=2, min(height)=80)] + >>> df = spark.createDataFrame([ + ... 
(2, "Alice", 80), (3, "Alice", 100), + ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df.show() + +---+-----+------+ + |age| name|height| + +---+-----+------+ + | 2|Alice| 80| + | 3|Alice| 100| + | 5| Bob| 120| + | 10| Bob| 140| + +---+-----+------+ + + Group-by name, and calculate the min of the age in each group. + + >>> df.groupBy("name").min("age").sort("name").show() + +-----+--------+ + | name|min(age)| + +-----+--------+ + |Alice| 2| + | Bob| 5| + +-----+--------+ + + Calculate the min of the age and height in all data. + + >>> df.groupBy().min("age", "height").show() + +--------+-----------+ + |min(age)|min(height)| + +--------+-----------+ + | 2| 80| + +--------+-----------+ """ @df_varargs_api @@ -237,10 +362,37 @@ def sum(self, *cols: str) -> DataFrame: Examples -------- - >>> df.groupBy().sum('age').collect() - [Row(sum(age)=7)] - >>> df3.groupBy().sum('age', 'height').collect() - [Row(sum(age)=7, sum(height)=165)] + >>> df = spark.createDataFrame([ + ... (2, "Alice", 80), (3, "Alice", 100), + ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df.show() + +---+-----+------+ + |age| name|height| + +---+-----+------+ + | 2|Alice| 80| + | 3|Alice| 100| + | 5| Bob| 120| + | 10| Bob| 140| + +---+-----+------+ + + Group-by name, and calculate the sum of the age in each group. + + >>> df.groupBy("name").sum("age").sort("name").show() + +-----+--------+ + | name|sum(age)| + +-----+--------+ + |Alice| 5| + | Bob| 15| + +-----+--------+ + + Calculate the sum of the age and height in all data. + + >>> df.groupBy().sum("age", "height").show() + +--------+-----------+ + |sum(age)|sum(height)| + +--------+-----------+ + | 20| 440| + +--------+-----------+ """ def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) -> "GroupedData": @@ -261,17 +413,69 @@ def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) -> Examples -------- - # Compute the sum of earnings for each year by course with each course as a separate column - - >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() - [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] - - # Or without specifying column values (less efficient) - - >>> df4.groupBy("year").pivot("course").sum("earnings").collect() - [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] - >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect() - [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] + >>> from pyspark.sql import Row + >>> spark = SparkSession.builder.master("local[4]").appName("sql.group tests").getOrCreate() + >>> df1 = spark.createDataFrame([ + ... Row(course="dotNET", year=2012, earnings=10000), + ... Row(course="Java", year=2012, earnings=20000), + ... Row(course="dotNET", year=2012, earnings=5000), + ... Row(course="dotNET", year=2013, earnings=48000), + ... Row(course="Java", year=2013, earnings=30000), + ... ]) + >>> df1.show() + +------+----+--------+ + |course|year|earnings| + +------+----+--------+ + |dotNET|2012| 10000| + | Java|2012| 20000| + |dotNET|2012| 5000| + |dotNET|2013| 48000| + | Java|2013| 30000| + +------+----+--------+ + >>> df2 = spark.createDataFrame([ + ... Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)), + ... Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)), + ... 
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)), + ... Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)), + ... Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000)), + ... ]) + >>> df2.show() + +--------+--------------------+ + |training| sales| + +--------+--------------------+ + | expert|{dotNET, 2012, 10...| + | junior| {Java, 2012, 20000}| + | expert|{dotNET, 2012, 5000}| + | junior|{dotNET, 2013, 48...| + | expert| {Java, 2013, 30000}| + +--------+--------------------+ + + Compute the sum of earnings for each year by course with each course as a separate column + + >>> df1.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").show() + +----+------+-----+ + |year|dotNET| Java| + +----+------+-----+ + |2012| 15000|20000| + |2013| 48000|30000| + +----+------+-----+ + + Or without specifying column values (less efficient) + + >>> df1.groupBy("year").pivot("course").sum("earnings").show() + +----+-----+------+ + |year| Java|dotNET| + +----+-----+------+ + |2012|20000| 15000| + |2013|30000| 48000| + +----+-----+------+ + >>> df2.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").show() + +----+-----+------+ + |year| Java|dotNET| + +----+-----+------+ + |2012|20000| 15000| + |2013|30000| 48000| + +----+-----+------+ """ if values is None: jgd = self._jgd.pivot(pivot_col) @@ -282,7 +486,7 @@ def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) -> def _test() -> None: import doctest - from pyspark.sql import Row, SparkSession + from pyspark.sql import SparkSession import pyspark.sql.group globs = pyspark.sql.group.__dict__.copy() @@ -290,30 +494,6 @@ def _test() -> None: sc = spark.sparkContext globs["sc"] = sc globs["spark"] = spark - globs["df"] = sc.parallelize([(2, "Alice"), (5, "Bob")]).toDF( - StructType([StructField("age", IntegerType()), StructField("name", StringType())]) - ) - globs["df3"] = sc.parallelize( - [Row(name="Alice", age=2, height=80), Row(name="Bob", age=5, height=85)] - ).toDF() - globs["df4"] = sc.parallelize( - [ - Row(course="dotNET", year=2012, earnings=10000), - Row(course="Java", year=2012, earnings=20000), - Row(course="dotNET", year=2012, earnings=5000), - Row(course="dotNET", year=2013, earnings=48000), - Row(course="Java", year=2013, earnings=30000), - ] - ).toDF() - globs["df5"] = sc.parallelize( - [ - Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)), - Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)), - Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)), - Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)), - Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000)), - ] - ).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.group, diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py index 7d81234bce256..24cd67251a8f7 100644 --- a/python/pyspark/sql/tests/test_catalog.py +++ b/python/pyspark/sql/tests/test_catalog.py @@ -198,7 +198,7 @@ def test_list_functions(self): self.assertTrue("to_unix_timestamp" in functions) self.assertTrue("current_database" in functions) self.assertEqual(functions["+"].name, "+") - self.assertEqual(functions["+"].description, None) + self.assertEqual(functions["+"].description, "expr1 + expr2 - Returns `expr1`+`expr2`.") self.assertEqual( functions["+"].className, "org.apache.spark.sql.catalyst.expressions.Add" ) diff 
--git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index b8bc90f458cdf..f895e2010ce49 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -128,11 +128,23 @@ def rowsBetween(start: int, end: int) -> "WindowSpec": -------- >>> from pyspark.sql import Window >>> from pyspark.sql import functions as func - >>> from pyspark.sql import SQLContext - >>> sc = SparkContext.getOrCreate() - >>> sqlContext = SQLContext(sc) - >>> tup = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")] - >>> df = sqlContext.createDataFrame(tup, ["id", "category"]) + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> df.show() + +---+--------+ + | id|category| + +---+--------+ + | 1| a| + | 1| a| + | 2| a| + | 1| b| + | 2| b| + | 3| b| + +---+--------+ + + Calculate sum of ``id`` in the range from currentRow to currentRow + 1 + in partition ``category`` + >>> window = Window.partitionBy("category").orderBy("id").rowsBetween(Window.currentRow, 1) >>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category", "sum").show() +---+--------+---+ @@ -196,11 +208,23 @@ def rangeBetween(start: int, end: int) -> "WindowSpec": -------- >>> from pyspark.sql import Window >>> from pyspark.sql import functions as func - >>> from pyspark.sql import SQLContext - >>> sc = SparkContext.getOrCreate() - >>> sqlContext = SQLContext(sc) - >>> tup = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")] - >>> df = sqlContext.createDataFrame(tup, ["id", "category"]) + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> df.show() + +---+--------+ + | id|category| + +---+--------+ + | 1| a| + | 1| a| + | 2| a| + | 1| b| + | 2| b| + | 3| b| + +---+--------+ + + Calculate sum of ``id`` in the range from ``id`` of currentRow to ``id`` of currentRow + 1 + in partition ``category`` + >>> window = Window.partitionBy("category").orderBy("id").rangeBetween(Window.currentRow, 1) >>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category").show() +---+--------+---+ @@ -329,13 +353,17 @@ def rangeBetween(self, start: int, end: int) -> "WindowSpec": def _test() -> None: import doctest + from pyspark.sql import SparkSession import pyspark.sql.window - SparkContext("local[4]", "PythonTest") globs = pyspark.sql.window.__dict__.copy() + spark = SparkSession.builder.master("local[4]").appName("sql.window tests").getOrCreate() + globs["spark"] = spark + (failure_count, test_count) = doctest.testmod( pyspark.sql.window, globs=globs, optionflags=doctest.NORMALIZE_WHITESPACE ) + spark.stop() if failure_count: sys.exit(-1) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 43c6597362e41..985b8b7bef051 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -141,11 +141,13 @@ private[spark] class KubernetesClusterSchedulerBackend( } } - Utils.tryLogNonFatalError { - kubernetesClient - .persistentVolumeClaims() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .delete() + if 
(conf.get(KUBERNETES_DRIVER_OWN_PVC)) { + Utils.tryLogNonFatalError { + kubernetesClient + .persistentVolumeClaims() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .delete() + } } if (shouldDeleteExecutors) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 476201c9a8d8e..83ac80bf4ba51 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -46,10 +46,9 @@ /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * - * Each tuple has three parts: [null bit set] [values] [variable length portion] + * Each tuple has three parts: [null-tracking bit set] [values] [variable length portion] * - * The bit set is used for null tracking and is aligned to 8-byte word boundaries. It stores - * one bit per field. + * The null-tracking bit set is aligned to 8-byte word boundaries. It stores one bit per field. * * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length * primitive types, such as long, double, or int, we store the value directly in the word. For diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java index 26b97b46fe2ef..44111913f124b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java @@ -42,4 +42,9 @@ public Cast(Expression expression, DataType dataType) { @Override public Expression[] children() { return new Expression[]{ expression() }; } + + @Override + public String toString() { + return "CAST(" + expression.describe() + " AS " + dataType.typeName() + ")"; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 16d89c9b2e48e..a0c98aac6c4ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -971,13 +971,17 @@ class SessionCatalog( } def lookupTempView(name: TableIdentifier): Option[View] = { - val tableName = formatTableName(name.table) - if (name.database.isEmpty) { - tempViews.get(tableName).map(getTempViewPlan) - } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { - globalTempViewManager.get(tableName).map(getTempViewPlan) - } else { - None + lookupLocalOrGlobalRawTempView(name.database.toSeq :+ name.table).map(getTempViewPlan) + } + + /** + * Return the raw logical plan of a temporary local or global view for the given name. 
+ */ + def lookupLocalOrGlobalRawTempView(name: Seq[String]): Option[TemporaryViewRelation] = { + name match { + case Seq(v) => getRawTempView(v) + case Seq(db, v) if isGlobalTempViewDB(db) => getRawGlobalTempView(v) + case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 424b82533fca4..25e1889109e8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -112,6 +112,7 @@ object Cast { case (StringType, _: AnsiIntervalType) => true case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true + case (_: IntegralType, _: AnsiIntervalType) => true case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true @@ -196,6 +197,7 @@ object Cast { case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true + case (_: IntegralType, _: AnsiIntervalType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true @@ -786,7 +788,6 @@ case class Cast( case _: DayTimeIntervalType => buildCast[Long](_, s => IntervalUtils.durationToMicros(IntervalUtils.microsToDuration(s), it.endField)) case x: IntegralType => - assert(it.startField == it.endField) if (x == LongType) { b => IntervalUtils.longToDayTimeInterval( x.integral.asInstanceOf[Integral[Any]].toLong(b), it.endField) @@ -804,7 +805,6 @@ case class Cast( case _: YearMonthIntervalType => buildCast[Int](_, s => IntervalUtils.periodToMonths(IntervalUtils.monthsToPeriod(s), it.endField)) case x: IntegralType => - assert(it.startField == it.endField) if (x == LongType) { b => IntervalUtils.longToYearMonthInterval( x.integral.asInstanceOf[Integral[Any]].toLong(b), it.endField) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index f672153436901..6d346c80ab7de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -49,7 +49,7 @@ case class Mode( // Returns null for empty inputs override def nullable: Boolean = true - override val dataType: DataType = child.dataType + override def dataType: DataType = child.dataType override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 236636ac7ea11..8c63012c6814d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate -import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} +import 
org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} @@ -117,13 +117,38 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J * do not add a subquery that might have an expensive computation */ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { - val ret = plan match { - case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => - filters.forall(isSimpleExpression) && - filters.exists(isLikelySelective) + def isSelective( + p: LogicalPlan, + predicateReference: AttributeSet, + hasHitFilter: Boolean, + hasHitSelectiveFilter: Boolean): Boolean = p match { + case Project(projectList, child) => + if (hasHitFilter) { + // We need to make sure all expressions referenced by filter predicates are simple + // expressions. + val referencedExprs = projectList.filter(predicateReference.contains) + referencedExprs.forall(isSimpleExpression) && + isSelective( + child, + referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _), + hasHitFilter, + hasHitSelectiveFilter) + } else { + assert(predicateReference.isEmpty && !hasHitSelectiveFilter) + isSelective(child, predicateReference, hasHitFilter, hasHitSelectiveFilter) + } + case Filter(condition, child) => + isSimpleExpression(condition) && isSelective( + child, + predicateReference ++ condition.references, + hasHitFilter = true, + hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition)) + case _: LeafNode => hasHitSelectiveFilter case _ => false } - !plan.isStreaming && ret + + !plan.isStreaming && + isSelective(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false) } private def isSimpleExpression(e: Expression): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala index 635434741b944..88f92262dcc20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentRow, DenseRank, IntegerLiteral, NamedExpression, NTile, Rank, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentRow, DenseRank, IntegerLiteral, NamedExpression, Rank, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalLimit, LogicalPlan, Project, Sort, Window} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{LIMIT, WINDOW} @@ -33,8 +33,7 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] { // The window frame of RankLike and RowNumberLike can only be UNBOUNDED PRECEDING to CURRENT ROW. 
private def supportsPushdownThroughWindow( windowExpressions: Seq[NamedExpression]): Boolean = windowExpressions.forall { - case Alias(WindowExpression(_: Rank | _: DenseRank | _: NTile | _: RowNumber, - WindowSpecDefinition(Nil, _, + case Alias(WindowExpression(_: Rank | _: DenseRank | _: RowNumber, WindowSpecDefinition(Nil, _, SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))), _) => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala index 2195eef2fc93b..fde22b249b690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, Literal, Or, SubqueryExpression} -import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.CTE @@ -69,7 +69,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] { } gatherPredicatesAndAttributes(child, cteMap) - case ScanOperation(projects, predicates, ref: CTERelationRef) => + case PhysicalOperation(projects, predicates, ref: CTERelationRef) => val (cteDef, precedence, preds, attrs) = cteMap(ref.cteId) val attrMapping = ref.output.zip(cteDef.output).map{ case (r, d) => r -> d }.toMap val newPredicates = if (isTruePredicate(preds)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index bf3fced0ae0fd..74085436870e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -21,7 +21,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.planning.{NodeWithOnlyDeterministicProjectAndFilter, PhysicalOperation} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -82,7 +82,8 @@ object StarSchemaDetection extends PredicateHelper with SQLConfHelper { // Find if the input plans are eligible for star join detection. // An eligible plan is a base table access with valid statistics. 
val foundEligibleJoin = input.forall { - case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true + case NodeWithOnlyDeterministicProjectAndFilter(t: LeafNode) + if t.stats.rowCount.isDefined => true case _ => false } @@ -177,7 +178,7 @@ object StarSchemaDetection extends PredicateHelper with SQLConfHelper { private def isUnique( column: Attribute, plan: LogicalPlan): Boolean = plan match { - case PhysicalOperation(_, _, t: LeafNode) => + case NodeWithOnlyDeterministicProjectAndFilter(t: LeafNode) => val leafCol = findLeafNodeCol(column, plan) leafCol match { case Some(col) if t.outputSet.contains(col) => @@ -212,7 +213,7 @@ object StarSchemaDetection extends PredicateHelper with SQLConfHelper { private def findLeafNodeCol( column: Attribute, plan: LogicalPlan): Option[Attribute] = plan match { - case pl @ PhysicalOperation(_, _, _: LeafNode) => + case pl @ NodeWithOnlyDeterministicProjectAndFilter(_: LeafNode) => pl match { case t: LeafNode if t.outputSet.contains(column) => Option(column) @@ -233,7 +234,7 @@ object StarSchemaDetection extends PredicateHelper with SQLConfHelper { private def hasStatistics( column: Attribute, plan: LogicalPlan): Boolean = plan match { - case PhysicalOperation(_, _, t: LeafNode) => + case NodeWithOnlyDeterministicProjectAndFilter(t: LeafNode) => val leafCol = findLeafNodeCol(column, plan) leafCol match { case Some(col) if t.outputSet.contains(col) => @@ -296,7 +297,7 @@ object StarSchemaDetection extends PredicateHelper with SQLConfHelper { */ private def getTableAccessCardinality( input: LogicalPlan): Option[BigInt] = input match { - case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined => + case NodeWithOnlyDeterministicProjectAndFilter(t: LeafNode) if t.stats.rowCount.isDefined => if (conf.cboEnabled && input.stats.rowCount.isDefined) { Option(input.stats.rowCount.get) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 4e12b811acd1b..72546ea73dd9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -29,7 +29,14 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.internal.SQLConf -trait OperationHelper extends PredicateHelper { +/** + * A pattern that matches any number of project or filter operations even if they are + * non-deterministic, as long as they satisfy the requirement of CollapseProject and CombineFilters. + * All filter operators are collected and their conditions are broken up and returned + * together with the top project operator. [[Alias Aliases]] are in-lined/substituted if + * necessary. + */ +object PhysicalOperation extends AliasHelper with PredicateHelper { import org.apache.spark.sql.catalyst.optimizer.CollapseProject.canCollapseExpressions type ReturnType = @@ -43,16 +50,6 @@ trait OperationHelper extends PredicateHelper { Some((fields.getOrElse(child.output), filters, child)) } - /** - * This legacy mode is for PhysicalOperation which has been there for years and we want to be - * extremely safe to not change its behavior. There are two differences when legacy mode is off: - * 1. 
We postpone the deterministic check to the very end (calling `canCollapseExpressions`), - * so that it's more likely to collect more projects and filters. - * 2. We follow CollapseProject and only collect adjacent projects if they don't produce - * repeated expensive expressions. - */ - protected def legacyMode: Boolean - /** * Collects all adjacent projects and filters, in-lining/substituting aliases if necessary. * Here are two examples for alias in-lining/substitution. @@ -73,31 +70,27 @@ trait OperationHelper extends PredicateHelper { def empty: IntermediateType = (None, Nil, plan, AttributeMap.empty) plan match { - case Project(fields, child) if !legacyMode || fields.forall(_.deterministic) => + case Project(fields, child) => val (_, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline) - if (legacyMode || canCollapseExpressions(fields, aliases, alwaysInline)) { + if (canCollapseExpressions(fields, aliases, alwaysInline)) { val replaced = fields.map(replaceAliasButKeepName(_, aliases)) (Some(replaced), filters, other, getAliasMap(replaced)) } else { empty } - case Filter(condition, child) if !legacyMode || condition.deterministic => + case Filter(condition, child) => val (fields, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline) - val canIncludeThisFilter = if (legacyMode) { - true - } else { - // When collecting projects and filters, we effectively push down filters through - // projects. We need to meet the following conditions to do so: - // 1) no Project collected so far or the collected Projects are all deterministic - // 2) the collected filters and this filter are all deterministic, or this is the - // first collected filter. - // 3) this filter does not repeat any expensive expressions from the collected - // projects. - fields.forall(_.forall(_.deterministic)) && { - filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic) - } && canCollapseExpressions(Seq(condition), aliases, alwaysInline) - } + // When collecting projects and filters, we effectively push down filters through + // projects. We need to meet the following conditions to do so: + // 1) no Project collected so far or the collected Projects are all deterministic + // 2) the collected filters and this filter are all deterministic, or this is the + // first collected filter. + // 3) this filter does not repeat any expensive expressions from the collected + // projects. + val canIncludeThisFilter = fields.forall(_.forall(_.deterministic)) && { + filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic) + } && canCollapseExpressions(Seq(condition), aliases, alwaysInline) if (canIncludeThisFilter) { val replaced = replaceAlias(condition, aliases) (fields, filters ++ splitConjunctivePredicates(replaced), other, aliases) @@ -112,24 +105,12 @@ trait OperationHelper extends PredicateHelper { } } -/** - * A pattern that matches any number of project or filter operations on top of another relational - * operator. All filter operators are collected and their conditions are broken up and returned - * together with the top project operator. - * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if - * necessary. - */ -object PhysicalOperation extends OperationHelper { - override protected def legacyMode: Boolean = true -} - -/** - * A variant of [[PhysicalOperation]]. 
It matches any number of project or filter - * operations even if they are non-deterministic, as long as they satisfy the - * requirement of CollapseProject and CombineFilters. - */ -object ScanOperation extends OperationHelper { - override protected def legacyMode: Boolean = false +object NodeWithOnlyDeterministicProjectAndFilter { + def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { + case Project(projectList, child) if projectList.forall(_.deterministic) => unapply(child) + case Filter(cond, child) if cond.deterministic => unapply(child) + case _ => Some(plan) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala index b09d10b260174..99812d20bf55f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{CurrentRow, PercentRank, Rank, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding} +import org.apache.spark.sql.catalyst.expressions.{CurrentRow, NTile, PercentRank, Rank, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -198,4 +198,15 @@ class LimitPushdownThroughWindowSuite extends PlanTest { Optimize.execute(originalQuery.analyze), WithoutOptimize.execute(originalQuery.analyze)) } + + test("SPARK-40002: Should not push through ntile window function") { + val originalQuery = testRelation + .select(a, b, c, + windowExpr(new NTile(), windowSpec(Nil, c.desc :: Nil, windowFrame)).as("nt")) + .limit(2) + + comparePlans( + Optimize.execute(originalQuery.analyze), + WithoutOptimize.execute(originalQuery.analyze)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 86e1b625910c5..741b1bb8c082c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class SimplifyCastsSuite extends PlanTest { @@ -95,10 +96,17 @@ class SimplifyCastsSuite extends PlanTest { Optimize.execute( input.select($"b".cast(DecimalType(10, 2)).cast(DecimalType(24, 2)).as("casted")).analyze), input.select($"b".cast(DecimalType(10, 2)).cast(DecimalType(24, 2)).as("casted")).analyze) - comparePlans( - Optimize.execute( - input.select($"c".cast(DecimalType(10, 2)).cast(DecimalType(24, 2)).as("casted")).analyze), - input.select($"c".cast(DecimalType(10, 2)).cast(DecimalType(24, 2)).as("casted")).analyze) + + withClue("SPARK-39963: cast 
date to decimal") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { + // ANSI mode does not allow to cast a date to a decimal. + comparePlans(Optimize.execute( + input.select( + $"c".cast(DecimalType(10, 2)).cast(DecimalType(24, 2)).as("casted")).analyze), + input.select( + $"c".cast(DecimalType(10, 2)).cast(DecimalType(24, 2)).as("casted")).analyze) + } + } comparePlans( Optimize.execute( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/PhysicalOperationSuite.scala similarity index 88% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/PhysicalOperationSuite.scala index eb3899c9187db..3d3f4c4c448b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/PhysicalOperationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.DoubleType -class ScanOperationSuite extends SparkFunSuite { +class PhysicalOperationSuite extends SparkFunSuite { private val relation = TestRelations.testRelation2 private val colA = relation.output(0) private val colB = relation.output(1) @@ -34,7 +34,7 @@ class ScanOperationSuite extends SparkFunSuite { test("Project with a non-deterministic field and a deterministic child Filter") { val project1 = Project(Seq(colB, aliasR), Filter(EqualTo(colA, Literal(1)), relation)) project1 match { - case ScanOperation(projects, filters, _: LocalRelation) => + case PhysicalOperation(projects, filters, _: LocalRelation) => assert(projects.size === 2) assert(projects(0) === colB) assert(projects(1) === aliasR) @@ -46,7 +46,7 @@ class ScanOperationSuite extends SparkFunSuite { test("Project with all deterministic fields but a non-deterministic child Filter") { val project2 = Project(Seq(colA, colB), Filter(EqualTo(aliasR, Literal(1)), relation)) project2 match { - case ScanOperation(projects, filters, _: LocalRelation) => + case PhysicalOperation(projects, filters, _: LocalRelation) => assert(projects.size === 2) assert(projects(0) === colA) assert(projects(1) === colB) @@ -58,7 +58,7 @@ class ScanOperationSuite extends SparkFunSuite { test("Project which has the same non-deterministic expression with its child Project") { val project3 = Project(Seq(colA, colR), Project(Seq(colA, aliasR), relation)) project3 match { - case ScanOperation(projects, filters, _: Project) => + case PhysicalOperation(projects, filters, _: Project) => assert(projects.size === 2) assert(projects(0) === colA) assert(projects(1) === colR) @@ -70,7 +70,7 @@ class ScanOperationSuite extends SparkFunSuite { test("Project which has different non-deterministic expressions with its child Project") { val project4 = Project(Seq(colA, aliasId), Project(Seq(colA, aliasR), relation)) project4 match { - case ScanOperation(projects, _, _: LocalRelation) => + case PhysicalOperation(projects, _, _: LocalRelation) => assert(projects.size === 2) assert(projects(0) === colA) assert(projects(1) === aliasId) @@ -81,7 +81,7 @@ class ScanOperationSuite extends SparkFunSuite { test("Filter with non-deterministic Project") { val filter1 = Filter(EqualTo(colA, Literal(1)), Project(Seq(colA, aliasR), relation)) filter1 match { - case 
ScanOperation(projects, filters, _: Filter) => + case PhysicalOperation(projects, filters, _: Filter) => assert(projects.size === 2) assert(filters.isEmpty) case _ => assert(false) @@ -92,7 +92,7 @@ class ScanOperationSuite extends SparkFunSuite { val filter2 = Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)), Project(Seq(colA, colB), relation)) filter2 match { - case ScanOperation(projects, filters, _: LocalRelation) => + case PhysicalOperation(projects, filters, _: LocalRelation) => assert(projects.size === 2) assert(projects(0) === colA) assert(projects(1) === colB) @@ -105,7 +105,7 @@ class ScanOperationSuite extends SparkFunSuite { test("Deterministic filter which has a non-deterministic child Filter") { val filter3 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation)) filter3 match { - case ScanOperation(projects, filters, _: Filter) => + case PhysicalOperation(projects, filters, _: Filter) => assert(filters.isEmpty) case _ => assert(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 29b35229e9753..82ac8fd604994 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -33,28 +33,28 @@ import org.apache.spark.storage.StorageLevel abstract class Catalog { /** - * Returns the current default database in this session. + * Returns the current database (namespace) in this session. * * @since 2.0.0 */ def currentDatabase: String /** - * Sets the current default database in this session. + * Sets the current database (namespace) in this session. * * @since 2.0.0 */ def setCurrentDatabase(dbName: String): Unit /** - * Returns a list of databases available across all sessions. + * Returns a list of databases (namespaces) available within the current catalog. * * @since 2.0.0 */ def listDatabases(): Dataset[Database] /** - * Returns a list of tables/views in the current database. + * Returns a list of tables/views in the current database (namespace). * This includes all temporary views. * * @since 2.0.0 @@ -62,7 +62,8 @@ abstract class Catalog { def listTables(): Dataset[Table] /** - * Returns a list of tables/views in the specified database. + * Returns a list of tables/views in the specified database (namespace) (the name can be qualified + * with catalog). * This includes all temporary views. * * @since 2.0.0 @@ -71,16 +72,17 @@ abstract class Catalog { def listTables(dbName: String): Dataset[Table] /** - * Returns a list of functions registered in the current database. - * This includes all temporary functions + * Returns a list of functions registered in the current database (namespace). + * This includes all temporary functions. * * @since 2.0.0 */ def listFunctions(): Dataset[Function] /** - * Returns a list of functions registered in the specified database. - * This includes all temporary functions + * Returns a list of functions registered in the specified database (namespace) (the name can be + * qualified with catalog). + * This includes all built-in and temporary functions. * * @since 2.0.0 */ @@ -90,21 +92,22 @@ abstract class Catalog { /** * Returns a list of columns for the given table/view or temporary view. * - * @param tableName is either a qualified or unqualified name that designates a table/view. - * If no database identifier is provided, it refers to a temporary view or - * a table/view in the current database. 
+ * @param tableName is either a qualified or unqualified name that designates a table/view. It + * follows the same resolution rule with SQL: search for temp views first then + * table/views in the current database (namespace). * @since 2.0.0 */ @throws[AnalysisException]("table does not exist") def listColumns(tableName: String): Dataset[Column] /** - * Returns a list of columns for the given table/view in the specified database. + * Returns a list of columns for the given table/view in the specified database under the Hive + * Metastore. * - * This API does not support 3 layer namespace since 3.4.0. To use 3 layer namespace, - * use listColumns(tableName) instead. + * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with + * qualified table/view name instead. * - * @param dbName is a name that designates a database. + * @param dbName is an unqualified name that designates a database. * @param tableName is an unqualified name that designates a table/view. * @since 2.0.0 */ @@ -112,8 +115,8 @@ abstract class Catalog { def listColumns(dbName: String, tableName: String): Dataset[Column] /** - * Get the database with the specified name. This throws an AnalysisException when the database - * cannot be found. + * Get the database (namespace) with the specified name (can be qualified with catalog). This + * throws an AnalysisException when the database (namespace) cannot be found. * * @since 2.1.0 */ @@ -124,20 +127,20 @@ abstract class Catalog { * Get the table or view with the specified name. This table can be a temporary view or a * table/view. This throws an AnalysisException when no Table can be found. * - * @param tableName is either a qualified or unqualified name that designates a table/view. - * If no database identifier is provided, it refers to a table/view in - * the current database. + * @param tableName is either a qualified or unqualified name that designates a table/view. It + * follows the same resolution rule with SQL: search for temp views first then + * table/views in the current database (namespace). * @since 2.1.0 */ @throws[AnalysisException]("table does not exist") def getTable(tableName: String): Table /** - * Get the table or view with the specified name in the specified database. This throws an - * AnalysisException when no Table can be found. + * Get the table or view with the specified name in the specified database under the Hive + * Metastore. This throws an AnalysisException when no Table can be found. * - * This API does not support 3 layer namespace since 3.4.0. To use 3 layer namespace, - * use getTable(tableName) instead. + * To get table/view in other catalogs, please use `getTable(tableName)` with qualified table/view + * name instead. * * @since 2.1.0 */ @@ -148,22 +151,22 @@ abstract class Catalog { * Get the function with the specified name. This function can be a temporary function or a * function. This throws an AnalysisException when the function cannot be found. * - * @param functionName is either a qualified or unqualified name that designates a function. - * If no database identifier is provided, it refers to a temporary function - * or a function in the current database. + * @param functionName is either a qualified or unqualified name that designates a function. It + * follows the same resolution rule with SQL: search for built-in/temp + * functions first then functions in the current database (namespace). 
* @since 2.1.0 */ @throws[AnalysisException]("function does not exist") def getFunction(functionName: String): Function /** - * Get the function with the specified name. This throws an AnalysisException when the function - * cannot be found. + * Get the function with the specified name in the specified database under the Hive Metastore. + * This throws an AnalysisException when the function cannot be found. * - * This API does not support 3 layer namespace since 3.4.0. To use 3 layer namespace, - * use getFunction(functionName) instead. + * To get functions in other catalogs, please use `getFunction(functionName)` with qualified + * function name instead. * - * @param dbName is a name that designates a database. + * @param dbName is an unqualified name that designates a database. * @param functionName is an unqualified name that designates a function in the specified database * @since 2.1.0 */ @@ -171,7 +174,8 @@ abstract class Catalog { def getFunction(dbName: String, functionName: String): Function /** - * Check if the database with the specified name exists. + * Check if the database (namespace) with the specified name exists (the name can be qualified + * with catalog). * * @since 2.1.0 */ @@ -181,20 +185,21 @@ abstract class Catalog { * Check if the table or view with the specified name exists. This can either be a temporary * view or a table/view. * - * @param tableName is either a qualified or unqualified name that designates a table/view. - * If no database identifier is provided, it refers to a table/view in - * the current database. + * @param tableName is either a qualified or unqualified name that designates a table/view. It + * follows the same resolution rule with SQL: search for temp views first then + * table/views in the current database (namespace). * @since 2.1.0 */ def tableExists(tableName: String): Boolean /** - * Check if the table or view with the specified name exists in the specified database. + * Check if the table or view with the specified name exists in the specified database under the + * Hive Metastore. * - * This API does not support 3 layer namespace since 3.4.0. To use 3 layer namespace, - * use tableExists(tableName) instead. + * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with + * qualified table/view name instead. * - * @param dbName is a name that designates a database. + * @param dbName is an unqualified name that designates a database. * @param tableName is an unqualified name that designates a table. * @since 2.1.0 */ @@ -204,20 +209,21 @@ abstract class Catalog { * Check if the function with the specified name exists. This can either be a temporary function * or a function. * - * @param functionName is either a qualified or unqualified name that designates a function. - * If no database identifier is provided, it refers to a function in - * the current database. + * @param functionName is either a qualified or unqualified name that designates a function. It + * follows the same resolution rule with SQL: search for built-in/temp + * functions first then functions in the current database (namespace). * @since 2.1.0 */ def functionExists(functionName: String): Boolean /** - * Check if the function with the specified name exists in the specified database. + * Check if the function with the specified name exists in the specified database under the + * Hive Metastore. * - * This API does not support 3 layer namespace since 3.4.0. To use 3 layer namespace, - * use functionExists(functionName) instead. 
+ * To check existence of functions in other catalogs, please use `functionExists(functionName)` + * with qualified function name instead. * - * @param dbName is a name that designates a database. + * @param dbName is an unqualified name that designates a database. * @param functionName is an unqualified name that designates a function. * @since 2.1.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index a9d5c6da3844c..0216503fba0f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 @@ -318,7 +318,7 @@ object DataSourceStrategy extends Strategy with Logging with CastSupport with PredicateHelper with SQLConfHelper { def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { - case ScanOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) => pruneFilterProjectRaw( l, projects, @@ -326,7 +326,7 @@ object DataSourceStrategy (requestedColumns, allPredicates, _) => toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil - case ScanOperation(projects, filters, + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _, _)) => pruneFilterProject( l, @@ -334,7 +334,7 @@ object DataSourceStrategy filters, (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil - case ScanOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _, _)) => pruneFilterProject( l, projects, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 9356e46a69187..4995a0d6cd4f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.FileFormat.METADATA_NAME @@ -144,7 +144,7 @@ object FileSourceStrategy extends Strategy with 
PredicateHelper with Logging { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ScanOperation(projects, filters, + case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => // Filters on this relation fall into four categories based on where we can use them to avoid // reading unneeded data: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 01b0ae451b2a9..27daa899583e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject -import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} @@ -97,7 +97,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match { - case ScanOperation(project, Nil, holder @ ScanBuilderHolder(_, _, + case PhysicalOperation(project, Nil, holder @ ScanBuilderHolder(_, _, r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions( agg.aggregateExpressions, project, alwaysInline = true) => val aliasMap = getAliasMap(project) @@ -189,12 +189,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22] // Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation. 
// scalastyle:on - val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) => - AttributeReference(s"group_col_$i", e.dataType)() + val groupOutputMap = normalizedGroupingExpr.zipWithIndex.map { case (e, i) => + AttributeReference(s"group_col_$i", e.dataType)() -> e } - val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) => - AttributeReference(s"agg_func_$i", e.dataType)() + val groupOutput = groupOutputMap.unzip._1 + val aggOutputMap = finalAggExprs.zipWithIndex.map { case (e, i) => + AttributeReference(s"agg_func_$i", e.dataType)() -> e } + val aggOutput = aggOutputMap.unzip._1 val newOutput = groupOutput ++ aggOutput val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) => @@ -204,6 +206,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } holder.pushedAggregate = Some(translatedAgg) + holder.pushedAggOutputMap = AttributeMap(groupOutputMap ++ aggOutputMap) holder.output = newOutput logInfo( s""" @@ -342,7 +345,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform { - case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => + case PhysicalOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning val normalizedProjects = DataSourceStrategy .normalizeExprs(project, sHolder.output) @@ -382,7 +385,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform { case sample: Sample => sample.child match { - case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => + case PhysicalOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => val tableSample = TableSampleInfo( sample.lowerBound, sample.upperBound, @@ -401,21 +404,27 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } private def pushDownLimit(plan: LogicalPlan, limit: Int): (LogicalPlan, Boolean) = plan match { - case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => + case operation @ PhysicalOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => val (isPushed, isPartiallyPushed) = PushDownUtils.pushLimit(sHolder.builder, limit) if (isPushed) { sHolder.pushedLimit = Some(limit) } (operation, isPushed && !isPartiallyPushed) - case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder)) - // Without building the Scan, we do not know the resulting column names after aggregate - // push-down, and thus can't push down Top-N which needs to know the ordering column names. - // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same - // columns, which we know the resulting column names: the original table columns. 
- if sHolder.pushedAggregate.isEmpty && filter.isEmpty && - CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) => + case s @ Sort(order, _, operation @ PhysicalOperation(project, Nil, sHolder: ScanBuilderHolder)) + if CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) => val aliasMap = getAliasMap(project) - val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]] + val aliasReplacedOrder = order.map(replaceAlias(_, aliasMap)) + val newOrder = if (sHolder.pushedAggregate.isDefined) { + // `ScanBuilderHolder` has different output columns after aggregate push-down. Here we + // replace the attributes in ordering expressions with the original table output columns. + aliasReplacedOrder.map { + _.transform { + case a: Attribute => sHolder.pushedAggOutputMap.getOrElse(a, a) + }.asInstanceOf[SortOrder] + } + } else { + aliasReplacedOrder.asInstanceOf[Seq[SortOrder]] + } val normalizedOrders = DataSourceStrategy.normalizeExprs( newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]] val orders = DataSourceStrategy.translateSortOrders(normalizedOrders) @@ -545,6 +554,8 @@ case class ScanBuilderHolder( var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate] var pushedAggregate: Option[Aggregation] = None + + var pushedAggOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression] } // A wrapper for v1 scan to carry the translated filters and the handled ones, along with diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala index 9607ca5396449..2d0130985eacc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.dynamicpruning import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions.{DynamicPruning, DynamicPruningSubquery, EqualNullSafe, EqualTo, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral -import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.planning.NodeWithOnlyDeterministicProjectAndFilter import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -72,12 +72,15 @@ object CleanupDynamicPruningFilters extends Rule[LogicalPlan] with PredicateHelp // No-op for trees that do not contain dynamic pruning. 
_.containsAnyPattern(DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY)) { // pass through anything that is pushed down into PhysicalOperation - case p @ PhysicalOperation(_, _, LogicalRelation(_: HadoopFsRelation, _, _, _)) => + case p @ NodeWithOnlyDeterministicProjectAndFilter( + LogicalRelation(_: HadoopFsRelation, _, _, _)) => removeUnnecessaryDynamicPruningSubquery(p) // pass through anything that is pushed down into PhysicalOperation - case p @ PhysicalOperation(_, _, HiveTableRelation(_, _, _, _, _)) => + case p @ NodeWithOnlyDeterministicProjectAndFilter( + HiveTableRelation(_, _, _, _, _)) => removeUnnecessaryDynamicPruningSubquery(p) - case p @ PhysicalOperation(_, _, _: DataSourceV2ScanRelation) => + case p @ NodeWithOnlyDeterministicProjectAndFilter( + _: DataSourceV2ScanRelation) => removeUnnecessaryDynamicPruningSubquery(p) // remove any Filters with DynamicPruning that didn't get pushed down to PhysicalOperation. case f @ Filter(condition, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c056baba8bacd..533c5614885c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -670,6 +670,14 @@ object functions { */ def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false) + /** + * Aggregate function: returns the most frequent value in a group. + * + * @group agg_funcs + * @since 3.4.0 + */ + def mode(e: Column): Column = withAggregateFunction { Mode(e.expr) } + /** * Aggregate function: returns the maximum value of the expression in a group. * @@ -712,6 +720,14 @@ object functions { */ def mean(columnName: String): Column = avg(columnName) + /** + * Aggregate function: returns the median of the values in a group. + * + * @group agg_funcs + * @since 3.4.0 + */ + def median(e: Column): Column = withAggregateFunction { Median(e.expr) } + /** * Aggregate function: returns the minimum value of the expression in a group. 
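The hunk above adds `mode` and `median` to `org.apache.spark.sql.functions` (both tagged since 3.4.0). A minimal usage sketch, assuming a Spark build that already contains this change; the dataset and column names below are made up for illustration:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{median, mode}

    object ModeMedianSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("mode-median").getOrCreate()
        import spark.implicits._

        val df = Seq(("a", 1), ("a", 2), ("b", 2), ("b", 2), ("b", 5)).toDF("k", "v")

        // mode(v) returns the most frequent value per group, median(v) the median value.
        df.groupBy($"k")
          .agg(mode($"v").as("mode_v"), median($"v").as("median_v"))
          .show()

        spark.stop()
      }
    }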
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index e11b349777e8f..657ed87e609e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -23,14 +23,14 @@ import scala.util.control.NonFatal import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, CatalogMetadata, Column, Database, Function, Table} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedNonPersistentFunc, ResolvedPersistentFunc, ResolvedTable, ResolvedView, UnresolvedFunc, UnresolvedIdentifier, UnresolvedNamespace, UnresolvedTable, UnresolvedTableOrView} +import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNamespace, ResolvedNonPersistentFunc, ResolvedPersistentFunc, ResolvedTable, ResolvedView, UnresolvedFunc, UnresolvedIdentifier, UnresolvedNamespace, UnresolvedTable, UnresolvedTableOrView} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, SubqueryAlias, TableSpec, View} -import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, SupportsNamespaces, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, FunctionCatalog, Identifier, SupportsNamespaces, Table => V2Table, TableCatalog, V1Table} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, IdentifierHelper, MultipartIdentifierHelper, TransformHelper} import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.connector.V1Function @@ -45,15 +45,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { private def sessionCatalog: SessionCatalog = sparkSession.sessionState.catalog - private def requireDatabaseExists(dbName: String): Unit = { - if (!sessionCatalog.databaseExists(dbName)) { - throw QueryCompilationErrors.databaseDoesNotExistError(dbName) - } + private def parseIdent(name: String): Seq[String] = { + sparkSession.sessionState.sqlParser.parseMultipartIdentifier(name) } - private def requireTableExists(dbName: String, tableName: String): Unit = { - if (!sessionCatalog.tableExists(TableIdentifier(tableName, Some(dbName)))) { - throw QueryCompilationErrors.tableDoesNotExistInDatabaseError(tableName, dbName) + private def qualifyV1Ident(nameParts: Seq[String]): Seq[String] = { + assert(nameParts.length == 1 || nameParts.length == 2) + if (nameParts.length == 1) { + Seq(CatalogManager.SESSION_CATALOG_NAME, sessionCatalog.getCurrentDatabase) ++ nameParts + } else { + CatalogManager.SESSION_CATALOG_NAME +: nameParts } } @@ -68,32 +69,27 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database does not exist") override def setCurrentDatabase(dbName: String): Unit = { - // we assume dbName will not include the catalog 
prefix. e.g. if you call - // setCurrentDatabase("catalog.db") it will search for a database catalog.db in the catalog. - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(dbName) - sparkSession.sessionState.catalogManager.setCurrentNamespace(ident.toArray) + // we assume `dbName` will not include the catalog name. e.g. if you call + // `setCurrentDatabase("catalog.db")`, it will search for a database 'catalog.db' in the current + // catalog. + sparkSession.sessionState.catalogManager.setCurrentNamespace(parseIdent(dbName).toArray) } /** * Returns a list of databases available across all sessions. */ override def listDatabases(): Dataset[Database] = { - val catalog = currentCatalog() - val plan = ShowNamespaces(UnresolvedNamespace(Seq(catalog)), None) - val databases = sparkSession.sessionState.executePlan(plan).toRdd.collect() - .map(row => catalog + "." + row.getString(0)) - .map(getDatabase) + val plan = ShowNamespaces(UnresolvedNamespace(Nil), None) + val qe = sparkSession.sessionState.executePlan(plan) + val catalog = qe.analyzed.collectFirst { + case ShowNamespaces(r: ResolvedNamespace, _, _) => r.catalog + }.get + val databases = qe.toRdd.collect().map { row => + getNamespace(catalog, parseIdent(row.getString(0))) + } CatalogImpl.makeDataset(databases, sparkSession) } - private def makeDatabase(dbName: String): Database = { - val metadata = sessionCatalog.getDatabaseMetadata(dbName) - new Database( - name = metadata.name, - description = metadata.description, - locationUri = CatalogUtils.URIToString(metadata.locationUri)) - } - /** * Returns a list of tables in the current database. * This includes all temporary tables. @@ -110,74 +106,93 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { override def listTables(dbName: String): Dataset[Table] = { // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or // a qualified namespace with catalog name. We assume it's a single database name - // and check if we can find the dbName in sessionCatalog. If so we listTables under - // that database. Otherwise we try 3-part name parsing and locate the database. - if (sessionCatalog.databaseExists(dbName) || sessionCatalog.isGlobalTempViewDB(dbName)) { - val tables = sessionCatalog.listTables(dbName).map(makeTable) - CatalogImpl.makeDataset(tables, sparkSession) + // and check if we can find it in the sessionCatalog. If so we list tables under + // that database. Otherwise we will resolve the catalog/namespace and list tables there. 
+ val namespace = if (sessionCatalog.databaseExists(dbName)) { + Seq(CatalogManager.SESSION_CATALOG_NAME, dbName) } else { - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(dbName) - val plan = ShowTables(UnresolvedNamespace(ident), None) - val ret = sparkSession.sessionState.executePlan(plan).toRdd.collect() - val tables = ret - .map(row => ident ++ Seq(row.getString(1))) - .map(makeTable) - CatalogImpl.makeDataset(tables, sparkSession) + parseIdent(dbName) } - } + val plan = ShowTables(UnresolvedNamespace(namespace), None) + val qe = sparkSession.sessionState.executePlan(plan) + val catalog = qe.analyzed.collectFirst { + case ShowTables(r: ResolvedNamespace, _, _) => r.catalog + case _: ShowTablesCommand => + sparkSession.sessionState.catalogManager.v2SessionCatalog + }.get + val tables = qe.toRdd.collect().map { row => + val tableName = row.getString(1) + val namespaceName = row.getString(0) + val isTemp = row.getBoolean(2) + if (isTemp) { + // Temp views do not belong to any catalog. We shouldn't prepend the catalog name here. + val ns = if (namespaceName.isEmpty) Nil else Seq(namespaceName) + makeTable(ns :+ tableName) + } else { + val ns = parseIdent(namespaceName) + makeTable(catalog.name() +: ns :+ tableName) + } + } + CatalogImpl.makeDataset(tables, sparkSession) + } + + private def makeTable(nameParts: Seq[String]): Table = { + sessionCatalog.lookupLocalOrGlobalRawTempView(nameParts).map { tempView => + new Table( + name = tempView.tableMeta.identifier.table, + catalog = null, + namespace = tempView.tableMeta.identifier.database.toArray, + description = tempView.tableMeta.comment.orNull, + tableType = "TEMPORARY", + isTemporary = true) + }.getOrElse { + val plan = UnresolvedIdentifier(nameParts) + sparkSession.sessionState.executePlan(plan).analyzed match { + case ResolvedIdentifier(catalog: TableCatalog, ident) => + val tableOpt = try { + loadTable(catalog, ident) + } catch { + // Even if the table exits, error may still happen. For example, Spark can't read Hive's + // index table. We return a Table without description and tableType in this case. + case NonFatal(_) => + Some(new Table( + name = ident.name(), + catalog = catalog.name(), + namespace = ident.namespace(), + description = null, + tableType = null, + isTemporary = false)) + } + tableOpt.getOrElse(throw QueryCompilationErrors.tableOrViewNotFound(nameParts)) - /** - * Returns a Table for the given table/view or temporary view. - * - * Note that this function requires the table already exists in the Catalog. 
- * - * If the table metadata retrieval failed due to any reason (e.g., table serde class - * is not accessible or the table type is not accepted by Spark SQL), this function - * still returns the corresponding Table without the description and tableType) - */ - private def makeTable(tableIdent: TableIdentifier): Table = { - val metadata = try { - Some(sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent)) - } catch { - case NonFatal(_) => None + case _ => throw QueryCompilationErrors.tableOrViewNotFound(nameParts) + } } - val isTemp = sessionCatalog.isTempView(tableIdent) - val qualifier = - metadata.map(_.identifier.database).getOrElse(tableIdent.database).map(Array(_)).orNull - new Table( - name = tableIdent.table, - catalog = CatalogManager.SESSION_CATALOG_NAME, - namespace = qualifier, - description = metadata.map(_.comment.orNull).orNull, - tableType = if (isTemp) "TEMPORARY" else metadata.map(_.tableType.name).orNull, - isTemporary = isTemp) - } - - private def makeTable(ident: Seq[String]): Table = { - val plan = UnresolvedTableOrView(ident, "Catalog.listTables", true) - val node = sparkSession.sessionState.executePlan(plan).analyzed - node match { - case t: ResolvedTable => - val isExternal = t.table.properties().getOrDefault( + } + + private def loadTable(catalog: TableCatalog, ident: Identifier): Option[Table] = { + // TODO: support v2 view when it gets implemented. + CatalogV2Util.loadTable(catalog, ident).map { + case v1: V1Table if v1.v1Table.tableType == CatalogTableType.VIEW => + new Table( + name = v1.v1Table.identifier.table, + catalog = catalog.name(), + namespace = v1.v1Table.identifier.database.toArray, + description = v1.v1Table.comment.orNull, + tableType = "VIEW", + isTemporary = false) + case t: V2Table => + val isExternal = t.properties().getOrDefault( TableCatalog.PROP_EXTERNAL, "false").equals("true") new Table( - name = t.identifier.name(), - catalog = t.catalog.name(), - namespace = t.identifier.namespace(), - description = t.table.properties().get("comment"), + name = ident.name(), + catalog = catalog.name(), + namespace = ident.namespace(), + description = t.properties().get("comment"), tableType = if (isExternal) CatalogTableType.EXTERNAL.name else CatalogTableType.MANAGED.name, isTemporary = false) - case v: ResolvedView => - new Table( - name = v.identifier.name(), - catalog = null, - namespace = v.identifier.namespace(), - description = null, - tableType = if (v.isTemp) "TEMPORARY" else "VIEW", - isTemporary = v.isTemp) - case _ => throw QueryCompilationErrors.tableOrViewNotFound(ident) } } @@ -197,48 +212,37 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { override def listFunctions(dbName: String): Dataset[Function] = { // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or // a qualified namespace with catalog name. We assume it's a single database name - // and check if we can find the dbName in sessionCatalog. If so we listFunctions under - // that database. Otherwise we try 3-part name parsing and locate the database. - if (sessionCatalog.databaseExists(dbName)) { - val functions = sessionCatalog.listFunctions(dbName) - .map { case (functIdent, _) => makeFunction(functIdent) } - CatalogImpl.makeDataset(functions, sparkSession) + // and check if we can find it in the sessionCatalog. If so we list functions under + // that database. Otherwise we will resolve the catalog/namespace and list functions there. 
+ val namespace = if (sessionCatalog.databaseExists(dbName)) { + Seq(CatalogManager.SESSION_CATALOG_NAME, dbName) } else { - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(dbName) - val functions = collection.mutable.ArrayBuilder.make[Function] - - // built-in functions - val plan0 = ShowFunctions(UnresolvedNamespace(ident), - userScope = false, systemScope = true, None) - sparkSession.sessionState.executePlan(plan0).toRdd.collect().foreach { row => - // `lookupBuiltinOrTempFunction` and `lookupBuiltinOrTempTableFunction` in Analyzer - // require the input identifier only contains the function name, otherwise, built-in - // functions will be skipped. - val name = row.getString(0) - functions += makeFunction(Seq(name)) - } - - // user functions - val plan1 = ShowFunctions(UnresolvedNamespace(ident), - userScope = true, systemScope = false, None) - sparkSession.sessionState.executePlan(plan1).toRdd.collect().foreach { row => - // `row.getString(0)` may contain dbName like `db.function`, so extract the function name. - val name = row.getString(0).split("\\.").last - functions += makeFunction(ident :+ name) - } + parseIdent(dbName) + } + val functions = collection.mutable.ArrayBuilder.make[Function] + + // TODO: The SHOW FUNCTIONS should tell us the function type (built-in, temp, persistent) and + // we can simply the code below quite a bit. For now we need to list built-in functions + // separately as several built-in function names are not parsable, such as `!=`. + + // List built-in functions. We don't need to specify the namespace here as SHOW FUNCTIONS with + // only system scope does not need to know the catalog and namespace. + val plan0 = ShowFunctions(UnresolvedNamespace(Nil), userScope = false, systemScope = true, None) + sparkSession.sessionState.executePlan(plan0).toRdd.collect().foreach { row => + // Built-in functions do not belong to any catalog or namespace. We can only look it up with + // a single part name. + val name = row.getString(0) + functions += makeFunction(Seq(name)) + } - CatalogImpl.makeDataset(functions.result(), sparkSession) + // List user functions. + val plan1 = ShowFunctions(UnresolvedNamespace(namespace), + userScope = true, systemScope = false, None) + sparkSession.sessionState.executePlan(plan1).toRdd.collect().foreach { row => + functions += makeFunction(parseIdent(row.getString(0))) } - } - private def makeFunction(funcIdent: FunctionIdentifier): Function = { - val metadata = sessionCatalog.lookupFunctionInfo(funcIdent) - new Function( - name = metadata.getName, - database = metadata.getDb, - description = null, // for now, this is always undefined - className = metadata.getClassName, - isTemporary = metadata.getDb == null) + CatalogImpl.makeDataset(functions.result(), sparkSession) } private def makeFunction(ident: Seq[String]): Function = { @@ -279,23 +283,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("table does not exist") override def listColumns(tableName: String): Dataset[Column] = { - // calling `sqlParser.parseTableIdentifier` to parse tableName. If it contains only table name - // and optionally contains a database name(thus a TableIdentifier), then we look up the table in - // sessionCatalog. Otherwise we try `sqlParser.parseMultipartIdentifier` to have a sequence of - // string as the qualified identifier and resolve the table through SQL analyzer. 
- try { - val ident = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) - if (tableExists(ident.database.orNull, ident.table)) { - listColumns(ident) - } else { - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) - listColumns(ident) - } - } catch { - case e: org.apache.spark.sql.catalyst.parser.ParseException => - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) - listColumns(ident) + val parsed = parseIdent(tableName) + // For backward compatibility (Spark 3.3 and prior), we should check if the table exists in + // the Hive Metastore first. + val nameParts = if (parsed.length <= 2 && !sessionCatalog.isTempView(parsed) && + sessionCatalog.tableExists(parsed.asTableIdentifier)) { + qualifyV1Ident(parsed) + } else { + parsed } + listColumns(nameParts) } /** @@ -303,25 +300,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database or table does not exist") override def listColumns(dbName: String, tableName: String): Dataset[Column] = { - requireTableExists(dbName, tableName) - listColumns(TableIdentifier(tableName, Some(dbName))) - } - - private def listColumns(tableIdentifier: TableIdentifier): Dataset[Column] = { - val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdentifier) - - val partitionColumnNames = tableMetadata.partitionColumnNames.toSet - val bucketColumnNames = tableMetadata.bucketSpec.map(_.bucketColumnNames).getOrElse(Nil).toSet - val columns = tableMetadata.schema.map { c => - new Column( - name = c.name, - description = c.getComment().orNull, - dataType = CharVarcharUtils.getRawType(c.metadata).getOrElse(c.dataType).catalogString, - nullable = c.nullable, - isPartition = partitionColumnNames.contains(c.name), - isBucket = bucketColumnNames.contains(c.name)) - } - CatalogImpl.makeDataset(columns, sparkSession) + // For backward compatibility (Spark 3.3 and prior), here we always look up the table from the + // Hive Metastore. + listColumns(Seq(CatalogManager.SESSION_CATALOG_NAME, dbName, tableName)) } private def listColumns(ident: Seq[String]): Dataset[Column] = { @@ -361,41 +342,44 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(columns, sparkSession) } + private def getNamespace(catalog: CatalogPlugin, ns: Seq[String]): Database = catalog match { + case catalog: SupportsNamespaces => + val metadata = catalog.loadNamespaceMetadata(ns.toArray) + new Database( + name = ns.quoted, + catalog = catalog.name, + description = metadata.get(SupportsNamespaces.PROP_COMMENT), + locationUri = metadata.get(SupportsNamespaces.PROP_LOCATION)) + // If the catalog doesn't support namespaces, we assume it's an implicit namespace, which always + // exists but has no metadata. + case catalog: CatalogPlugin => + new Database( + name = ns.quoted, + catalog = catalog.name, + description = null, + locationUri = null) + case _ => new Database(name = ns.quoted, description = null, locationUri = null) + } /** * Gets the database with the specified name. This throws an `AnalysisException` when no * `Database` can be found. */ override def getDatabase(dbName: String): Database = { - // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or a - // qualified namespace with catalog name. To maintain backwards compatibility, we first assume - // it's a single database name and return the database from sessionCatalog if it exists. 
- // Otherwise we try 3-part name parsing and locate the database. If the parased identifier - // contains both catalog name and database name, we then search the database in the catalog. - if (sessionCatalog.databaseExists(dbName) || sessionCatalog.isGlobalTempViewDB(dbName)) { - makeDatabase(dbName) + // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or + // a qualified namespace with catalog name. We assume it's a single database name + // and check if we can find it in the sessionCatalog. Otherwise we will parse `dbName` and + // resolve catalog/namespace with it. + val namespace = if (sessionCatalog.databaseExists(dbName)) { + Seq(CatalogManager.SESSION_CATALOG_NAME, dbName) } else { - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(dbName) - val plan = UnresolvedNamespace(ident) - val resolved = sparkSession.sessionState.executePlan(plan).analyzed - resolved match { - case ResolvedNamespace(catalog: SupportsNamespaces, namespace) => - val metadata = catalog.loadNamespaceMetadata(namespace.toArray) - new Database( - name = namespace.quoted, - catalog = catalog.name, - description = metadata.get(SupportsNamespaces.PROP_COMMENT), - locationUri = metadata.get(SupportsNamespaces.PROP_LOCATION)) - // similar to databaseExists: if the catalog doesn't support namespaces, we assume it's an - // implicit namespace, which exists but has no metadata. - case ResolvedNamespace(catalog: CatalogPlugin, namespace) => - new Database( - name = namespace.quoted, - catalog = catalog.name, - description = null, - locationUri = null) - case _ => new Database(name = dbName, description = null, locationUri = null) - } + sparkSession.sessionState.sqlParser.parseMultipartIdentifier(dbName) + } + val plan = UnresolvedNamespace(namespace) + sparkSession.sessionState.executePlan(plan).analyzed match { + case ResolvedNamespace(catalog, namespace) => + getNamespace(catalog, namespace) + case _ => new Database(name = dbName, description = null, locationUri = null) } } @@ -404,26 +388,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * table/view. This throws an `AnalysisException` when no `Table` can be found. */ override def getTable(tableName: String): Table = { - // calling `sqlParser.parseTableIdentifier` to parse tableName. If it contains only table name - // and optionally contains a database name(thus a TableIdentifier), then we look up the table in - // sessionCatalog. Otherwise we try `sqlParser.parseMultipartIdentifier` to have a sequence of - // string as the qualified identifier and resolve the table through SQL analyzer. - try { - val ident = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) - if (tableExists(ident.database.orNull, ident.table)) { - makeTable(ident) - } else { - getTable3LNamespace(tableName) - } - } catch { - case e: org.apache.spark.sql.catalyst.parser.ParseException => - getTable3LNamespace(tableName) + val parsed = parseIdent(tableName) + // For backward compatibility (Spark 3.3 and prior), we should check if the table exists in + // the Hive Metastore first. 
+ val nameParts = if (parsed.length <= 2 && !sessionCatalog.isTempView(parsed) && + sessionCatalog.tableExists(parsed.asTableIdentifier)) { + qualifyV1Ident(parsed) + } else { + parsed } - } - - private def getTable3LNamespace(tableName: String): Table = { - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) - makeTable(ident) + makeTable(nameParts) } /** @@ -431,10 +405,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * `AnalysisException` when no `Table` can be found. */ override def getTable(dbName: String, tableName: String): Table = { - if (tableExists(dbName, tableName)) { - makeTable(TableIdentifier(tableName, Option(dbName))) + if (sessionCatalog.isGlobalTempViewDB(dbName)) { + makeTable(Seq(dbName, tableName)) } else { - throw QueryCompilationErrors.tableOrViewNotFoundInDatabaseError(tableName, dbName) + // For backward compatibility (Spark 3.3 and prior), here we always look up the table from the + // Hive Metastore. + makeTable(Seq(CatalogManager.SESSION_CATALOG_NAME, dbName, tableName)) } } @@ -443,19 +419,17 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * function. This throws an `AnalysisException` when no `Function` can be found. */ override def getFunction(functionName: String): Function = { - // calling `sqlParser.parseFunctionIdentifier` to parse functionName. If it contains only - // function name and optionally contains a database name(thus a FunctionIdentifier), then - // we look up the function in sessionCatalog. - // Otherwise we try `sqlParser.parseMultipartIdentifier` to have a sequence of string as - // the qualified identifier and resolve the function through SQL analyzer. - try { - val ident = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName) - getFunction(ident.database.orNull, ident.funcName) - } catch { - case e: org.apache.spark.sql.catalyst.parser.ParseException => - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(functionName) - makeFunction(ident) + val parsed = parseIdent(functionName) + // For backward compatibility (Spark 3.3 and prior), we should check if the function exists in + // the Hive Metastore first. + val nameParts = if (parsed.length <= 2 && + !sessionCatalog.isTemporaryFunction(parsed.asFunctionIdentifier) && + sessionCatalog.isPersistentFunction(parsed.asFunctionIdentifier)) { + qualifyV1Ident(parsed) + } else { + parsed } + makeFunction(nameParts) } /** @@ -463,7 +437,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * found. */ override def getFunction(dbName: String, functionName: String): Function = { - makeFunction(FunctionIdentifier(functionName, Option(dbName))) + // For backward compatibility (Spark 3.3 and prior), here we always look up the function from + // the Hive Metastore. + makeFunction(Seq(CatalogManager.SESSION_CATALOG_NAME, dbName, functionName)) } /** @@ -471,15 +447,13 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def databaseExists(dbName: String): Boolean = { // To maintain backwards compatibility, we first treat the input is a simple dbName and check - // if sessionCatalog contains it. If no, we try to parse it as 3 part name. If the parased - // identifier contains both catalog name and database name, we then search the database in the - // catalog. + // if sessionCatalog contains it. If no, we try to parse it, resolve catalog and namespace, + // and check if namespace exists in the catalog. 
if (!sessionCatalog.databaseExists(dbName)) { - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(dbName) - val plan = sparkSession.sessionState.executePlan(UnresolvedNamespace(ident)).analyzed - plan match { - case ResolvedNamespace(catalog: SupportsNamespaces, _) => - catalog.namespaceExists(ident.slice(1, ident.size).toArray) + val plan = UnresolvedNamespace(parseIdent(dbName)) + sparkSession.sessionState.executePlan(plan).analyzed match { + case ResolvedNamespace(catalog: SupportsNamespaces, ns) => + catalog.namespaceExists(ns.toArray) case _ => true } } else { @@ -492,11 +466,18 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * view or a table/view. */ override def tableExists(tableName: String): Boolean = { - try { - getTable(tableName) - true - } catch { - case e: AnalysisException => false + val parsed = parseIdent(tableName) + // For backward compatibility (Spark 3.3 and prior), we should check if the table exists in + // the Hive Metastore first. This also checks if it's a temp view. + (parsed.length <= 2 && { + val v1Ident = parsed.asTableIdentifier + sessionCatalog.isTempView(v1Ident) || sessionCatalog.tableExists(v1Ident) + }) || { + val plan = UnresolvedIdentifier(parsed) + sparkSession.sessionState.executePlan(plan).analyzed match { + case ResolvedIdentifier(catalog: TableCatalog, ident) => catalog.tableExists(ident) + case _ => false + } } } @@ -513,22 +494,15 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * or a function. */ override def functionExists(functionName: String): Boolean = { - try { - val ident = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName) - functionExists(ident.database.orNull, ident.funcName) - } catch { - case e: org.apache.spark.sql.catalyst.parser.ParseException => - try { - val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(functionName) - val plan = UnresolvedFunc(ident, "Catalog.functionExists", false, None) - sparkSession.sessionState.executePlan(plan).analyzed match { - case _: ResolvedPersistentFunc => true - case _: ResolvedNonPersistentFunc => true - case _ => false - } - } catch { - case _: org.apache.spark.sql.AnalysisException => false - } + val parsed = parseIdent(functionName) + // For backward compatibility (Spark 3.3 and prior), we should check if the function exists in + // the Hive Metastore first. This also checks if it's a built-in/temp function. 
+ (parsed.length <= 2 && sessionCatalog.functionExists(parsed.asFunctionIdentifier)) || { + val plan = UnresolvedIdentifier(parsed) + sparkSession.sessionState.executePlan(plan).analyzed match { + case ResolvedIdentifier(catalog: FunctionCatalog, ident) => catalog.functionExists(ident) + case _ => false + } } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql index 66a78ec9473ad..34102a1250780 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql @@ -105,7 +105,7 @@ select cast('a' as timestamp_ntz); select cast(cast('inf' as double) as timestamp); select cast(cast('inf' as float) as timestamp); --- cast ANSI intervals to numerics +-- cast ANSI intervals to integrals select cast(interval '1' year as tinyint); select cast(interval '-10-2' year to month as smallint); select cast(interval '1000' month as int); @@ -117,6 +117,18 @@ select cast(interval '10' day as bigint); select cast(interval '-1000' month as tinyint); select cast(interval '1000000' second as smallint); +-- cast integrals to ANSI intervals +select cast(1Y as interval year); +select cast(-122S as interval year to month); +select cast(1000 as interval month); +select cast(-10L as interval second); +select cast(100Y as interval hour to second); +select cast(-1000S as interval day to second); +select cast(10 as interval day); + +select cast(2147483647 as interval year); +select cast(-9223372036854775808L as interval day); + -- cast ANSI intervals to decimals select cast(interval '-1' year as decimal(10, 0)); select cast(interval '1.000001' second as decimal(10, 6)); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out index 470a6081c469d..c4b454b135c9e 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out @@ -840,6 +840,80 @@ org.apache.spark.SparkArithmeticException [CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. +-- !query +select cast(1Y as interval year) +-- !query schema +struct +-- !query output +1-0 + + +-- !query +select cast(-122S as interval year to month) +-- !query schema +struct +-- !query output +-10-2 + + +-- !query +select cast(1000 as interval month) +-- !query schema +struct +-- !query output +83-4 + + +-- !query +select cast(-10L as interval second) +-- !query schema +struct +-- !query output +-0 00:00:10.000000000 + + +-- !query +select cast(100Y as interval hour to second) +-- !query schema +struct +-- !query output +0 00:01:40.000000000 + + +-- !query +select cast(-1000S as interval day to second) +-- !query schema +struct +-- !query output +-0 00:16:40.000000000 + + +-- !query +select cast(10 as interval day) +-- !query schema +struct +-- !query output +10 00:00:00.000000000 + + +-- !query +select cast(2147483647 as interval year) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +[CAST_OVERFLOW] The value 2147483647 of the type "INT" cannot be cast to "INTERVAL YEAR" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. 
+ + +-- !query +select cast(-9223372036854775808L as interval day) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +[CAST_OVERFLOW] The value -9223372036854775808L of the type "BIGINT" cannot be cast to "INTERVAL DAY" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. + + -- !query select cast(interval '-1' year as decimal(10, 0)) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out index 911eaff30b938..2b976914bfe98 100644 --- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out @@ -668,6 +668,80 @@ org.apache.spark.SparkArithmeticException [CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. +-- !query +select cast(1Y as interval year) +-- !query schema +struct +-- !query output +1-0 + + +-- !query +select cast(-122S as interval year to month) +-- !query schema +struct +-- !query output +-10-2 + + +-- !query +select cast(1000 as interval month) +-- !query schema +struct +-- !query output +83-4 + + +-- !query +select cast(-10L as interval second) +-- !query schema +struct +-- !query output +-0 00:00:10.000000000 + + +-- !query +select cast(100Y as interval hour to second) +-- !query schema +struct +-- !query output +0 00:01:40.000000000 + + +-- !query +select cast(-1000S as interval day to second) +-- !query schema +struct +-- !query output +-0 00:16:40.000000000 + + +-- !query +select cast(10 as interval day) +-- !query schema +struct +-- !query output +10 00:00:00.000000000 + + +-- !query +select cast(2147483647 as interval year) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +[CAST_OVERFLOW] The value 2147483647 of the type "INT" cannot be cast to "INTERVAL YEAR" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. + + +-- !query +select cast(-9223372036854775808L as interval day) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +[CAST_OVERFLOW] The value -9223372036854775808L of the type "BIGINT" cannot be cast to "INTERVAL DAY" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. 
+ + -- !query select cast(interval '-1' year as decimal(10, 0)) -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index e872b6aaa640e..2581afd1df346 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -1203,4 +1203,17 @@ class DataFrameWindowFunctionsSuite extends QueryTest ) ) } + + test("SPARK-40002: ntile should apply before limit") { + val df = Seq.tabulate(101)(identity).toDF("id") + val w = Window.orderBy("id") + checkAnswer( + df.select($"id", ntile(10).over(w)).limit(3), + Seq( + Row(0, 1), + Row(1, 1), + Row(2, 1) + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 8a635807abbbb..1a9baecfa747a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalog.Table import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint} -import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType @@ -166,7 +165,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSparkSession { assert(spark.catalog.tableExists(globalTempDB, "src")) assert(spark.catalog.getTable(globalTempDB, "src").toString == new Table( name = "src", - catalog = CatalogManager.SESSION_CATALOG_NAME, + catalog = null, namespace = Array(globalTempDB), description = null, tableType = "TEMPORARY", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 0de48325d981e..ab26a4fcc35f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -182,6 +182,31 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) } + test("SPARK-39828: Catalog.listTables() should respect currentCatalog") { + assert(spark.catalog.currentCatalog() == "spark_catalog") + assert(spark.catalog.listTables().collect().isEmpty) + createTable("my_table1") + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table1")) + + val catalogName = "testcat" + val dbName = "my_db" + val tableName = "my_table2" + val tableSchema = new StructType().add("i", "int") + val description = "this is a test managed table" + sql(s"CREATE NAMESPACE $catalogName.$dbName") + + spark.catalog.setCurrentCatalog("testcat") + spark.catalog.setCurrentDatabase("my_db") + assert(spark.catalog.listTables().collect().isEmpty) + + createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema, + Map.empty[String, String], description) + assert(spark.catalog.listTables() + .collect() + .map(t => Array(t.catalog, t.namespace.mkString("."), 
t.name).mkString(".")).toSet == + Set("testcat.my_db.my_table2")) + } + test("list tables with database") { assert(spark.catalog.listTables("default").collect().isEmpty) createDatabase("my_db1") @@ -229,6 +254,33 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf assert(!funcNames2.contains("my_temp_func")) } + test("SPARK-39828: Catalog.listFunctions() should respect currentCatalog") { + assert(spark.catalog.currentCatalog() == "spark_catalog") + assert(Set("+", "current_database", "window").subsetOf( + spark.catalog.listFunctions().collect().map(_.name).toSet)) + createFunction("my_func") + assert(spark.catalog.listFunctions().collect().map(_.name).contains("my_func")) + + sql(s"CREATE NAMESPACE testcat.ns") + spark.catalog.setCurrentCatalog("testcat") + spark.catalog.setCurrentDatabase("ns") + + val funcCatalog = spark.sessionState.catalogManager.catalog("testcat") + .asInstanceOf[InMemoryCatalog] + val function: UnboundFunction = new UnboundFunction { + override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType) + override def resultType(): DataType = IntegerType + override def name(): String = "my_bound_function" + } + override def description(): String = "my_function" + override def name(): String = "my_function" + } + assert(!spark.catalog.listFunctions().collect().map(_.name).contains("my_func")) + funcCatalog.createFunction(Identifier.of(Array("ns"), "my_func"), function) + assert(spark.catalog.listFunctions().collect().map(_.name).contains("my_func")) + } + test("list functions with database") { assert(Set("+", "current_database", "window").subsetOf( spark.catalog.listFunctions().collect().map(_.name).toSet)) @@ -283,7 +335,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf testListColumns("tab1", dbName = Some("db1")) } - test("SPARK-39615: three layer namespace compatibility - listColumns") { + test("SPARK-39615: qualified name with catalog - listColumns") { val answers = Map( "col1" -> ("int", true, false, true), "col2" -> ("string", true, false, false), @@ -637,7 +689,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf assert(errMsg.contains("my_temp_table is a temp view. 
'recoverPartitions()' expects a table"))
   }
 
-  test("three layer namespace compatibility - create managed table") {
+  test("qualified name with catalog - create managed table") {
     val catalogName = "testcat"
     val dbName = "my_db"
     val tableName = "my_table"
@@ -656,7 +708,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     assert(table.properties().get("comment").equals(description))
   }
 
-  test("three layer namespace compatibility - create external table") {
+  test("qualified name with catalog - create external table") {
     withTempDir { dir =>
       val catalogName = "testcat"
       val dbName = "my_db"
@@ -680,7 +732,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     }
   }
 
-  test("three layer namespace compatibility - list tables") {
+  test("qualified name with catalog - list tables") {
     withTempDir { dir =>
       val catalogName = "testcat"
       val dbName = "my_db"
@@ -729,7 +781,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
       Set("my_table1", "my_table2", "my_temp_table"))
   }
 
-  test("three layer namespace compatibility - get table") {
+  test("qualified name with catalog - get table") {
     val catalogName = "testcat"
     val dbName = "default"
     val tableName = "my_table"
@@ -757,7 +809,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     assert(t2.catalog == CatalogManager.SESSION_CATALOG_NAME)
   }
 
-  test("three layer namespace compatibility - table exists") {
+  test("qualified name with catalog - table exists") {
     val catalogName = "testcat"
     val dbName = "my_db"
     val tableName = "my_table"
@@ -781,7 +833,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     assert(spark.catalog.tableExists(Array(catalogName, dbName, tableName).mkString(".")))
   }
 
-  test("three layer namespace compatibility - database exists") {
+  test("qualified name with catalog - database exists") {
     val catalogName = "testcat"
     val dbName = "my_db"
     assert(!spark.catalog.databaseExists(Array(catalogName, dbName).mkString(".")))
@@ -793,7 +845,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     assert(!spark.catalog.databaseExists(Array(catalogName2, dbName).mkString(".")))
   }
 
-  test("SPARK-39506: three layer namespace compatibility - cache table, isCached and" +
+  test("SPARK-39506: qualified name with catalog - cache table, isCached and" +
     "uncacheTable") {
     val tableSchema = new StructType().add("i", "int")
     createTable("my_table", "my_db", "testcat", classOf[FakeV2Provider].getName,
@@ -840,7 +892,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     }
   }
 
-  test("three layer namespace compatibility - get database") {
+  test("qualified name with catalog - get database") {
     val catalogsAndDatabases =
       Seq(("testcat", "somedb"), ("testcat", "ns.somedb"), ("spark_catalog", "somedb"))
     catalogsAndDatabases.foreach { case (catalog, dbName) =>
@@ -863,7 +915,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     intercept[AnalysisException](spark.catalog.getDatabase("randomcat.db10"))
   }
 
-  test("three layer namespace compatibility - get database, same in hive and testcat") {
+  test("qualified name with catalog - get database, same in hive and testcat") {
     // create 'testdb' in hive and testcat
     val dbName = "testdb"
     sql(s"CREATE NAMESPACE spark_catalog.$dbName COMMENT 'hive database'")
@@ -883,7 +935,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
assert(spark.catalog.getDatabase(qualified).name === db) } - test("three layer namespace compatibility - set current database") { + test("qualified name with catalog - set current database") { spark.catalog.setCurrentCatalog("testcat") // namespace with the same name as catalog sql("CREATE NAMESPACE testcat.testcat.my_db") @@ -912,8 +964,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf assert(e3.contains("unknown_db")) } - test("SPARK-39579: Three layer namespace compatibility - " + - "listFunctions, getFunction, functionExists") { + test("SPARK-39579: qualified name with catalog - listFunctions, getFunction, functionExists") { createDatabase("my_db1") createFunction("my_func1", Some("my_db1")) @@ -931,8 +982,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf val func1b = spark.catalog.getFunction("spark_catalog.my_db1.my_func1") assert(func1a.name === func1b.name && func1a.namespace === func1b.namespace && func1a.className === func1b.className && func1a.isTemporary === func1b.isTemporary) - assert(func1a.catalog === null && func1b.catalog === "spark_catalog") - assert(func1a.description === null && func1b.description === "N/A.") + assert(func1a.catalog === "spark_catalog" && func1b.catalog === "spark_catalog") + assert(func1a.description === "N/A." && func1b.description === "N/A.") val function: UnboundFunction = new UnboundFunction { override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 02dff0973fe12..a8c770f46cd67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -775,59 +775,46 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true))) + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } val df6 = spark.read .table("h2.test.employee") - .groupBy("DEPT").sum("SALARY") - .orderBy("DEPT") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .sort($"SALARY".desc) .limit(1) + // LIMIT is pushed down only if all the filters are pushed down checkSortRemoved(df6, false) checkLimitRemoved(df6, false) - checkPushedInfo(df6, - "PushedAggregates: [SUM(SALARY)]", - "PushedFilters: []", - "PushedGroupByExpressions: [DEPT]") - checkAnswer(df6, Seq(Row(1, 19000.00))) + checkPushedInfo(df6, "PushedFilters: []") + checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy"))) - val name = udf { (x: String) => x.matches("cat|dav|amy") } - val sub = udf { (x: String) => x.substring(0, 3) } val df7 = spark.read .table("h2.test.employee") - .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) - .filter(name($"shortName")) - .sort($"SALARY".desc) + .sort(sub($"NAME")) .limit(1) - // LIMIT is pushed down only if all the filters are pushed down checkSortRemoved(df7, false) checkLimitRemoved(df7, false) checkPushedInfo(df7, "PushedFilters: []") - checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy"))) + checkAnswer(df7, Seq(Row(2, "alex", 12000.00, 1200.0, false))) val df8 = spark.read - .table("h2.test.employee") - .sort(sub($"NAME")) - .limit(1) - checkSortRemoved(df8, false) - checkLimitRemoved(df8, false) - checkPushedInfo(df8, 
"PushedFilters: []") - checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false))) - - val df9 = spark.read .table("h2.test.employee") .select($"DEPT", $"name", $"SALARY", when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key")) .sort("key", "dept", "SALARY") .limit(3) - checkSortRemoved(df9) - checkLimitRemoved(df9) - checkPushedInfo(df9, + checkSortRemoved(df8) + checkLimitRemoved(df8) + checkPushedInfo(df8, "PushedFilters: []", - "PushedTopN: " + - "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " + - "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,") - checkAnswer(df9, + "PushedTopN: ORDER BY " + + "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END" + + " ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3") + checkAnswer(df8, Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0))) - val df10 = spark.read + val df9 = spark.read .option("partitionColumn", "dept") .option("lowerBound", "0") .option("upperBound", "2") @@ -837,14 +824,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key")) .orderBy($"key", $"dept", $"SALARY") .limit(3) - checkSortRemoved(df10, false) - checkLimitRemoved(df10, false) - checkPushedInfo(df10, + checkSortRemoved(df9, false) + checkLimitRemoved(df9, false) + checkPushedInfo(df9, "PushedFilters: []", - "PushedTopN: " + - "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " + - "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,") - checkAnswer(df10, + "PushedTopN: ORDER BY " + + "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " + + "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3") + checkAnswer(df9, Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0))) } @@ -873,6 +860,196 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df2, Seq(Row(2, "david", 10000.00))) } + test("scan with aggregate push-down, top N push-down and offset push-down") { + val df1 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .orderBy("DEPT") + + val paging1 = df1.offset(1).limit(1) + checkSortRemoved(paging1) + checkLimitRemoved(paging1) + checkPushedInfo(paging1, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging1, Seq(Row(2, 22000.00))) + + val topN1 = df1.limit(1) + checkSortRemoved(topN1) + checkLimitRemoved(topN1) + checkPushedInfo(topN1, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN1, Seq(Row(1, 19000.00))) + + val df2 = spark.read + .table("h2.test.employee") + .select($"DEPT".cast("string").as("my_dept"), $"SALARY") + .groupBy("my_dept").sum("SALARY") + .orderBy("my_dept") + + val paging2 = df2.offset(1).limit(1) + checkSortRemoved(paging2) + checkLimitRemoved(paging2) + checkPushedInfo(paging2, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string)]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2") + 
checkAnswer(paging2, Seq(Row("2", 22000.00))) + + val topN2 = df2.limit(1) + checkSortRemoved(topN2) + checkLimitRemoved(topN2) + checkPushedInfo(topN2, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string)]", + "PushedFilters: []", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN2, Seq(Row("1", 19000.00))) + + val df3 = spark.read + .table("h2.test.employee") + .groupBy("dept").sum("SALARY") + .orderBy($"dept".cast("string")) + + val paging3 = df3.offset(1).limit(1) + checkSortRemoved(paging3) + checkLimitRemoved(paging3) + checkPushedInfo(paging3, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging3, Seq(Row(2, 22000.00))) + + val topN3 = df3.limit(1) + checkSortRemoved(topN3) + checkLimitRemoved(topN3) + checkPushedInfo(topN3, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN3, Seq(Row(1, 19000.00))) + + val df4 = spark.read + .table("h2.test.employee") + .groupBy("DEPT", "IS_MANAGER").sum("SALARY") + .orderBy("DEPT", "IS_MANAGER") + + val paging4 = df4.offset(1).limit(1) + checkSortRemoved(paging4) + checkLimitRemoved(paging4) + checkPushedInfo(paging4, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT, IS_MANAGER]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging4, Seq(Row(1, true, 10000.00))) + + val topN4 = df4.limit(1) + checkSortRemoved(topN4) + checkLimitRemoved(topN4) + checkPushedInfo(topN4, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT, IS_MANAGER]", + "PushedFilters: []", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN4, Seq(Row(1, false, 9000.00))) + + val df5 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"IS_MANAGER", $"DEPT".cast("string").as("my_dept")) + .groupBy("my_dept", "IS_MANAGER").sum("SALARY") + .orderBy("my_dept", "IS_MANAGER") + + val paging5 = df5.offset(1).limit(1) + checkSortRemoved(paging5) + checkLimitRemoved(paging5) + checkPushedInfo(paging5, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: " + + "ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging5, Seq(Row("1", true, 10000.00))) + + val topN5 = df5.limit(1) + checkSortRemoved(topN5) + checkLimitRemoved(topN5) + checkPushedInfo(topN5, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]", + "PushedFilters: []", + "PushedTopN: " + + "ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN5, Seq(Row("1", false, 9000.00))) + + val df6 = spark.read + .table("h2.test.employee") + .select($"DEPT", $"SALARY") + .groupBy("dept").agg(sum("SALARY")) + .orderBy(sum("SALARY")) + + val paging6 = df6.offset(1).limit(1) + checkSortRemoved(paging6) + checkLimitRemoved(paging6) + checkPushedInfo(paging6, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER 
BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging6, Seq(Row(1, 19000.00))) + + val topN6 = df6.limit(1) + checkSortRemoved(topN6) + checkLimitRemoved(topN6) + checkPushedInfo(topN6, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN6, Seq(Row(6, 12000.00))) + + val df7 = spark.read + .table("h2.test.employee") + .select($"DEPT", $"SALARY") + .groupBy("dept").agg(sum("SALARY").as("total")) + .orderBy("total") + + val paging7 = df7.offset(1).limit(1) + checkSortRemoved(paging7) + checkLimitRemoved(paging7) + checkPushedInfo(paging7, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging7, Seq(Row(1, 19000.00))) + + val topN7 = df7.limit(1) + checkSortRemoved(topN7) + checkLimitRemoved(topN7) + checkPushedInfo(topN7, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN7, Seq(Row(6, 12000.00))) + } + test("scan with filter push-down") { val df = spark.table("h2.test.people").filter($"id" > 1) checkFiltersRemoved(df) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d1e222794a526..42bf1e31bb04a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -276,7 +276,7 @@ private[hive] trait HiveStrategies { */ object HiveTableScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ScanOperation(projectList, filters, relation: HiveTableRelation) => + case PhysicalOperation(projectList, filters, relation: HiveTableRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. val partitionKeyIds = AttributeSet(relation.partitionCols) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 12bb1b3631c9a..e65e6d42937c1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -69,7 +69,7 @@ private[hive] object IsolatedClientLoader extends Logging { // If the error message contains hadoop, it is probably because the hadoop // version cannot be resolved. val fallbackVersion = if (VersionUtils.isHadoop3) { - "3.3.3" + "3.3.4" } else { "2.7.4" }