From 5c9843db2b3ddec0b03374df03dcaa1847941c34 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 28 Oct 2022 19:05:38 +0800 Subject: [PATCH 1/6] [SPARK-40229][PS][TEST][FOLLOWUP] Add `openpyxl` to `requirements.txt` ### What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/37671. ### Why are the changes needed? Since https://github.com/apache/spark/pull/37671 added `openpyxl` for PySpark test environments and re-enabled the `test_to_excel` test, we need to add it to `requirements.txt` explicitly as a PySpark test dependency. ### Does this PR introduce _any_ user-facing change? No. This is a test dependency. ### How was this patch tested? Manually. Closes #38425 from dongjoon-hyun/SPARK-40229. Authored-by: Dongjoon Hyun Signed-off-by: Yikun Jiang --- dev/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/requirements.txt b/dev/requirements.txt index fa4b6752f145..2f32066d6a88 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -13,6 +13,7 @@ matplotlib<3.3.0 # PySpark test dependencies unittest-xml-reporting +openpyxl # PySpark test dependencies (optional) coverage From 145de7d56e0a4ead4d5d4715ba39e3108e659834 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 28 Oct 2022 15:05:25 +0300 Subject: [PATCH 2/6] [SPARK-40936][SQL][TESTS] Refactor `AnalysisTest#assertAnalysisErrorClass` by reusing the `SparkFunSuite#checkError` ### What changes were proposed in this pull request? This PR aims to refactor the `AnalysisTest#assertAnalysisErrorClass` method by reusing the `checkError` method in `SparkFunSuite`. As part of this, the signature of `AnalysisTest#assertAnalysisErrorClass` is changed from ``` protected def assertAnalysisErrorClass( inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], caseSensitive: Boolean = true, line: Int = -1, pos: Int = -1): Unit ``` to ``` protected def assertAnalysisErrorClass( inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], queryContext: Array[QueryContext] = Array.empty, caseSensitive: Boolean = true): Unit ``` Call sites that need to assert the error position now pass a `queryContext` instead of `line` + `pos`. ### Why are the changes needed? `assertAnalysisErrorClass` and `checkError` do the same work. ### Does this PR introduce _any_ user-facing change? No, this is a test-only change. ### How was this patch tested? - Pass GitHub Actions Closes #38413 from LuciferYang/simplify-assertAnalysisErrorClass.
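For reference, a minimal before/after sketch of one migrated call site, taken from the `AnalysisExceptionPositionSuite` change in the diff below (it assumes the `AnalysisTest` trait is in scope; the offsets are those of the `unknown` fragment in the SQL text):

```scala
// Before: the expected error position was passed as line/pos.
assertAnalysisErrorClass(
  parsePlan("SHOW COLUMNS FROM unknown IN db"),
  "TABLE_OR_VIEW_NOT_FOUND",
  Map("relationName" -> "`db`.`unknown`"),
  line = 1,
  pos = 18)

// After: the position is described by a QueryContext built from the SQL fragment
// and its start/stop offsets, which is what SparkFunSuite#checkError verifies.
assertAnalysisErrorClass(
  parsePlan("SHOW COLUMNS FROM unknown IN db"),
  "TABLE_OR_VIEW_NOT_FOUND",
  Map("relationName" -> "`db`.`unknown`"),
  Array(ExpectedContext("unknown", 18, 24)))
```

Note that `stop` is inclusive, i.e. `start + fragment.length - 1`, as in the suite's `verifyPosition` helper.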
Authored-by: yangjie01 Signed-off-by: Max Gekk --- .../analysis/AnalysisErrorSuite.scala | 8 ++-- .../AnalysisExceptionPositionSuite.scala | 12 +++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 39 ++++++++++--------- .../sql/catalyst/analysis/AnalysisTest.scala | 39 +++++-------------- .../analysis/ResolveSubquerySuite.scala | 30 +++++--------- .../analysis/V2WriteAnalysisSuite.scala | 12 ++---- 6 files changed, 53 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 04de2bdcb51c..b718f410be6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -119,8 +119,7 @@ class AnalysisErrorSuite extends AnalysisTest { messageParameters: Map[String, String], caseSensitive: Boolean = true): Unit = { test(name) { - assertAnalysisErrorClass(plan, errorClass, messageParameters, - caseSensitive = true, line = -1, pos = -1) + assertAnalysisErrorClass(plan, errorClass, messageParameters, caseSensitive = caseSensitive) } } @@ -899,9 +898,8 @@ class AnalysisErrorSuite extends AnalysisTest { "inputSql" -> inputSql, "inputType" -> inputType, "requiredType" -> "(\"INT\" or \"BIGINT\")"), - caseSensitive = false, - line = -1, - pos = -1) + caseSensitive = false + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala index 7b720a7a0472..be256adbd892 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala @@ -52,8 +52,8 @@ class AnalysisExceptionPositionSuite extends AnalysisTest { parsePlan("SHOW COLUMNS FROM unknown IN db"), "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`db`.`unknown`"), - line = 1, - pos = 18) + Array(ExpectedContext("unknown", 18, 24)) + ) verifyTableOrViewPosition("ALTER TABLE unknown RENAME TO t", "unknown") verifyTableOrViewPosition("ALTER VIEW unknown RENAME TO v", "unknown") } @@ -92,13 +92,13 @@ class AnalysisExceptionPositionSuite extends AnalysisTest { } private def verifyPosition(sql: String, table: String): Unit = { - val expectedPos = sql.indexOf(table) - assert(expectedPos != -1) + val startPos = sql.indexOf(table) + assert(startPos != -1) assertAnalysisErrorClass( parsePlan(sql), "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> s"`$table`"), - line = 1, - pos = expectedPos) + Array(ExpectedContext(table, startPos, startPos + table.length - 1)) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6f0e6ef0c110..c1106a265451 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -104,10 +104,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { Project(Seq(UnresolvedAttribute("tBl.a")), SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), "UNRESOLVED_COLUMN.WITH_SUGGESTION", - Map("objectName" -> 
"`tBl`.`a`", "proposal" -> "`TbL`.`a`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`tBl`.`a`", "proposal" -> "`TbL`.`a`") + ) checkAnalysisWithoutViewWrapper( Project(Seq(UnresolvedAttribute("TbL.a")), @@ -716,9 +714,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertAnalysisErrorClass(parsePlan("WITH t(x) AS (SELECT 1) SELECT * FROM t WHERE y = 1"), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`y`", "proposal" -> "`t`.`x`"), - caseSensitive = true, - line = -1, - pos = -1) + Array(ExpectedContext("y", 46, 46)) + ) } test("CTE with non-matching column alias") { @@ -729,7 +726,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-28251: Insert into non-existing table error message is user friendly") { assertAnalysisErrorClass(parsePlan("INSERT INTO test VALUES (1)"), - "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`test`")) + "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`test`"), + Array(ExpectedContext("test", 12, 15))) } test("check CollectMetrics resolved") { @@ -1157,9 +1155,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { |""".stripMargin), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`c`.`y`", "proposal" -> "`x`"), - caseSensitive = true, - line = -1, - pos = -1) + Array(ExpectedContext("c.y", 123, 125)) + ) } test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") { @@ -1178,7 +1175,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "inputSql" -> "\"c\"", "inputType" -> "\"BOOLEAN\"", "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""), - caseSensitive = false) + queryContext = Array(ExpectedContext("mean(t.c)", 65, 73)), + caseSensitive = false + ) assertAnalysisErrorClass( inputPlan = parsePlan( @@ -1195,6 +1194,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { "inputSql" -> "\"c\"", "inputType" -> "\"BOOLEAN\"", "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""), + queryContext = Array(ExpectedContext("mean(c)", 91, 97)), caseSensitive = false) assertAnalysisErrorClass( @@ -1213,9 +1213,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "inputType" -> "\"BOOLEAN\"", "requiredType" -> "(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR TO MONTH\")"), - caseSensitive = false, - line = -1, - pos = -1) + queryContext = Array(ExpectedContext("abs(t.c)", 65, 72)), + caseSensitive = false + ) assertAnalysisErrorClass( inputPlan = parsePlan( @@ -1233,9 +1233,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "inputType" -> "\"BOOLEAN\"", "requiredType" -> "(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR TO MONTH\")"), - caseSensitive = false, - line = -1, - pos = -1) + queryContext = Array(ExpectedContext("abs(c)", 91, 96)), + caseSensitive = false + ) } test("SPARK-39354: should be [TABLE_OR_VIEW_NOT_FOUND]") { @@ -1246,7 +1246,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { |FROM t1 |JOIN t2 ON t1.user_id = t2.user_id |WHERE t1.dt >= DATE_SUB('2020-12-27', 90)""".stripMargin), - "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`t2`")) + "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`t2`"), + Array(ExpectedContext("t2", 84, 85))) } test("SPARK-39144: nested subquery expressions deduplicate relations should be done bottom up") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index a195b76d7c43..5e7395d905d2 
100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.types.StructType trait AnalysisTest extends PlanTest { + import org.apache.spark.QueryContext + protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil protected def createTempView( @@ -174,40 +176,19 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - caseSensitive: Boolean = true, - line: Int = -1, - pos: Int = -1): Unit = { + queryContext: Array[QueryContext] = Array.empty, + caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val analyzer = getAnalyzer val e = intercept[AnalysisException] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - - if (e.getErrorClass != expectedErrorClass || - e.messageParameters != expectedMessageParameters || - (line >= 0 && e.line.getOrElse(-1) != line) || - (pos >= 0) && e.startPosition.getOrElse(-1) != pos) { - var failMsg = "" - if (e.getErrorClass != expectedErrorClass) { - failMsg += - s"""Error class should be: ${expectedErrorClass} - |Actual error class: ${e.getErrorClass} - """.stripMargin - } - if (e.messageParameters != expectedMessageParameters) { - failMsg += - s"""Message parameters should be: ${expectedMessageParameters.mkString("\n ")} - |Actual message parameters: ${e.messageParameters.mkString("\n ")} - """.stripMargin - } - if (e.line.getOrElse(-1) != line || e.startPosition.getOrElse(-1) != pos) { - failMsg += - s"""Line/position should be: $line, $pos - |Actual line/position: ${e.line.getOrElse(-1)}, ${e.startPosition.getOrElse(-1)} - """.stripMargin - } - fail(failMsg) - } + checkError( + exception = e, + errorClass = expectedErrorClass, + parameters = expectedMessageParameters, + queryContext = queryContext + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 716d7aeb60fb..b1d569be5ba2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -134,10 +134,8 @@ class ResolveSubquerySuite extends AnalysisTest { assertAnalysisErrorClass( lateralJoin(t1, lateralJoin(t2, t0.select($"a", $"b", $"c"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", - Map("objectName" -> "`a`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`a`") + ) } test("lateral subquery with unresolvable attributes") { @@ -145,34 +143,26 @@ class ResolveSubquerySuite extends AnalysisTest { assertAnalysisErrorClass( lateralJoin(t1, t0.select($"a", $"c")), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", - Map("objectName" -> "`c`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`c`") + ) // SELECT * FROM t1, LATERAL (SELECT a, b, c, d FROM t2) assertAnalysisErrorClass( lateralJoin(t1, t2.select($"a", $"b", $"c", $"d")), "UNRESOLVED_COLUMN.WITH_SUGGESTION", - Map("objectName" -> "`d`", "proposal" -> "`b`, `c`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`d`", "proposal" -> "`b`, `c`") + ) // SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT t1.a)) assertAnalysisErrorClass( lateralJoin(t1, 
lateralJoin(t2, t0.select($"t1.a"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", - Map("objectName" -> "`t1`.`a`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`t1`.`a`") + ) // SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT a, b)) assertAnalysisErrorClass( lateralJoin(t1, lateralJoin(t2, t0.select($"a", $"b"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", - Map("objectName" -> "`a`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`a`") + ) } test("lateral subquery with struct type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index b87531832970..9f9df10b398b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -691,10 +691,8 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertAnalysisErrorClass( parsedPlan, "UNRESOLVED_COLUMN.WITH_SUGGESTION", - Map("objectName" -> "`a`", "proposal" -> "`x`, `y`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`a`", "proposal" -> "`x`, `y`") + ) val tableAcceptAnySchema = TestRelationAcceptAnySchema(StructType(Seq( StructField("x", DoubleType, nullable = false), @@ -706,10 +704,8 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertAnalysisErrorClass( parsedPlan2, "UNRESOLVED_COLUMN.WITH_SUGGESTION", - Map("objectName" -> "`a`", "proposal" -> "`x`, `y`"), - caseSensitive = true, - line = -1, - pos = -1) + Map("objectName" -> "`a`", "proposal" -> "`x`, `y`") + ) } test("SPARK-36498: reorder inner fields with byName mode") { From 7b4dfa319b421f81e02f63ea867ce0977ac438d7 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 28 Oct 2022 15:07:27 +0300 Subject: [PATCH 3/6] [SPARK-40889][SQL][TESTS] Check error classes in PlanResolutionSuite ### What changes were proposed in this pull request? This PR aims to replace 'intercept' with 'Check error classes' in PlanResolutionSuite. ### Why are the changes needed? The changes improve the error framework. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By running the modified test suite: ``` $ build/sbt "test:testOnly *PlanResolutionSuite" ``` Closes #38421 from panbingkun/SPARK-40889. 
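To make the shape of the rewrite concrete before the diff, here is a sketch of the pattern applied throughout the suite. The error class, message and query under test are the ones used for the SKEWED BY case below; `stop` is written as `sql.length - 1` (the inclusive end offset) instead of the literal value used in the suite:

```scala
val sql = "CREATE TABLE my_table (id int, name string) SKEWED BY(id) ON (1, 10, 100)"

// Before: the test only matched a substring of the exception message.
val e = intercept[ParseException] { parsePlan(sql) }
assert(e.getMessage.contains("Operation not allowed"))

// After: the test pins down the error class, the message parameters and the query context.
checkError(
  exception = intercept[ParseException] { parsePlan(sql) },
  errorClass = "_LEGACY_ERROR_TEMP_0035",
  parameters = Map("message" -> "CREATE TABLE ... SKEWED BY"),
  context = ExpectedContext(fragment = sql, start = 0, stop = sql.length - 1))
```

The old assertion only proved that some "operation not allowed" error was raised; the new one fails if the error class, any message parameter, or the reported SQL fragment changes.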
Authored-by: panbingkun Signed-off-by: Max Gekk --- .../command/PlanResolutionSuite.scala | 428 ++++++++++++------ 1 file changed, 298 insertions(+), 130 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 3b2271afc862..c50f1bec18cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.command import java.net.URI -import java.util.{Collections, Locale} +import java.util.Collections import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, when} @@ -39,7 +39,6 @@ import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, V1Table} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -251,14 +250,18 @@ class PlanResolutionSuite extends AnalysisTest { }.head } - private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { - val e = intercept[ParseException] { - parsePlan(sql) - } - assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) - containsThesePhrases.foreach { p => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) - } + private def assertUnsupported( + sql: String, + parameters: Map[String, String], + context: ExpectedContext): Unit = { + checkError( + exception = intercept[ParseException] { + parsePlan(sql) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = parameters, + context = context + ) } test("create table - with partitioned by") { @@ -400,16 +403,20 @@ class PlanResolutionSuite extends AnalysisTest { } val v2 = - """ - |CREATE TABLE my_tab(a INT, b STRING) + """CREATE TABLE my_tab(a INT, b STRING) |USING parquet |OPTIONS (path '/tmp/file') - |LOCATION '/tmp/file' - """.stripMargin - val e = intercept[AnalysisException] { - parseAndResolve(v2) - } - assert(e.message.contains("you can only specify one of them.")) + |LOCATION '/tmp/file'""".stripMargin + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(v2) + }, + errorClass = "_LEGACY_ERROR_TEMP_0032", + parameters = Map("pathOne" -> "/tmp/file", "pathTwo" -> "/tmp/file"), + context = ExpectedContext( + fragment = v2, + start = 0, + stop = 97)) } test("create table - byte length literal table name") { @@ -1137,10 +1144,12 @@ class PlanResolutionSuite extends AnalysisTest { case _ => fail("Expect UpdateTable, but got:\n" + parsed7.treeString) } - assert(intercept[AnalysisException] { - parseAndResolve(sql8) - }.getMessage.contains( - QueryCompilationErrors.defaultReferencesNotAllowedInUpdateWhereClause().getMessage)) + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(sql8) + }, + errorClass = "_LEGACY_ERROR_TEMP_1341", + parameters = Map.empty) parsed9 match { case UpdateTable( @@ -1250,10 +1259,23 @@ 
class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed2, expected2) val sql3 = s"ALTER TABLE $tblName ALTER COLUMN j COMMENT 'new comment'" - val e1 = intercept[AnalysisException] { - parseAndResolve(sql3) - } - assert(e1.getMessage.contains("Missing field j in table spark_catalog.default.v1Table")) + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(sql3) + }, + errorClass = "_LEGACY_ERROR_TEMP_1331", + parameters = Map( + "fieldName" -> "j", + "table" -> "spark_catalog.default.v1Table", + "schema" -> + """root + | |-- i: integer (nullable = true) + | |-- s: string (nullable = true) + | |-- point: struct (nullable = true) + | | |-- x: integer (nullable = true) + | | |-- y: integer (nullable = true) + |""".stripMargin), + context = ExpectedContext(fragment = sql3, start = 0, stop = 55)) val sql4 = s"ALTER TABLE $tblName ALTER COLUMN point.x TYPE bigint" val e2 = intercept[AnalysisException] { @@ -1304,11 +1326,15 @@ class PlanResolutionSuite extends AnalysisTest { } test("alter table: alter column action is not specified") { - val e = intercept[AnalysisException] { - parseAndResolve("ALTER TABLE v1Table ALTER COLUMN i") - } - assert(e.getMessage.contains( - "ALTER TABLE table ALTER COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER")) + val sql = "ALTER TABLE v1Table ALTER COLUMN i" + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(sql) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> + "ALTER TABLE table ALTER COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER"), + context = ExpectedContext(fragment = sql, start = 0, stop = 33)) } test("alter table: alter column case sensitivity for v1 table") { @@ -1317,10 +1343,23 @@ class PlanResolutionSuite extends AnalysisTest { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val sql = s"ALTER TABLE $tblName ALTER COLUMN I COMMENT 'new comment'" if (caseSensitive) { - val e = intercept[AnalysisException] { - parseAndResolve(sql) - } - assert(e.getMessage.contains("Missing field I in table spark_catalog.default.v1Table")) + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(sql) + }, + errorClass = "_LEGACY_ERROR_TEMP_1331", + parameters = Map( + "fieldName" -> "I", + "table" -> "spark_catalog.default.v1Table", + "schema" -> + """root + | |-- i: integer (nullable = true) + | |-- s: string (nullable = true) + | |-- point: struct (nullable = true) + | | |-- x: integer (nullable = true) + | | |-- y: integer (nullable = true) + |""".stripMargin), + context = ExpectedContext(fragment = sql, start = 0, stop = 55)) } else { val actual = parseAndResolve(sql) val expected = AlterTableChangeColumnCommand( @@ -1669,40 +1708,39 @@ class PlanResolutionSuite extends AnalysisTest { // This MERGE INTO command includes an ON clause with a DEFAULT column reference. This is // invalid and returns an error message. 
val mergeWithDefaultReferenceInMergeCondition = - s""" - |MERGE INTO testcat.tab AS target + s"""MERGE INTO testcat.tab AS target |USING testcat.tab1 AS source |ON target.i = DEFAULT |WHEN MATCHED AND (target.s = 31) THEN DELETE |WHEN MATCHED AND (target.s = 31) | THEN UPDATE SET target.s = DEFAULT |WHEN NOT MATCHED AND (source.s='insert') - | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT) - """.stripMargin - assert(intercept[AnalysisException] { - parseAndResolve(mergeWithDefaultReferenceInMergeCondition) - }.getMessage.contains( - QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition().getMessage)) + | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)""".stripMargin + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(mergeWithDefaultReferenceInMergeCondition) + }, + errorClass = "_LEGACY_ERROR_TEMP_1342", + parameters = Map.empty) // DEFAULT column reference within a complex expression: // This MERGE INTO command includes a WHEN MATCHED clause with a DEFAULT column reference as // of a complex expression (DEFAULT + 1). This is invalid and returns an error message. val mergeWithDefaultReferenceAsPartOfComplexExpression = - s""" - |MERGE INTO testcat.tab AS target + s"""MERGE INTO testcat.tab AS target |USING testcat.tab1 AS source |ON target.i = source.i |WHEN MATCHED AND (target.s = 31) THEN DELETE |WHEN MATCHED AND (target.s = 31) | THEN UPDATE SET target.s = DEFAULT + 1 |WHEN NOT MATCHED AND (source.s='insert') - | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT) - """.stripMargin - assert(intercept[AnalysisException] { - parseAndResolve(mergeWithDefaultReferenceAsPartOfComplexExpression) - }.getMessage.contains( - QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates().getMessage)) + | THEN INSERT (target.i, target.s) values (DEFAULT, DEFAULT)""".stripMargin + checkError( + exception = intercept[AnalysisException] { + parseAndResolve(mergeWithDefaultReferenceAsPartOfComplexExpression) + }, + errorClass = "_LEGACY_ERROR_TEMP_1343", + parameters = Map.empty) // Ambiguous DEFAULT column reference when the table itself contains a column named // "DEFAULT". @@ -1835,52 +1873,68 @@ class PlanResolutionSuite extends AnalysisTest { } val sql2 = - s""" - |MERGE INTO $target + s"""MERGE INTO $target |USING $source |ON i = 1 - |WHEN MATCHED THEN DELETE - """.stripMargin + |WHEN MATCHED THEN DELETE""".stripMargin // merge condition is resolved with both target and source tables, and we can't // resolve column `i` as it's ambiguous. - val e2 = intercept[AnalysisException](parseAndResolve(sql2)) - assert(e2.message.contains("Reference 'i' is ambiguous")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql2)), + errorClass = null, + parameters = Map.empty, + context = ExpectedContext( + fragment = "i", + start = 22 + target.length + source.length, + stop = 22 + target.length + source.length)) val sql3 = - s""" - |MERGE INTO $target + s"""MERGE INTO $target |USING $source |ON 1 = 1 - |WHEN MATCHED AND (s='delete') THEN DELETE - """.stripMargin + |WHEN MATCHED AND (s='delete') THEN DELETE""".stripMargin // delete condition is resolved with both target and source tables, and we can't // resolve column `s` as it's ambiguous. 
- val e3 = intercept[AnalysisException](parseAndResolve(sql3)) - assert(e3.message.contains("Reference 's' is ambiguous")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql3)), + errorClass = null, + parameters = Map.empty, + context = ExpectedContext( + fragment = "s", + start = 46 + target.length + source.length, + stop = 46 + target.length + source.length)) val sql4 = - s""" - |MERGE INTO $target + s"""MERGE INTO $target |USING $source |ON 1 = 1 - |WHEN MATCHED AND (s = 'a') THEN UPDATE SET i = 1 - """.stripMargin + |WHEN MATCHED AND (s = 'a') THEN UPDATE SET i = 1""".stripMargin // update condition is resolved with both target and source tables, and we can't // resolve column `s` as it's ambiguous. - val e4 = intercept[AnalysisException](parseAndResolve(sql4)) - assert(e4.message.contains("Reference 's' is ambiguous")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql4)), + errorClass = null, + parameters = Map.empty, + context = ExpectedContext( + fragment = "s", + start = 46 + target.length + source.length, + stop = 46 + target.length + source.length)) val sql5 = - s""" - |MERGE INTO $target + s"""MERGE INTO $target |USING $source |ON 1 = 1 - |WHEN MATCHED THEN UPDATE SET s = s - """.stripMargin + |WHEN MATCHED THEN UPDATE SET s = s""".stripMargin // update value is resolved with both target and source tables, and we can't // resolve column `s` as it's ambiguous. - val e5 = intercept[AnalysisException](parseAndResolve(sql5)) - assert(e5.message.contains("Reference 's' is ambiguous")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql5)), + errorClass = null, + parameters = Map.empty, + context = ExpectedContext( + fragment = "s", + start = 61 + target.length + source.length, + stop = 61 + target.length + source.length)) } val sql1 = @@ -1908,9 +1962,10 @@ class PlanResolutionSuite extends AnalysisTest { |ON 1 = 1 |WHEN MATCHED THEN UPDATE SET * |""".stripMargin - val e2 = intercept[AnalysisException](parseAndResolve(sql2)) - assert(e2.message.contains( - "cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql2)), + errorClass = null, + parameters = Map.empty) // INSERT * with incompatible schema between source and target tables. 
val sql3 = @@ -1920,9 +1975,10 @@ class PlanResolutionSuite extends AnalysisTest { |ON 1 = 1 |WHEN NOT MATCHED THEN INSERT * |""".stripMargin - val e3 = intercept[AnalysisException](parseAndResolve(sql3)) - assert(e3.message.contains( - "cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]")) + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql3)), + errorClass = null, + parameters = Map.empty) val sql4 = """ @@ -2108,9 +2164,11 @@ class PlanResolutionSuite extends AnalysisTest { ) ) - interceptParseException(parsePlan)( - "CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)", - "Syntax error at or near ':': extra input ':'")() + val sql = "CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)" + checkError( + exception = parseException(parsePlan)(sql), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "':'", "hint" -> ": extra input ':'")) } test("create hive table - table file format") { @@ -2170,7 +2228,11 @@ class PlanResolutionSuite extends AnalysisTest { assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) } } else { - assertUnsupported(query, Seq("row format serde", "incompatible", s)) + assertUnsupported( + query, + Map("message" -> (s"ROW FORMAT SERDE is incompatible with format '$s', " + + "which also specifies a serde")), + ExpectedContext(fragment = query, start = 0, stop = 57 + s.length)) } } } @@ -2192,7 +2254,11 @@ class PlanResolutionSuite extends AnalysisTest { assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) } } else { - assertUnsupported(query, Seq("row format delimited", "only compatible with 'textfile'", s)) + assertUnsupported( + query, + Map("message" -> ("ROW FORMAT DELIMITED is only compatible with 'textfile', " + + s"not '$s'")), + ExpectedContext(fragment = query, start = 0, stop = 75 + s.length)) } } } @@ -2214,14 +2280,23 @@ class PlanResolutionSuite extends AnalysisTest { } test("create hive table - property values must be set") { + val sql1 = "CREATE TABLE my_tab STORED AS parquet " + + "TBLPROPERTIES('key_without_value', 'key_with_value'='x')" assertUnsupported( - sql = "CREATE TABLE my_tab STORED AS parquet " + - "TBLPROPERTIES('key_without_value', 'key_with_value'='x')", - containsThesePhrases = Seq("key_without_value")) + sql1, + Map("message" -> "Values must be specified for key(s): [key_without_value]"), + ExpectedContext(fragment = sql1, start = 0, stop = 93)) + + val sql2 = "CREATE TABLE my_tab ROW FORMAT SERDE 'serde' " + + "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')" assertUnsupported( - sql = "CREATE TABLE my_tab ROW FORMAT SERDE 'serde' " + - "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')", - containsThesePhrases = Seq("key_without_value")) + sql2, + Map("message" -> "Values must be specified for key(s): [key_without_value]"), + ExpectedContext( + fragment = "ROW FORMAT SERDE 'serde' WITH SERDEPROPERTIES('key_without_value', " + + "'key_with_value'='x')", + start = 20, + stop = 107)) } test("create hive table - location implies external") { @@ -2234,28 +2309,58 @@ class PlanResolutionSuite extends AnalysisTest { } test("Duplicate clauses - create hive table") { - def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*)() - def createTableHeader(duplicateClause: String): String = { s"CREATE TABLE my_tab(a INT, b STRING) STORED AS parquet $duplicateClause $duplicateClause" } - intercept(createTableHeader("TBLPROPERTIES('test' = 
'test2')"), - "Found duplicate clauses: TBLPROPERTIES") - intercept(createTableHeader("LOCATION '/tmp/file'"), - "Found duplicate clauses: LOCATION") - intercept(createTableHeader("COMMENT 'a table'"), - "Found duplicate clauses: COMMENT") - intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"), - "Found duplicate clauses: CLUSTERED BY") - intercept(createTableHeader("PARTITIONED BY (k int)"), - "Found duplicate clauses: PARTITIONED BY") - intercept(createTableHeader("STORED AS parquet"), - "Found duplicate clauses: STORED AS/BY") - intercept( - createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'"), - "Found duplicate clauses: ROW FORMAT") + val sql1 = createTableHeader("TBLPROPERTIES('test' = 'test2')") + checkError( + exception = parseException(parsePlan)(sql1), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "TBLPROPERTIES"), + context = ExpectedContext(fragment = sql1, start = 0, stop = 117)) + + val sql2 = createTableHeader("LOCATION '/tmp/file'") + checkError( + exception = parseException(parsePlan)(sql2), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "LOCATION"), + context = ExpectedContext(fragment = sql2, start = 0, stop = 95)) + + val sql3 = createTableHeader("COMMENT 'a table'") + checkError( + exception = parseException(parsePlan)(sql3), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "COMMENT"), + context = ExpectedContext(fragment = sql3, start = 0, stop = 89)) + + val sql4 = createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS") + checkError( + exception = parseException(parsePlan)(sql4), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "CLUSTERED BY"), + context = ExpectedContext(fragment = sql4, start = 0, stop = 119)) + + val sql5 = createTableHeader("PARTITIONED BY (k int)") + checkError( + exception = parseException(parsePlan)(sql5), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "PARTITIONED BY"), + context = ExpectedContext(fragment = sql5, start = 0, stop = 99)) + + val sql6 = createTableHeader("STORED AS parquet") + checkError( + exception = parseException(parsePlan)(sql6), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "STORED AS/BY"), + context = ExpectedContext(fragment = sql6, start = 0, stop = 89)) + + val sql7 = createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'") + checkError( + exception = parseException(parsePlan)(sql7), + errorClass = "_LEGACY_ERROR_TEMP_0041", + parameters = Map("clauseName" -> "ROW FORMAT"), + context = ExpectedContext(fragment = sql7, start = 0, stop = 163)) } test("Test CTAS #1") { @@ -2390,9 +2495,16 @@ class PlanResolutionSuite extends AnalysisTest { val s4 = """CREATE TABLE page_view |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin - intercept[AnalysisException] { - extractTableDesc(s4) - } + checkError( + exception = intercept[AnalysisException] { + extractTableDesc(s4) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "STORED BY"), + context = ExpectedContext( + fragment = "STORED BY 'storage.handler.class.name'", + start = 23, + stop = 60)) } test("Test CTAS #5") { @@ -2423,13 +2535,27 @@ class PlanResolutionSuite extends AnalysisTest { } test("CTAS statement with a PARTITIONED BY clause is not allowed") { - assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + - " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") + val sql = 
s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + + " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp" + assertUnsupported( + sql, + Map("message" -> + "Partition column types may not be specified in Create Table As Select (CTAS)"), + ExpectedContext(fragment = sql, start = 0, stop = 100)) } test("CTAS statement with schema") { - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") + val sql1 = s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src" + assertUnsupported( + sql1, + Map("message" -> "Schema may not be specified in a Create Table As Select (CTAS) statement"), + ExpectedContext(fragment = sql1, start = 0, stop = 61)) + + val sql2 = s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'" + assertUnsupported( + sql2, + Map("message" -> "Schema may not be specified in a Create Table As Select (CTAS) statement"), + ExpectedContext(fragment = sql2, start = 0, stop = 61)) } test("create table - basic") { @@ -2464,8 +2590,14 @@ class PlanResolutionSuite extends AnalysisTest { test("create table - temporary") { val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" - val e = intercept[ParseException] { parsePlan(query) } - assert(e.message.contains("Operation not allowed: CREATE TEMPORARY TABLE")) + checkError( + exception = intercept[ParseException] { + parsePlan(query) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map( + "message" -> "CREATE TEMPORARY TABLE ..., use CREATE TEMPORARY VIEW instead"), + context = ExpectedContext(fragment = query, start = 0, stop = 48)) } test("create table - external") { @@ -2528,15 +2660,33 @@ class PlanResolutionSuite extends AnalysisTest { test("create table(hive) - skewed by") { val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + checkError( + exception = intercept[ParseException] { + parsePlan(query1) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "CREATE TABLE ... SKEWED BY"), + context = ExpectedContext(fragment = query1, start = 0, stop = 72)) + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + checkError( + exception = intercept[ParseException] { + parsePlan(query2) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "CREATE TABLE ... SKEWED BY"), + context = ExpectedContext(fragment = query2, start = 0, stop = 96)) + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" - val e1 = intercept[ParseException] { parsePlan(query1) } - val e2 = intercept[ParseException] { parsePlan(query2) } - val e3 = intercept[ParseException] { parsePlan(query3) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - assert(e3.getMessage.contains("Operation not allowed")) + checkError( + exception = intercept[ParseException] { + parsePlan(query3) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "CREATE TABLE ... 
SKEWED BY"), + context = ExpectedContext(fragment = query3, start = 0, stop = 118)) } test("create table(hive) - row format") { @@ -2583,12 +2733,30 @@ class PlanResolutionSuite extends AnalysisTest { test("create table(hive) - storage handler") { val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + checkError( + exception = intercept[ParseException] { + parsePlan(query1) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "STORED BY"), + context = ExpectedContext( + fragment = "STORED BY 'org.papachi.StorageHandler'", + start = 44, + stop = 81)) + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" - val e1 = intercept[ParseException] { parsePlan(query1) } - val e2 = intercept[ParseException] { parsePlan(query2) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) + checkError( + exception = intercept[ParseException] { + parsePlan(query2) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "STORED BY"), + context = ExpectedContext( + fragment = "STORED BY 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')", + start = 44, + stop = 114)) } test("create table(hive) - everything!") { From 77694b4673dd2efb5b79d596fbd647af3db5f8a0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 28 Oct 2022 20:59:13 +0800 Subject: [PATCH 4/6] [SPARK-40918][SQL] Mismatch between FileSourceScanExec and Orc and ParquetFileFormat on producing columnar output ### What changes were proposed in this pull request? We move the decision about supporting columnar output based on WSCG one level from ParquetFileFormat / OrcFileFormat up to FileSourceScanExec, and pass it as a new required option for ParquetFileFormat / OrcFileFormat. Now the semantics is as follows: * `ParquetFileFormat.supportsBatch` and `OrcFileFormat.supportsBatch` returns whether it **can**, not necessarily **will** return columnar output. * To return columnar output, an option `FileFormat.OPTION_RETURNING_BATCH` needs to be passed to `buildReaderWithPartitionValues` in these two file formats. It should only be set to `true` if `supportsBatch` is also `true`, but it can be set to `false` if we don't want columnar output nevertheless - this way, `FileSourceScanExec` can set it to false when there are more than 100 columsn for WSCG, and `ParquetFileFormat` / `OrcFileFormat` doesn't have to concern itself about WSCG limits. * To avoid not passing it by accident, this option is made required. Making it required requires updating a few places that use it, but an error resulting from this is very obscure. It's better to fail early and explicitly here. ### Why are the changes needed? This explains it for `ParquetFileFormat`. `OrcFileFormat` had exactly the same issue. `java.lang.ClassCastException: org.apache.spark.sql.vectorized.ColumnarBatch cannot be cast to org.apache.spark.sql.catalyst.InternalRow` was being thrown because ParquetReader was outputting columnar batches, while FileSourceScanExec expected row output. The mismatch comes from the fact that `ParquetFileFormat.supportBatch` depends on `WholeStageCodegenExec.isTooManyFields(conf, schema)`, where the threshold is 100 fields. 
When this is used in `FileSourceScanExec`: ``` override lazy val supportsColumnar: Boolean = { relation.fileFormat.supportBatch(relation.sparkSession, schema) } ``` the `schema` comes from output attributes, which includes extra metadata attributes. However, inside `ParquetFileFormat.buildReaderWithPartitionValues` it was calculated again as ``` relation.fileFormat.buildReaderWithPartitionValues( sparkSession = relation.sparkSession, dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, requiredSchema = requiredSchema, filters = pushedDownFilters, options = options, hadoopConf = hadoopConf ... val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) ... val returningBatch = supportBatch(sparkSession, resultSchema) ``` Where `requiredSchema` and `partitionSchema` wouldn't include the metadata columns: ``` FileSourceScanExec: output: List(c1#4608L, c2#4609L, ..., c100#4707L, file_path#6388) FileSourceScanExec: dataSchema: StructType(StructField(c1,LongType,true),StructField(c2,LongType,true),...,StructField(c100,LongType,true)) FileSourceScanExec: partitionSchema: StructType() FileSourceScanExec: requiredSchema: StructType(StructField(c1,LongType,true),StructField(c2,LongType,true),...,StructField(c100,LongType,true)) ``` Column like `file_path#6388` are added by the scan, and contain metadata added by the scan, not by the file reader which concerns itself with what is within the file. ### Does this PR introduce _any_ user-facing change? Not a public API change, but it is now required to pass `FileFormat.OPTION_RETURNING_BATCH` in `options` to `ParquetFileFormat.buildReaderWithPartitionValues`. The only user of this API in Apache Spark is `FileSourceScanExec`. ### How was this patch tested? Tests added Closes #38397 from juliuszsompolski/SPARK-40918. Authored-by: Juliusz Sompolski Signed-off-by: Wenchen Fan --- .../sql/execution/DataSourceScanExec.scala | 11 ++++-- .../execution/datasources/FileFormat.scala | 13 +++++++ .../datasources/orc/OrcFileFormat.scala | 33 ++++++++++++++--- .../parquet/ParquetFileFormat.scala | 36 +++++++++++++++---- .../datasources/FileMetadataStructSuite.scala | 26 ++++++++++++++ 5 files changed, 107 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index fdb49bd76746..d9e3db6903cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -523,7 +523,12 @@ case class FileSourceScanExec( // Note that some vals referring the file-based relation are lazy intentionally // so that this plan can be canonicalized on executor side too. See SPARK-23731. override lazy val supportsColumnar: Boolean = { - relation.fileFormat.supportBatch(relation.sparkSession, schema) + val conf = relation.sparkSession.sessionState.conf + // Only output columnar if there is WSCG to read it. 
+ val requiredWholeStageCodegenSettings = + conf.wholeStageEnabled && !WholeStageCodegenExec.isTooManyFields(conf, schema) + requiredWholeStageCodegenSettings && + relation.fileFormat.supportBatch(relation.sparkSession, schema) } private lazy val needsUnsafeRowConversion: Boolean = { @@ -535,6 +540,8 @@ case class FileSourceScanExec( } lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (FileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) val readFile: (PartitionedFile) => Iterator[InternalRow] = relation.fileFormat.buildReaderWithPartitionValues( sparkSession = relation.sparkSession, @@ -542,7 +549,7 @@ case class FileSourceScanExec( partitionSchema = relation.partitionSchema, requiredSchema = requiredSchema, filters = pushedDownFilters, - options = relation.options, + options = options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) val readRDD = if (bucketedScan) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index f7f917d89477..7e920773c048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -61,6 +61,11 @@ trait FileFormat { /** * Returns whether this format supports returning columnar batch or not. + * If columnar batch output is requested, users shall supply + * FileFormat.OPTION_RETURNING_BATCH -> true + * in relation options when calling buildReaderWithPartitionValues. + * This should only be passed as true if it can actually be supported. + * For ParquetFileFormat and OrcFileFormat, passing this option is required. * * TODO: we should just have different traits for the different formats. */ @@ -191,6 +196,14 @@ object FileFormat { val METADATA_NAME = "_metadata" + /** + * Option to pass to buildReaderWithPartitionValues to return columnar batch output or not. + * For ParquetFileFormat and OrcFileFormat, passing this option is required. + * This should only be passed as true if it can actually be supported, which can be checked + * by calling supportBatch. + */ + val OPTION_RETURNING_BATCH = "returning_batch" + /** Schema of metadata struct that can be produced by every file format. 
*/ val BASE_METADATA_STRUCT: StructType = new StructType() .add(StructField(FileFormat.FILE_PATH, StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 024c458feaff..6a58513c346d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -102,8 +101,7 @@ class OrcFileFormat override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { val conf = sparkSession.sessionState.conf - conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && - !WholeStageCodegenExec.isTooManyFields(conf, schema) && + conf.orcVectorizedReaderEnabled && schema.forall(s => OrcUtils.supportColumnarReads( s.dataType, sparkSession.sessionState.conf.orcVectorizedReaderNestedColumnEnabled)) } @@ -115,6 +113,18 @@ class OrcFileFormat true } + /** + * Build the reader. + * + * @note It is required to pass FileFormat.OPTION_RETURNING_BATCH in options, to indicate whether + * the reader should return row or columnar output. + * If the caller can handle both, pass + * FileFormat.OPTION_RETURNING_BATCH -> + * supportBatch(sparkSession, + * StructType(requiredSchema.fields ++ partitionSchema.fields)) + * as the option. + * It should be set to "true" only if this reader can support it. + */ override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -126,9 +136,24 @@ class OrcFileFormat val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) val sqlConf = sparkSession.sessionState.conf - val enableVectorizedReader = supportBatch(sparkSession, resultSchema) val capacity = sqlConf.orcVectorizedReaderBatchSize + // Should always be set by FileSourceScanExec creating this. + // Check conf before checking option, to allow working around an issue by changing conf. + val enableVectorizedReader = sqlConf.orcVectorizedReaderEnabled && + options.get(FileFormat.OPTION_RETURNING_BATCH) + .getOrElse { + throw new IllegalArgumentException( + "OPTION_RETURNING_BATCH should always be set for OrcFileFormat. " + + "To workaround this issue, set spark.sql.orc.enableVectorizedReader=false.") + } + .equals("true") + if (enableVectorizedReader) { + // If the passed option said that we are to return batches, we need to also be able to + // do this based on config and resultSchema. 
+ assert(supportBatch(sparkSession, resultSchema)) + } + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(hadoopConf, sqlConf.caseSensitiveAnalysis) val broadcastedConf = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 5116a6bdb90c..80b6791d8fae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -42,7 +42,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf @@ -82,12 +81,11 @@ class ParquetFileFormat } /** - * Returns whether the reader will return the rows as batch or not. + * Returns whether the reader can return the rows as batch or not. */ override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { val conf = sparkSession.sessionState.conf - ParquetUtils.isBatchReadSupportedForSchema(conf, schema) && conf.wholeStageEnabled && - !WholeStageCodegenExec.isTooManyFields(conf, schema) + ParquetUtils.isBatchReadSupportedForSchema(conf, schema) } override def vectorTypes( @@ -110,6 +108,18 @@ class ParquetFileFormat true } + /** + * Build the reader. + * + * @note It is required to pass FileFormat.OPTION_RETURNING_BATCH in options, to indicate whether + * the reader should return row or columnar output. + * If the caller can handle both, pass + * FileFormat.OPTION_RETURNING_BATCH -> + * supportBatch(sparkSession, + * StructType(requiredSchema.fields ++ partitionSchema.fields)) + * as the option. + * It should be set to "true" only if this reader can support it. + */ override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -161,8 +171,6 @@ class ParquetFileFormat val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown - // Whole stage codegen (PhysicalRDD) is able to deal with batches directly - val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal @@ -173,6 +181,22 @@ class ParquetFileFormat val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + // Should always be set by FileSourceScanExec creating this. + // Check conf before checking option, to allow working around an issue by changing conf. + val returningBatch = sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && + options.get(FileFormat.OPTION_RETURNING_BATCH) + .getOrElse { + throw new IllegalArgumentException( + "OPTION_RETURNING_BATCH should always be set for ParquetFileFormat. 
" + + "To workaround this issue, set spark.sql.parquet.enableVectorizedReader=false.") + } + .equals("true") + if (returningBatch) { + // If the passed option said that we are to return batches, we need to also be able to + // do this based on config and resultSchema. + assert(supportBatch(sparkSession, resultSchema)) + } + (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index 8909fe49aac1..2c8d72ec6093 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -622,4 +622,30 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { } } } + + Seq("parquet", "orc").foreach { format => + test(s"SPARK-40918: Output cols around WSCG.isTooManyFields limit in $format") { + // The issue was that ParquetFileFormat would not count the _metadata columns towards + // the WholeStageCodegenExec.isTooManyFields limit, while FileSourceScanExec would, + // resulting in Parquet reader returning columnar output, while scan expected row. + withTempPath { dir => + sql(s"SELECT ${(1 to 100).map(i => s"id+$i as c$i").mkString(", ")} FROM RANGE(100)") + .write.format(format).save(dir.getAbsolutePath) + (98 to 102).foreach { wscgCols => + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> wscgCols.toString) { + // Would fail with + // java.lang.ClassCastException: org.apache.spark.sql.vectorized.ColumnarBatch + // cannot be cast to org.apache.spark.sql.catalyst.InternalRow + sql( + s""" + |SELECT + | ${(1 to 100).map(i => s"sum(c$i)").mkString(", ")}, + | max(_metadata.file_path) + |FROM $format.`$dir`""".stripMargin + ).collect() + } + } + } + } + } } From 0b892a543f9ea913f961eea95a4e45f1231b9a57 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Fri, 28 Oct 2022 21:06:49 +0800 Subject: [PATCH 5/6] [SPARK-40932][CORE] Fix issue messages for allGather are overridden ### What changes were proposed in this pull request? The messages returned by allGather may be overridden by the following barrier APIs, eg, ``` scala val messages: Array[String] = context.allGather("ABC") context.barrier() ``` the `messages` may be like Array("", ""), but we're expecting Array("ABC", "ABC") The root cause of this issue is the [messages got by allGather](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala#L102) pointing to the [original message](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala#L107) in the local mode. So when the following barrier APIs changed the messages, then the allGather message will be changed accordingly. Finally, users can't get the correct result. This PR fixed this issue by sending back the cloned messages. ### Why are the changes needed? The bug mentioned in this description may block some external SPARK ML libraries which heavily depend on the spark barrier API to do some synchronization. If the barrier mechanism can't guarantee the correctness of the barrier APIs, it will be a disaster for external SPARK ML libraries. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? 
I added a unit test, with this PR, the unit test can pass Closes #38410 from wbo4958/allgather-issue. Authored-by: Bobby Wang Signed-off-by: Wenchen Fan --- .../org/apache/spark/BarrierCoordinator.scala | 2 +- .../scheduler/BarrierTaskContextSuite.scala | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 04faf7f87cf2..8ffccdf664b2 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -176,7 +176,7 @@ private[spark] class BarrierCoordinator( logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (requesters.size == numTasks) { - requesters.foreach(_.reply(messages)) + requesters.foreach(_.reply(messages.clone())) // Finished current barrier() call successfully, clean up ContextBarrierState and // increase the barrier epoch. logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 4f97003e2ed5..26cd5374fa09 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -367,4 +367,27 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with // double check we kill task success assert(System.currentTimeMillis() - startTime < 5000) } + + test("SPARK-40932, messages of allGather should not been overridden " + + "by the following barrier APIs") { + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local[2]")) + sc.setLogLevel("INFO") + val rdd = sc.makeRDD(1 to 10, 2) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message: String = context.partitionId().toString + val messages: Array[String] = context.allGather(message) + context.barrier() + Iterator.single(messages.toList) + } + val messages = rdd2.collect() + // All the task partitionIds are shared across all tasks + assert(messages.length === 2) + assert(messages.forall(_ == List("0", "1"))) + } + } From e3b720fc7de9082169ae479128c8facabe6923a6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 28 Oct 2022 10:03:08 -0700 Subject: [PATCH 6/6] [SPARK-40951][PYSPARK][TESTS] `pyspark-connect` tests should be skipped if `pandas` doesn't exist ### What changes were proposed in this pull request? This PR aims to skip `pyspark-connect` unit tests when `pandas` is unavailable. ### Why are the changes needed? **BEFORE** ``` % python/run-tests --modules pyspark-connect Running PySpark tests. 
Output is in /Users/dongjoon/APACHE/spark-merge/python/unit-tests.log Will test against the following Python executables: ['python3.9'] Will test the following Python modules: ['pyspark-connect'] python3.9 python_implementation is CPython python3.9 version is: Python 3.9.15 Starting test(python3.9): pyspark.sql.tests.connect.test_connect_plan_only (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/f14573f1-131f-494a-a015-8b4762219fb5/python3.9__pyspark.sql.tests.connect.test_connect_plan_only__86sd4pxg.log) Starting test(python3.9): pyspark.sql.tests.connect.test_connect_column_expressions (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/51391499-d21a-4c1d-8b79-6ac52859a4c9/python3.9__pyspark.sql.tests.connect.test_connect_column_expressions__kn__9aur.log) Starting test(python3.9): pyspark.sql.tests.connect.test_connect_basic (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/7854cbef-e40d-4090-a37d-5a5314eb245f/python3.9__pyspark.sql.tests.connect.test_connect_basic__i1rutevd.log) Starting test(python3.9): pyspark.sql.tests.connect.test_connect_select_ops (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/6f947453-7481-4891-81b0-169aaac8c6ee/python3.9__pyspark.sql.tests.connect.test_connect_select_ops__5sxao0ji.log) Traceback (most recent call last): File "/opt/homebrew/Cellar/python3.9/3.9.15/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main return _run_code(code, main_globals, None, File "/opt/homebrew/Cellar/python3.9/3.9.15/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 87, in _run_code exec(code, run_globals) File "/Users/dongjoon/APACHE/spark-merge/python/pyspark/sql/tests/connect/test_connect_basic.py", line 22, in import pandas ModuleNotFoundError: No module named 'pandas' ``` **AFTER** ``` % python/run-tests --modules pyspark-connect Running PySpark tests. Output is in /Users/dongjoon/APACHE/spark-merge/python/unit-tests.log Will test against the following Python executables: ['python3.9'] Will test the following Python modules: ['pyspark-connect'] python3.9 python_implementation is CPython python3.9 version is: Python 3.9.15 Starting test(python3.9): pyspark.sql.tests.connect.test_connect_basic (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/571609c0-3070-476c-afbe-56e215eb5647/python3.9__pyspark.sql.tests.connect.test_connect_basic__4e9k__5x.log) Starting test(python3.9): pyspark.sql.tests.connect.test_connect_column_expressions (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/4a30d035-e392-4ad2-ac10-5d8bc5421321/python3.9__pyspark.sql.tests.connect.test_connect_column_expressions__c9x39tvp.log) Starting test(python3.9): pyspark.sql.tests.connect.test_connect_plan_only (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/eea0b5db-9a92-4fbb-912d-a59daaf73f8e/python3.9__pyspark.sql.tests.connect.test_connect_plan_only__0p9ivnod.log) Starting test(python3.9): pyspark.sql.tests.connect.test_connect_select_ops (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/6069c664-afd9-4a3c-a0cc-f707577e039e/python3.9__pyspark.sql.tests.connect.test_connect_select_ops__sxzrtiqa.log) Finished test(python3.9): pyspark.sql.tests.connect.test_connect_column_expressions (1s) ... 2 tests were skipped Finished test(python3.9): pyspark.sql.tests.connect.test_connect_select_ops (1s) ... 2 tests were skipped Finished test(python3.9): pyspark.sql.tests.connect.test_connect_plan_only (1s) ... 
10 tests were skipped Finished test(python3.9): pyspark.sql.tests.connect.test_connect_basic (1s) ... 6 tests were skipped Tests passed in 1 seconds Skipped tests in pyspark.sql.tests.connect.test_connect_basic with python3.9: test_limit_offset (pyspark.sql.tests.connect.test_connect_basic.SparkConnectTests) ... skip (0.002s) test_schema (pyspark.sql.tests.connect.test_connect_basic.SparkConnectTests) ... skip (0.000s) test_simple_datasource_read (pyspark.sql.tests.connect.test_connect_basic.SparkConnectTests) ... skip (0.000s) test_simple_explain_string (pyspark.sql.tests.connect.test_connect_basic.SparkConnectTests) ... skip (0.000s) test_simple_read (pyspark.sql.tests.connect.test_connect_basic.SparkConnectTests) ... skip (0.000s) test_simple_udf (pyspark.sql.tests.connect.test_connect_basic.SparkConnectTests) ... skip (0.000s) Skipped tests in pyspark.sql.tests.connect.test_connect_column_expressions with python3.9: test_column_literals (pyspark.sql.tests.connect.test_connect_column_expressions.SparkConnectColumnExpressionSuite) ... skip (0.000s) test_simple_column_expressions (pyspark.sql.tests.connect.test_connect_column_expressions.SparkConnectColumnExpressionSuite) ... skip (0.000s) Skipped tests in pyspark.sql.tests.connect.test_connect_plan_only with python3.9: test_all_the_plans (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.002s) test_datasource_read (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.000s) test_deduplicate (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.001s) test_filter (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.000s) test_limit (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.000s) test_offset (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.000s) test_relation_alias (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.000s) test_sample (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.001s) test_simple_project (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.000s) test_simple_udf (pyspark.sql.tests.connect.test_connect_plan_only.SparkConnectTestsPlanOnly) ... skip (0.000s) Skipped tests in pyspark.sql.tests.connect.test_connect_select_ops with python3.9: test_join_with_join_type (pyspark.sql.tests.connect.test_connect_select_ops.SparkConnectToProtoSuite) ... skip (0.002s) test_select_with_columns_and_strings (pyspark.sql.tests.connect.test_connect_select_ops.SparkConnectToProtoSuite) ... skip (0.000s) ``` ### Does this PR introduce _any_ user-facing change? No. This is a test-only PR. ### How was this patch tested? Manually run the following. ``` $ pip3 uninstall pandas $ python/run-tests --modules pyspark-connect ``` Closes #38426 from dongjoon-hyun/SPARK-40951. 
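For reference, the guard pattern that the diffs below apply to each `pyspark.sql.tests.connect` module can be condensed into a short sketch. This is an illustration only: the class and test names here are hypothetical, while `have_pandas` and `pandas_requirement_message` are the existing helpers in `pyspark.testing.sqlutils` that this patch reuses.

```python
import unittest
from typing import cast

from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message

if have_pandas:
    # The Spark Connect test modules pull in pandas transitively, so only
    # import them when pandas is actually available.
    import pyspark.sql.connect.proto as proto  # noqa: F401


@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
class HypotheticalConnectSuite(unittest.TestCase):
    """Hypothetical example suite; the real suites live under pyspark/sql/tests/connect."""

    def test_runs_only_with_pandas(self) -> None:
        # Reaching this point means pandas (and the guarded imports above) are present;
        # otherwise the whole class is reported as skipped instead of erroring out.
        self.assertTrue(have_pandas)
```

Guarding the imports and applying the skip at the class level lets `python/run-tests` report the connect suites as skipped when pandas is missing, instead of aborting at collection time with `ModuleNotFoundError` as shown in the BEFORE log above.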
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../sql/tests/connect/test_connect_basic.py | 15 ++++++++++----- .../connect/test_connect_column_expressions.py | 16 +++++++++++----- .../sql/tests/connect/test_connect_plan_only.py | 13 +++++++++---- .../sql/tests/connect/test_connect_select_ops.py | 15 +++++++++++---- python/pyspark/testing/connectutils.py | 15 ++++++++++----- 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 459b05cc37aa..17d50f0f50e9 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -18,14 +18,18 @@ import unittest import shutil import tempfile +from pyspark.testing.sqlutils import have_pandas -import pandas +if have_pandas: + import pandas from pyspark.sql import SparkSession, Row from pyspark.sql.types import StructType, StructField, LongType, StringType -from pyspark.sql.connect.client import RemoteSparkSession -from pyspark.sql.connect.function_builder import udf -from pyspark.sql.connect.functions import lit + +if have_pandas: + from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.function_builder import udf + from pyspark.sql.connect.functions import lit from pyspark.sql.dataframe import DataFrame from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import ReusedPySparkTestCase @@ -36,7 +40,8 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase): """Parent test fixture class for all Spark Connect related test cases.""" - connect: RemoteSparkSession + if have_pandas: + connect: RemoteSparkSession tbl_name: str df_text: "DataFrame" diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 6036b63d76f2..790a987e8809 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -15,14 +15,20 @@ # limitations under the License. # +from typing import cast +import unittest from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.sql.connect.proto import Expression as ProtoExpression -import pyspark.sql.connect as c -import pyspark.sql.connect.plan as p -import pyspark.sql.connect.column as col -import pyspark.sql.connect.functions as fun +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +if have_pandas: + from pyspark.sql.connect.proto import Expression as ProtoExpression + import pyspark.sql.connect as c + import pyspark.sql.connect.plan as p + import pyspark.sql.connect.column as col + import pyspark.sql.connect.functions as fun + +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): def test_simple_column_expressions(self): df = c.DataFrame.withPlan(p.Read("table")) diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 450f5c70faba..14b939e019ba 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -14,15 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from typing import cast import unittest from pyspark.testing.connectutils import PlanOnlyTestFixture -import pyspark.sql.connect.proto as proto -from pyspark.sql.connect.readwriter import DataFrameReader -from pyspark.sql.connect.function_builder import UserDefinedFunction, udf -from pyspark.sql.types import StringType +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +if have_pandas: + import pyspark.sql.connect.proto as proto + from pyspark.sql.connect.readwriter import DataFrameReader + from pyspark.sql.connect.function_builder import UserDefinedFunction, udf + from pyspark.sql.types import StringType + +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py index e89b4b34ea01..a29c70541462 100644 --- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py +++ b/python/pyspark/sql/tests/connect/test_connect_select_ops.py @@ -14,13 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from typing import cast +import unittest + from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.sql.connect import DataFrame -from pyspark.sql.connect.functions import col -from pyspark.sql.connect.plan import Read -import pyspark.sql.connect.proto as proto +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message + +if have_pandas: + from pyspark.sql.connect import DataFrame + from pyspark.sql.connect.functions import col + from pyspark.sql.connect.plan import Read + import pyspark.sql.connect.proto as proto +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectToProtoSuite(PlanOnlyTestFixture): def test_select_with_columns_and_strings(self): df = DataFrame.withPlan(Read("table")) diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 700b7bb72e18..d9bced3af114 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -18,13 +18,18 @@ from typing import Any, Dict import functools import unittest +from pyspark.testing.sqlutils import have_pandas -from pyspark.sql.connect import DataFrame -from pyspark.sql.connect.plan import Read -from pyspark.testing.utils import search_jar +if have_pandas: + from pyspark.sql.connect import DataFrame + from pyspark.sql.connect.plan import Read + from pyspark.testing.utils import search_jar + + connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") +else: + connect_jar = None -connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") if connect_jar is None: connect_requirement_message = ( "Skipping all Spark Connect Python tests as the optional Spark Connect project was " @@ -38,7 +43,7 @@ os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args]) connect_requirement_message = None # type: ignore -should_test_connect = connect_requirement_message is None +should_test_connect = connect_requirement_message is None and have_pandas class MockRemoteSession: