diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 8a01b80c4164b..6d74f8328aea2 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -19,8 +19,6 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. -set -x - function error { echo "$@" 1>&2 exit 1 diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a577194a48006..726cff6703dcb 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -74,13 +74,25 @@ private[spark] class PythonRDD( * runner. */ private[spark] case class PythonFunction( - command: Array[Byte], + command: Seq[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: PythonAccumulatorV2) + accumulator: PythonAccumulatorV2) { + + def this( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: PythonAccumulatorV2) = { + this(command.toSeq, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator) + } +} /** * A wrapper for chained Python functions (from bottom to top). 
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f34316424c4ca..d7a09b599794e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -613,7 +613,7 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions]) protected override def writeCommand(dataOut: DataOutputStream): Unit = { val command = funcs.head.funcs.head.command dataOut.writeInt(command.length) - dataOut.write(command) + dataOut.write(command.toArray) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 7022b986ea025..7c5ab43a9e1b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy +import java.util.concurrent.TimeUnit + import scala.collection.mutable.HashSet import scala.concurrent.ExecutionContext import scala.reflect.ClassTag @@ -27,6 +29,7 @@ import org.apache.log4j.Logger import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} +import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config.Network.RPC_ASK_TIMEOUT import org.apache.spark.resource.ResourceUtils @@ -61,6 +64,11 @@ private class ClientEndpoint( private val lostMasters = new HashSet[RpcAddress] private var activeMasterEndpoint: RpcEndpointRef = null + private val waitAppCompletion = conf.get(config.STANDALONE_SUBMIT_WAIT_APP_COMPLETION) + private val REPORT_DRIVER_STATUS_INTERVAL = 10000 + private var submittedDriverID = "" + private var 
driverStatusReported = false + private def getProperty(key: String, conf: SparkConf): Option[String] = { sys.props.get(key).orElse(conf.getOption(key)) @@ -107,8 +115,13 @@ private class ClientEndpoint( case "kill" => val driverId = driverArgs.driverId + submittedDriverID = driverId asyncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) } + logInfo("... waiting before polling master for driver state") + forwardMessageThread.scheduleAtFixedRate(() => Utils.tryLogNonFatalError { + monitorDriverStatus() + }, 5000, REPORT_DRIVER_STATUS_INTERVAL, TimeUnit.MILLISECONDS) } /** @@ -124,58 +137,87 @@ private class ClientEndpoint( } } - /* Find out driver status then exit the JVM */ - def pollAndReportStatus(driverId: String): Unit = { - // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread - // is fine. - logInfo("... waiting before polling master for driver state") - Thread.sleep(5000) - logInfo("... polling master for driver state") - val statusResponse = - activeMasterEndpoint.askSync[DriverStatusResponse](RequestDriverStatus(driverId)) - if (statusResponse.found) { - logInfo(s"State of $driverId is ${statusResponse.state.get}") - // Worker node, if present - (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { - case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => - logInfo(s"Driver running on $hostPort ($id)") - case _ => + private def monitorDriverStatus(): Unit = { + if (submittedDriverID != "") { + asyncSendToMasterAndForwardReply[DriverStatusResponse](RequestDriverStatus(submittedDriverID)) + } + } + + /** + * Processes and reports the driver status then exit the JVM if the + * waitAppCompletion is set to false, else reports the driver status + * if debug logs are enabled. 
+ */ + + def reportDriverStatus( + found: Boolean, + state: Option[DriverState], + workerId: Option[String], + workerHostPort: Option[String], + exception: Option[Exception]): Unit = { + if (found) { + // Using driverStatusReported to avoid writing following + // logs again when waitAppCompletion is set to true + if (!driverStatusReported) { + driverStatusReported = true + logInfo(s"State of $submittedDriverID is ${state.get}") + // Worker node, if present + (workerId, workerHostPort, state) match { + case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => + logInfo(s"Driver running on $hostPort ($id)") + case _ => + } } // Exception, if present - statusResponse.exception match { + exception match { case Some(e) => logError(s"Exception from cluster was: $e") e.printStackTrace() System.exit(-1) case _ => - System.exit(0) + state.get match { + case DriverState.FINISHED | DriverState.FAILED | + DriverState.ERROR | DriverState.KILLED => + logInfo(s"State of driver $submittedDriverID is ${state.get}, " + + s"exiting spark-submit JVM.") + System.exit(0) + case _ => + if (!waitAppCompletion) { + logInfo(s"spark-submit not configured to wait for completion, " + + s"exiting spark-submit JVM.") + System.exit(0) + } else { + logDebug(s"State of driver $submittedDriverID is ${state.get}, " + + s"continue monitoring driver status.") + } + } + } + } else { + logError(s"ERROR: Cluster master did not recognize $submittedDriverID") + System.exit(-1) } - } else { - logError(s"ERROR: Cluster master did not recognize $driverId") - System.exit(-1) } - } - override def receive: PartialFunction[Any, Unit] = { case SubmitDriverResponse(master, success, driverId, message) => logInfo(message) if (success) { activeMasterEndpoint = master - pollAndReportStatus(driverId.get) + submittedDriverID = driverId.get } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } - case KillDriverResponse(master, driverId, success, message) => logInfo(message) if (success) { 
activeMasterEndpoint = master - pollAndReportStatus(driverId) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } + + case DriverStatusResponse(found, state, workerId, workerHostPort, exception) => + reportDriverStatus(found, state, workerId, workerHostPort, exception) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 71df5dfa423a9..d2e65db970380 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -715,7 +715,9 @@ private[deploy] class Master( val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) .filter(canLaunchExecutor(_, app.desc)) .sortBy(_.coresFree).reverse - if (waitingApps.length == 1 && usableWorkers.isEmpty) { + val appMayHang = waitingApps.length == 1 && + waitingApps.head.executors.isEmpty && usableWorkers.isEmpty + if (appMayHang) { logWarning(s"App ${app.id} requires more resource than any of Workers could have.") } val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8ef0c37198568..ee437c696b47e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1864,4 +1864,13 @@ package object config { .version("3.1.0") .booleanConf .createWithDefault(false) + + private[spark] val STANDALONE_SUBMIT_WAIT_APP_COMPLETION = + ConfigBuilder("spark.standalone.submit.waitAppCompletion") + .doc("In standalone cluster mode, controls whether the client waits to exit until the " + + "application completes. If set to true, the client process will stay alive polling " + + "the driver's status. 
Otherwise, the client process will exit after submission.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index ea033d0c890ac..bd19c9522f3df 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -42,7 +42,7 @@ private[spark] class AppStatusStore( store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info } catch { case _: NoSuchElementException => - throw new SparkException("Failed to get the application information. " + + throw new NoSuchElementException("Failed to get the application information. " + "If you are starting up Spark, please wait a while until it's ready.") } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 844d9b7cf2c27..1c788a30022d0 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -363,12 +363,22 @@ private[spark] object JsonProtocol { case v: Long => JInt(v) // We only have 3 kind of internal accumulator types, so if it's not int or long, it must be // the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]` - case v => - JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map { - case (id, status) => - ("Block ID" -> id.toString) ~ - ("Status" -> blockStatusToJson(status)) + case v: java.util.List[_] => + JArray(v.asScala.toList.flatMap { + case (id: BlockId, status: BlockStatus) => + Some( + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) + ) + case _ => + // Ignore unsupported types. A user may put `METRICS_PREFIX` in the name. We should + // not crash. + None }) + case _ => + // Ignore unsupported types. 
A user may put `METRICS_PREFIX` in the name. We should not + // crash. + JNothing } } else { // For all external accumulators, just use strings diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 248142a5ad633..5a4073baa19d4 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -507,6 +507,54 @@ class JsonProtocolSuite extends SparkFunSuite { testAccumValue(Some("anything"), 123, JString("123")) } + /** Create an AccumulableInfo and verify we can serialize and deserialize it. */ + private def testAccumulableInfo( + name: String, + value: Option[Any], + expectedValue: Option[Any]): Unit = { + val isInternal = name.startsWith(InternalAccumulator.METRICS_PREFIX) + val accum = AccumulableInfo( + 123L, + Some(name), + update = value, + value = value, + internal = isInternal, + countFailedValues = false) + val json = JsonProtocol.accumulableInfoToJson(accum) + val newAccum = JsonProtocol.accumulableInfoFromJson(json) + assert(newAccum == accum.copy(update = expectedValue, value = expectedValue)) + } + + test("SPARK-31923: unexpected value type of internal accumulator") { + // Because a user may use `METRICS_PREFIX` in an accumulator name, we should test unexpected + // types to make sure we don't crash. 
+ import InternalAccumulator.METRICS_PREFIX + testAccumulableInfo( + METRICS_PREFIX + "fooString", + value = Some("foo"), + expectedValue = None) + testAccumulableInfo( + METRICS_PREFIX + "fooList", + value = Some(java.util.Arrays.asList("string")), + expectedValue = Some(java.util.Collections.emptyList()) + ) + val blocks = Seq( + (TestBlockId("block1"), BlockStatus(StorageLevel.MEMORY_ONLY, 1L, 2L)), + (TestBlockId("block2"), BlockStatus(StorageLevel.DISK_ONLY, 3L, 4L))) + testAccumulableInfo( + METRICS_PREFIX + "fooList", + value = Some(java.util.Arrays.asList( + "string", + blocks(0), + blocks(1))), + expectedValue = Some(blocks.asJava) + ) + testAccumulableInfo( + METRICS_PREFIX + "fooSet", + value = Some(Set("foo")), + expectedValue = None) + } + test("SPARK-30936: forwards compatibility - ignore unknown fields") { val expected = TestListenerEvent("foo", 123) val unknownFieldsJson = diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1e6f8c586d546..1f70d46d587a8 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -374,6 +374,25 @@ To run an interactive Spark shell against the cluster, run the following command You can also pass an option `--total-executor-cores ` to control the number of cores that spark-shell uses on the cluster. +# Client Properties + +Spark applications supports the following configuration properties specific to standalone mode: + + + + + + + + + +
Property NameDefault ValueMeaningSince Version
spark.standalone.submit.waitAppCompletionfalse + In standalone cluster mode, controls whether the client waits to exit until the application completes. + If set to true, the client process will stay alive polling the driver's status. + Otherwise, the client process will exit after submission. + 3.1.0
+ + # Launching Spark Applications The [`spark-submit` script](submitting-applications.html) provides the most straightforward way to diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 2272c90384847..0130923e694b1 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -27,6 +27,8 @@ license: | - In Spark 3.1, grouping_id() returns long values. In Spark version 3.0 and earlier, this function returns int values. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.integerGroupingId` to `true`. - In Spark 3.1, SQL UI data adopts the `formatted` mode for the query plan explain results. To restore the behavior before Spark 3.0, you can set `spark.sql.ui.explainMode` to `extended`. + + - In Spark 3.1, `from_unixtime`, `unix_timestamp`, `to_unix_timestamp`, `to_timestamp` and `to_date` will fail if the specified datetime pattern is invalid. In Spark 3.0 or earlier, they return `NULL`. ## Upgrading from Spark SQL 2.4 to 3.0 diff --git a/docs/ss-migration-guide.md b/docs/ss-migration-guide.md index 963ef07af7ace..002058b69bf30 100644 --- a/docs/ss-migration-guide.md +++ b/docs/ss-migration-guide.md @@ -31,3 +31,5 @@ Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide. - In Spark 3.0, Structured Streaming forces the source schema into nullable when file-based datasources such as text, json, csv, parquet and orc are used via `spark.readStream(...)`. Previously, it respected the nullability in source schema; however, it caused issues tricky to debug with NPE. To restore the previous behavior, set `spark.sql.streaming.fileSource.schema.forceNullable` to `false`. - Spark 3.0 fixes the correctness issue on Stream-stream outer join, which changes the schema of state. (See [SPARK-26154](https://issues.apache.org/jira/browse/SPARK-26154) for more details). If you start your query from checkpoint constructed from Spark 2.x which uses stream-stream outer join, Spark 3.0 fails the query. 
To recalculate outputs, discard the checkpoint and replay previous inputs. + +- In Spark 3.0, the deprecated class `org.apache.spark.sql.streaming.ProcessingTime` has been removed. Use `org.apache.spark.sql.streaming.Trigger.ProcessingTime` instead. Likewise, `org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger` has been removed in favor of `Trigger.Continuous`, and `org.apache.spark.sql.execution.streaming.OneTimeTrigger` has been hidden in favor of `Trigger.Once`. \ No newline at end of file diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index eb12f2f1f6ab7..1360d30fdd575 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -480,7 +480,8 @@ object SparkParallelTestGrouping { "org.apache.spark.sql.hive.thriftserver.SparkSQLEnvSuite", "org.apache.spark.sql.hive.thriftserver.ui.ThriftServerPageSuite", "org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2ListenerSuite", - "org.apache.spark.sql.hive.thriftserver.ThriftServerWithSparkContextSuite", + "org.apache.spark.sql.hive.thriftserver.ThriftServerWithSparkContextInHttpSuite", + "org.apache.spark.sql.hive.thriftserver.ThriftServerWithSparkContextInBinarySuite", "org.apache.spark.sql.kafka010.KafkaDelegationTokenSuite" ) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 061d3f5e1f7ac..2689b9c33d576 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -642,6 +642,15 @@ def f(*a): r = df.select(fUdf(*df.columns)) self.assertEqual(r.first()[0], "success") + def test_udf_cache(self): + func = lambda x: x + + df = self.spark.range(1) + df.select(udf(func)("id")).cache() + + self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution() + .withCachedData().getClass().getSimpleName(), 'InMemoryRelation') + class UDFInitializationTests(unittest.TestCase): def tearDown(self): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 
1dbea12ab53ef..1d5bc49d252e2 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -116,8 +116,8 @@ def convert_exception(e): # To make sure this only catches Python UDFs. and any(map(lambda v: "org.apache.spark.sql.execution.python" in v.toString(), c.getStackTrace()))): - msg = ("\n An exception was thrown from Python worker in the executor. " - "The below is the Python worker stacktrace.\n%s" % c.getMessage()) + msg = ("\n An exception was thrown from the Python worker. " + "Please see the stack trace below.\n%s" % c.getMessage()) return PythonException(msg, stacktrace) return UnknownException(s, stacktrace, c) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index fc429d6fb1972..7b121194d1b31 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -635,7 +635,12 @@ private[spark] class Client( distribute(args.primaryPyFile, appMasterOnly = true) } - pySparkArchives.foreach { f => distribute(f) } + pySparkArchives.foreach { f => + val uri = Utils.resolveURI(f) + if (uri.getScheme != Utils.LOCAL_SCHEME) { + distribute(f) + } + } // The python files list needs to be treated especially. All files that are not an // archive need to be placed in a subdirectory that will be added to PYTHONPATH. 
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index b03e6372a8eae..14b5fa3dbfda6 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -1173,6 +1173,7 @@ ansiNonReserved | TRIM | TRUE | TRUNCATE + | TYPE | UNARCHIVE | UNBOUNDED | UNCACHE diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index c5ead9412a438..c5cf447c103b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.ParseException import java.time.{DateTimeException, LocalDate, LocalDateTime, ZoneId} +import java.time.format.DateTimeParseException import java.time.temporal.IsoFields import java.util.Locale -import scala.util.control.NonFatal - import org.apache.commons.text.StringEscapeUtils -import org.apache.spark.SparkUpgradeException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -34,7 +33,6 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, Tim import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT -import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, 
UTF8String} @@ -56,6 +54,26 @@ trait TimeZoneAwareExpression extends Expression { @transient lazy val zoneId: ZoneId = DateTimeUtils.getZoneId(timeZoneId.get) } +trait TimestampFormatterHelper extends TimeZoneAwareExpression { + + protected def formatString: Expression + + protected def isParsing: Boolean + + @transient final protected lazy val formatterOption: Option[TimestampFormatter] = + if (formatString.foldable) { + Option(formatString.eval()).map(fmt => getFormatter(fmt.toString)) + } else None + + final protected def getFormatter(fmt: String): TimestampFormatter = { + TimestampFormatter( + format = fmt, + zoneId = zoneId, + legacyFormat = SIMPLE_DATE_FORMAT, + isParsing = isParsing) + } +} + /** * Returns the current date at the start of query evaluation. * All calls of current_date within the same query return the same value. @@ -715,7 +733,7 @@ case class WeekOfYear(child: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes with NullIntolerant { def this(left: Expression, right: Expression) = this(left, right, None) @@ -727,33 +745,13 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - @transient private lazy val formatter: Option[TimestampFormatter] = { - if (right.foldable) { - Option(right.eval()).map { format => - TimestampFormatter( - format.toString, - zoneId, - legacyFormat = SIMPLE_DATE_FORMAT, - isParsing = false) - } - } else None - } - override protected def nullSafeEval(timestamp: Any, format: Any): Any = { - val tf = if (formatter.isEmpty) { - TimestampFormatter( - format.toString, - zoneId, - legacyFormat = SIMPLE_DATE_FORMAT, 
- isParsing = false) - } else { - formatter.get - } - UTF8String.fromString(tf.format(timestamp.asInstanceOf[Long])) + val formatter = formatterOption.getOrElse(getFormatter(format.toString)) + UTF8String.fromString(formatter.format(timestamp.asInstanceOf[Long])) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - formatter.map { tf => + formatterOption.map { tf => val timestampFormatter = ctx.addReferenceObj("timestampFormatter", tf) defineCodeGen(ctx, ev, (timestamp, _) => { s"""UTF8String.fromString($timestampFormatter.format($timestamp))""" @@ -774,6 +772,10 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti } override def prettyName: String = "date_format" + + override protected def formatString: Expression = right + + override protected def isParsing: Boolean = false } /** @@ -871,31 +873,21 @@ case class UnixTimestamp(timeExp: Expression, format: Expression, timeZoneId: Op } abstract class ToTimestamp - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + extends BinaryExpression with TimestampFormatterHelper with ExpectsInputTypes { // The result of the conversion to timestamp is microseconds divided by this factor. // For example if the factor is 1000000, the result of the expression is in seconds. 
protected def downScaleFactor: Long + override protected def formatString: Expression = right + override protected def isParsing = true + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, DateType, TimestampType), StringType) override def dataType: DataType = LongType override def nullable: Boolean = true - private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - private lazy val formatter: TimestampFormatter = - try { - TimestampFormatter( - constFormat.toString, - zoneId, - legacyFormat = SIMPLE_DATE_FORMAT, - isParsing = true) - } catch { - case e: SparkUpgradeException => throw e - case NonFatal(_) => null - } - override def eval(input: InternalRow): Any = { val t = left.eval(input) if (t == null) { @@ -906,34 +898,18 @@ abstract class ToTimestamp epochDaysToMicros(t.asInstanceOf[Int], zoneId) / downScaleFactor case TimestampType => t.asInstanceOf[Long] / downScaleFactor - case StringType if right.foldable => - if (constFormat == null || formatter == null) { - null - } else { - try { - formatter.parse( - t.asInstanceOf[UTF8String].toString) / downScaleFactor - } catch { - case e: SparkUpgradeException => throw e - case NonFatal(_) => null - } - } case StringType => - val f = right.eval(input) - if (f == null) { + val fmt = right.eval(input) + if (fmt == null) { null } else { - val formatString = f.asInstanceOf[UTF8String].toString + val formatter = formatterOption.getOrElse(getFormatter(fmt.toString)) try { - TimestampFormatter( - formatString, - zoneId, - legacyFormat = SIMPLE_DATE_FORMAT, - isParsing = true) - .parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor + formatter.parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor } catch { - case e: SparkUpgradeException => throw e - case NonFatal(_) => null + case _: DateTimeParseException | + _: DateTimeException | + _: ParseException => null } } } @@ -943,55 +919,44 @@ abstract class ToTimestamp override def doGenCode(ctx: CodegenContext, 
ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) left.dataType match { - case StringType if right.foldable => + case StringType => formatterOption.map { fmt => val df = classOf[TimestampFormatter].getName - if (formatter == null) { - ExprCode.forNullValue(dataType) - } else { - val formatterName = ctx.addReferenceObj("formatter", formatter, df) - val eval1 = left.genCode(ctx) - ev.copy(code = code""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (!${ev.isNull}) { - try { - ${ev.value} = $formatterName.parse(${eval1.value}.toString()) / $downScaleFactor; - } catch (java.lang.IllegalArgumentException e) { - ${ev.isNull} = true; - } catch (java.text.ParseException e) { - ${ev.isNull} = true; - } catch (java.time.format.DateTimeParseException e) { - ${ev.isNull} = true; - } catch (java.time.DateTimeException e) { - ${ev.isNull} = true; - } - }""") - } - case StringType => + val formatterName = ctx.addReferenceObj("formatter", fmt, df) + nullSafeCodeGen(ctx, ev, (datetimeStr, _) => + s""" + |try { + | ${ev.value} = $formatterName.parse($datetimeStr.toString()) / $downScaleFactor; + |} catch (java.time.DateTimeException e) { + | ${ev.isNull} = true; + |} catch (java.time.format.DateTimeParseException e) { + | ${ev.isNull} = true; + |} catch (java.text.ParseException e) { + | ${ev.isNull} = true; + |} + |""".stripMargin) + }.getOrElse { val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) val tf = TimestampFormatter.getClass.getName.stripSuffix("$") val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$") - nullSafeCodeGen(ctx, ev, (string, format) => { + val timestampFormatter = ctx.freshName("timestampFormatter") + nullSafeCodeGen(ctx, ev, (string, format) => s""" - try { - ${ev.value} = $tf$$.MODULE$$.apply( - $format.toString(), - $zid, - $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT(), - true) - .parse($string.toString()) / $downScaleFactor; - 
} catch (java.lang.IllegalArgumentException e) { - ${ev.isNull} = true; - } catch (java.text.ParseException e) { - ${ev.isNull} = true; - } catch (java.time.format.DateTimeParseException e) { - ${ev.isNull} = true; - } catch (java.time.DateTimeException e) { - ${ev.isNull} = true; - } - """ - }) + |$tf $timestampFormatter = $tf$$.MODULE$$.apply( + | $format.toString(), + | $zid, + | $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT(), + | true); + |try { + | ${ev.value} = $timestampFormatter.parse($string.toString()) / $downScaleFactor; + |} catch (java.time.format.DateTimeParseException e) { + | ${ev.isNull} = true; + |} catch (java.time.DateTimeException e) { + | ${ev.isNull} = true; + |} catch (java.text.ParseException e) { + | ${ev.isNull} = true; + |} + |""".stripMargin) + } case TimestampType => val eval1 = left.genCode(ctx) ev.copy(code = code""" @@ -1044,7 +1009,8 @@ abstract class UnixTime extends ToTimestamp { since = "1.5.0") // scalastyle:on line.size.limit case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes + with NullIntolerant { def this(sec: Expression, format: Expression) = this(sec, format, None) @@ -1065,93 +1031,34 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - private lazy val formatter: TimestampFormatter = - try { - TimestampFormatter( - constFormat.toString, - zoneId, - legacyFormat = SIMPLE_DATE_FORMAT, - isParsing = false) - } catch { - case e: SparkUpgradeException => throw e - case NonFatal(_) => null - } - - override def eval(input: InternalRow): Any = { - val time = left.eval(input) - if (time == null) { - null 
- } else { - if (format.foldable) { - if (constFormat == null || formatter == null) { - null - } else { - try { - UTF8String.fromString(formatter.format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) - } catch { - case e: SparkUpgradeException => throw e - case NonFatal(_) => null - } - } - } else { - val f = format.eval(input) - if (f == null) { - null - } else { - try { - UTF8String.fromString( - TimestampFormatter( - f.toString, - zoneId, - legacyFormat = SIMPLE_DATE_FORMAT, - isParsing = false) - .format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) - } catch { - case e: SparkUpgradeException => throw e - case NonFatal(_) => null - } - } - } - } + override def nullSafeEval(seconds: Any, format: Any): Any = { + val fmt = formatterOption.getOrElse(getFormatter(format.toString)) + UTF8String.fromString(fmt.format(seconds.asInstanceOf[Long] * MICROS_PER_SECOND)) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val df = classOf[TimestampFormatter].getName - if (format.foldable) { - if (formatter == null) { - ExprCode.forNullValue(StringType) - } else { - val formatterName = ctx.addReferenceObj("formatter", formatter, df) - val t = left.genCode(ctx) - ev.copy(code = code""" - ${t.code} - boolean ${ev.isNull} = ${t.isNull}; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (!${ev.isNull}) { - try { - ${ev.value} = UTF8String.fromString($formatterName.format(${t.value} * 1000000L)); - } catch (java.lang.IllegalArgumentException e) { - ${ev.isNull} = true; - } - }""") - } - } else { - val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) + formatterOption.map { f => + val formatterName = ctx.addReferenceObj("formatter", f) + defineCodeGen(ctx, ev, (seconds, _) => + s"UTF8String.fromString($formatterName.format($seconds * 1000000L))") + }.getOrElse { val tf = TimestampFormatter.getClass.getName.stripSuffix("$") val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$") - 
nullSafeCodeGen(ctx, ev, (seconds, f) => { + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) + defineCodeGen(ctx, ev, (seconds, format) => s""" - try { - ${ev.value} = UTF8String.fromString( - $tf$$.MODULE$$.apply($f.toString(), $zid, $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT(), false) - .format($seconds * 1000000L)); - } catch (java.lang.IllegalArgumentException e) { - ${ev.isNull} = true; - }""" - }) + |UTF8String.fromString( + | $tf$$.MODULE$$.apply($format.toString(), + | $zid, + | $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT(), + | false).format($seconds * 1000000L)) + |""".stripMargin) } } + + override protected def formatString: Expression = format + + override protected def isParsing: Boolean = false } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 77b4cecc263c7..2abd9d7bb4423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -85,6 +85,7 @@ trait NamedExpression extends Expression { * e.g. top level attributes aliased in the SELECT clause, or column from a LocalRelation. * 2. Seq with a Single element: either the table name or the alias name of the table. * 3. Seq with 2 elements: database name and table name + * 4. 
Seq with 3 elements: catalog name, database name and table name */ def qualifier: Seq[String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 30da902a33cf2..c38a1189387d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -35,6 +35,11 @@ object NestedColumnAliasing { case Project(projectList, child) if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) => getAliasSubMap(projectList) + + case plan if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(plan) => + val exprCandidatesToPrune = plan.expressions + getAliasSubMap(exprCandidatesToPrune, plan.producedAttributes.toSeq) + case _ => None } @@ -48,7 +53,11 @@ object NestedColumnAliasing { case Project(projectList, child) => Project( getNewProjectList(projectList, nestedFieldToAlias), - replaceChildrenWithAliases(child, attrToAliases)) + replaceChildrenWithAliases(child, nestedFieldToAlias, attrToAliases)) + + // The operators reaching here was already guarded by `canPruneOn`. + case other => + replaceChildrenWithAliases(other, nestedFieldToAlias, attrToAliases) } /** @@ -68,10 +77,23 @@ object NestedColumnAliasing { */ def replaceChildrenWithAliases( plan: LogicalPlan, + nestedFieldToAlias: Map[ExtractValue, Alias], attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = { plan.withNewChildren(plan.children.map { plan => Project(plan.output.flatMap(a => attrToAliases.getOrElse(a.exprId, Seq(a))), plan) - }) + }).transformExpressions { + case f: ExtractValue if nestedFieldToAlias.contains(f) => + nestedFieldToAlias(f).toAttribute + } + } + + /** + * Returns true for those operators that we can prune nested column on it. 
+ */ + private def canPruneOn(plan: LogicalPlan) = plan match { + case _: Aggregate => true + case _: Expand => true + case _ => false } /** @@ -204,15 +226,8 @@ object GeneratorNestedColumnAliasing { g: Generate, nestedFieldToAlias: Map[ExtractValue, Alias], attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = { - val newGenerator = g.generator.transform { - case f: ExtractValue if nestedFieldToAlias.contains(f) => - nestedFieldToAlias(f).toAttribute - }.asInstanceOf[Generator] - // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. - val newGenerate = g.copy(generator = newGenerator) - - NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases) + NestedColumnAliasing.replaceChildrenWithAliases(g, nestedFieldToAlias, attrToAliases) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala index 76ae3e5e8469a..da80e629ee31d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -100,6 +100,17 @@ trait LegacyDateFormatter extends DateFormatter { } } +/** + * The legacy formatter is based on Apache Commons FastDateFormat. The formatter uses the default + * JVM time zone intentionally for compatibility with Spark 2.4 and earlier versions. + * + * Note: Using of the default JVM time zone makes the formatter compatible with the legacy + * `DateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default + * JVM time zone too. + * + * @param pattern `java.text.SimpleDateFormat` compatible pattern. + * @param locale The locale overrides the system locale and is used in parsing/formatting. 
+ */ class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { @transient private lazy val fdf = FastDateFormat.getInstance(pattern, locale) @@ -108,6 +119,22 @@ class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDat override def validatePatternString(): Unit = fdf } +// scalastyle:off line.size.limit +/** + * The legacy formatter is based on `java.text.SimpleDateFormat`. The formatter uses the default + * JVM time zone intentionally for compatibility with Spark 2.4 and earlier versions. + * + * Note: Using of the default JVM time zone makes the formatter compatible with the legacy + * `DateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default + * JVM time zone too. + * + * @param pattern The pattern describing the date and time format. + * See + * Date and Time Patterns + * @param locale The locale whose date format symbols should be used. It overrides the system + * locale in parsing/formatting. + */ +// scalastyle:on line.size.limit class LegacySimpleDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { @transient private lazy val sdf = new SimpleDateFormat(pattern, locale) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 21a478aaf06a6..41a271b95e83c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -88,19 +88,20 @@ object DateTimeUtils { } /** - * Converts an instance of `java.sql.Date` to a number of days since the epoch - * 1970-01-01 via extracting date fields `year`, `month`, `days` from the input, - * creating a local date in Proleptic Gregorian calendar from the fields, and - * getting the number of days from the resulted local date. 
+ * Converts a local date at the default JVM time zone to the number of days since 1970-01-01 + * in the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulted days are + * rebased from the hybrid to Proleptic Gregorian calendar. The days rebasing is performed via + * UTC time zone for simplicity because the difference between two calendars is the same in + * any given time zone and UTC time zone. * - * This approach was taken to have the same local date as the triple of `year`, - * `month`, `day` in the original hybrid calendar used by `java.sql.Date` and - * Proleptic Gregorian calendar used by Spark since version 3.0.0, see SPARK-26651. + * Note: The date is shifted by the offset of the default JVM time zone for backward compatibility + * with Spark 2.4 and earlier versions. The goal of the shift is to get a local date derived + * from the number of days that has the same date fields (year, month, day) as the original + * `date` at the default JVM time zone. * - * @param date It represents a specific instant in time based on - * the hybrid calendar which combines Julian and - * Gregorian calendars. - * @return The number of days since epoch from java.sql.Date. + * @param date It represents a specific instant in time based on the hybrid calendar which + * combines Julian and Gregorian calendars. + * @return The number of days since the epoch in Proleptic Gregorian calendar. */ def fromJavaDate(date: Date): SQLDate = { val millisUtc = date.getTime @@ -110,17 +111,18 @@ object DateTimeUtils { } /** - * The opposite to `fromJavaDate` method which converts a number of days to an - * instance of `java.sql.Date`. It builds a local date in Proleptic Gregorian - * calendar, extracts date fields `year`, `month`, `day`, and creates a local - * date in the hybrid calendar (Julian + Gregorian calendars) from the fields. 
+ * Converts days since the epoch 1970-01-01 in Proleptic Gregorian calendar to a local date + * at the default JVM time zone in the hybrid calendar (Julian + Gregorian). It rebases the given + * days from Proleptic Gregorian to the hybrid calendar at UTC time zone for simplicity because + * the difference between two calendars doesn't depend on any time zone. The result is shifted + * by the time zone offset in wall clock to have the same date fields (year, month, day) + * at the default JVM time zone as the input `daysSinceEpoch` in Proleptic Gregorian calendar. * - * The purpose of the conversion is to have the same local date as the triple - * of `year`, `month`, `day` in the original Proleptic Gregorian calendar and - * in the target calender. + * Note: The date is shifted by the offset of the default JVM time zone for backward compatibility + * with Spark 2.4 and earlier versions. * - * @param daysSinceEpoch The number of days since 1970-01-01. - * @return A `java.sql.Date` from number of days since epoch. + * @param daysSinceEpoch The number of days since 1970-01-01 in Proleptic Gregorian calendar. + * @return A local date in the hybrid calendar as `java.sql.Date` from number of days since epoch. 
*/ def toJavaDate(daysSinceEpoch: SQLDate): Date = { val rebasedDays = rebaseGregorianToJulianDays(daysSinceEpoch) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index f3b589657b254..f460404800264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -292,14 +292,14 @@ object TimestampFormatter { legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT, isParsing: Boolean): TimestampFormatter = { val pattern = format.getOrElse(defaultPattern) - if (SQLConf.get.legacyTimeParserPolicy == LEGACY) { + val formatter = if (SQLConf.get.legacyTimeParserPolicy == LEGACY) { getLegacyFormatter(pattern, zoneId, locale, legacyFormat) } else { - val tf = new Iso8601TimestampFormatter( + new Iso8601TimestampFormatter( pattern, zoneId, locale, legacyFormat, isParsing) - tf.validatePatternString() - tf } + formatter.validatePatternString() + formatter } def getLegacyFormatter( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3a41b0553db54..189740e313207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2784,7 +2784,15 @@ class SQLConf extends Serializable with Logging { def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED) - def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) + def defaultNumShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) + + def numShufflePartitions: Int = { + if (adaptiveExecutionEnabled && coalesceShufflePartitionsEnabled) { + 
getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(defaultNumShufflePartitions) + } else { + defaultNumShufflePartitions + } + } def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) @@ -2797,9 +2805,6 @@ class SQLConf extends Serializable with Logging { def coalesceShufflePartitionsEnabled: Boolean = getConf(COALESCE_PARTITIONS_ENABLED) - def initialShufflePartitionNum: Int = - getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(numShufflePartitions) - def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN) def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 2dc5990eb6103..f248a3454f39a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -23,6 +23,8 @@ import java.time.{Instant, LocalDate, ZoneId} import java.util.{Calendar, Locale, TimeZone} import java.util.concurrent.TimeUnit._ +import scala.reflect.ClassTag + import org.apache.spark.{SparkFunSuite, SparkUpgradeException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection @@ -777,8 +779,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId), null) - checkEvaluation( - FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null) // SPARK-28072 The codegen path for non-literal input should also work checkEvaluation( @@ -792,7 +792,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } // Test escaping of format - 
GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote"), UTC_OPT) :: Nil) + GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\""), UTC_OPT) :: Nil) } test("unix_timestamp") { @@ -854,15 +854,13 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnixTimestamp(Literal(date1), Literal.create(null, StringType), timeZoneId), MICROSECONDS.toSeconds( DateTimeUtils.daysToMicros(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) - checkEvaluation( - UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) } } } } // Test escaping of format GenerateUnsafeProjection.generate( - UnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil) + UnixTimestamp(Literal("2015-07-24"), Literal("\""), UTC_OPT) :: Nil) } test("to_unix_timestamp") { @@ -920,10 +918,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(date1), Literal.create(null, StringType), timeZoneId), MICROSECONDS.toSeconds( DateTimeUtils.daysToMicros(DateTimeUtils.fromJavaDate(date1), zid))) - checkEvaluation( - ToUnixTimestamp( - Literal("2015-07-24"), - Literal("not a valid format"), timeZoneId), null) // SPARK-28072 The codegen path for non-literal input should also work checkEvaluation( @@ -940,7 +934,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // Test escaping of format GenerateUnsafeProjection.generate( - ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil) + ToUnixTimestamp(Literal("2015-07-24"), Literal("\""), UTC_OPT) :: Nil) } test("datediff") { @@ -1169,36 +1163,28 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MillisToTimestamp(Literal(-92233720368547758L)), "long overflow") } - test("Disable week-based date fields and quarter fields for parsing") { + test("Consistent error handling for datetime formatting and parsing functions") { - def checkSparkUpgrade(c: 
Char): Unit = { - checkExceptionInExpression[SparkUpgradeException]( - new ParseToTimestamp(Literal("1"), Literal(c.toString)).child, "3.0") - checkExceptionInExpression[SparkUpgradeException]( - new ParseToDate(Literal("1"), Literal(c.toString)).child, "3.0") - checkExceptionInExpression[SparkUpgradeException]( - ToUnixTimestamp(Literal("1"), Literal(c.toString)), "3.0") - checkExceptionInExpression[SparkUpgradeException]( - UnixTimestamp(Literal("1"), Literal(c.toString)), "3.0") - } - - def checkNullify(c: Char): Unit = { - checkEvaluation(new ParseToTimestamp(Literal("1"), Literal(c.toString)).child, null) - checkEvaluation(new ParseToDate(Literal("1"), Literal(c.toString)).child, null) - checkEvaluation(ToUnixTimestamp(Literal("1"), Literal(c.toString)), null) - checkEvaluation(UnixTimestamp(Literal("1"), Literal(c.toString)), null) + def checkException[T <: Exception : ClassTag](c: String): Unit = { + checkExceptionInExpression[T](new ParseToTimestamp(Literal("1"), Literal(c)).child, c) + checkExceptionInExpression[T](new ParseToDate(Literal("1"), Literal(c)).child, c) + checkExceptionInExpression[T](ToUnixTimestamp(Literal("1"), Literal(c)), c) + checkExceptionInExpression[T](UnixTimestamp(Literal("1"), Literal(c)), c) + if (!Set("E", "F", "q", "Q").contains(c)) { + checkExceptionInExpression[T](DateFormatClass(CurrentTimestamp(), Literal(c)), c) + checkExceptionInExpression[T](FromUnixTime(Literal(0L), Literal(c)), c) + } } Seq('Y', 'W', 'w', 'E', 'u', 'F').foreach { l => - checkSparkUpgrade(l) + checkException[SparkUpgradeException](l.toString) } - Seq('q', 'Q').foreach { l => - checkNullify(l) + Seq('q', 'Q', 'e', 'c', 'A', 'n', 'N', 'p').foreach { l => + checkException[IllegalArgumentException](l.toString) } } - test("SPARK-31896: Handle am-pm timestamp parsing when hour is missing") { checkEvaluation( new ParseToTimestamp(Literal("PM"), Literal("a")).child, diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala index d4d6f79d7895e..30fdcf17f8d60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala @@ -341,6 +341,100 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { .analyze comparePlans(optimized, expected) } + + test("Nested field pruning for Aggregate") { + def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = { + val query1 = basePlan(contact).groupBy($"id")(first($"name.first").as("first")).analyze + val optimized1 = Optimize.execute(query1) + val aliases1 = collectGeneratedAliases(optimized1) + + val expected1 = basePlan( + contact + .select($"id", 'name.getField("first").as(aliases1(0))) + ).groupBy($"id")(first($"${aliases1(0)}").as("first")).analyze + comparePlans(optimized1, expected1) + + val query2 = basePlan(contact).groupBy($"name.last")(first($"name.first").as("first")).analyze + val optimized2 = Optimize.execute(query2) + val aliases2 = collectGeneratedAliases(optimized2) + + val expected2 = basePlan( + contact + .select('name.getField("last").as(aliases2(0)), 'name.getField("first").as(aliases2(1))) + ).groupBy($"${aliases2(0)}")(first($"${aliases2(1)}").as("first")).analyze + comparePlans(optimized2, expected2) + } + + Seq( + (plan: LogicalPlan) => plan, + (plan: LogicalPlan) => plan.limit(100), + (plan: LogicalPlan) => plan.repartition(100), + (plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base => + runTest(base) + } + + val query3 = contact.groupBy($"id")(first($"name"), first($"name.first").as("first")).analyze + val optimized3 = Optimize.execute(query3) + val expected3 = contact.select($"id", $"name") + .groupBy($"id")(first($"name"), 
first($"name.first").as("first")).analyze + comparePlans(optimized3, expected3) + } + + test("Nested field pruning for Expand") { + def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = { + val query1 = Expand( + Seq( + Seq($"name.first", $"name.middle"), + Seq(ConcatWs(Seq($"name.first", $"name.middle")), + ConcatWs(Seq($"name.middle", $"name.first"))) + ), + Seq('a.string, 'b.string), + basePlan(contact) + ).analyze + val optimized1 = Optimize.execute(query1) + val aliases1 = collectGeneratedAliases(optimized1) + + val expected1 = Expand( + Seq( + Seq($"${aliases1(0)}", $"${aliases1(1)}"), + Seq(ConcatWs(Seq($"${aliases1(0)}", $"${aliases1(1)}")), + ConcatWs(Seq($"${aliases1(1)}", $"${aliases1(0)}"))) + ), + Seq('a.string, 'b.string), + basePlan(contact.select( + 'name.getField("first").as(aliases1(0)), + 'name.getField("middle").as(aliases1(1)))) + ).analyze + comparePlans(optimized1, expected1) + } + + Seq( + (plan: LogicalPlan) => plan, + (plan: LogicalPlan) => plan.limit(100), + (plan: LogicalPlan) => plan.repartition(100), + (plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base => + runTest(base) + } + + val query2 = Expand( + Seq( + Seq($"name", $"name.middle"), + Seq($"name", ConcatWs(Seq($"name.middle", $"name.first"))) + ), + Seq('a.string, 'b.string), + contact + ).analyze + val optimized2 = Optimize.execute(query2) + val expected2 = Expand( + Seq( + Seq($"name", $"name.middle"), + Seq($"name", ConcatWs(Seq($"name.middle", $"name.first"))) + ), + Seq('a.string, 'b.string), + contact.select($"name") + ).analyze + comparePlans(optimized2, expected2) + } } object NestedColumnAliasingSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index d5b0885555462..bd617bf7e3df6 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -513,6 +513,7 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "transform", "true", "truncate", + "type", "unarchive", "unbounded", "uncache", diff --git a/sql/core/benchmarks/DateTimeBenchmark-jdk11-results.txt b/sql/core/benchmarks/DateTimeBenchmark-jdk11-results.txt index f4ed8ce4afaea..70d888227141d 100644 --- a/sql/core/benchmarks/DateTimeBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/DateTimeBenchmark-jdk11-results.txt @@ -453,5 +453,9 @@ From java.time.Instant 325 328 Collect longs 1300 1321 25 3.8 260.0 0.3X Collect java.sql.Timestamp 1450 1557 102 3.4 290.0 0.3X Collect java.time.Instant 1499 1599 87 3.3 299.9 0.3X +java.sql.Date to Hive string 17536 18367 1059 0.3 3507.2 0.0X +java.time.LocalDate to Hive string 12089 12897 725 0.4 2417.8 0.0X +java.sql.Timestamp to Hive string 48014 48625 752 0.1 9602.9 0.0X +java.time.Instant to Hive string 37346 37445 93 0.1 7469.1 0.0X diff --git a/sql/core/benchmarks/DateTimeBenchmark-results.txt b/sql/core/benchmarks/DateTimeBenchmark-results.txt index 7a9aa4badfeb7..0795f11a57f28 100644 --- a/sql/core/benchmarks/DateTimeBenchmark-results.txt +++ b/sql/core/benchmarks/DateTimeBenchmark-results.txt @@ -453,5 +453,9 @@ From java.time.Instant 236 243 Collect longs 1280 1337 79 3.9 256.1 0.3X Collect java.sql.Timestamp 1485 1501 15 3.4 297.0 0.3X Collect java.time.Instant 1441 1465 37 3.5 288.1 0.3X +java.sql.Date to Hive string 18745 20895 1364 0.3 3749.0 0.0X +java.time.LocalDate to Hive string 15296 15450 143 0.3 3059.2 0.0X +java.sql.Timestamp to Hive string 46421 47210 946 0.1 9284.2 0.0X +java.time.Instant to Hive string 34747 35187 382 0.1 6949.4 0.0X diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 222fea6528261..07d7c4e97a095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -110,6 +110,9 @@ case class DataSource( private def providingInstance() = providingClass.getConstructor().newInstance() + private def newHadoopConfiguration(): Configuration = + sparkSession.sessionState.newHadoopConfWithOptions(options) + lazy val sourceInfo: SourceInfo = sourceSchema() private val caseInsensitiveOptions = CaseInsensitiveMap(options) private val equality = sparkSession.sessionState.conf.resolver @@ -231,7 +234,7 @@ case class DataSource( // once the streaming job starts and some upstream source starts dropping data. val hdfsPath = new Path(path) if (!SparkHadoopUtil.get.isGlobPath(hdfsPath)) { - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = hdfsPath.getFileSystem(newHadoopConfiguration()) if (!fs.exists(hdfsPath)) { throw new AnalysisException(s"Path does not exist: $path") } @@ -358,7 +361,7 @@ case class DataSource( case (format: FileFormat, _) if FileStreamSink.hasMetadata( caseInsensitiveOptions.get("path").toSeq ++ paths, - sparkSession.sessionState.newHadoopConf(), + newHadoopConfiguration(), sparkSession.sessionState.conf) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, @@ -450,7 +453,7 @@ case class DataSource( val allPaths = paths ++ caseInsensitiveOptions.get("path") val outputPath = if (allPaths.length == 1) { val path = new Path(allPaths.head) - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = path.getFileSystem(newHadoopConfiguration()) path.makeQualified(fs.getUri, fs.getWorkingDirectory) } else { throw new IllegalArgumentException("Expected exactly one 
path to be specified, but " + @@ -570,9 +573,7 @@ case class DataSource( checkEmptyGlobPath: Boolean, checkFilesExist: Boolean): Seq[Path] = { val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - - DataSource.checkAndGlobPathIfNecessary(allPaths.toSeq, hadoopConf, + DataSource.checkAndGlobPathIfNecessary(allPaths.toSeq, newHadoopConfiguration(), checkEmptyGlobPath, checkFilesExist) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 28ef793ed62db..3242ac21ab324 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -35,12 +35,6 @@ import org.apache.spark.sql.internal.SQLConf * the input partition ordering requirements are met. */ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { - private def defaultNumPreShufflePartitions: Int = - if (conf.adaptiveExecutionEnabled && conf.coalesceShufflePartitionsEnabled) { - conf.initialShufflePartitionNum - } else { - conf.numShufflePartitions - } private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution @@ -57,7 +51,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { BroadcastExchangeExec(mode, child) case (child, distribution) => val numPartitions = distribution.requiredNumPartitions - .getOrElse(defaultNumPreShufflePartitions) + .getOrElse(conf.numShufflePartitions) ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child) } @@ -95,7 +89,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // expected number of shuffle partitions. 
However, if it's smaller than // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the // expected number of shuffle partitions. - math.max(nonShuffleChildrenNumPartitions.max, conf.numShufflePartitions) + math.max(nonShuffleChildrenNumPartitions.max, conf.defaultNumShufflePartitions) } else { childrenNumPartitions.max } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 0a250b27ccb94..d341d7019f0ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -104,7 +104,7 @@ object PythonUDFRunner { dataOut.writeInt(chained.funcs.length) chained.funcs.foreach { f => dataOut.writeInt(f.command.length) - dataOut.write(f.command) + dataOut.write(f.command.toArray) } } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index a63bb8526da44..06765627f5545 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -138,25 +138,11 @@ select to_timestamp("2019 40", "yyyy mm"); select to_timestamp("2019 10:10:10", "yyyy hh:mm:ss"); -- Unsupported narrow text style -select date_format(date '2020-05-23', 'GGGGG'); -select date_format(date '2020-05-23', 'MMMMM'); -select date_format(date '2020-05-23', 'LLLLL'); -select date_format(timestamp '2020-05-23', 'EEEEE'); -select date_format(timestamp '2020-05-23', 'uuuuu'); -select date_format('2020-05-23', 'QQQQQ'); -select date_format('2020-05-23', 'qqqqq'); select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG'); select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE'); select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE'); select unix_timestamp('22 05 2020 Friday', 'dd MM yyyy 
EEEEE'); -select from_unixtime(12345, 'MMMMM'); -select from_unixtime(54321, 'QQQQQ'); -select from_unixtime(23456, 'aaaaa'); select from_json('{"time":"26/October/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')); select from_json('{"date":"26/October/2015"}', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')); select from_csv('26/October/2015', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')); select from_csv('26/October/2015', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy')); - -select from_unixtime(1, 'yyyyyyyyyyy-MM-dd'); -select date_format(timestamp '2018-11-17 13:33:33', 'yyyyyyyyyy-MM-dd HH:mm:ss'); -select date_format(date '2018-11-17', 'yyyyyyyyyyy-MM-dd'); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out index a4e6e79b4573e..26adb40ce1b14 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 116 +-- Number of queries: 103 -- !query @@ -814,69 +814,6 @@ struct 2019-01-01 10:10:10 --- !query -select date_format(date '2020-05-23', 'GGGGG') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(date '2020-05-23', 'MMMMM') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 
1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(date '2020-05-23', 'LLLLL') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'LLLLL' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(timestamp '2020-05-23', 'EEEEE') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(timestamp '2020-05-23', 'uuuuu') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'uuuuu' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format('2020-05-23', 'QQQQQ') --- !query schema -struct<> --- !query output -java.lang.IllegalArgumentException -Too many pattern letters: Q - - --- !query -select date_format('2020-05-23', 'qqqqq') --- !query schema -struct<> --- !query output -java.lang.IllegalArgumentException -Too many pattern letters: q - - -- !query select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') -- !query schema @@ -913,32 +850,6 @@ org.apache.spark.SparkUpgradeException You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html --- !query -select from_unixtime(12345, 'MMMMM') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select from_unixtime(54321, 'QQQQQ') --- !query schema -struct --- !query output -NULL - - --- !query -select from_unixtime(23456, 'aaaaa') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aaaaa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - -- !query select from_json('{"time":"26/October/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query schema @@ -973,29 +884,3 @@ struct<> -- !query output org.apache.spark.SparkUpgradeException You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select from_unixtime(1, 'yyyyyyyyyyy-MM-dd') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyyyyyy-MM-dd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(timestamp '2018-11-17 13:33:33', 'yyyyyyyyyy-MM-dd HH:mm:ss') --- !query schema -struct --- !query output -0000002018-11-17 13:33:33 - - --- !query -select date_format(date '2018-11-17', 'yyyyyyyyyyy-MM-dd') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyyyyyy-MM-dd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out index 38d078838ebee..15092f0a27c1f 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 116 +-- Number of queries: 103 -- !query @@ -786,64 +786,6 @@ struct 2019-01-01 10:10:10 --- !query -select date_format(date '2020-05-23', 'GGGGG') --- !query schema -struct --- !query output -AD - - --- !query -select date_format(date '2020-05-23', 'MMMMM') --- !query schema -struct --- !query output -May - - --- !query -select date_format(date '2020-05-23', 'LLLLL') --- !query schema -struct --- !query output -May - - --- !query -select date_format(timestamp '2020-05-23', 'EEEEE') --- !query schema -struct --- !query output -Saturday - - --- !query -select date_format(timestamp '2020-05-23', 'uuuuu') --- !query schema -struct --- !query output -00006 - - --- !query -select date_format('2020-05-23', 'QQQQQ') --- !query schema -struct<> --- !query output -java.lang.IllegalArgumentException -Illegal pattern character 'Q' - - --- !query -select date_format('2020-05-23', 'qqqqq') --- !query schema -struct<> --- !query output -java.lang.IllegalArgumentException -Illegal pattern character 'q' - - -- !query select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') -- !query schema @@ -876,30 +818,6 @@ struct 1590130800 --- !query -select from_unixtime(12345, 'MMMMM') --- !query schema -struct --- !query output -December - - --- !query -select from_unixtime(54321, 'QQQQQ') --- !query schema -struct --- !query output -NULL - - --- !query -select from_unixtime(23456, 'aaaaa') --- !query schema -struct --- !query output -PM 
- - -- !query select from_json('{"time":"26/October/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query schema @@ -930,27 +848,3 @@ select from_csv('26/October/2015', 'date Date', map('dateFormat', 'dd/MMMMM/yyyy struct> -- !query output {"date":2015-10-26} - - --- !query -select from_unixtime(1, 'yyyyyyyyyyy-MM-dd') --- !query schema -struct --- !query output -00000001969-12-31 - - --- !query -select date_format(timestamp '2018-11-17 13:33:33', 'yyyyyyyyyy-MM-dd HH:mm:ss') --- !query schema -struct --- !query output -0000002018-11-17 13:33:33 - - --- !query -select date_format(date '2018-11-17', 'yyyyyyyyyyy-MM-dd') --- !query schema -struct --- !query output -00000002018-11-17 diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index dc4220ff62261..b80f36e9c2347 100755 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 116 +-- Number of queries: 103 -- !query @@ -786,69 +786,6 @@ struct 2019-01-01 10:10:10 --- !query -select date_format(date '2020-05-23', 'GGGGG') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'GGGGG' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(date '2020-05-23', 'MMMMM') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 
1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(date '2020-05-23', 'LLLLL') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'LLLLL' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(timestamp '2020-05-23', 'EEEEE') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(timestamp '2020-05-23', 'uuuuu') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'uuuuu' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format('2020-05-23', 'QQQQQ') --- !query schema -struct<> --- !query output -java.lang.IllegalArgumentException -Too many pattern letters: Q - - --- !query -select date_format('2020-05-23', 'qqqqq') --- !query schema -struct<> --- !query output -java.lang.IllegalArgumentException -Too many pattern letters: q - - -- !query select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') -- !query schema @@ -885,32 +822,6 @@ org.apache.spark.SparkUpgradeException You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html --- !query -select from_unixtime(12345, 'MMMMM') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'MMMMM' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select from_unixtime(54321, 'QQQQQ') --- !query schema -struct --- !query output -NULL - - --- !query -select from_unixtime(23456, 'aaaaa') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'aaaaa' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - -- !query select from_json('{"time":"26/October/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query schema @@ -945,29 +856,3 @@ struct<> -- !query output org.apache.spark.SparkUpgradeException You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'dd/MMMMM/yyyy' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select from_unixtime(1, 'yyyyyyyyyyy-MM-dd') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyyyyyy-MM-dd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html - - --- !query -select date_format(timestamp '2018-11-17 13:33:33', 'yyyyyyyyyy-MM-dd HH:mm:ss') --- !query schema -struct --- !query output -0000002018-11-17 13:33:33 - - --- !query -select date_format(date '2018-11-17', 'yyyyyyyyyyy-MM-dd') --- !query schema -struct<> --- !query output -org.apache.spark.SparkUpgradeException -You may get a different result due to the upgrading of Spark 3.0: Fail to recognize 'yyyyyyyyyyy-MM-dd' pattern in the DateTimeFormatter. 1) You can set spark.sql.legacy.timeParserPolicy to LEGACY to restore the behavior before Spark 3.0. 
2) You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index c12468a4e70f8..5cc9e156db1b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -689,8 +689,9 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { Row(secs(ts5.getTime)), Row(null))) // invalid format - checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd bb:HH:ss')"), Seq( - Row(null), Row(null), Row(null), Row(null))) + val invalid = df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd bb:HH:ss')") + val e = intercept[IllegalArgumentException](invalid.collect()) + assert(e.getMessage.contains('b')) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index cb410b4f0d7dc..efc7cac6a5f21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -843,6 +843,26 @@ class FileBasedDataSourceSuite extends QueryTest } } + test("SPARK-31935: Hadoop file system config should be effective in data source options") { + Seq("parquet", "").foreach { format => + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> format) { + withTempDir { dir => + val path = dir.getCanonicalPath + val defaultFs = "nonexistFS://nonexistFS" + val expectMessage = "No FileSystem for scheme: nonexistFS" + val message1 = intercept[java.io.IOException] { + spark.range(10).write.option("fs.defaultFS", defaultFs).parquet(path) + }.getMessage + assert(message1 == expectMessage) + val message2 = intercept[java.io.IOException] { + spark.read.option("fs.defaultFS", 
defaultFs).parquet(path) + }.getMessage + assert(message2 == expectMessage) + } + } + } + } + test("SPARK-31116: Select nested schema with case insensitive mode") { // This test case failed at only Parquet. ORC is added for test coverage parity. Seq("orc", "parquet").foreach { format => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 3d0ba05f76b71..9fa97bffa8910 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1021,4 +1021,20 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-31220 repartition obeys initialPartitionNum when adaptiveExecutionEnabled") { + Seq(true, false).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.SHUFFLE_PARTITIONS.key -> "6", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") { + val partitionsNum = spark.range(10).repartition($"id").rdd.collectPartitions().length + if (enableAQE) { + assert(partitionsNum === 7) + } else { + assert(partitionsNum === 6) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala index f56efa3bba600..c7b8737b7a753 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala @@ -21,8 +21,10 @@ import java.sql.{Date, Timestamp} import java.time.{Instant, LocalDate} import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY import 
org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA} +import org.apache.spark.sql.execution.HiveResult import org.apache.spark.sql.internal.SQLConf /** @@ -182,14 +184,19 @@ object DateTimeBenchmark extends SqlBasedBenchmark { benchmark.addCase("From java.time.LocalDate", numIters) { _ => spark.range(rowsNum).map(millis => LocalDate.ofEpochDay(millis / MILLIS_PER_DAY)).noop() } + def dates = { + spark.range(0, rowsNum, 1, 1).map(millis => new Date(millis)) + } benchmark.addCase("Collect java.sql.Date", numIters) { _ => - spark.range(0, rowsNum, 1, 1).map(millis => new Date(millis)).collect() + dates.collect() + } + def localDates = { + spark.range(0, rowsNum, 1, 1) + .map(millis => LocalDate.ofEpochDay(millis / MILLIS_PER_DAY)) } benchmark.addCase("Collect java.time.LocalDate", numIters) { _ => withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { - spark.range(0, rowsNum, 1, 1) - .map(millis => LocalDate.ofEpochDay(millis / MILLIS_PER_DAY)) - .collect() + localDates.collect() } } benchmark.addCase("From java.sql.Timestamp", numIters) { _ => @@ -202,14 +209,37 @@ object DateTimeBenchmark extends SqlBasedBenchmark { spark.range(0, rowsNum, 1, 1) .collect() } + def timestamps = { + spark.range(0, rowsNum, 1, 1).map(millis => new Timestamp(millis)) + } benchmark.addCase("Collect java.sql.Timestamp", numIters) { _ => - spark.range(0, rowsNum, 1, 1).map(millis => new Timestamp(millis)).collect() + timestamps.collect() + } + def instants = { + spark.range(0, rowsNum, 1, 1).map(millis => Instant.ofEpochMilli(millis)) } benchmark.addCase("Collect java.time.Instant", numIters) { _ => withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { - spark.range(0, rowsNum, 1, 1) - .map(millis => Instant.ofEpochMilli(millis)) - .collect() + instants.collect() + } + } + def toHiveString(df: Dataset[_]): Unit = { + HiveResult.hiveResultString(df.queryExecution.executedPlan) + } + benchmark.addCase("java.sql.Date to Hive string", numIters) { _ 
=> + toHiveString(dates) + } + benchmark.addCase("java.time.LocalDate to Hive string", numIters) { _ => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { + toHiveString(localDates) + } + } + benchmark.addCase("java.sql.Timestamp to Hive string", numIters) { _ => + toHiveString(timestamps) + } + benchmark.addCase("java.time.Instant to Hive string", numIters) { _ => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { + toHiveString(instants) } } benchmark.run() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala index 1e3c660e09454..9345158fd07ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.execution.datasources import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} +import org.scalatest.PrivateMethodTester import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.test.SharedSparkSession -class DataSourceSuite extends SharedSparkSession { +class DataSourceSuite extends SharedSparkSession with PrivateMethodTester { import TestPaths._ test("test glob and non glob paths") { @@ -132,6 +133,17 @@ class DataSourceSuite extends SharedSparkSession { ) ) } + + test("Data source options should be propagated in method checkAndGlobPathIfNecessary") { + val dataSourceOptions = Map("fs.defaultFS" -> "nonexistsFs://nonexistsFs") + val dataSource = DataSource(spark, "parquet", Seq("/path3"), options = dataSourceOptions) + val checkAndGlobPathIfNecessary = PrivateMethod[Seq[Path]]('checkAndGlobPathIfNecessary) + + val message = intercept[java.io.IOException] { + dataSource invokePrivate checkAndGlobPathIfNecessary(false, false) + }.getMessage + 
assert(message.equals("No FileSystem for scheme: nonexistsFs")) + } } object TestPaths { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 80061dc84efbc..2f9e510752b02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -23,7 +23,9 @@ import org.scalactic.Equality import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.expressions.Concat import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions._ @@ -338,6 +340,75 @@ abstract class SchemaPruningSuite } } + testSchemaPruning("select one deep nested complex field after repartition") { + val query = sql("select * from contacts") + .repartition(100) + .where("employer.company.address is not null") + .selectExpr("employer.id as employer_id") + checkScan(query, + "struct>>") + checkAnswer(query, Row(0) :: Nil) + } + + testSchemaPruning("select nested field in aggregation function of Aggregate") { + val query1 = sql("select count(name.first) from contacts group by name.last") + checkScan(query1, "struct>") + checkAnswer(query1, Row(2) :: Row(2) :: Nil) + + val query2 = sql("select count(name.first), sum(pets) from contacts group by id") + checkScan(query2, "struct,pets:int>") + checkAnswer(query2, Row(1, 1) :: Row(1, null) :: Row(1, 3) :: Row(1, null) :: Nil) + + val query3 = sql("select count(name.first), first(name) from contacts group by id") + checkScan(query3, "struct>") + checkAnswer(query3, + 
Row(1, Row("Jane", "X.", "Doe")) :: + Row(1, Row("Jim", null, "Jones")) :: + Row(1, Row("John", "Y.", "Doe")) :: + Row(1, Row("Janet", null, "Jones")) :: Nil) + + val query4 = sql("select count(name.first), sum(pets) from contacts group by name.last") + checkScan(query4, "struct,pets:int>") + checkAnswer(query4, Row(2, null) :: Row(2, 4) :: Nil) + } + + testSchemaPruning("select nested field in Expand") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + val query1 = Expand( + Seq( + Seq($"name.first", $"name.last"), + Seq(Concat(Seq($"name.first", $"name.last")), + Concat(Seq($"name.last", $"name.first"))) + ), + Seq('a.string, 'b.string), + sql("select * from contacts").logicalPlan + ).toDF() + checkScan(query1, "struct>") + checkAnswer(query1, + Row("Jane", "Doe") :: + Row("JaneDoe", "DoeJane") :: + Row("John", "Doe") :: + Row("JohnDoe", "DoeJohn") :: + Row("Jim", "Jones") :: + Row("JimJones", "JonesJim") :: + Row("Janet", "Jones") :: + Row("JanetJones", "JonesJanet") :: Nil) + + val name = StructType.fromDDL("first string, middle string, last string") + val query2 = Expand( + Seq(Seq($"name", $"name.last")), + Seq('a.struct(name), 'b.string), + sql("select * from contacts").logicalPlan + ).toDF() + checkScan(query2, "struct>") + checkAnswer(query2, + Row(Row("Jane", "X.", "Doe"), "Doe") :: + Row(Row("John", "Y.", "Doe"), "Doe") :: + Row(Row("Jim", null, "Jones"), "Jones") :: + Row(Row("Janet", null, "Jones"), "Jones") ::Nil) + } + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index fa320333143ec..32dceaac7059c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -532,6 +532,18 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-31935: Hadoop file system config should be effective in data source options") { + withTempDir { dir => + val path = dir.getCanonicalPath + val defaultFs = "nonexistFS://nonexistFS" + val expectMessage = "No FileSystem for scheme: nonexistFS" + val message = intercept[java.io.IOException] { + spark.readStream.option("fs.defaultFS", defaultFs).text(path) + }.getMessage + assert(message == expectMessage) + } + } + test("read from textfile") { withTempDirs { case (src, tmp) => val textStream = spark.readStream.textFile(src.getCanonicalPath) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala index e002bc0117c8b..c9e41db52cd50 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala @@ -33,6 +33,8 @@ trait SharedThriftServer extends SharedSparkSession { private var hiveServer2: HiveThriftServer2 = _ private var serverPort: Int = 0 + def mode: ServerMode.Value + override def beforeAll(): Unit = { super.beforeAll() // Retries up to 3 times with different port numbers if the server fails to start @@ -53,11 +55,17 @@ trait SharedThriftServer extends SharedSparkSession { } } + protected def jdbcUri: String = if (mode == ServerMode.http) { + s"jdbc:hive2://localhost:$serverPort/default;transportMode=http;httpPath=cliservice" + } else { + s"jdbc:hive2://localhost:$serverPort" + } + protected def withJdbcStatement(fs: (Statement => Unit)*): Unit = { val user = System.getProperty("user.name") require(serverPort != 0, "Failed to bind an actual port for HiveThriftServer2") val 
connections = - fs.map { _ => DriverManager.getConnection(s"jdbc:hive2://localhost:$serverPort", user, "") } + fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") } val statements = connections.map(_.createStatement()) try { @@ -71,21 +79,33 @@ trait SharedThriftServer extends SharedSparkSession { private def startThriftServer(attempt: Int): Unit = { logInfo(s"Trying to start HiveThriftServer2:, attempt=$attempt") val sqlContext = spark.newSession().sqlContext - // Set the HIVE_SERVER2_THRIFT_PORT to 0, so it could randomly pick any free port to use. + // Set the HIVE_SERVER2_THRIFT_PORT and HIVE_SERVER2_THRIFT_HTTP_PORT to 0, so it could + // randomly pick any free port to use. // It's much more robust than set a random port generated by ourselves ahead sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname, "0") - hiveServer2 = HiveThriftServer2.startWithContext(sqlContext) - hiveServer2.getServices.asScala.foreach { - case t: ThriftCLIService if t.getPortNumber != 0 => - serverPort = t.getPortNumber - logInfo(s"Started HiveThriftServer2: port=$serverPort, attempt=$attempt") - case _ => - } + sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT.varname, "0") + sqlContext.setConf(ConfVars.HIVE_SERVER2_TRANSPORT_MODE.varname, mode.toString) + + try { + hiveServer2 = HiveThriftServer2.startWithContext(sqlContext) + hiveServer2.getServices.asScala.foreach { + case t: ThriftCLIService => + serverPort = t.getPortNumber + logInfo(s"Started HiveThriftServer2: port=$serverPort, attempt=$attempt") + case _ => + } - // Wait for thrift server to be ready to serve the query, via executing simple query - // till the query succeeds. See SPARK-30345 for more details. - eventually(timeout(30.seconds), interval(1.seconds)) { - withJdbcStatement { _.execute("SELECT 1") } + // Wait for thrift server to be ready to serve the query, via executing simple query + // till the query succeeds. See SPARK-30345 for more details. 
+ eventually(timeout(30.seconds), interval(1.seconds)) { + withJdbcStatement { _.execute("SELECT 1") } + } + } catch { + case e: Exception => + logError("Error start hive server with Context ", e) + if (hiveServer2 != null) { + hiveServer2.stop() + } } } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index 15cc3109da3f7..553f10a275bce 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -54,6 +54,9 @@ import org.apache.spark.sql.types._ */ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServer { + + override def mode: ServerMode.Value = ServerMode.binary + override protected def testFile(fileName: String): String = { val url = Thread.currentThread().getContextClassLoader.getResource(fileName) // Copy to avoid URISyntaxException during accessing the resources in `sql/core` diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 3e1fce78ae71c..d6420dee41adb 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -class ThriftServerWithSparkContextSuite extends SharedThriftServer { +trait ThriftServerWithSparkContextSuite extends SharedThriftServer { test("SPARK-29911: Uncache cached tables when session closed") { val 
cacheManager = spark.sharedState.cacheManager @@ -42,3 +42,12 @@ class ThriftServerWithSparkContextSuite extends SharedThriftServer { } } } + + +class ThriftServerWithSparkContextInBinarySuite extends ThriftServerWithSparkContextSuite { + override def mode: ServerMode.Value = ServerMode.binary +} + +class ThriftServerWithSparkContextInHttpSuite extends ThriftServerWithSparkContextSuite { + override def mode: ServerMode.Value = ServerMode.http +} diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java index e1ee503b81209..00bdf7e19126e 100644 --- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java +++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.shims.ShimLoader; +import org.apache.hive.service.ServiceException; import org.apache.hive.service.auth.HiveAuthFactory; import org.apache.hive.service.cli.CLIService; import org.apache.hive.service.server.ThreadFactoryWithGarbageCleanup; @@ -45,7 +46,7 @@ public ThriftBinaryCLIService(CLIService cliService) { } @Override - public void run() { + protected void initializeServer() { try { // Server thread pool String threadPoolName = "HiveServer2-Handler-Pool"; @@ -100,6 +101,14 @@ public void run() { String msg = "Starting " + ThriftBinaryCLIService.class.getSimpleName() + " on port " + serverSocket.getServerSocket().getLocalPort() + " with " + minWorkerThreads + "..." 
+ maxWorkerThreads + " worker threads"; LOG.info(msg); + } catch (Exception t) { + throw new ServiceException("Error initializing " + getName(), t); + } + } + + @Override + public void run() { + try { server.serve(); } catch (Throwable t) { LOG.fatal( diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java index 8fce9d9383438..783e5795aca76 100644 --- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java +++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java @@ -175,6 +175,7 @@ public synchronized void init(HiveConf hiveConf) { public synchronized void start() { super.start(); if (!isStarted && !isEmbedded) { + initializeServer(); new Thread(this).start(); isStarted = true; } @@ -633,6 +634,8 @@ public TFetchResultsResp FetchResults(TFetchResultsReq req) throws TException { return resp; } + protected abstract void initializeServer(); + @Override public abstract void run(); diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 1099a00b67eb7..bd64c777c1d76 100644 --- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.util.Shell; +import org.apache.hive.service.ServiceException; import org.apache.hive.service.auth.HiveAuthFactory; import org.apache.hive.service.cli.CLIService; import org.apache.hive.service.cli.thrift.TCLIService.Iface; @@ -53,13 +54,8 
@@ public ThriftHttpCLIService(CLIService cliService) { super(cliService, ThriftHttpCLIService.class.getSimpleName()); } - /** - * Configure Jetty to serve http requests. Example of a client connection URL: - * http://localhost:10000/servlets/thrifths2/ A gateway may cause actual target URL to differ, - * e.g. http://gateway:port/hive2/servlets/thrifths2/ - */ @Override - public void run() { + protected void initializeServer() { try { // Server thread pool // Start with minWorkerThreads, expand till maxWorkerThreads and reject subsequent requests @@ -150,6 +146,19 @@ public void run() { + " mode on port " + connector.getLocalPort()+ " path=" + httpPath + " with " + minWorkerThreads + "..." + maxWorkerThreads + " worker threads"; LOG.info(msg); + } catch (Exception t) { + throw new ServiceException("Error initializing " + getName(), t); + } + } + + /** + * Configure Jetty to serve http requests. Example of a client connection URL: + * http://localhost:10000/servlets/thrifths2/ A gateway may cause actual target URL to differ, + * e.g. 
http://gateway:port/hive2/servlets/thrifths2/ + */ + @Override + public void run() { + try { httpServer.join(); } catch (Throwable t) { LOG.fatal( diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java index a7de9c0f3d0d2..ce79e3c8228a6 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.shims.ShimLoader; +import org.apache.hive.service.ServiceException; import org.apache.hive.service.auth.HiveAuthFactory; import org.apache.hive.service.cli.CLIService; import org.apache.hive.service.server.ThreadFactoryWithGarbageCleanup; @@ -46,7 +47,7 @@ public ThriftBinaryCLIService(CLIService cliService) { } @Override - public void run() { + protected void initializeServer() { try { // Server thread pool String threadPoolName = "HiveServer2-Handler-Pool"; @@ -101,6 +102,14 @@ public void run() { String msg = "Starting " + ThriftBinaryCLIService.class.getSimpleName() + " on port " + portNum + " with " + minWorkerThreads + "..." 
+ maxWorkerThreads + " worker threads"; LOG.info(msg); + } catch (Exception t) { + throw new ServiceException("Error initializing " + getName(), t); + } + } + + @Override + public void run() { + try { server.serve(); } catch (Throwable t) { LOG.error( diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java index d41c3b493bb47..e46799a1c427d 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java @@ -176,6 +176,7 @@ public synchronized void init(HiveConf hiveConf) { public synchronized void start() { super.start(); if (!isStarted && !isEmbedded) { + initializeServer(); new Thread(this).start(); isStarted = true; } @@ -670,6 +671,8 @@ public TGetCrossReferenceResp GetCrossReference(TGetCrossReferenceReq req) return resp; } + protected abstract void initializeServer(); + @Override public abstract void run(); diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 73d5f84476af0..ab9ed5b1f371e 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.util.Shell; +import org.apache.hive.service.ServiceException; import org.apache.hive.service.auth.HiveAuthFactory; import org.apache.hive.service.cli.CLIService; import org.apache.hive.service.rpc.thrift.TCLIService; @@ -54,13 +55,8 @@ public 
ThriftHttpCLIService(CLIService cliService) { super(cliService, ThriftHttpCLIService.class.getSimpleName()); } - /** - * Configure Jetty to serve http requests. Example of a client connection URL: - * http://localhost:10000/servlets/thrifths2/ A gateway may cause actual target URL to differ, - * e.g. http://gateway:port/hive2/servlets/thrifths2/ - */ @Override - public void run() { + protected void initializeServer() { try { // Server thread pool // Start with minWorkerThreads, expand till maxWorkerThreads and reject subsequent requests @@ -151,6 +147,19 @@ public void run() { + " mode on port " + portNum + " path=" + httpPath + " with " + minWorkerThreads + "..." + maxWorkerThreads + " worker threads"; LOG.info(msg); + } catch (Exception t) { + throw new ServiceException("Error initializing " + getName(), t); + } + } + + /** + * Configure Jetty to serve http requests. Example of a client connection URL: + * http://localhost:10000/servlets/thrifths2/ A gateway may cause actual target URL to differ, + * e.g. 
http://gateway:port/hive2/servlets/thrifths2/ + */ + @Override + public void run() { + try { httpServer.join(); } catch (Throwable t) { LOG.error( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 9f83f2ab96094..116217ecec0ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.hive.HiveExternalCatalog @@ -225,9 +226,12 @@ case class InsertIntoHiveTable( ExternalCatalogUtils.unescapePathName(splitPart(1)) }.toMap + val caseInsensitiveDpMap = CaseInsensitiveMap(dpMap) + val updatedPartitionSpec = partition.map { case (key, Some(value)) => key -> value - case (key, None) if dpMap.contains(key) => key -> dpMap(key) + case (key, None) if caseInsensitiveDpMap.contains(key) => + key -> caseInsensitiveDpMap(key) case (key, _) => throw new SparkException(s"Dynamic partition key $key is not among " + "written partition paths.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 79c6ade2807d3..d12eae0e410b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2544,6 +2544,19 @@ abstract class 
SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi assert(e.getMessage.contains("Cannot modify the value of a static config")) } } + + test("SPARK-29295: dynamic partition map parsed from partition path should be case insensitive") { + withTable("t") { + withSQLConf("hive.exec.dynamic.partition" -> "true", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTempDir { loc => + sql(s"CREATE TABLE t(c1 INT) PARTITIONED BY(P1 STRING) LOCATION '${loc.getAbsolutePath}'") + sql("INSERT OVERWRITE TABLE t PARTITION(P1) VALUES(1, 'caseSensitive')") + checkAnswer(sql("select * from t"), Row(1, "caseSensitive")) + } + } + } + } } class SQLQuerySuite extends SQLQuerySuiteBase with DisableAdaptiveExecutionSuite