diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 04faf7f87cf2..8ffccdf664b2 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -176,7 +176,7 @@ private[spark] class BarrierCoordinator( logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (requesters.size == numTasks) { - requesters.foreach(_.reply(messages)) + requesters.foreach(_.reply(messages.clone())) // Finished current barrier() call successfully, clean up ContextBarrierState and // increase the barrier epoch. logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 4f97003e2ed5..26cd5374fa09 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -367,4 +367,27 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with // double check we kill task success assert(System.currentTimeMillis() - startTime < 5000) } + + test("SPARK-40932, messages of allGather should not be overridden " + "by the following barrier APIs") { + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local[2]")) + sc.setLogLevel("INFO") + val rdd = sc.makeRDD(1 to 10, 2) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. 
+ Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message: String = context.partitionId().toString + val messages: Array[String] = context.allGather(message) + context.barrier() + Iterator.single(messages.toList) + } + val messages = rdd2.collect() + // All the task partitionIds are shared across all tasks + assert(messages.length === 2) + assert(messages.forall(_ == List("0", "1"))) + } + } diff --git a/dev/requirements.txt b/dev/requirements.txt index fa4b6752f145..2f32066d6a88 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -13,6 +13,7 @@ matplotlib<3.3.0 # PySpark test dependencies unittest-xml-reporting +openpyxl # PySpark test dependencies (optional) coverage diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 459b05cc37aa..17d50f0f50e9 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -18,14 +18,18 @@ import unittest import shutil import tempfile +from pyspark.testing.sqlutils import have_pandas -import pandas +if have_pandas: + import pandas from pyspark.sql import SparkSession, Row from pyspark.sql.types import StructType, StructField, LongType, StringType -from pyspark.sql.connect.client import RemoteSparkSession -from pyspark.sql.connect.function_builder import udf -from pyspark.sql.connect.functions import lit + +if have_pandas: + from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.function_builder import udf + from pyspark.sql.connect.functions import lit from pyspark.sql.dataframe import DataFrame from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import ReusedPySparkTestCase @@ -36,7 +40,8 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase): """Parent test fixture class for all Spark Connect related test cases.""" - connect: RemoteSparkSession + if have_pandas: + connect: RemoteSparkSession tbl_name: str df_text: "DataFrame" diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 6036b63d76f2..790a987e8809 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -15,14 +15,20 @@ # limitations under the License. 
# +from typing import cast +import unittest from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.sql.connect.proto import Expression as ProtoExpression -import pyspark.sql.connect as c -import pyspark.sql.connect.plan as p -import pyspark.sql.connect.column as col -import pyspark.sql.connect.functions as fun +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +if have_pandas: + from pyspark.sql.connect.proto import Expression as ProtoExpression + import pyspark.sql.connect as c + import pyspark.sql.connect.plan as p + import pyspark.sql.connect.column as col + import pyspark.sql.connect.functions as fun + +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): def test_simple_column_expressions(self): df = c.DataFrame.withPlan(p.Read("table")) diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 450f5c70faba..14b939e019ba 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -14,15 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from typing import cast import unittest from pyspark.testing.connectutils import PlanOnlyTestFixture -import pyspark.sql.connect.proto as proto -from pyspark.sql.connect.readwriter import DataFrameReader -from pyspark.sql.connect.function_builder import UserDefinedFunction, udf -from pyspark.sql.types import StringType +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +if have_pandas: + import pyspark.sql.connect.proto as proto + from pyspark.sql.connect.readwriter import DataFrameReader + from pyspark.sql.connect.function_builder import UserDefinedFunction, udf + from pyspark.sql.types import StringType + +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py index e89b4b34ea01..a29c70541462 100644 --- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py +++ b/python/pyspark/sql/tests/connect/test_connect_select_ops.py @@ -14,13 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from typing import cast +import unittest + from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.sql.connect import DataFrame -from pyspark.sql.connect.functions import col -from pyspark.sql.connect.plan import Read -import pyspark.sql.connect.proto as proto +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message + +if have_pandas: + from pyspark.sql.connect import DataFrame + from pyspark.sql.connect.functions import col + from pyspark.sql.connect.plan import Read + import pyspark.sql.connect.proto as proto +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectToProtoSuite(PlanOnlyTestFixture): def test_select_with_columns_and_strings(self): df = DataFrame.withPlan(Read("table")) diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 700b7bb72e18..d9bced3af114 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -18,13 +18,18 @@ from typing import Any, Dict import functools import unittest +from pyspark.testing.sqlutils import have_pandas -from pyspark.sql.connect import DataFrame -from pyspark.sql.connect.plan import Read -from pyspark.testing.utils import search_jar +if have_pandas: + from pyspark.sql.connect import DataFrame + from pyspark.sql.connect.plan import Read + from pyspark.testing.utils import search_jar + + connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") +else: + connect_jar = None -connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") if connect_jar is None: connect_requirement_message = ( "Skipping all Spark Connect Python tests as the optional Spark Connect project was " @@ -38,7 +43,7 @@ os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args]) connect_requirement_message = None # type: ignore -should_test_connect = connect_requirement_message is None +should_test_connect = connect_requirement_message is None and have_pandas class MockRemoteSession: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 04de2bdcb51c..b718f410be6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -119,8 +119,7 @@ class AnalysisErrorSuite extends AnalysisTest { messageParameters: Map[String, String], caseSensitive: Boolean = true): Unit = { test(name) { - assertAnalysisErrorClass(plan, errorClass, messageParameters, - caseSensitive = true, line = -1, pos = -1) + assertAnalysisErrorClass(plan, errorClass, messageParameters, caseSensitive = caseSensitive) } } @@ -899,9 +898,8 @@ class AnalysisErrorSuite extends AnalysisTest { "inputSql" -> inputSql, "inputType" -> inputType, "requiredType" -> "(\"INT\" or \"BIGINT\")"), - caseSensitive = false, - line = -1, - pos = -1) + caseSensitive = false + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala index 7b720a7a0472..be256adbd892 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala @@ -52,8 +52,8 @@ class AnalysisExceptionPositionSuite extends AnalysisTest { parsePlan("SHOW COLUMNS FROM unknown IN db"), "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`db`.`unknown`"), - line = 1, - pos = 18) + Array(ExpectedContext("unknown", 18, 24)) + ) verifyTableOrViewPosition("ALTER TABLE unknown RENAME TO t", "unknown") verifyTableOrViewPosition("ALTER VIEW unknown RENAME TO v", "unknown") } @@ -92,13 +92,13 @@ class AnalysisExceptionPositionSuite extends AnalysisTest { } private def verifyPosition(sql: String, table: String): Unit = { - val expectedPos = sql.indexOf(table) - assert(expectedPos != -1) + val startPos = sql.indexOf(table) + assert(startPos != -1) assertAnalysisErrorClass( parsePlan(sql), "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> s"`$table`"), - line = 1, - pos = expectedPos) + Array(ExpectedContext(table, startPos, startPos + table.length - 1)) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6f0e6ef0c110..c1106a265451 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -104,10 +104,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { Project(Seq(UnresolvedAttribute("tBl.a")), SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), "UNRESOLVED_COLUMN.WITH_SUGGESTION", - Map("objectName" -> "`tBl`.`a`", "proposal" -> "`TbL`.`a`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`tBl`.`a`", "proposal" -> "`TbL`.`a`") + ) checkAnalysisWithoutViewWrapper( Project(Seq(UnresolvedAttribute("TbL.a")), @@ -716,9 +714,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertAnalysisErrorClass(parsePlan("WITH t(x) AS (SELECT 1) SELECT * FROM t WHERE y = 1"), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`y`", "proposal" -> "`t`.`x`"), - caseSensitive = true, - line = -1, - pos = -1) + Array(ExpectedContext("y", 46, 46)) + ) } test("CTE with non-matching column alias") { @@ -729,7 +726,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-28251: Insert into non-existing table error message is user friendly") { assertAnalysisErrorClass(parsePlan("INSERT INTO test VALUES (1)"), - "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`test`")) + "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`test`"), + Array(ExpectedContext("test", 12, 15))) } test("check CollectMetrics resolved") { @@ -1157,9 +1155,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { |""".stripMargin), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`c`.`y`", "proposal" -> "`x`"), - caseSensitive = true, - line = -1, - pos = -1) + Array(ExpectedContext("c.y", 123, 125)) + ) } test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") { @@ -1178,7 +1175,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "inputSql" -> "\"c\"", "inputType" -> "\"BOOLEAN\"", "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""), - caseSensitive = false) + queryContext = Array(ExpectedContext("mean(t.c)", 65, 73)), + caseSensitive = false + ) assertAnalysisErrorClass( inputPlan = parsePlan( @@ -1195,6 +1194,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { 
"inputSql" -> "\"c\"", "inputType" -> "\"BOOLEAN\"", "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""), + queryContext = Array(ExpectedContext("mean(c)", 91, 97)), caseSensitive = false) assertAnalysisErrorClass( @@ -1213,9 +1213,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "inputType" -> "\"BOOLEAN\"", "requiredType" -> "(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR TO MONTH\")"), - caseSensitive = false, - line = -1, - pos = -1) + queryContext = Array(ExpectedContext("abs(t.c)", 65, 72)), + caseSensitive = false + ) assertAnalysisErrorClass( inputPlan = parsePlan( @@ -1233,9 +1233,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "inputType" -> "\"BOOLEAN\"", "requiredType" -> "(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR TO MONTH\")"), - caseSensitive = false, - line = -1, - pos = -1) + queryContext = Array(ExpectedContext("abs(c)", 91, 96)), + caseSensitive = false + ) } test("SPARK-39354: should be [TABLE_OR_VIEW_NOT_FOUND]") { @@ -1246,7 +1246,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { |FROM t1 |JOIN t2 ON t1.user_id = t2.user_id |WHERE t1.dt >= DATE_SUB('2020-12-27', 90)""".stripMargin), - "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`t2`")) + "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`t2`"), + Array(ExpectedContext("t2", 84, 85))) } test("SPARK-39144: nested subquery expressions deduplicate relations should be done bottom up") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index a195b76d7c43..5e7395d905d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.types.StructType trait AnalysisTest extends PlanTest { + import org.apache.spark.QueryContext + protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil protected def createTempView( @@ -174,40 +176,19 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - caseSensitive: Boolean = true, - line: Int = -1, - pos: Int = -1): Unit = { + queryContext: Array[QueryContext] = Array.empty, + caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val analyzer = getAnalyzer val e = intercept[AnalysisException] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - - if (e.getErrorClass != expectedErrorClass || - e.messageParameters != expectedMessageParameters || - (line >= 0 && e.line.getOrElse(-1) != line) || - (pos >= 0) && e.startPosition.getOrElse(-1) != pos) { - var failMsg = "" - if (e.getErrorClass != expectedErrorClass) { - failMsg += - s"""Error class should be: ${expectedErrorClass} - |Actual error class: ${e.getErrorClass} - """.stripMargin - } - if (e.messageParameters != expectedMessageParameters) { - failMsg += - s"""Message parameters should be: ${expectedMessageParameters.mkString("\n ")} - |Actual message parameters: ${e.messageParameters.mkString("\n ")} - """.stripMargin - } - if (e.line.getOrElse(-1) != line || e.startPosition.getOrElse(-1) != pos) { - failMsg += - s"""Line/position should be: $line, $pos - |Actual line/position: ${e.line.getOrElse(-1)}, ${e.startPosition.getOrElse(-1)} - """.stripMargin - } - fail(failMsg) - } + checkError( + 
exception = e, + errorClass = expectedErrorClass, + parameters = expectedMessageParameters, + queryContext = queryContext + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 716d7aeb60fb..b1d569be5ba2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -134,10 +134,8 @@ class ResolveSubquerySuite extends AnalysisTest { assertAnalysisErrorClass( lateralJoin(t1, lateralJoin(t2, t0.select($"a", $"b", $"c"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", - Map("objectName" -> "`a`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`a`") + ) } test("lateral subquery with unresolvable attributes") { @@ -145,34 +143,26 @@ class ResolveSubquerySuite extends AnalysisTest { assertAnalysisErrorClass( lateralJoin(t1, t0.select($"a", $"c")), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", - Map("objectName" -> "`c`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`c`") + ) // SELECT * FROM t1, LATERAL (SELECT a, b, c, d FROM t2) assertAnalysisErrorClass( lateralJoin(t1, t2.select($"a", $"b", $"c", $"d")), "UNRESOLVED_COLUMN.WITH_SUGGESTION", - Map("objectName" -> "`d`", "proposal" -> "`b`, `c`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`d`", "proposal" -> "`b`, `c`") + ) // SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT t1.a)) assertAnalysisErrorClass( lateralJoin(t1, lateralJoin(t2, t0.select($"t1.a"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", - Map("objectName" -> "`t1`.`a`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`t1`.`a`") + ) // SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT a, b)) assertAnalysisErrorClass( lateralJoin(t1, lateralJoin(t2, t0.select($"a", $"b"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", - Map("objectName" -> "`a`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`a`") + ) } test("lateral subquery with struct type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index b87531832970..9f9df10b398b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -691,10 +691,8 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertAnalysisErrorClass( parsedPlan, "UNRESOLVED_COLUMN.WITH_SUGGESTION", - Map("objectName" -> "`a`", "proposal" -> "`x`, `y`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`a`", "proposal" -> "`x`, `y`") + ) val tableAcceptAnySchema = TestRelationAcceptAnySchema(StructType(Seq( StructField("x", DoubleType, nullable = false), @@ -706,10 +704,8 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertAnalysisErrorClass( parsedPlan2, "UNRESOLVED_COLUMN.WITH_SUGGESTION", - Map("objectName" -> "`a`", "proposal" -> "`x`, `y`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`a`", "proposal" -> "`x`, `y`") + ) } test("SPARK-36498: reorder inner fields with byName mode") { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index fdb49bd76746..d9e3db6903cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -523,7 +523,12 @@ case class FileSourceScanExec( // Note that some vals referring the file-based relation are lazy intentionally // so that this plan can be canonicalized on executor side too. See SPARK-23731. override lazy val supportsColumnar: Boolean = { - relation.fileFormat.supportBatch(relation.sparkSession, schema) + val conf = relation.sparkSession.sessionState.conf + // Only output columnar if there is WSCG to read it. + val requiredWholeStageCodegenSettings = + conf.wholeStageEnabled && !WholeStageCodegenExec.isTooManyFields(conf, schema) + requiredWholeStageCodegenSettings && + relation.fileFormat.supportBatch(relation.sparkSession, schema) } private lazy val needsUnsafeRowConversion: Boolean = { @@ -535,6 +540,8 @@ case class FileSourceScanExec( } lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (FileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) val readFile: (PartitionedFile) => Iterator[InternalRow] = relation.fileFormat.buildReaderWithPartitionValues( sparkSession = relation.sparkSession, @@ -542,7 +549,7 @@ case class FileSourceScanExec( partitionSchema = relation.partitionSchema, requiredSchema = requiredSchema, filters = pushedDownFilters, - options = relation.options, + options = options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) val readRDD = if (bucketedScan) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index f7f917d89477..7e920773c048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -61,6 +61,11 @@ trait FileFormat { /** * Returns whether this format supports returning columnar batch or not. + * If columnar batch output is requested, users shall supply + * FileFormat.OPTION_RETURNING_BATCH -> true + * in relation options when calling buildReaderWithPartitionValues. + * This should only be passed as true if it can actually be supported. + * For ParquetFileFormat and OrcFileFormat, passing this option is required. * * TODO: we should just have different traits for the different formats. */ @@ -191,6 +196,14 @@ object FileFormat { val METADATA_NAME = "_metadata" + /** + * Option to pass to buildReaderWithPartitionValues to return columnar batch output or not. + * For ParquetFileFormat and OrcFileFormat, passing this option is required. + * This should only be passed as true if it can actually be supported, which can be checked + * by calling supportBatch. + */ + val OPTION_RETURNING_BATCH = "returning_batch" + /** Schema of metadata struct that can be produced by every file format. 
*/ val BASE_METADATA_STRUCT: StructType = new StructType() .add(StructField(FileFormat.FILE_PATH, StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 024c458feaff..6a58513c346d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -102,8 +101,7 @@ class OrcFileFormat override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { val conf = sparkSession.sessionState.conf - conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && - !WholeStageCodegenExec.isTooManyFields(conf, schema) && + conf.orcVectorizedReaderEnabled && schema.forall(s => OrcUtils.supportColumnarReads( s.dataType, sparkSession.sessionState.conf.orcVectorizedReaderNestedColumnEnabled)) } @@ -115,6 +113,18 @@ class OrcFileFormat true } + /** + * Build the reader. + * + * @note It is required to pass FileFormat.OPTION_RETURNING_BATCH in options, to indicate whether + * the reader should return row or columnar output. + * If the caller can handle both, pass + * FileFormat.OPTION_RETURNING_BATCH -> + * supportBatch(sparkSession, + * StructType(requiredSchema.fields ++ partitionSchema.fields)) + * as the option. + * It should be set to "true" only if this reader can support it. + */ override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -126,9 +136,24 @@ class OrcFileFormat val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) val sqlConf = sparkSession.sessionState.conf - val enableVectorizedReader = supportBatch(sparkSession, resultSchema) val capacity = sqlConf.orcVectorizedReaderBatchSize + // Should always be set by FileSourceScanExec creating this. + // Check conf before checking option, to allow working around an issue by changing conf. + val enableVectorizedReader = sqlConf.orcVectorizedReaderEnabled && + options.get(FileFormat.OPTION_RETURNING_BATCH) + .getOrElse { + throw new IllegalArgumentException( + "OPTION_RETURNING_BATCH should always be set for OrcFileFormat. " + + "To workaround this issue, set spark.sql.orc.enableVectorizedReader=false.") + } + .equals("true") + if (enableVectorizedReader) { + // If the passed option said that we are to return batches, we need to also be able to + // do this based on config and resultSchema. 
+ assert(supportBatch(sparkSession, resultSchema)) + } + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(hadoopConf, sqlConf.caseSensitiveAnalysis) val broadcastedConf = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 5116a6bdb90c..80b6791d8fae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -42,7 +42,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf @@ -82,12 +81,11 @@ class ParquetFileFormat } /** - * Returns whether the reader will return the rows as batch or not. + * Returns whether the reader can return the rows as batch or not. */ override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { val conf = sparkSession.sessionState.conf - ParquetUtils.isBatchReadSupportedForSchema(conf, schema) && conf.wholeStageEnabled && - !WholeStageCodegenExec.isTooManyFields(conf, schema) + ParquetUtils.isBatchReadSupportedForSchema(conf, schema) } override def vectorTypes( @@ -110,6 +108,18 @@ class ParquetFileFormat true } + /** + * Build the reader. + * + * @note It is required to pass FileFormat.OPTION_RETURNING_BATCH in options, to indicate whether + * the reader should return row or columnar output. + * If the caller can handle both, pass + * FileFormat.OPTION_RETURNING_BATCH -> + * supportBatch(sparkSession, + * StructType(requiredSchema.fields ++ partitionSchema.fields)) + * as the option. + * It should be set to "true" only if this reader can support it. + */ override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -161,8 +171,6 @@ class ParquetFileFormat val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown - // Whole stage codegen (PhysicalRDD) is able to deal with batches directly - val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal @@ -173,6 +181,22 @@ class ParquetFileFormat val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + // Should always be set by FileSourceScanExec creating this. + // Check conf before checking option, to allow working around an issue by changing conf. + val returningBatch = sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && + options.get(FileFormat.OPTION_RETURNING_BATCH) + .getOrElse { + throw new IllegalArgumentException( + "OPTION_RETURNING_BATCH should always be set for ParquetFileFormat. 
" + + "To workaround this issue, set spark.sql.parquet.enableVectorizedReader=false.") + } + .equals("true") + if (returningBatch) { + // If the passed option said that we are to return batches, we need to also be able to + // do this based on config and resultSchema. + assert(supportBatch(sparkSession, resultSchema)) + } + (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 3b2271afc862..c50f1bec18cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.command import java.net.URI -import java.util.{Collections, Locale} +import java.util.Collections import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, when} @@ -39,7 +39,6 @@ import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, V1Table} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -251,14 +250,18 @@ class PlanResolutionSuite extends AnalysisTest { }.head } - private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { - val e = intercept[ParseException] { - parsePlan(sql) - } - assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) - containsThesePhrases.foreach { p => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) - } + private def assertUnsupported( + sql: String, + parameters: Map[String, String], + context: ExpectedContext): Unit = { + checkError( + exception = intercept[ParseException] { + parsePlan(sql) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = parameters, + context = context + ) } test("create table - with partitioned by") { @@ -400,16 +403,20 @@ class PlanResolutionSuite extends AnalysisTest { } val v2 = - """ - |CREATE TABLE my_tab(a INT, b STRING) + """CREATE TABLE my_tab(a INT, b STRING) |USING parquet |OPTIONS (path '/tmp/file') - |LOCATION '/tmp/file' - """.stripMargin - val e = intercept[AnalysisException] { - parseAndResolve(v2) - } - assert(e.message.contains("you can only specify one of them.")) + |LOCATION '/tmp/file'""".stripMargin + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(v2) + }, + errorClass = "_LEGACY_ERROR_TEMP_0032", + parameters = Map("pathOne" -> "/tmp/file", "pathTwo" -> "/tmp/file"), + context = ExpectedContext( + fragment = v2, + start = 0, + stop = 97)) } test("create table - byte length literal table name") { @@ -1137,10 +1144,12 @@ class PlanResolutionSuite extends AnalysisTest { case _ => fail("Expect UpdateTable, but got:\n" + parsed7.treeString) } - assert(intercept[AnalysisException] { - parseAndResolve(sql8) - }.getMessage.contains( - 
QueryCompilationErrors.defaultReferencesNotAllowedInUpdateWhereClause().getMessage)) + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(sql8) + }, + errorClass = "_LEGACY_ERROR_TEMP_1341", + parameters = Map.empty) parsed9 match { case UpdateTable( @@ -1250,10 +1259,23 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed2, expected2) val sql3 = s"ALTER TABLE $tblName ALTER COLUMN j COMMENT 'new comment'" - val e1 = intercept[AnalysisException] { - parseAndResolve(sql3) - } - assert(e1.getMessage.contains("Missing field j in table spark_catalog.default.v1Table")) + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(sql3) + }, + errorClass = "_LEGACY_ERROR_TEMP_1331", + parameters = Map( + "fieldName" -> "j", + "table" -> "spark_catalog.default.v1Table", + "schema" -> + """root + | |-- i: integer (nullable = true) + | |-- s: string (nullable = true) + | |-- point: struct (nullable = true) + | | |-- x: integer (nullable = true) + | | |-- y: integer (nullable = true) + |""".stripMargin), + context = ExpectedContext(fragment = sql3, start = 0, stop = 55)) val sql4 = s"ALTER TABLE $tblName ALTER COLUMN point.x TYPE bigint" val e2 = intercept[AnalysisException] { @@ -1304,11 +1326,15 @@ class PlanResolutionSuite extends AnalysisTest { } test("alter table: alter column action is not specified") { - val e = intercept[AnalysisException] { - parseAndResolve("ALTER TABLE v1Table ALTER COLUMN i") - } - assert(e.getMessage.contains( - "ALTER TABLE table ALTER COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER")) + val sql = "ALTER TABLE v1Table ALTER COLUMN i" + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(sql) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> + "ALTER TABLE table ALTER COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER"), + context = ExpectedContext(fragment = sql, start = 0, stop = 33)) } test("alter table: alter column case sensitivity for v1 table") { @@ -1317,10 +1343,23 @@ class PlanResolutionSuite extends AnalysisTest { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val sql = s"ALTER TABLE $tblName ALTER COLUMN I COMMENT 'new comment'" if (caseSensitive) { - val e = intercept[AnalysisException] { - parseAndResolve(sql) - } - assert(e.getMessage.contains("Missing field I in table spark_catalog.default.v1Table")) + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(sql) + }, + errorClass = "_LEGACY_ERROR_TEMP_1331", + parameters = Map( + "fieldName" -> "I", + "table" -> "spark_catalog.default.v1Table", + "schema" -> + """root + | |-- i: integer (nullable = true) + | |-- s: string (nullable = true) + | |-- point: struct (nullable = true) + | | |-- x: integer (nullable = true) + | | |-- y: integer (nullable = true) + |""".stripMargin), + context = ExpectedContext(fragment = sql, start = 0, stop = 55)) } else { val actual = parseAndResolve(sql) val expected = AlterTableChangeColumnCommand( @@ -1669,40 +1708,39 @@ class PlanResolutionSuite extends AnalysisTest { // This MERGE INTO command includes an ON clause with a DEFAULT column reference. This is // invalid and returns an error message. 
val mergeWithDefaultReferenceInMergeCondition = - s""" - |MERGE INTO testcat.tab AS target + s"""MERGE INTO testcat.tab AS target |USING testcat.tab1 AS source |ON target.i = DEFAULT |WHEN MATCHED AND (target.s = 31) THEN DELETE |WHEN MATCHED AND (target.s = 31) | THEN UPDATE SET target.s = DEFAULT |WHEN NOT MATCHED AND (source.s='insert') - | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT) - """.stripMargin - assert(intercept[AnalysisException] { - parseAndResolve(mergeWithDefaultReferenceInMergeCondition) - }.getMessage.contains( - QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition().getMessage)) + | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)""".stripMargin + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(mergeWithDefaultReferenceInMergeCondition) + }, + errorClass = "_LEGACY_ERROR_TEMP_1342", + parameters = Map.empty) // DEFAULT column reference within a complex expression: // This MERGE INTO command includes a WHEN MATCHED clause with a DEFAULT column reference as // of a complex expression (DEFAULT + 1). This is invalid and returns an error message. val mergeWithDefaultReferenceAsPartOfComplexExpression = - s""" - |MERGE INTO testcat.tab AS target + s"""MERGE INTO testcat.tab AS target |USING testcat.tab1 AS source |ON target.i = source.i |WHEN MATCHED AND (target.s = 31) THEN DELETE |WHEN MATCHED AND (target.s = 31) | THEN UPDATE SET target.s = DEFAULT + 1 |WHEN NOT MATCHED AND (source.s='insert') - | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT) - """.stripMargin - assert(intercept[AnalysisException] { - parseAndResolve(mergeWithDefaultReferenceAsPartOfComplexExpression) - }.getMessage.contains( - QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates().getMessage)) + | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)""".stripMargin + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(mergeWithDefaultReferenceAsPartOfComplexExpression) + }, + errorClass = "_LEGACY_ERROR_TEMP_1343", + parameters = Map.empty) // Ambiguous DEFAULT column reference when the table itself contains a column named // "DEFAULT". @@ -1835,52 +1873,68 @@ class PlanResolutionSuite extends AnalysisTest { } val sql2 = - s""" - |MERGE INTO $target + s"""MERGE INTO $target |USING $source |ON i = 1 - |WHEN MATCHED THEN DELETE - """.stripMargin + |WHEN MATCHED THEN DELETE""".stripMargin // merge condition is resolved with both target and source tables, and we can't // resolve column `i` as it's ambiguous. - val e2 = intercept[AnalysisException](parseAndResolve(sql2)) - assert(e2.message.contains("Reference 'i' is ambiguous")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql2)), + errorClass = null, + parameters = Map.empty, + context = ExpectedContext( + fragment = "i", + start = 22 + target.length + source.length, + stop = 22 + target.length + source.length)) val sql3 = - s""" - |MERGE INTO $target + s"""MERGE INTO $target |USING $source |ON 1 = 1 - |WHEN MATCHED AND (s='delete') THEN DELETE - """.stripMargin + |WHEN MATCHED AND (s='delete') THEN DELETE""".stripMargin // delete condition is resolved with both target and source tables, and we can't // resolve column `s` as it's ambiguous. 
- val e3 = intercept[AnalysisException](parseAndResolve(sql3)) - assert(e3.message.contains("Reference 's' is ambiguous")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql3)), + errorClass = null, + parameters = Map.empty, + context = ExpectedContext( + fragment = "s", + start = 46 + target.length + source.length, + stop = 46 + target.length + source.length)) val sql4 = - s""" - |MERGE INTO $target + s"""MERGE INTO $target |USING $source |ON 1 = 1 - |WHEN MATCHED AND (s = 'a') THEN UPDATE SET i = 1 - """.stripMargin + |WHEN MATCHED AND (s = 'a') THEN UPDATE SET i = 1""".stripMargin // update condition is resolved with both target and source tables, and we can't // resolve column `s` as it's ambiguous. - val e4 = intercept[AnalysisException](parseAndResolve(sql4)) - assert(e4.message.contains("Reference 's' is ambiguous")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql4)), + errorClass = null, + parameters = Map.empty, + context = ExpectedContext( + fragment = "s", + start = 46 + target.length + source.length, + stop = 46 + target.length + source.length)) val sql5 = - s""" - |MERGE INTO $target + s"""MERGE INTO $target |USING $source |ON 1 = 1 - |WHEN MATCHED THEN UPDATE SET s = s - """.stripMargin + |WHEN MATCHED THEN UPDATE SET s = s""".stripMargin // update value is resolved with both target and source tables, and we can't // resolve column `s` as it's ambiguous. - val e5 = intercept[AnalysisException](parseAndResolve(sql5)) - assert(e5.message.contains("Reference 's' is ambiguous")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql5)), + errorClass = null, + parameters = Map.empty, + context = ExpectedContext( + fragment = "s", + start = 61 + target.length + source.length, + stop = 61 + target.length + source.length)) } val sql1 = @@ -1908,9 +1962,10 @@ class PlanResolutionSuite extends AnalysisTest { |ON 1 = 1 |WHEN MATCHED THEN UPDATE SET * |""".stripMargin - val e2 = intercept[AnalysisException](parseAndResolve(sql2)) - assert(e2.message.contains( - "cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql2)), + errorClass = null, + parameters = Map.empty) // INSERT * with incompatible schema between source and target tables. 
val sql3 = @@ -1920,9 +1975,10 @@ class PlanResolutionSuite extends AnalysisTest { |ON 1 = 1 |WHEN NOT MATCHED THEN INSERT * |""".stripMargin - val e3 = intercept[AnalysisException](parseAndResolve(sql3)) - assert(e3.message.contains( - "cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql3)), + errorClass = null, + parameters = Map.empty) val sql4 = """ @@ -2108,9 +2164,11 @@ class PlanResolutionSuite extends AnalysisTest { ) ) - interceptParseException(parsePlan)( - "CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)", - "Syntax error at or near ':': extra input ':'")() + val sql = "CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)" + checkError( + exception = parseException(parsePlan)(sql), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "':'", "hint" -> ": extra input ':'")) } test("create hive table - table file format") { @@ -2170,7 +2228,11 @@ class PlanResolutionSuite extends AnalysisTest { assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) } } else { - assertUnsupported(query, Seq("row format serde", "incompatible", s)) + assertUnsupported( + query, + Map("message" -> (s"ROW FORMAT SERDE is incompatible with format '$s', " + + "which also specifies a serde")), + ExpectedContext(fragment = query, start = 0, stop = 57 + s.length)) } } } @@ -2192,7 +2254,11 @@ class PlanResolutionSuite extends AnalysisTest { assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) } } else { - assertUnsupported(query, Seq("row format delimited", "only compatible with 'textfile'", s)) + assertUnsupported( + query, + Map("message" -> ("ROW FORMAT DELIMITED is only compatible with 'textfile', " + + s"not '$s'")), + ExpectedContext(fragment = query, start = 0, stop = 75 + s.length)) } } } @@ -2214,14 +2280,23 @@ class PlanResolutionSuite extends AnalysisTest { } test("create hive table - property values must be set") { + val sql1 = "CREATE TABLE my_tab STORED AS parquet " + + "TBLPROPERTIES('key_without_value', 'key_with_value'='x')" assertUnsupported( - sql = "CREATE TABLE my_tab STORED AS parquet " + - "TBLPROPERTIES('key_without_value', 'key_with_value'='x')", - containsThesePhrases = Seq("key_without_value")) + sql1, + Map("message" -> "Values must be specified for key(s): [key_without_value]"), + ExpectedContext(fragment = sql1, start = 0, stop = 93)) + + val sql2 = "CREATE TABLE my_tab ROW FORMAT SERDE 'serde' " + + "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')" assertUnsupported( - sql = "CREATE TABLE my_tab ROW FORMAT SERDE 'serde' " + - "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')", - containsThesePhrases = Seq("key_without_value")) + sql2, + Map("message" -> "Values must be specified for key(s): [key_without_value]"), + ExpectedContext( + fragment = "ROW FORMAT SERDE 'serde' WITH SERDEPROPERTIES('key_without_value', " + + "'key_with_value'='x')", + start = 20, + stop = 107)) } test("create hive table - location implies external") { @@ -2234,28 +2309,58 @@ class PlanResolutionSuite extends AnalysisTest { } test("Duplicate clauses - create hive table") { - def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*)() - def createTableHeader(duplicateClause: String): String = { s"CREATE TABLE my_tab(a INT, b STRING) STORED AS parquet $duplicateClause $duplicateClause" } - intercept(createTableHeader("TBLPROPERTIES('test' = 
'test2')"), - "Found duplicate clauses: TBLPROPERTIES") - intercept(createTableHeader("LOCATION '/tmp/file'"), - "Found duplicate clauses: LOCATION") - intercept(createTableHeader("COMMENT 'a table'"), - "Found duplicate clauses: COMMENT") - intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"), - "Found duplicate clauses: CLUSTERED BY") - intercept(createTableHeader("PARTITIONED BY (k int)"), - "Found duplicate clauses: PARTITIONED BY") - intercept(createTableHeader("STORED AS parquet"), - "Found duplicate clauses: STORED AS/BY") - intercept( - createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'"), - "Found duplicate clauses: ROW FORMAT") + val sql1 = createTableHeader("TBLPROPERTIES('test' = 'test2')") + checkError( + exception = parseException(parsePlan)(sql1), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "TBLPROPERTIES"), + context = ExpectedContext(fragment = sql1, start = 0, stop = 117)) + + val sql2 = createTableHeader("LOCATION '/tmp/file'") + checkError( + exception = parseException(parsePlan)(sql2), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "LOCATION"), + context = ExpectedContext(fragment = sql2, start = 0, stop = 95)) + + val sql3 = createTableHeader("COMMENT 'a table'") + checkError( + exception = parseException(parsePlan)(sql3), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "COMMENT"), + context = ExpectedContext(fragment = sql3, start = 0, stop = 89)) + + val sql4 = createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS") + checkError( + exception = parseException(parsePlan)(sql4), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "CLUSTERED BY"), + context = ExpectedContext(fragment = sql4, start = 0, stop = 119)) + + val sql5 = createTableHeader("PARTITIONED BY (k int)") + checkError( + exception = parseException(parsePlan)(sql5), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "PARTITIONED BY"), + context = ExpectedContext(fragment = sql5, start = 0, stop = 99)) + + val sql6 = createTableHeader("STORED AS parquet") + checkError( + exception = parseException(parsePlan)(sql6), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "STORED AS/BY"), + context = ExpectedContext(fragment = sql6, start = 0, stop = 89)) + + val sql7 = createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'") + checkError( + exception = parseException(parsePlan)(sql7), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "ROW FORMAT"), + context = ExpectedContext(fragment = sql7, start = 0, stop = 163)) } test("Test CTAS #1") { @@ -2390,9 +2495,16 @@ class PlanResolutionSuite extends AnalysisTest { val s4 = """CREATE TABLE page_view |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin - intercept[AnalysisException] { - extractTableDesc(s4) - } + checkError( + exception = intercept[AnalysisException] { + extractTableDesc(s4) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "STORED BY"), + context = ExpectedContext( + fragment = "STORED BY 'storage.handler.class.name'", + start = 23, + stop = 60)) } test("Test CTAS #5") { @@ -2423,13 +2535,27 @@ class PlanResolutionSuite extends AnalysisTest { } test("CTAS statement with a PARTITIONED BY clause is not allowed") { - assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + - " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") + val sql = 
s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + + " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp" + assertUnsupported( + sql, + Map("message" -> + "Partition column types may not be specified in Create Table As Select (CTAS)"), + ExpectedContext(fragment = sql, start = 0, stop = 100)) } test("CTAS statement with schema") { - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") + val sql1 = s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src" + assertUnsupported( + sql1, + Map("message" -> "Schema may not be specified in a Create Table As Select (CTAS) statement"), + ExpectedContext(fragment = sql1, start = 0, stop = 61)) + + val sql2 = s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'" + assertUnsupported( + sql2, + Map("message" -> "Schema may not be specified in a Create Table As Select (CTAS) statement"), + ExpectedContext(fragment = sql2, start = 0, stop = 61)) } test("create table - basic") { @@ -2464,8 +2590,14 @@ class PlanResolutionSuite extends AnalysisTest { test("create table - temporary") { val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" - val e = intercept[ParseException] { parsePlan(query) } - assert(e.message.contains("Operation not allowed: CREATE TEMPORARY TABLE")) + checkError( + exception = intercept[ParseException] { + parsePlan(query) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map( + "message" -> "CREATE TEMPORARY TABLE ..., use CREATE TEMPORARY VIEW instead"), + context = ExpectedContext(fragment = query, start = 0, stop = 48)) } test("create table - external") { @@ -2528,15 +2660,33 @@ class PlanResolutionSuite extends AnalysisTest { test("create table(hive) - skewed by") { val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + checkError( + exception = intercept[ParseException] { + parsePlan(query1) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "CREATE TABLE ... SKEWED BY"), + context = ExpectedContext(fragment = query1, start = 0, stop = 72)) + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + checkError( + exception = intercept[ParseException] { + parsePlan(query2) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "CREATE TABLE ... SKEWED BY"), + context = ExpectedContext(fragment = query2, start = 0, stop = 96)) + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" - val e1 = intercept[ParseException] { parsePlan(query1) } - val e2 = intercept[ParseException] { parsePlan(query2) } - val e3 = intercept[ParseException] { parsePlan(query3) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - assert(e3.getMessage.contains("Operation not allowed")) + checkError( + exception = intercept[ParseException] { + parsePlan(query3) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "CREATE TABLE ... 
SKEWED BY"), + context = ExpectedContext(fragment = query3, start = 0, stop = 118)) } test("create table(hive) - row format") { @@ -2583,12 +2733,30 @@ class PlanResolutionSuite extends AnalysisTest { test("create table(hive) - storage handler") { val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + checkError( + exception = intercept[ParseException] { + parsePlan(query1) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "STORED BY"), + context = ExpectedContext( + fragment = "STORED BY 'org.papachi.StorageHandler'", + start = 44, + stop = 81)) + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" - val e1 = intercept[ParseException] { parsePlan(query1) } - val e2 = intercept[ParseException] { parsePlan(query2) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) + checkError( + exception = intercept[ParseException] { + parsePlan(query2) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "STORED BY"), + context = ExpectedContext( + fragment = "STORED BY 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')", + start = 44, + stop = 114)) } test("create table(hive) - everything!") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index 8909fe49aac1..2c8d72ec6093 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -622,4 +622,30 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { } } } + + Seq("parquet", "orc").foreach { format => + test(s"SPARK-40918: Output cols around WSCG.isTooManyFields limit in $format") { + // The issue was that ParquetFileFormat would not count the _metadata columns towards + // the WholeStageCodegenExec.isTooManyFields limit, while FileSourceScanExec would, + // resulting in Parquet reader returning columnar output, while scan expected row. + withTempPath { dir => + sql(s"SELECT ${(1 to 100).map(i => s"id+$i as c$i").mkString(", ")} FROM RANGE(100)") + .write.format(format).save(dir.getAbsolutePath) + (98 to 102).foreach { wscgCols => + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> wscgCols.toString) { + // Would fail with + // java.lang.ClassCastException: org.apache.spark.sql.vectorized.ColumnarBatch + // cannot be cast to org.apache.spark.sql.catalyst.InternalRow + sql( + s""" + |SELECT + | ${(1 to 100).map(i => s"sum(c$i)").mkString(", ")}, + | max(_metadata.file_path) + |FROM $format.`$dir`""".stripMargin + ).collect() + } + } + } + } + } }