diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 512d539ee9c38..ef28e2c48ad02 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -321,6 +321,17 @@ package object config {
.intConf
.createWithDefault(3)
+ private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS =
+ ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress")
+ .doc("This configuration limits the number of remote blocks being fetched per reduce task" +
+ " from a given host port. When a large number of blocks are being requested from a given" +
+ " address in a single fetch or simultaneously, this could crash the serving executor or" +
+ " Node Manager. This is especially useful to reduce the load on the Node Manager when" +
+ " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.")
+ .intConf
+ .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.")
+ .createWithDefault(Int.MaxValue)
+
private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM =
ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem")
.doc("The blocks of a shuffle request will be fetched to disk when size of the request is " +
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 2fbac79a2305b..c8d1460300934 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 81d822dc8a98f..2d176b62f8b36 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.LinkedBlockingQueue
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
-import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
@@ -52,6 +52,8 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
+ * for a given remote host:port.
* @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
* @param detectCorrupt whether to detect any corruption in fetched blocks.
*/
@@ -64,6 +66,7 @@ final class ShuffleBlockFetcherIterator(
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
+ maxBlocksInFlightPerAddress: Int,
maxReqSizeShuffleToMem: Long,
detectCorrupt: Boolean)
extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging {
@@ -110,12 +113,21 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val fetchRequests = new Queue[FetchRequest]
+ /**
+ * Queue of fetch requests which could not be issued the first time they were dequeued. These
+ * requests are tried again when the fetch constraints are satisfied.
+ */
+ private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]()
+
/** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
/** Current number of requests in flight */
private[this] var reqsInFlight = 0
+ /** Current number of blocks in flight per host:port */
+ private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
+
/**
* The blocks that can't be decompressed successfully, it is used to guarantee that we retry
* at most once for those corrupted blocks.
@@ -248,7 +260,8 @@ final class ShuffleBlockFetcherIterator(
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+ logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
+ + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
@@ -277,11 +290,13 @@ final class ShuffleBlockFetcherIterator(
} else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
- if (curRequestSize >= targetRequestSize) {
+ if (curRequestSize >= targetRequestSize ||
+ curBlocks.size >= maxBlocksInFlightPerAddress) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
+ logDebug(s"Creating fetch request of $curRequestSize at $address "
+ + s"with ${curBlocks.size} blocks")
curBlocks = new ArrayBuffer[(BlockId, Long)]
- logDebug(s"Creating fetch request of $curRequestSize at $address")
curRequestSize = 0
}
}
@@ -375,6 +390,7 @@ final class ShuffleBlockFetcherIterator(
result match {
case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
+ numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
shuffleMetrics.incRemoteBytesRead(buf.size)
if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
@@ -443,12 +459,57 @@ final class ShuffleBlockFetcherIterator(
}
private def fetchUpToMaxBytes(): Unit = {
- // Send fetch requests up to maxBytesInFlight
- while (fetchRequests.nonEmpty &&
- (bytesInFlight == 0 ||
- (reqsInFlight + 1 <= maxReqsInFlight &&
- bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
- sendRequest(fetchRequests.dequeue())
+ // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
+ // immediately, defer the request until the next time it can be processed.
+
+ // Process any outstanding deferred fetch requests if possible.
+ if (deferredFetchRequests.nonEmpty) {
+ for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
+ while (isRemoteBlockFetchable(defReqQueue) &&
+ !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
+ val request = defReqQueue.dequeue()
+ logDebug(s"Processing deferred fetch request for $remoteAddress with "
+ + s"${request.blocks.length} blocks")
+ send(remoteAddress, request)
+ if (defReqQueue.isEmpty) {
+ deferredFetchRequests -= remoteAddress
+ }
+ }
+ }
+ }
+
+ // Process any regular fetch requests if possible.
+ while (isRemoteBlockFetchable(fetchRequests)) {
+ val request = fetchRequests.dequeue()
+ val remoteAddress = request.address
+ if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+ logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
+ val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
+ defReqQueue.enqueue(request)
+ deferredFetchRequests(remoteAddress) = defReqQueue
+ } else {
+ send(remoteAddress, request)
+ }
+ }
+
+ def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
+ sendRequest(request)
+ numBlocksInFlightPerAddress(remoteAddress) =
+ numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
+ }
+
+ def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
+ fetchReqQueue.nonEmpty &&
+ (bytesInFlight == 0 ||
+ (reqsInFlight + 1 <= maxReqsInFlight &&
+ bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
+ }
+
+ // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
+ // given remote address.
+ def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
+ numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
+ maxBlocksInFlightPerAddress
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
index 1b2b1932e0c3d..eff0aa4453f08 100644
--- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
+++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
@@ -51,6 +51,10 @@ private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Orderin
this
}
+ def poll(): A = {
+ underlying.poll()
+ }
+
override def +=(elem1: A, elem2: A, elems: A*): this.type = {
this += elem1 += elem2 ++= elems
}
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 6a70cedf769b8..c371cbcf8dff5 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -110,6 +110,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
+ Int.MaxValue,
true)
// 3 local blocks fetched in initialization
@@ -187,6 +188,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
+ Int.MaxValue,
true)
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
@@ -254,6 +256,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
+ Int.MaxValue,
true)
// Continue only after the mock calls onBlockFetchFailure
@@ -319,6 +322,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
+ Int.MaxValue,
true)
// Continue only after the mock calls onBlockFetchFailure
@@ -400,6 +404,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
+ Int.MaxValue,
false)
// Continue only after the mock calls onBlockFetchFailure
@@ -457,6 +462,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => in,
maxBytesInFlight = Int.MaxValue,
maxReqsInFlight = Int.MaxValue,
+ maxBlocksInFlightPerAddress = Int.MaxValue,
maxReqSizeShuffleToMem = 200,
detectCorrupt = true)
}
diff --git a/core/src/test/scala/org/apache/spark/util/BoundedPriorityQueueSuite.scala b/core/src/test/scala/org/apache/spark/util/BoundedPriorityQueueSuite.scala
new file mode 100644
index 0000000000000..9465ca70e94f2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/BoundedPriorityQueueSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.SparkFunSuite
+
+class BoundedPriorityQueueSuite extends SparkFunSuite {
+ test("BoundedPriorityQueue poll test") {
+ val pq = new BoundedPriorityQueue[Double](4)
+
+ pq += 0.1
+ pq += 1.5
+ pq += 1.0
+ pq += 0.3
+ pq += 0.01
+
+ assert(pq.isEmpty == false)
+ assert(pq.poll() == 0.1)
+ assert(pq.poll() == 0.3)
+ assert(pq.poll() == 1.0)
+ assert(pq.poll() == 1.5)
+ assert(pq.isEmpty == true)
+
+ val pq2 = new BoundedPriorityQueue[(Int, Double)](4)(Ordering.by(_._2))
+ pq2 += 1 -> 0.5
+ pq2 += 5 -> 0.1
+ pq2 += 3 -> 0.3
+ pq2 += 4 -> 0.2
+ pq2 += 1 -> 0.4
+
+ assert(pq2.poll()._2 == 0.2)
+ assert(pq2.poll()._2 == 0.3)
+ assert(pq2.poll()._2 == 0.4)
+ assert(pq2.poll()._2 == 0.5)
+ }
+}
diff --git a/docs/configuration.md b/docs/configuration.md
index 91b5befd1b1eb..d3df923c42690 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -536,6 +536,15 @@ Apart from these, the following properties are also available, and may be useful
+ spark.reducer.maxBlocksInFlightPerAddress |
+ Int.MaxValue |
+
+ This configuration limits the number of remote blocks being fetched per reduce task from a
+ given host port. When a large number of blocks are being requested from a given address in a
+ single fetch or simultaneously, this could crash the serving executor or Node Manager. This
+ is especially useful to reduce the load on the Node Manager when external shuffle is enabled.
+ You can mitigate this issue by setting it to a lower value.
+ |
spark.reducer.maxReqSizeShuffleToMem |
Long.MaxValue |
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 7401b63e022c1..cf257c06c9516 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -545,6 +545,17 @@ See the [configuration page](configuration.html) for information on Spark config
Fetcher Cache
|
+
+ spark.mesos.driver.failoverTimeout |
+ 0.0 |
+
+ The amount of time (in seconds) that the master will wait for the
+ driver to reconnect, after being temporarily disconnected, before
+ it tears down the driver framework by killing all its
+ executors. The default value is zero, meaning no timeout: if the
+ driver disconnects, the master immediately tears down the framework.
+ |
+
# Troubleshooting and Debugging
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala
index 56d697f359614..6c8619e3c3c13 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala
@@ -58,9 +58,16 @@ package object config {
private [spark] val DRIVER_LABELS =
ConfigBuilder("spark.mesos.driver.labels")
- .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value" +
+ .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value " +
"pairs should be separated by a colon, and commas used to list more than one." +
"Ex. key:value,key2:value2")
.stringConf
.createOptional
+
+ private [spark] val DRIVER_FAILOVER_TIMEOUT =
+ ConfigBuilder("spark.mesos.driver.failoverTimeout")
+ .doc("Amount of time in seconds that the master will wait to hear from the driver, " +
+ "during a temporary disconnection, before tearing down all the executors.")
+ .doubleConf
+ .createWithDefault(0.0)
}
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index 7dd42c41aa7c2..6e7f41dad34ba 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -29,6 +29,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
import org.apache.mesos.SchedulerDriver
import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState}
+import org.apache.spark.deploy.mesos.config._
import org.apache.spark.internal.config
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
@@ -177,7 +178,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
sc.conf,
sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)),
None,
- None,
+ Some(sc.conf.get(DRIVER_FAILOVER_TIMEOUT)),
sc.conf.getOption("spark.mesos.driver.frameworkId")
)
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
index 7cca5fedb31eb..d9ff4a403ea36 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.mock.MockitoSugar
import org.scalatest.BeforeAndAfter
import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.deploy.mesos.config._
import org.apache.spark.internal.config._
import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef}
@@ -369,6 +370,41 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite
backend.start()
}
+ test("failover timeout is set in created scheduler driver") {
+ val failoverTimeoutIn = 3600.0
+ initializeSparkConf(Map(DRIVER_FAILOVER_TIMEOUT.key -> failoverTimeoutIn.toString))
+ sc = new SparkContext(sparkConf)
+
+ val taskScheduler = mock[TaskSchedulerImpl]
+ when(taskScheduler.sc).thenReturn(sc)
+
+ val driver = mock[SchedulerDriver]
+ when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING)
+
+ val securityManager = mock[SecurityManager]
+
+ val backend = new MesosCoarseGrainedSchedulerBackend(
+ taskScheduler, sc, "master", securityManager) {
+ override protected def createSchedulerDriver(
+ masterUrl: String,
+ scheduler: Scheduler,
+ sparkUser: String,
+ appName: String,
+ conf: SparkConf,
+ webuiUrl: Option[String] = None,
+ checkpoint: Option[Boolean] = None,
+ failoverTimeout: Option[Double] = None,
+ frameworkId: Option[String] = None): SchedulerDriver = {
+ markRegistered()
+ assert(failoverTimeout.isDefined)
+ assert(failoverTimeout.get.equals(failoverTimeoutIn))
+ driver
+ }
+ }
+
+ backend.start()
+ }
+
test("honors unset spark.mesos.containerizer") {
setBackend(Map("spark.mesos.executor.docker.image" -> "test"))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
index 2652f6d72730c..e0748043c46e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
@@ -35,13 +35,13 @@ trait LogicalPlanVisitor[T] {
case p: LocalLimit => visitLocalLimit(p)
case p: Pivot => visitPivot(p)
case p: Project => visitProject(p)
- case p: Range => visitRange(p)
case p: Repartition => visitRepartition(p)
case p: RepartitionByExpression => visitRepartitionByExpr(p)
case p: ResolvedHint => visitHint(p)
case p: Sample => visitSample(p)
case p: ScriptTransformation => visitScriptTransform(p)
case p: Union => visitUnion(p)
+ case p: Window => visitWindow(p)
case p: LogicalPlan => default(p)
}
@@ -73,8 +73,6 @@ trait LogicalPlanVisitor[T] {
def visitProject(p: Project): T
- def visitRange(p: Range): T
-
def visitRepartition(p: Repartition): T
def visitRepartitionByExpr(p: RepartitionByExpression): T
@@ -84,4 +82,6 @@ trait LogicalPlanVisitor[T] {
def visitScriptTransform(p: ScriptTransformation): T
def visitUnion(p: Union): T
+
+ def visitWindow(p: Window): T
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
index 93908b04fb643..4cff72d45a400 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
@@ -65,11 +65,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
ProjectEstimation.estimate(p).getOrElse(fallback(p))
}
- override def visitRange(p: logical.Range): Statistics = {
- val sizeInBytes = LongType.defaultSize * p.numElements
- Statistics(sizeInBytes = sizeInBytes)
- }
-
override def visitRepartition(p: Repartition): Statistics = fallback(p)
override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p)
@@ -79,4 +74,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
override def visitScriptTransform(p: ScriptTransformation): Statistics = fallback(p)
override def visitUnion(p: Union): Statistics = fallback(p)
+
+ override def visitWindow(p: Window): Statistics = fallback(p)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
index 559f12072e448..d701a956887a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
@@ -136,10 +136,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
override def visitProject(p: Project): Statistics = visitUnaryNode(p)
- override def visitRange(p: logical.Range): Statistics = {
- p.computeStats()
- }
-
override def visitRepartition(p: Repartition): Statistics = default(p)
override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = default(p)
@@ -160,4 +156,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
override def visitUnion(p: Union): Statistics = {
Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).sum)
}
+
+ override def visitWindow(p: Window): Statistics = visitUnaryNode(p)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index 913be6d1ff07f..7d532ff343178 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.statsEstimation
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
@@ -54,6 +56,24 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
)
}
+ test("range") {
+ val range = Range(1, 5, 1, None)
+ val rangeStats = Statistics(sizeInBytes = 4 * 8)
+ checkStats(
+ range,
+ expectedStatsCboOn = rangeStats,
+ expectedStatsCboOff = rangeStats)
+ }
+
+ test("windows") {
+ val windows = plan.window(Seq(min(attribute).as('sum_attr)), Seq(attribute), Nil)
+ val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / (4 + 8))
+ checkStats(
+ windows,
+ expectedStatsCboOn = windowsStats,
+ expectedStatsCboOff = windowsStats)
+ }
+
test("limit estimation: limit < child's rowCount") {
val localLimit = LocalLimit(Literal(2), plan)
val globalLimit = GlobalLimit(Literal(2), plan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b825b6cd6160f..71ab0ddf2d6f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -980,7 +980,7 @@ class Dataset[T] private[sql](
* @param condition Join expression.
* @param joinType Type of join to perform. Default `inner`. Must be one of:
* `inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`,
- * `right`, `right_outer`, `left_semi`, `left_anti`.
+ * `right`, `right_outer`.
*
* @group typedrel
* @since 1.6.0
@@ -997,6 +997,10 @@ class Dataset[T] private[sql](
JoinType(joinType),
Some(condition.expr))).analyzed.asInstanceOf[Join]
+ if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
+ throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql)
+ }
+
// For both join side, combine all outputs into a single column and alias it with "_1" or "_2",
// to match the schema for the encoder of the join result.
// Note that we do this before joining them, to enable the join operator to return null for one
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 1007a7d55691b..34134db278ad8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -489,13 +489,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
* Inserts an InputAdapter on top of those that do not support codegen.
*/
private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
- case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen =>
- // The children of SortMergeJoin should do codegen separately.
- j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
- right = InputAdapter(insertWholeStageCodegen(right)))
case p if !supportCodegen(p) =>
// collapse them recursively
InputAdapter(insertWholeStageCodegen(p))
+ case j @ SortMergeJoinExec(_, _, _, _, left, right) =>
+ // The children of SortMergeJoin should do codegen separately.
+ j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
+ right = InputAdapter(insertWholeStageCodegen(right)))
case p =>
p.withNewChildren(p.children.map(insertInputAdapter))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index d36a04f1fff8e..cbe8ce421f92b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -96,6 +96,24 @@ case class DataSource(
bucket.sortColumnNames, "in the sort definition", equality)
}
+ /**
+ * In the read path, only managed tables by Hive provide the partition columns properly when
+ * initializing this class. All other file based data sources will try to infer the partitioning,
+ * and then cast the inferred types to user specified dataTypes if the partition columns exist
+ * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or
+ * inconsistent data types as reported in SPARK-21463.
+ * @param fileIndex A FileIndex that will perform partition inference
+ * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema`
+ */
+ private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = {
+ val resolved = fileIndex.partitionSchema.map { partitionField =>
+ // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
+ userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
+ partitionField)
+ }
+ StructType(resolved)
+ }
+
/**
* Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer
* it. In the read path, only managed tables by Hive provide the partition columns properly when
@@ -139,12 +157,7 @@ case class DataSource(
val partitionSchema = if (partitionColumns.isEmpty) {
// Try to infer partitioning, because no DataSource in the read path provides the partitioning
// columns properly unless it is a Hive DataSource
- val resolved = tempFileIndex.partitionSchema.map { partitionField =>
- // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
- userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
- partitionField)
- }
- StructType(resolved)
+ combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex)
} else {
// maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred
// partitioning
@@ -336,7 +349,13 @@ case class DataSource(
caseInsensitiveOptions.get("path").toSeq ++ paths,
sparkSession.sessionState.newHadoopConf()) =>
val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head)
- val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath)
+ val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None)
+ val fileCatalog = if (userSpecifiedSchema.nonEmpty) {
+ val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog)
+ new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema))
+ } else {
+ tempFileCatalog
+ }
val dataSchema = userSpecifiedSchema.orElse {
format.inferSchema(
sparkSession,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 57e9bc9b70454..24e13697c0c9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -271,7 +271,7 @@ private[jdbc] class JDBCRDD(
conn = getConnection()
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
- dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap)
+ dialect.beforeFetch(conn, options.asProperties.asScala.toMap)
// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index a9e64c640042a..4b1b2520390ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -195,7 +195,7 @@ class FileStreamSource(
private def allFilesUsingMetadataLogFileIndex() = {
// Note if `sourceHasMetadata` holds, then `qualifiedBasePath` is guaranteed to be a
// non-glob path
- new MetadataLogFileIndex(sparkSession, qualifiedBasePath).allFiles()
+ new MetadataLogFileIndex(sparkSession, qualifiedBasePath, None).allFiles()
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
index aeaa134736937..1da703cefd8ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
@@ -23,14 +23,21 @@ import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.types.StructType
/**
* A [[FileIndex]] that generates the list of files to processing by reading them from the
* metadata log files generated by the [[FileStreamSink]].
+ *
+ * @param userPartitionSchema an optional partition schema that will be use to provide types for
+ * the discovered partitions
*/
-class MetadataLogFileIndex(sparkSession: SparkSession, path: Path)
- extends PartitioningAwareFileIndex(sparkSession, Map.empty, None) {
+class MetadataLogFileIndex(
+ sparkSession: SparkSession,
+ path: Path,
+ userPartitionSchema: Option[StructType])
+ extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) {
private val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
logInfo(s"Reading streaming file log from $metadataDirectory")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index af2b4fb92062b..156002ef58fbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -195,15 +195,6 @@ private[window] final class SlidingWindowFunctionFrame(
override def write(index: Int, current: InternalRow): Unit = {
var bufferUpdated = index == 0
- // Add all rows to the buffer for which the input row value is equal to or less than
- // the output row upper bound.
- while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
- buffer.add(nextRow.copy())
- nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
- inputHighIndex += 1
- bufferUpdated = true
- }
-
// Drop all rows from the buffer for which the input row value is smaller than
// the output row lower bound.
while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
@@ -212,6 +203,19 @@ private[window] final class SlidingWindowFunctionFrame(
bufferUpdated = true
}
+ // Add all rows to the buffer for which the input row value is equal to or less than
+ // the output row upper bound.
+ while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
+ if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) {
+ inputLowIndex += 1
+ } else {
+ buffer.add(nextRow.copy())
+ bufferUpdated = true
+ }
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
+ inputHighIndex += 1
+ }
+
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
processor.initialize(input.length)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 69d110e414278..73098cdb92471 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
@@ -400,6 +401,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
((("b", 2), ("b", 2)), ("b", 2)))
}
+ test("joinWith join types") {
+ val ds1 = Seq(1, 2, 3).toDS().as("a")
+ val ds2 = Seq(1, 2).toDS().as("b")
+
+ val e1 = intercept[AnalysisException] {
+ ds1.joinWith(ds2, $"a.value" === $"b.value", "left_semi")
+ }.getMessage
+ assert(e1.contains("Invalid join type in joinWith: " + LeftSemi.sql))
+
+ val e2 = intercept[AnalysisException] {
+ ds1.joinWith(ds2, $"a.value" === $"b.value", "left_anti")
+ }.getMessage
+ assert(e2.contains("Invalid join type in joinWith: " + LeftAnti.sql))
+ }
+
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupByKey(v => (1, v._2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
index 52c200796ce41..623a1b6f854cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
@@ -22,20 +22,22 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.streaming.ProcessingTime
+import org.apache.spark.sql.streaming.{ProcessingTime, Trigger}
class ProcessingTimeSuite extends SparkFunSuite {
test("create") {
- assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000)
- assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000)
- assert(ProcessingTime("1 minute").intervalMs === 60 * 1000)
- assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000)
-
- intercept[IllegalArgumentException] { ProcessingTime(null: String) }
- intercept[IllegalArgumentException] { ProcessingTime("") }
- intercept[IllegalArgumentException] { ProcessingTime("invalid") }
- intercept[IllegalArgumentException] { ProcessingTime("1 month") }
- intercept[IllegalArgumentException] { ProcessingTime("1 year") }
+ def getIntervalMs(trigger: Trigger): Long = trigger.asInstanceOf[ProcessingTime].intervalMs
+
+ assert(getIntervalMs(Trigger.ProcessingTime(10.seconds)) === 10 * 1000)
+ assert(getIntervalMs(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) === 10 * 1000)
+ assert(getIntervalMs(Trigger.ProcessingTime("1 minute")) === 60 * 1000)
+ assert(getIntervalMs(Trigger.ProcessingTime("interval 1 minute")) === 60 * 1000)
+
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime(null: String) }
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime("") }
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime("invalid") }
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 month") }
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 year") }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index 52e4f047225de..a9f3fb355c775 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -356,6 +356,46 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
spark.catalog.dropTempView("nums")
}
+ test("window function: mutiple window expressions specified by range in a single expression") {
+ val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+ nums.createOrReplaceTempView("nums")
+ withTempView("nums") {
+ val expected =
+ Row(1, 1, 1, 4, null, 8, 25) ::
+ Row(1, 3, 4, 9, 1, 12, 24) ::
+ Row(1, 5, 9, 15, 4, 16, 21) ::
+ Row(1, 7, 16, 21, 8, 9, 16) ::
+ Row(1, 9, 25, 16, 12, null, 9) ::
+ Row(0, 2, 2, 6, null, 10, 30) ::
+ Row(0, 4, 6, 12, 2, 14, 28) ::
+ Row(0, 6, 12, 18, 6, 18, 24) ::
+ Row(0, 8, 20, 24, 10, 10, 18) ::
+ Row(0, 10, 30, 18, 14, null, 10) ::
+ Nil
+
+ val actual = sql(
+ """
+ |SELECT
+ | y,
+ | x,
+ | sum(x) over w1 as history_sum,
+ | sum(x) over w2 as period_sum1,
+ | sum(x) over w3 as period_sum2,
+ | sum(x) over w4 as period_sum3,
+ | sum(x) over w5 as future_sum
+ |FROM nums
+ |WINDOW
+ | w1 AS (PARTITION BY y ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW),
+ | w2 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING),
+ | w3 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 4 PRECEDING AND 2 PRECEDING ),
+ | w4 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 4 FOLLOWING),
+ | w5 AS (PARTITION BY y ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
+ """.stripMargin
+ )
+ checkAnswer(actual, expected)
+ }
+ }
+
test("SPARK-7595: Window will cause resolve failed with self join") {
checkAnswer(sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index a4b30a2f8cec1..183c68fd3c016 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -127,4 +129,24 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
"named_struct('a',id+2, 'b',id+2) as col2")
.filter("col1 = col2").count()
}
+
+ test("SPARK-21441 SortMergeJoin codegen with CodegenFallback expressions should be disabled") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
+ import testImplicits._
+
+ val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int")
+ val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str")
+
+ val df = df1.join(df2, df1("key") === df2("key"))
+ .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1")
+ .select("int")
+
+ val plan = df.queryExecution.executedPlan
+ assert(!plan.find(p =>
+ p.isInstanceOf[WholeStageCodegenExec] &&
+ p.asInstanceOf[WholeStageCodegenExec].child.children(0)
+ .isInstanceOf[SortMergeJoinExec]).isDefined)
+ assert(df.collect() === Array(Row(1), Row(2)))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 84b34d5ad26d1..2f5fd8438f682 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition}
+import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -1022,4 +1023,36 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
}
}
}
+
+ test("SPARK-21463: MetadataLogFileIndex should respect userSpecifiedSchema for partition cols") {
+ withTempDir { tempDir =>
+ val output = new File(tempDir, "output").toString
+ val checkpoint = new File(tempDir, "chkpoint").toString
+ try {
+ val stream = MemoryStream[(String, Int)]
+ val df = stream.toDS().toDF("time", "value")
+ val sq = df.writeStream
+ .option("checkpointLocation", checkpoint)
+ .format("parquet")
+ .partitionBy("time")
+ .start(output)
+
+ stream.addData(("2017-01-01-00", 1), ("2017-01-01-01", 2))
+ sq.processAllAvailable()
+
+ val schema = new StructType()
+ .add("time", StringType)
+ .add("value", IntegerType)
+ val readBack = spark.read.schema(schema).parquet(output)
+ assert(readBack.schema.toSet === schema.toSet)
+
+ checkAnswer(
+ readBack,
+ Seq(Row("2017-01-01-00", 1), Row("2017-01-01-01", 2))
+ )
+ } finally {
+ spark.streams.active.foreach(_.stop())
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 9f2f0d195de9f..a5399cdb6e5b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -664,7 +664,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
.flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc)
testStream(result, Update)(
- StartStream(ProcessingTime("1 second"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
AddData(inputData, "a"),
AdvanceManualClock(1 * 1000),
CheckLastBatch(("a", "1")),
@@ -729,7 +729,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
.flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc)
testStream(result, Update)(
- StartStream(ProcessingTime("1 second")),
+ StartStream(Trigger.ProcessingTime("1 second")),
AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ...
CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s
AddData(inputData, ("a", 4)), // Add data older than watermark for "a"
@@ -901,7 +901,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
.flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc)
testStream(result, Update)(
- StartStream(ProcessingTime("1 second"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
AddData(inputData, ("a", 1L)),
AdvanceManualClock(1 * 1000),
CheckLastBatch(("a", "1"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 4345a70601c34..b6e82b621c8cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -267,7 +267,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
.where('value >= current_timestamp().cast("long") - 10L)
testStream(aggregated, Complete)(
- StartStream(ProcessingTime("10 seconds"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock),
// advance clock to 10 seconds, all keys retained
AddData(inputData, 0L, 5L, 5L, 10L),
@@ -294,7 +294,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
clock.advance(60 * 1000L)
true
},
- StartStream(ProcessingTime("10 seconds"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock),
// The commit log blown, causing the last batch to re-run
CheckLastBatch((20L, 1), (85L, 1)),
AssertOnQuery { q =>
@@ -322,7 +322,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
.where($"value".cast("date") >= date_sub(current_date(), 10))
.select(($"value".cast("long") / DateTimeUtils.SECONDS_PER_DAY).cast("long"), $"count(1)")
testStream(aggregated, Complete)(
- StartStream(ProcessingTime("10 day"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock),
// advance clock to 10 days, should retain all keys
AddData(inputData, 0L, 5L, 5L, 10L),
AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
@@ -346,7 +346,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60)
true
},
- StartStream(ProcessingTime("10 day"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock),
// Commit log blown, causing a re-run of the last batch
CheckLastBatch((20L, 1), (85L, 1)),