diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 357ea2e6126b..0521270cee7c 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -232,7 +232,7 @@ jobs: distribution: temurin java-version: ${{ matrix.java }} - name: Install Python 3.8 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 # We should install one Python that is higher then 3+ for SQL and Yarn because: # - SQL component also has Python related tests, for example, IntegratedUDFTestUtils. # - Yarn has a Python specific test too, for example, YarnClusterSuite. diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala index 4ecbfd123f0d..bb694a767989 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/SparkConnectPlugin.scala @@ -22,7 +22,6 @@ import java.util import scala.collection.JavaConverters._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Unstable import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin} import org.apache.spark.sql.connect.service.SparkConnectService @@ -33,7 +32,6 @@ import org.apache.spark.sql.connect.service.SparkConnectService * implement it as a Driver Plugin. To enable Spark Connect, simply make sure that the appropriate * JAR is available in the CLASSPATH and the driver plugin is configured to load this class. */ -@Unstable class SparkConnectPlugin extends SparkPlugin { /** diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala index 47d421a0359b..80c36a4773e6 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._ import com.google.common.collect.{Lists, Maps} -import org.apache.spark.annotation.{Since, Unstable} import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.WriteOperation @@ -35,8 +34,6 @@ final case class InvalidCommandInput( private val cause: Throwable = null) extends Exception(message, cause) -@Unstable -@Since("3.4.0") class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) { lazy val pythonExec = diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 6ae6dfa1577c..a9a97e740d84 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -26,6 +26,9 @@ import org.apache.spark.sql.connect.planner.DataTypeProtoConverter /** * A collection of implicit conversions that create a DSL for constructing connect protos. + * + * All classes in connect/dsl are considered an internal API to Spark Connect and are subject to + * change between minor releases. 
*/ package object dsl { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 9e3899f4a1a0..53abf2e77090 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.connect.planner import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Since, Unstable} import org.apache.spark.connect.proto import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} @@ -38,8 +37,6 @@ final case class InvalidPlanInput( private val cause: Throwable = None.orNull) extends Exception(message, cause) -@Unstable -@Since("3.4.0") class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { def transform(): LogicalPlan = { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 5841017e5bb7..a1e70975da55 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -27,7 +27,6 @@ import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.spark.SparkEnv -import org.apache.spark.annotation.{Since, Unstable} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, SparkConnectServiceGrpc} import org.apache.spark.internal.Logging @@ -44,8 +43,6 @@ import org.apache.spark.sql.execution.ExtendedMode * @param debug * delegates debug behavior to the handlers. */ -@Unstable -@Since("3.4.0") class SparkConnectService(debug: Boolean) extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging { @@ -127,9 +124,7 @@ class SparkConnectService(debug: Boolean) * @param userId * @param session */ -@Unstable -@Since("3.4.0") -private[connect] case class SessionHolder(userId: String, session: SparkSession) +case class SessionHolder(userId: String, session: SparkSession) /** * Static instance of the SparkConnectService. @@ -137,8 +132,6 @@ private[connect] case class SessionHolder(userId: String, session: SparkSession) * Used to start the overall SparkConnect service and provides global state to manage the * different SparkSession from different users connecting to the cluster. */ -@Unstable -@Since("3.4.0") object SparkConnectService { private val CACHE_SIZE = 100 @@ -169,7 +162,7 @@ object SparkConnectService { /** * Based on the `key` find or create a new SparkSession. 
*/ - private[connect] def getOrCreateIsolatedSession(key: SessionCacheKey): SessionHolder = { + def getOrCreateIsolatedSession(key: SessionCacheKey): SessionHolder = { userSessionMapping.get( key, () => { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 997b0f6b6d82..a429823c02f8 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -23,7 +23,6 @@ import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Unstable} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{Request, Response} import org.apache.spark.internal.Logging @@ -34,8 +33,6 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.internal.SQLConf -@Unstable -@Since("3.4.0") class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { // The maximum batch size in bytes for a single batch of data to be returned via proto. @@ -60,7 +57,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte processRows(request.getClientId, rows) } - private[connect] def processRows(clientId: String, rows: DataFrame): Unit = { + def processRows(clientId: String, rows: DataFrame): Unit = { val timeZoneId = SQLConf.get.sessionLocalTimeZone // Only process up to 10MB of data. diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 9c494c043796..16347f89463e 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -200,6 +200,11 @@ "The <functionName> accepts only arrays of pair structs, but <childExpr> is of <childType>." ] }, + "MAP_ZIP_WITH_DIFF_TYPES" : { + "message" : [ + "Input to the <functionName> should have been two maps with compatible key types, but it's [<leftType>, <rightType>]." + ] + }, "NON_FOLDABLE_INPUT" : { "message" : [ "the input <inputName> should be a foldable <inputType> expression; however, got <inputExpr>." ] }, @@ -215,6 +220,11 @@ "Null typed values cannot be used as arguments of <functionName>." ] }, + "PARAMETER_CONSTRAINT_VIOLATION" : { + "message" : [ + "The <leftExprName>(<leftExprValue>) must be <constraint> the <rightExprName>(<rightExprValue>)" + ] + }, "RANGE_FRAME_INVALID_TYPE" : { "message" : [ "The data type <orderSpecType> used in the order specification does not match the data type <valueBoundaryType> which is used in the range frame." ] }, @@ -270,6 +280,11 @@ "The <exprName> must not be null" ] }, + "UNEXPECTED_RETURN_TYPE" : { + "message" : [ + "The <functionName> requires return <expectedType> type, but the actual is <actualType> type." + ] + }, "UNEXPECTED_STATIC_METHOD" : { "message" : [ "cannot find a static method <methodName> that matches the argument types in <className>" ] }, @@ -949,6 +964,11 @@ "Literal for '<value>' of <type>." ] }, + "MULTIPLE_BUCKET_TRANSFORMS" : { + "message" : [ + "Multiple bucket TRANSFORMs." + ] + }, "NATURAL_CROSS_JOIN" : { "message" : [ "NATURAL CROSS JOIN."
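
Editorial note on the error-classes.json hunk above: each new entry (MAP_ZIP_WITH_DIFF_TYPES, PARAMETER_CONSTRAINT_VIOLATION, UNEXPECTED_RETURN_TYPE, MULTIPLE_BUCKET_TRANSFORMS) is a message template whose `<param>` placeholders are filled from the parameter map attached to the thrown error. The sketch below only illustrates that substitution under invented names (ErrorTemplateSketch, format); it is not Spark's SparkThrowable or error-reader machinery.

```scala
// Minimal sketch of how an error-classes.json template becomes a user-facing
// message: every <param> placeholder is replaced by the value supplied in the
// error's parameter map. Object and method names are invented for illustration.
object ErrorTemplateSketch {
  def format(template: String, parameters: Map[String, String]): String =
    parameters.foldLeft(template) { case (msg, (name, value)) =>
      msg.replace(s"<$name>", value)
    }

  def main(args: Array[String]): Unit = {
    // Template of the new DATATYPE_MISMATCH.PARAMETER_CONSTRAINT_VIOLATION subclass.
    val template =
      "The <leftExprName>(<leftExprValue>) must be <constraint> the <rightExprName>(<rightExprValue>)"
    // Parameter values as produced by TimeWindow.checkInputDataTypes in this patch.
    val parameters = Map(
      "leftExprName" -> "`slide_duration`",
      "leftExprValue" -> "2000000L",
      "constraint" -> "<=",
      "rightExprName" -> "`window_duration`",
      "rightExprValue" -> "1000000L")
    println(format(template, parameters))
    // The `slide_duration`(2000000L) must be <= the `window_duration`(1000000L)
  }
}
```

The parameter keys are the same ones passed to DataTypeMismatch in the TimeWindow.scala hunk and asserted in AnalysisErrorSuite further down in this patch.
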
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 46b62d879cf3..7a08de9c1814 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -370,6 +370,12 @@ abstract class SparkFunSuite checkError(exception, errorClass, sqlState, parameters, false, Array(context)) + protected def checkErrorMatchPVals( + exception: SparkThrowable, + errorClass: String, + parameters: Map[String, String]): Unit = + checkError(exception, errorClass, None, parameters, matchPVals = true) + protected def checkErrorMatchPVals( exception: SparkThrowable, errorClass: String, diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 33883a2efaa5..20c537e0e672 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1177,6 +1177,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/internal"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive"))) diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index 21468855858e..333d39fc77cd 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -213,7 +213,7 @@ def range( def read_csv( - path: str, + path: Union[str, List[str]], sep: str = ",", header: Union[str, int, None] = "infer", names: Optional[Union[str, List[str]]] = None, @@ -234,8 +234,8 @@ def read_csv( Parameters ---------- - path : str - The path string storing the CSV file to be read. + path : str or list + Path(s) of the CSV file(s) to be read. sep : str, default ‘,’ Delimiter to use. Non empty string. header : int, default ‘infer’ @@ -296,6 +296,10 @@ def read_csv( Examples -------- >>> ps.read_csv('data.csv') # doctest: +SKIP + + Load multiple CSV files as a single DataFrame: + + >>> ps.read_csv(['data-01.csv', 'data-02.csv']) # doctest: +SKIP """ # For latin-1 encoding is same as iso-8859-1, that's why its mapped to iso-8859-1. 
encoding_mapping = {"latin-1": "iso-8859-1"} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 53c79d1fd54b..93c1074dfbed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreePattern.{TIME_WINDOW, TreePattern} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY import org.apache.spark.sql.catalyst.util.IntervalUtils -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.types._ // scalastyle:off line.size.limit line.contains.tab @@ -71,7 +71,8 @@ case class TimeWindow( startTime: Long) extends UnaryExpression with ImplicitCastInputTypes with Unevaluable - with NonSQLExpression { + with NonSQLExpression + with QueryErrorsBase { ////////////////////////// // SQL Constructors @@ -114,18 +115,48 @@ case class TimeWindow( val dataTypeCheck = super.checkInputDataTypes() if (dataTypeCheck.isSuccess) { if (windowDuration <= 0) { - return TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.") + return DataTypeMismatch( + errorSubClass = "VALUE_OUT_OF_RANGE", + messageParameters = Map( + "exprName" -> toSQLId("window_duration"), + "valueRange" -> s"(0, ${Long.MaxValue}]", + "currentValue" -> toSQLValue(windowDuration, LongType) + ) + ) } if (slideDuration <= 0) { - return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.") + return DataTypeMismatch( + errorSubClass = "VALUE_OUT_OF_RANGE", + messageParameters = Map( + "exprName" -> toSQLId("slide_duration"), + "valueRange" -> s"(0, ${Long.MaxValue}]", + "currentValue" -> toSQLValue(slideDuration, LongType) + ) + ) } if (slideDuration > windowDuration) { - return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" + - s" to the windowDuration ($windowDuration).") + return DataTypeMismatch( + errorSubClass = "PARAMETER_CONSTRAINT_VIOLATION", + messageParameters = Map( + "leftExprName" -> toSQLId("slide_duration"), + "leftExprValue" -> toSQLValue(slideDuration, LongType), + "constraint" -> "<=", + "rightExprName" -> toSQLId("window_duration"), + "rightExprValue" -> toSQLValue(windowDuration, LongType) + ) + ) } if (startTime.abs >= slideDuration) { - return TypeCheckFailure(s"The absolute value of start time ($startTime) must be less " + - s"than the slideDuration ($slideDuration).") + return DataTypeMismatch( + errorSubClass = "PARAMETER_CONSTRAINT_VIOLATION", + messageParameters = Map( + "leftExprName" -> toSQLId("abs(start_time)"), + "leftExprValue" -> toSQLValue(startTime.abs, LongType), + "constraint" -> "<", + "rightExprName" -> toSQLId("slide_duration"), + "rightExprValue" -> toSQLValue(slideDuration, LongType) + ) + ) } } dataTypeCheck diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 98513fb5dddf..b59860ff1812 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -24,6 +24,8 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -400,11 +402,25 @@ case class ArraySort( if (function.dataType == IntegerType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " + - "IntegerType") + DataTypeMismatch( + errorSubClass = "UNEXPECTED_RETURN_TYPE", + messageParameters = Map( + "functionName" -> toSQLId(function.prettyName), + "expectedType" -> toSQLType(IntegerType), + "actualType" -> toSQLType(function.dataType) + ) + ) } case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(argument), + "inputType" -> toSQLType(argument.dataType) + ) + ) } case failure => failure } @@ -804,9 +820,13 @@ case class ArrayAggregate( case TypeCheckResult.TypeCheckSuccess => if (!DataType.equalsStructurally( zero.dataType, merge.dataType, ignoreNullability = true)) { - TypeCheckResult.TypeCheckFailure( - s"argument 3 requires ${zero.dataType.simpleString} type, " + - s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "3", + "requiredType" -> toSQLType(zero.dataType), + "inputSql" -> toSQLExpr(merge), + "inputType" -> toSQLType(merge.dataType))) } else { TypeCheckResult.TypeCheckSuccess } @@ -1025,9 +1045,14 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) if (leftKeyType.sameType(rightKeyType)) { TypeUtils.checkForOrderingExpr(leftKeyType, prettyName) } else { - TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + - s"been two ${MapType.simpleString}s with compatible key types, but the key types are " + - s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].") + DataTypeMismatch( + errorSubClass = "MAP_ZIP_WITH_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "leftType" -> toSQLType(leftKeyType), + "rightType" -> toSQLType(rightKeyType) + ) + ) } case failure => failure } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index f69ece52d858..16f081a0cc2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1036,6 +1036,7 @@ object Hex { def unhex(bytes: Array[Byte]): Array[Byte] = { val out = new Array[Byte]((bytes.length + 1) >> 1) var i = 0 + var oddShift = 0 if ((bytes.length & 0x01) != 0) { // padding with '0' if (bytes(0) < 0) { @@ -1047,6 +1048,7 @@ object Hex { } out(0) = v i += 1 + oddShift = 1 } // two characters form the hex value. while (i < bytes.length) { @@ -1058,7 +1060,7 @@ object Hex { if (first == -1 || second == -1) { return null } - out(i / 2) = (((first << 4) | second) & 0xFF).toByte + out(i / 2 + oddShift) = (((first << 4) | second) & 0xFF).toByte i += 2 } out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 91809b6176c8..d9f15d84d893 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -60,7 +60,9 @@ private[sql] object CatalogV2Implicits { identityCols += col case BucketTransform(numBuckets, col, sortCol) => - if (bucketSpec.nonEmpty) throw QueryExecutionErrors.multipleBucketTransformsError + if (bucketSpec.nonEmpty) { + throw QueryExecutionErrors.unsupportedMultipleBucketTransformsError + } if (sortCol.isEmpty) { bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), Nil)) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 7e870e23fba0..ba78858debc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2623,9 +2623,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { "format" -> format)) } - def multipleBucketTransformsError(): SparkUnsupportedOperationException = { + def unsupportedMultipleBucketTransformsError(): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2279", + errorClass = "UNSUPPORTED_FEATURE.MULTIPLE_BUCKET_TRANSFORMS", messageParameters = Map.empty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ecd5b9e22fb2..04de2bdcb51c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -433,67 +433,131 @@ class AnalysisErrorSuite extends AnalysisTest { "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`bad_column`", "proposal" -> "`a`, `b`, `c`, `d`, `e`")) - errorTest( + errorClassTest( "slide duration greater than window in time window", testRelation2.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second").as("window")), - s"The slide duration " :: " must be less than or equal to the windowDuration " :: Nil + "DATATYPE_MISMATCH.PARAMETER_CONSTRAINT_VIOLATION", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, 1000000, 2000000, 0)\"", + "leftExprName" -> "`slide_duration`", + "leftExprValue" -> "2000000L", + "constraint" -> "<=", + 
"rightExprName" -> "`window_duration`", + "rightExprValue" -> "1000000L" + ) ) - errorTest( + errorClassTest( "start time greater than slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), - "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + "DATATYPE_MISMATCH.PARAMETER_CONSTRAINT_VIOLATION", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, 1000000, 1000000, 60000000)\"", + "leftExprName" -> "`abs(start_time)`", + "leftExprValue" -> "60000000L", + "constraint" -> "<", + "rightExprName" -> "`slide_duration`", + "rightExprValue" -> "1000000L" + ) ) - errorTest( + errorClassTest( "start time equal to slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), - "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + "DATATYPE_MISMATCH.PARAMETER_CONSTRAINT_VIOLATION", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, 1000000, 1000000, 1000000)\"", + "leftExprName" -> "`abs(start_time)`", + "leftExprValue" -> "1000000L", + "constraint" -> "<", + "rightExprName" -> "`slide_duration`", + "rightExprValue" -> "1000000L" + ) ) - errorTest( + errorClassTest( "SPARK-21590: absolute value of start time greater than slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 minute").as("window")), - "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + "DATATYPE_MISMATCH.PARAMETER_CONSTRAINT_VIOLATION", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, 1000000, 1000000, -60000000)\"", + "leftExprName" -> "`abs(start_time)`", + "leftExprValue" -> "60000000L", + "constraint" -> "<", + "rightExprName" -> "`slide_duration`", + "rightExprValue" -> "1000000L" + ) ) - errorTest( + errorClassTest( "SPARK-21590: absolute value of start time equal to slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 second").as("window")), - "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + "DATATYPE_MISMATCH.PARAMETER_CONSTRAINT_VIOLATION", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, 1000000, 1000000, -1000000)\"", + "leftExprName" -> "`abs(start_time)`", + "leftExprValue" -> "1000000L", + "constraint" -> "<", + "rightExprName" -> "`slide_duration`", + "rightExprValue" -> "1000000L" + ) ) - errorTest( + errorClassTest( "negative window duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "-1 second", "1 second", "0 second").as("window")), - "The window duration " :: " must be greater than 0." :: Nil + "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, -1000000, 1000000, 0)\"", + "exprName" -> "`window_duration`", + "valueRange" -> s"(0, 9223372036854775807]", + "currentValue" -> "-1000000L" + ) ) - errorTest( + errorClassTest( "zero window duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "0 second", "1 second", "0 second").as("window")), - "The window duration " :: " must be greater than 0." 
:: Nil + "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, 0, 1000000, 0)\"", + "exprName" -> "`window_duration`", + "valueRange" -> "(0, 9223372036854775807]", + "currentValue" -> "0L" + ) ) - errorTest( + errorClassTest( "negative slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "-1 second", "0 second").as("window")), - "The slide duration " :: " must be greater than 0." :: Nil + "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, 1000000, -1000000, 0)\"", + "exprName" -> "`slide_duration`", + "valueRange" -> "(0, 9223372036854775807]", + "currentValue" -> "-1000000L" + ) ) - errorTest( + errorClassTest( "zero slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "0 second", "0 second").as("window")), - "The slide duration" :: " must be greater than 0." :: Nil + "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", + Map( + "sqlExpr" -> "\"window(2016-01-01 01:01:01, 1000000, 0, 0)\"", + "exprName" -> "`slide_duration`", + "valueRange" -> "(0, 9223372036854775807]", + "currentValue" -> "0L" + ) ) errorTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index ec2cd79dee18..a195b76d7c43 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -184,9 +184,9 @@ trait AnalysisTest extends PlanTest { } if (e.getErrorClass != expectedErrorClass || - !e.messageParameters.sameElements(expectedMessageParameters) || - (line >= 0 && e.line.getOrElse(-1) != line) || - (pos >= 0) && e.startPosition.getOrElse(-1) != pos) { + e.messageParameters != expectedMessageParameters || + (line >= 0 && e.line.getOrElse(-1) != line) || + (pos >= 0) && e.startPosition.getOrElse(-1) != pos) { var failMsg = "" if (e.getErrorClass != expectedErrorClass) { failMsg += @@ -194,7 +194,7 @@ trait AnalysisTest extends PlanTest { |Actual error class: ${e.getErrorClass} """.stripMargin } - if (!e.messageParameters.sameElements(expectedMessageParameters)) { + if (e.messageParameters != expectedMessageParameters) { failMsg += s"""Message parameters should be: ${expectedMessageParameters.mkString("\n ")} |Actual message parameters: ${e.messageParameters.mkString("\n ")} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index a6546d8a5dbb..5f62dc970864 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -859,4 +861,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(1, 1, 2, 3)) } } + + test("Return type of 
the given function has to be IntegerType") { + val comparator = { + val comp = ArraySort.comparator _ + (left: Expression, right: Expression) => Literal.create("hello", StringType) + } + + val result = arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator).checkInputDataTypes() + assert(result == DataTypeMismatch( + errorSubClass = "UNEXPECTED_RETURN_TYPE", + messageParameters = Map( + "functionName" -> toSQLId("lambdafunction"), + "expectedType" -> toSQLType(IntegerType), + "actualType" -> toSQLType(StringType) + ))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index c741b685a34e..c78d72e7a98a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -590,6 +590,8 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) checkEvaluation(Unhex(Literal("GG")), null) + checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35)) + checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69)) // scalastyle:off // Turn off scala style for non-ascii chars checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes(StandardCharsets.UTF_8)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 8af82efeab37..cb18c547b612 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -83,6 +83,8 @@ SELECT rpad('hi', 'invalid_length'); SELECT hex(lpad(unhex(''), 5)); SELECT hex(lpad(unhex('aabb'), 5)); SELECT hex(lpad(unhex('aabbcc'), 2)); +SELECT hex(lpad(unhex('123'), 2)); +SELECT hex(lpad(unhex('12345'), 2)); SELECT hex(lpad(unhex(''), 5, unhex('1f'))); SELECT hex(lpad(unhex('aa'), 5, unhex('1f'))); SELECT hex(lpad(unhex('aa'), 6, unhex('1f'))); @@ -97,6 +99,8 @@ SELECT hex(lpad(unhex('aabbcc'), 2, unhex('ff'))); SELECT hex(rpad(unhex(''), 5)); SELECT hex(rpad(unhex('aabb'), 5)); SELECT hex(rpad(unhex('aabbcc'), 2)); +SELECT hex(rpad(unhex('123'), 2)); +SELECT hex(rpad(unhex('12345'), 2)); SELECT hex(rpad(unhex(''), 5, unhex('1f'))); SELECT hex(rpad(unhex('aa'), 5, unhex('1f'))); SELECT hex(rpad(unhex('aa'), 6, unhex('1f'))); @@ -202,6 +206,8 @@ select to_binary('737472696E67', 'hex'); select to_binary(''); select to_binary('1', 'hex'); select to_binary('FF'); +select to_binary('123', 'hex'); +select to_binary('12345', 'hex'); -- hex invalid select to_binary('GG'); select to_binary('01 AF', 'hex'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql index d21a80d482a0..4ff3e69da67a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/try-string-functions.sql @@ -27,6 +27,8 @@ select try_to_binary('737472696E67', 'hex'); select try_to_binary(''); select try_to_binary('1', 'hex'); select try_to_binary('FF'); +select try_to_binary('123'); +select try_to_binary('12345'); -- hex invalid select try_to_binary('GG'); select try_to_binary('01 AF', 'hex'); diff --git 
a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 1176042393b3..5b82cfa957d1 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -552,6 +552,22 @@ struct AABB +-- !query +SELECT hex(lpad(unhex('123'), 2)) +-- !query schema +struct +-- !query output +0123 + + +-- !query +SELECT hex(lpad(unhex('12345'), 2)) +-- !query schema +struct +-- !query output +0123 + + -- !query SELECT hex(lpad(unhex(''), 5, unhex('1f'))) -- !query schema @@ -648,6 +664,22 @@ struct AABB +-- !query +SELECT hex(rpad(unhex('123'), 2)) +-- !query schema +struct +-- !query output +0123 + + +-- !query +SELECT hex(rpad(unhex('12345'), 2)) +-- !query schema +struct +-- !query output +0123 + + -- !query SELECT hex(rpad(unhex(''), 5, unhex('1f'))) -- !query schema @@ -1408,6 +1440,22 @@ struct � +-- !query +select to_binary('123', 'hex') +-- !query schema +struct +-- !query output +# + + +-- !query +select to_binary('12345', 'hex') +-- !query schema +struct +-- !query output +#E + + -- !query select to_binary('GG') -- !query schema @@ -1489,7 +1537,8 @@ select to_binary('abc', 'Hex') -- !query schema struct -- !query output -� + +� -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 493fb3c34fcf..58a36b3299fb 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -484,6 +484,22 @@ struct AABB +-- !query +SELECT hex(lpad(unhex('123'), 2)) +-- !query schema +struct +-- !query output +0123 + + +-- !query +SELECT hex(lpad(unhex('12345'), 2)) +-- !query schema +struct +-- !query output +0123 + + -- !query SELECT hex(lpad(unhex(''), 5, unhex('1f'))) -- !query schema @@ -580,6 +596,22 @@ struct AABB +-- !query +SELECT hex(rpad(unhex('123'), 2)) +-- !query schema +struct +-- !query output +0123 + + +-- !query +SELECT hex(rpad(unhex('12345'), 2)) +-- !query schema +struct +-- !query output +0123 + + -- !query SELECT hex(rpad(unhex(''), 5, unhex('1f'))) -- !query schema @@ -1340,6 +1372,22 @@ struct � +-- !query +select to_binary('123', 'hex') +-- !query schema +struct +-- !query output +# + + +-- !query +select to_binary('12345', 'hex') +-- !query schema +struct +-- !query output +#E + + -- !query select to_binary('GG') -- !query schema @@ -1421,7 +1469,8 @@ select to_binary('abc', 'Hex') -- !query schema struct -- !query output -� + +� -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out index dacbc08a1038..4488bb649654 100644 Binary files a/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out and b/sql/core/src/test/resources/sql-tests/results/try-string-functions.sql.out differ diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out index 2f176951df84..09c6e10f7620 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -82,8 +82,22 @@ FROM various_maps struct<> -- !query 
output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.decimal_map1, various_maps.decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 - +{ + "errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", + "messageParameters" : { + "functionName" : "`map_zip_with`", + "leftType" : "\"DECIMAL(36,0)\"", + "rightType" : "\"DECIMAL(36,35)\"", + "sqlExpr" : "\"map_zip_with(decimal_map1, decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 81, + "fragment" : "map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2))" + } ] +} -- !query SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m @@ -110,7 +124,22 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.decimal_map2, various_maps.int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", + "messageParameters" : { + "functionName" : "`map_zip_with`", + "leftType" : "\"DECIMAL(36,35)\"", + "rightType" : "\"INT\"", + "sqlExpr" : "\"map_zip_with(decimal_map2, int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 76, + "fragment" : "map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2))" + } ] +} -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 85877c97ed59..3f02429fe629 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -533,6 +533,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("The given function only supports array input") { + val df = Seq(1, 2, 3).toDF("a") + checkErrorMatchPVals( + exception = intercept[AnalysisException] { + df.select(array_sort(col("a"), (x, y) => x - y)) + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), x_\d+, y_\d+\)\)"""", + "paramIndex" -> "1", + "requiredType" -> "\"ARRAY\"", + "inputSql" -> "\"a\"", + "inputType" -> "\"INT\"" + )) + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), @@ -3492,15 +3508,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"")) // scalastyle:on line.size.limit - val ex4 = intercept[AnalysisException] { - df.selectExpr("aggregate(s, 0, (acc, x) -> x)") - } - assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + // scalastyle:off line.size.limit + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("aggregate(s, 0, (acc, x) -> x)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + 
"sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""", + "paramIndex" -> "3", + "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"INT\"" + )) + // scalastyle:on line.size.limit - val ex4a = intercept[AnalysisException] { - df.select(aggregate(col("s"), lit(0), (acc, x) => x)) - } - assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type")) + // scalastyle:off line.size.limit + checkError( + exception = intercept[AnalysisException] { + df.select(aggregate(col("s"), lit(0), (acc, x) => x)) + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""", + "paramIndex" -> "3", + "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"INT\"" + )) + // scalastyle:on line.size.limit checkError( exception = @@ -3570,17 +3606,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) - val ex2 = intercept[AnalysisException] { - df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") - } - assert(ex2.getMessage.contains("The input to function map_zip_with should have " + - "been two maps with compatible key types")) + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") + }, + errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"map_zip_with(mis, mmi, lambdafunction(concat(x, y, z), x, y, z))\"", + "functionName" -> "`map_zip_with`", + "leftType" -> "\"INT\"", + "rightType" -> "\"MAP\""), + context = ExpectedContext( + fragment = "map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))", + start = 0, + stop = 51)) - val ex2a = intercept[AnalysisException] { - df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z))) - } - assert(ex2a.getMessage.contains("The input to function map_zip_with should have " + - "been two maps with compatible key types")) + // scalastyle:off line.size.limit + checkError( + exception = intercept[AnalysisException] { + df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z))) + }, + errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", + matchPVals = true, + parameters = Map( + "sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", + "functionName" -> "`map_zip_with`", + "leftType" -> "\"INT\"", + "rightType" -> "\"MAP\"")) + // scalastyle:on line.size.limit checkError( exception = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 6276b1a3b60f..3b2271afc862 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -24,6 +24,7 @@ import 
org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock +import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedIdentifier, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} @@ -292,13 +293,12 @@ class PlanResolutionSuite extends AnalysisTest { |CREATE TABLE my_tab(a INT, b STRING) USING parquet |PARTITIONED BY ($transform) """.stripMargin - - val ae = intercept[UnsupportedOperationException] { - parseAndResolve(query) - } - - assert(ae.getMessage - .contains(s"Unsupported partition transform: $transform")) + checkError( + exception = intercept[SparkUnsupportedOperationException] { + parseAndResolve(query) + }, + errorClass = "_LEGACY_ERROR_TEMP_2067", + parameters = Map("transform" -> transform)) } } @@ -310,13 +310,12 @@ class PlanResolutionSuite extends AnalysisTest { |CREATE TABLE my_tab(a INT, b STRING, c String) USING parquet |PARTITIONED BY ($transform) """.stripMargin - - val ae = intercept[UnsupportedOperationException] { - parseAndResolve(query) - } - - assert(ae.getMessage - .contains("Multiple bucket transforms are not supported.")) + checkError( + exception = intercept[SparkUnsupportedOperationException] { + parseAndResolve(query) + }, + errorClass = "UNSUPPORTED_FEATURE.MULTIPLE_BUCKET_TRANSFORMS", + parameters = Map.empty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 7850b2d79b04..32f50c1705dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -86,7 +86,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { } test("SPARK-34265: Instrument Python UDF execution using SQL Metrics") { - + assume(shouldTestPythonUDFs) val pythonSQLMetrics = List( "data sent to Python workers", "data returned from Python workers",
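
A note on the Hex.unhex change in mathExpressions.scala above: when the input has an odd number of hex digits, the first digit is decoded into out(0) on its own, so every following two-digit pair must be written one slot further to the right, which is what the new oddShift offset does; previously those pairs overwrote out(0). The standalone sketch below is a simplification with invented names (UnhexSketch, digit), not the Spark implementation:

```scala
// Standalone sketch of the corrected odd-length unhex logic (illustration only,
// not org.apache.spark.sql.catalyst.expressions.Hex).
object UnhexSketch {
  private def digit(b: Byte): Int = Character.digit(b.toChar, 16)

  def unhex(bytes: Array[Byte]): Array[Byte] = {
    val out = new Array[Byte]((bytes.length + 1) >> 1)
    var i = 0
    var oddShift = 0
    if ((bytes.length & 0x01) != 0) {
      // Odd length: decode the first digit alone, as if padded with a leading '0'.
      val v = digit(bytes(0))
      if (v == -1) return null
      out(0) = v.toByte
      i += 1
      oddShift = 1 // all later pairs land one byte further to the right
    }
    while (i < bytes.length) {
      val first = digit(bytes(i))
      val second = digit(bytes(i + 1))
      if (first == -1 || second == -1) return null
      out(i / 2 + oddShift) = (((first << 4) | second) & 0xFF).toByte
      i += 2
    }
    out
  }

  def main(args: Array[String]): Unit = {
    def hex(bytes: Array[Byte]): String = bytes.map("%02X".format(_)).mkString
    println(hex(unhex("123".getBytes("UTF-8"))))   // 0123
    println(hex(unhex("12345".getBytes("UTF-8")))) // 012345
  }
}
```

Without the offset, the second pair of "123" would be written to out(0) and clobber the leading digit; with it, the output matches the new MathExpressionsSuite expectations (Unhex("123") gives Array(1, 35), Unhex("12345") gives Array(1, 35, 69)).
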
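The new checkErrorMatchPVals overload added to SparkFunSuite.scala calls checkError with matchPVals = true, which treats each expected message parameter as a regular expression rather than a literal string; DataFrameFunctionsSuite relies on this to match generated lambda-variable names (the x_\d+ patterns above). A minimal sketch of that matching idea, using a hypothetical helper rather than SparkFunSuite's implementation:

```scala
// Illustration only: what "match parameter values as regexes" amounts to.
// parametersMatch is a hypothetical stand-in, not part of Spark's test framework.
object MatchPValsSketch {
  def parametersMatch(expected: Map[String, String], actual: Map[String, String]): Boolean =
    expected.keySet == actual.keySet && expected.forall { case (key, regex) =>
      actual(key).matches(regex)
    }

  def main(args: Array[String]): Unit = {
    // Expected value written as a regex, as in the array_sort test above.
    val expected = Map(
      "sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), x_\d+, y_\d+\)\)"""")
    // A value the analyzer might actually produce, with generated variable ids.
    val actual = Map(
      "sqlExpr" -> """"array_sort(a, lambdafunction((x_42 - y_43), x_42, y_43))"""")
    println(parametersMatch(expected, actual)) // true
  }
}
```

This keeps such tests stable even though the numeric suffixes of the generated lambda variables differ from run to run.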