diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 512d539ee9c38..ef28e2c48ad02 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -321,6 +321,17 @@ package object config { .intConf .createWithDefault(3) + private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = + ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") + .doc("This configuration limits the number of remote blocks being fetched per reduce task" + + " from a given host port. When a large number of blocks are being requested from a given" + + " address in a single fetch or simultaneously, this could crash the serving executor or" + + " Node Manager. This is especially useful to reduce the load on the Node Manager when" + + " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") + .intConf + .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") + .createWithDefault(Int.MaxValue) + private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 2fbac79a2305b..c8d1460300934 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 81d822dc8a98f..2d176b62f8b36 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,7 +23,7 @@ import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -52,6 +52,8 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point + * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. */ @@ -64,6 +66,7 @@ final class ShuffleBlockFetcherIterator( streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { @@ -110,12 +113,21 @@ final class ShuffleBlockFetcherIterator( */ private[this] val fetchRequests = new Queue[FetchRequest] + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L /** Current number of requests in flight */ private[this] var reqsInFlight = 0 + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + /** * The blocks that can't be decompressed successfully, it is used to guarantee that we retry * at most once for those corrupted blocks. @@ -248,7 +260,8 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize + + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. @@ -277,11 +290,13 @@ final class ShuffleBlockFetcherIterator( } else if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } - if (curRequestSize >= targetRequestSize) { + if (curRequestSize >= targetRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress) { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) + logDebug(s"Creating fetch request of $curRequestSize at $address " + + s"with ${curBlocks.size} blocks") curBlocks = new ArrayBuffer[(BlockId, Long)] - logDebug(s"Creating fetch request of $curRequestSize at $address") curRequestSize = 0 } } @@ -375,6 +390,7 @@ final class ShuffleBlockFetcherIterator( result match { case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { shuffleMetrics.incRemoteBytesReadToDisk(buf.size) @@ -443,12 +459,57 @@ final class ShuffleBlockFetcherIterator( } private def fetchUpToMaxBytes(): Unit = { - // Send fetch requests up to maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || - (reqsInFlight + 1 <= maxReqsInFlight && - bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) { - sendRequest(fetchRequests.dequeue()) + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while (isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { + val request = defReqQueue.dequeue() + logDebug(s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + sendRequest(request) + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress } } diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala index 1b2b1932e0c3d..eff0aa4453f08 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -51,6 +51,10 @@ private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Orderin this } + def poll(): A = { + underlying.poll() + } + override def +=(elem1: A, elem2: A, elems: A*): this.type = { this += elem1 += elem2 ++= elems } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 6a70cedf769b8..c371cbcf8dff5 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -110,6 +110,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, true) // 3 local blocks fetched in initialization @@ -187,6 +188,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, true) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -254,6 +256,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -319,6 +322,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -400,6 +404,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, false) // Continue only after the mock calls onBlockFetchFailure @@ -457,6 +462,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, maxBytesInFlight = Int.MaxValue, maxReqsInFlight = Int.MaxValue, + maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, detectCorrupt = true) } diff --git a/core/src/test/scala/org/apache/spark/util/BoundedPriorityQueueSuite.scala b/core/src/test/scala/org/apache/spark/util/BoundedPriorityQueueSuite.scala new file mode 100644 index 0000000000000..9465ca70e94f2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/BoundedPriorityQueueSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class BoundedPriorityQueueSuite extends SparkFunSuite { + test("BoundedPriorityQueue poll test") { + val pq = new BoundedPriorityQueue[Double](4) + + pq += 0.1 + pq += 1.5 + pq += 1.0 + pq += 0.3 + pq += 0.01 + + assert(pq.isEmpty == false) + assert(pq.poll() == 0.1) + assert(pq.poll() == 0.3) + assert(pq.poll() == 1.0) + assert(pq.poll() == 1.5) + assert(pq.isEmpty == true) + + val pq2 = new BoundedPriorityQueue[(Int, Double)](4)(Ordering.by(_._2)) + pq2 += 1 -> 0.5 + pq2 += 5 -> 0.1 + pq2 += 3 -> 0.3 + pq2 += 4 -> 0.2 + pq2 += 1 -> 0.4 + + assert(pq2.poll()._2 == 0.2) + assert(pq2.poll()._2 == 0.3) + assert(pq2.poll()._2 == 0.4) + assert(pq2.poll()._2 == 0.5) + } +} diff --git a/docs/configuration.md b/docs/configuration.md index 91b5befd1b1eb..d3df923c42690 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -536,6 +536,15 @@ Apart from these, the following properties are also available, and may be useful + spark.reducer.maxBlocksInFlightPerAddress + Int.MaxValue + + This configuration limits the number of remote blocks being fetched per reduce task from a + given host port. When a large number of blocks are being requested from a given address in a + single fetch or simultaneously, this could crash the serving executor or Node Manager. This + is especially useful to reduce the load on the Node Manager when external shuffle is enabled. + You can mitigate this issue by setting it to a lower value. + spark.reducer.maxReqSizeShuffleToMem Long.MaxValue diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 7401b63e022c1..cf257c06c9516 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -545,6 +545,17 @@ See the [configuration page](configuration.html) for information on Spark config Fetcher Cache + + spark.mesos.driver.failoverTimeout + 0.0 + + The amount of time (in seconds) that the master will wait for the + driver to reconnect, after being temporarily disconnected, before + it tears down the driver framework by killing all its + executors. The default value is zero, meaning no timeout: if the + driver disconnects, the master immediately tears down the framework. + + # Troubleshooting and Debugging diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 56d697f359614..6c8619e3c3c13 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -58,9 +58,16 @@ package object config { private [spark] val DRIVER_LABELS = ConfigBuilder("spark.mesos.driver.labels") - .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value" + + .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value " + "pairs should be separated by a colon, and commas used to list more than one." + "Ex. key:value,key2:value2") .stringConf .createOptional + + private [spark] val DRIVER_FAILOVER_TIMEOUT = + ConfigBuilder("spark.mesos.driver.failoverTimeout") + .doc("Amount of time in seconds that the master will wait to hear from the driver, " + + "during a temporary disconnection, before tearing down all the executors.") + .doubleConf + .createWithDefault(0.0) } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 7dd42c41aa7c2..6e7f41dad34ba 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} +import org.apache.spark.deploy.mesos.config._ import org.apache.spark.internal.config import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient @@ -177,7 +178,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.conf, sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), None, - None, + Some(sc.conf.get(DRIVER_FAILOVER_TIMEOUT)), sc.conf.getOption("spark.mesos.driver.frameworkId") ) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 7cca5fedb31eb..d9ff4a403ea36 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.mesos.config._ import org.apache.spark.internal.config._ import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} @@ -369,6 +370,41 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend.start() } + test("failover timeout is set in created scheduler driver") { + val failoverTimeoutIn = 3600.0 + initializeSparkConf(Map(DRIVER_FAILOVER_TIMEOUT.key -> failoverTimeoutIn.toString)) + sc = new SparkContext(sparkConf) + + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + + val securityManager = mock[SecurityManager] + + val backend = new MesosCoarseGrainedSchedulerBackend( + taskScheduler, sc, "master", securityManager) { + override protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = { + markRegistered() + assert(failoverTimeout.isDefined) + assert(failoverTimeout.get.equals(failoverTimeoutIn)) + driver + } + } + + backend.start() + } + test("honors unset spark.mesos.containerizer") { setBackend(Map("spark.mesos.executor.docker.image" -> "test")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index 2652f6d72730c..e0748043c46e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -35,13 +35,13 @@ trait LogicalPlanVisitor[T] { case p: LocalLimit => visitLocalLimit(p) case p: Pivot => visitPivot(p) case p: Project => visitProject(p) - case p: Range => visitRange(p) case p: Repartition => visitRepartition(p) case p: RepartitionByExpression => visitRepartitionByExpr(p) case p: ResolvedHint => visitHint(p) case p: Sample => visitSample(p) case p: ScriptTransformation => visitScriptTransform(p) case p: Union => visitUnion(p) + case p: Window => visitWindow(p) case p: LogicalPlan => default(p) } @@ -73,8 +73,6 @@ trait LogicalPlanVisitor[T] { def visitProject(p: Project): T - def visitRange(p: Range): T - def visitRepartition(p: Repartition): T def visitRepartitionByExpr(p: RepartitionByExpression): T @@ -84,4 +82,6 @@ trait LogicalPlanVisitor[T] { def visitScriptTransform(p: ScriptTransformation): T def visitUnion(p: Union): T + + def visitWindow(p: Window): T } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 93908b04fb643..4cff72d45a400 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -65,11 +65,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { ProjectEstimation.estimate(p).getOrElse(fallback(p)) } - override def visitRange(p: logical.Range): Statistics = { - val sizeInBytes = LongType.defaultSize * p.numElements - Statistics(sizeInBytes = sizeInBytes) - } - override def visitRepartition(p: Repartition): Statistics = fallback(p) override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p) @@ -79,4 +74,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitScriptTransform(p: ScriptTransformation): Statistics = fallback(p) override def visitUnion(p: Union): Statistics = fallback(p) + + override def visitWindow(p: Window): Statistics = fallback(p) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 559f12072e448..d701a956887a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -136,10 +136,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitProject(p: Project): Statistics = visitUnaryNode(p) - override def visitRange(p: logical.Range): Statistics = { - p.computeStats() - } - override def visitRepartition(p: Repartition): Statistics = default(p) override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = default(p) @@ -160,4 +156,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitUnion(p: Union): Statistics = { Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).sum) } + + override def visitWindow(p: Window): Statistics = visitUnaryNode(p) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 913be6d1ff07f..7d532ff343178 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.statsEstimation +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ @@ -54,6 +56,24 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { ) } + test("range") { + val range = Range(1, 5, 1, None) + val rangeStats = Statistics(sizeInBytes = 4 * 8) + checkStats( + range, + expectedStatsCboOn = rangeStats, + expectedStatsCboOff = rangeStats) + } + + test("windows") { + val windows = plan.window(Seq(min(attribute).as('sum_attr)), Seq(attribute), Nil) + val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / (4 + 8)) + checkStats( + windows, + expectedStatsCboOn = windowsStats, + expectedStatsCboOff = windowsStats) + } + test("limit estimation: limit < child's rowCount") { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b825b6cd6160f..71ab0ddf2d6f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -980,7 +980,7 @@ class Dataset[T] private[sql]( * @param condition Join expression. * @param joinType Type of join to perform. Default `inner`. Must be one of: * `inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`, - * `right`, `right_outer`, `left_semi`, `left_anti`. + * `right`, `right_outer`. * * @group typedrel * @since 1.6.0 @@ -997,6 +997,10 @@ class Dataset[T] private[sql]( JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] + if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) { + throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql) + } + // For both join side, combine all outputs into a single column and alias it with "_1" or "_2", // to match the schema for the encoder of the join result. // Note that we do this before joining them, to enable the join operator to return null for one diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 1007a7d55691b..34134db278ad8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -489,13 +489,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { * Inserts an InputAdapter on top of those that do not support codegen. */ private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match { - case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen => - // The children of SortMergeJoin should do codegen separately. - j.copy(left = InputAdapter(insertWholeStageCodegen(left)), - right = InputAdapter(insertWholeStageCodegen(right))) case p if !supportCodegen(p) => // collapse them recursively InputAdapter(insertWholeStageCodegen(p)) + case j @ SortMergeJoinExec(_, _, _, _, left, right) => + // The children of SortMergeJoin should do codegen separately. + j.copy(left = InputAdapter(insertWholeStageCodegen(left)), + right = InputAdapter(insertWholeStageCodegen(right))) case p => p.withNewChildren(p.children.map(insertInputAdapter)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d36a04f1fff8e..cbe8ce421f92b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -96,6 +96,24 @@ case class DataSource( bucket.sortColumnNames, "in the sort definition", equality) } + /** + * In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or + * inconsistent data types as reported in SPARK-21463. + * @param fileIndex A FileIndex that will perform partition inference + * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` + */ + private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = { + val resolved = fileIndex.partitionSchema.map { partitionField => + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } + /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer * it. In the read path, only managed tables by Hive provide the partition columns properly when @@ -139,12 +157,7 @@ case class DataSource( val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - val resolved = tempFileIndex.partitionSchema.map { partitionField => - // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred - userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( - partitionField) - } - StructType(resolved) + combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex) } else { // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred // partitioning @@ -336,7 +349,13 @@ case class DataSource( caseInsensitiveOptions.get("path").toSeq ++ paths, sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath) + val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None) + val fileCatalog = if (userSpecifiedSchema.nonEmpty) { + val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog) + new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema)) + } else { + tempFileCatalog + } val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 57e9bc9b70454..24e13697c0c9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -271,7 +271,7 @@ private[jdbc] class JDBCRDD( conn = getConnection() val dialect = JdbcDialects.get(url) import scala.collection.JavaConverters._ - dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap) + dialect.beforeFetch(conn, options.asProperties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index a9e64c640042a..4b1b2520390ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -195,7 +195,7 @@ class FileStreamSource( private def allFilesUsingMetadataLogFileIndex() = { // Note if `sourceHasMetadata` holds, then `qualifiedBasePath` is guaranteed to be a // non-glob path - new MetadataLogFileIndex(sparkSession, qualifiedBasePath).allFiles() + new MetadataLogFileIndex(sparkSession, qualifiedBasePath, None).allFiles() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index aeaa134736937..1da703cefd8ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -23,14 +23,21 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.types.StructType /** * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. + * + * @param userPartitionSchema an optional partition schema that will be use to provide types for + * the discovered partitions */ -class MetadataLogFileIndex(sparkSession: SparkSession, path: Path) - extends PartitioningAwareFileIndex(sparkSession, Map.empty, None) { +class MetadataLogFileIndex( + sparkSession: SparkSession, + path: Path, + userPartitionSchema: Option[StructType]) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index af2b4fb92062b..156002ef58fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -195,15 +195,6 @@ private[window] final class SlidingWindowFunctionFrame( override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 - // Add all rows to the buffer for which the input row value is equal to or less than - // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { - buffer.add(nextRow.copy()) - nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputHighIndex += 1 - bufferUpdated = true - } - // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { @@ -212,6 +203,19 @@ private[window] final class SlidingWindowFunctionFrame( bufferUpdated = true } + // Add all rows to the buffer for which the input row value is equal to or less than + // the output row upper bound. + while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { + if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) { + inputLowIndex += 1 + } else { + buffer.add(nextRow.copy()) + bufferUpdated = true + } + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) + inputHighIndex += 1 + } + // Only recalculate and update when the buffer changes. if (bufferUpdated) { processor.initialize(input.length) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 69d110e414278..73098cdb92471 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} +import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} @@ -400,6 +401,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ((("b", 2), ("b", 2)), ("b", 2))) } + test("joinWith join types") { + val ds1 = Seq(1, 2, 3).toDS().as("a") + val ds2 = Seq(1, 2).toDS().as("b") + + val e1 = intercept[AnalysisException] { + ds1.joinWith(ds2, $"a.value" === $"b.value", "left_semi") + }.getMessage + assert(e1.contains("Invalid join type in joinWith: " + LeftSemi.sql)) + + val e2 = intercept[AnalysisException] { + ds1.joinWith(ds2, $"a.value" === $"b.value", "left_anti") + }.getMessage + assert(e2.contains("Invalid join type in joinWith: " + LeftAnti.sql)) + } + test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() val grouped = ds.groupByKey(v => (1, v._2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala index 52c200796ce41..623a1b6f854cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala @@ -22,20 +22,22 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.streaming.ProcessingTime +import org.apache.spark.sql.streaming.{ProcessingTime, Trigger} class ProcessingTimeSuite extends SparkFunSuite { test("create") { - assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000) - assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000) - assert(ProcessingTime("1 minute").intervalMs === 60 * 1000) - assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000) - - intercept[IllegalArgumentException] { ProcessingTime(null: String) } - intercept[IllegalArgumentException] { ProcessingTime("") } - intercept[IllegalArgumentException] { ProcessingTime("invalid") } - intercept[IllegalArgumentException] { ProcessingTime("1 month") } - intercept[IllegalArgumentException] { ProcessingTime("1 year") } + def getIntervalMs(trigger: Trigger): Long = trigger.asInstanceOf[ProcessingTime].intervalMs + + assert(getIntervalMs(Trigger.ProcessingTime(10.seconds)) === 10 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) === 10 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime("1 minute")) === 60 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime("interval 1 minute")) === 60 * 1000) + + intercept[IllegalArgumentException] { Trigger.ProcessingTime(null: String) } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("invalid") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 month") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 year") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index 52e4f047225de..a9f3fb355c775 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -356,6 +356,46 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { spark.catalog.dropTempView("nums") } + test("window function: mutiple window expressions specified by range in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.createOrReplaceTempView("nums") + withTempView("nums") { + val expected = + Row(1, 1, 1, 4, null, 8, 25) :: + Row(1, 3, 4, 9, 1, 12, 24) :: + Row(1, 5, 9, 15, 4, 16, 21) :: + Row(1, 7, 16, 21, 8, 9, 16) :: + Row(1, 9, 25, 16, 12, null, 9) :: + Row(0, 2, 2, 6, null, 10, 30) :: + Row(0, 4, 6, 12, 2, 14, 28) :: + Row(0, 6, 12, 18, 6, 18, 24) :: + Row(0, 8, 20, 24, 10, 10, 18) :: + Row(0, 10, 30, 18, 14, null, 10) :: + Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) over w1 as history_sum, + | sum(x) over w2 as period_sum1, + | sum(x) over w3 as period_sum2, + | sum(x) over w4 as period_sum3, + | sum(x) over w5 as future_sum + |FROM nums + |WINDOW + | w1 AS (PARTITION BY y ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), + | w2 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING), + | w3 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 4 PRECEDING AND 2 PRECEDING ), + | w4 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 4 FOLLOWING), + | w5 AS (PARTITION BY y ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) + """.stripMargin + ) + checkAnswer(actual, expected) + } + } + test("SPARK-7595: Window will cause resolve failed with self join") { checkAnswer(sql( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index a4b30a2f8cec1..183c68fd3c016 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -127,4 +129,24 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { "named_struct('a',id+2, 'b',id+2) as col2") .filter("col1 = col2").count() } + + test("SPARK-21441 SortMergeJoin codegen with CodegenFallback expressions should be disabled") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { + import testImplicits._ + + val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int") + val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str") + + val df = df1.join(df2, df1("key") === df2("key")) + .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1") + .select("int") + + val plan = df.queryExecution.executedPlan + assert(!plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.children(0) + .isInstanceOf[SortMergeJoinExec]).isDefined) + assert(df.collect() === Array(Row(1), Row(2))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 84b34d5ad26d1..2f5fd8438f682 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -1022,4 +1023,36 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } } + + test("SPARK-21463: MetadataLogFileIndex should respect userSpecifiedSchema for partition cols") { + withTempDir { tempDir => + val output = new File(tempDir, "output").toString + val checkpoint = new File(tempDir, "chkpoint").toString + try { + val stream = MemoryStream[(String, Int)] + val df = stream.toDS().toDF("time", "value") + val sq = df.writeStream + .option("checkpointLocation", checkpoint) + .format("parquet") + .partitionBy("time") + .start(output) + + stream.addData(("2017-01-01-00", 1), ("2017-01-01-01", 2)) + sq.processAllAvailable() + + val schema = new StructType() + .add("time", StringType) + .add("value", IntegerType) + val readBack = spark.read.schema(schema).parquet(output) + assert(readBack.schema.toSet === schema.toSet) + + checkAnswer( + readBack, + Seq(Row("2017-01-01-00", 1), Row("2017-01-01-01", 2)) + ) + } finally { + spark.streams.active.foreach(_.stop()) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 9f2f0d195de9f..a5399cdb6e5b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -664,7 +664,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(ProcessingTime("1 second"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), CheckLastBatch(("a", "1")), @@ -729,7 +729,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(ProcessingTime("1 second")), + StartStream(Trigger.ProcessingTime("1 second")), AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s AddData(inputData, ("a", 4)), // Add data older than watermark for "a" @@ -901,7 +901,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(ProcessingTime("1 second"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, ("a", 1L)), AdvanceManualClock(1 * 1000), CheckLastBatch(("a", "1")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 4345a70601c34..b6e82b621c8cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -267,7 +267,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte .where('value >= current_timestamp().cast("long") - 10L) testStream(aggregated, Complete)( - StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock), // advance clock to 10 seconds, all keys retained AddData(inputData, 0L, 5L, 5L, 10L), @@ -294,7 +294,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte clock.advance(60 * 1000L) true }, - StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock), // The commit log blown, causing the last batch to re-run CheckLastBatch((20L, 1), (85L, 1)), AssertOnQuery { q => @@ -322,7 +322,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte .where($"value".cast("date") >= date_sub(current_date(), 10)) .select(($"value".cast("long") / DateTimeUtils.SECONDS_PER_DAY).cast("long"), $"count(1)") testStream(aggregated, Complete)( - StartStream(ProcessingTime("10 day"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock), // advance clock to 10 days, should retain all keys AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10), @@ -346,7 +346,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60) true }, - StartStream(ProcessingTime("10 day"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock), // Commit log blown, causing a re-run of the last batch CheckLastBatch((20L, 1), (85L, 1)),