diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ec85f723c08c6..88a138fd8eb1f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2818,14 +2818,14 @@ setMethod("write.df", signature(df = "SparkDataFrame"), function(df, path = NULL, source = NULL, mode = "error", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", "in 'spark.sql.sources.default' configuration by default.") } if (!is.character(mode)) { - stop("mode should be charactor or omitted. It is 'error' by default.") + stop("mode should be character or omitted. It is 'error' by default.") } if (is.null(source)) { source <- getDefaultSqlSource() @@ -3040,7 +3040,7 @@ setMethod("fillna", signature(x = "SparkDataFrame"), function(x, value, cols = NULL) { if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { - stop("value should be an integer, numeric, charactor or named list.") + stop("value should be an integer, numeric, character or named list.") } if (class(value) == "list") { @@ -3052,7 +3052,7 @@ setMethod("fillna", # Check each item in the named list is of valid type lapply(value, function(v) { if (!(class(v) %in% c("integer", "numeric", "character"))) { - stop("Each item in value should be an integer, numeric or charactor.") + stop("Each item in value should be an integer, numeric or character.") } }) @@ -3598,7 +3598,7 @@ setMethod("write.stream", "in 'spark.sql.sources.default' configuration by default.") } if (!is.null(outputMode) && !is.character(outputMode)) { - stop("outputMode should be charactor or omitted.") + stop("outputMode should be character or omitted.") } if (is.null(source)) { source <- getDefaultSqlSource() diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index c2a1e240ad395..f5c3a749fe0a1 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -606,7 +606,7 @@ tableToDF <- function(tableName) { #' @note read.df since 1.4.0 read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 58cf24256a94f..3fbb618ddfc39 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2926,9 +2926,9 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) expect_error(write.df(df, path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(write.df(df, mode = TRUE), - "mode should be charactor or omitted. It is 'error' by default.") + "mode should be character or omitted. It is 'error' by default.") }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { @@ -2947,7 +2947,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # Arguments checking in R side. expect_error(read.df(path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index 0ba95169529e6..97eed540b8f59 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -35,7 +35,7 @@ private[spark] trait RpcEnvFactory { * * The life-cycle of an endpoint is: * - * constructor -> onStart -> receive* -> onStop + * {@code constructor -> onStart -> receive* -> onStop} * * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use * [[ThreadSafeRpcEndpoint]] @@ -63,16 +63,16 @@ private[spark] trait RpcEndpoint { } /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a - * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.send` or `RpcCallContext.reply`. If receiving a + * unmatched message, `SparkException` will be thrown and sent to `onError`. */ def receive: PartialFunction[Any, Unit] = { case _ => throw new SparkException(self + " does not implement 'receive'") } /** - * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, - * [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.ask`. If receiving a unmatched message, + * `SparkException` will be thrown and sent to `onError`. */ def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case _ => context.sendFailure(new SparkException(self + " won't reply anything")) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 2c9a976e76939..0557b7a3cc0b7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkConf import org.apache.spark.util.{ThreadUtils, Utils} /** - * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + * An exception thrown if RpcTimeout modifies a `TimeoutException`. */ private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) extends TimeoutException(message) { initCause(cause) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 09717316833a7..aab177f257a8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -607,7 +607,7 @@ class DAGScheduler( * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name * - * @throws Exception when the job fails + * @note Throws `Exception` when the job fails */ def runJob[T, U]( rdd: RDD[T], @@ -644,7 +644,7 @@ class DAGScheduler( * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD - * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param evaluator `ApproximateEvaluator` to receive the partial results * @param callSite where in the user program this job was called * @param timeout maximum time to wait for the job, in milliseconds * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala index d1ac7131baba5..47f3527a32c01 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala @@ -42,7 +42,7 @@ private[spark] trait ExternalClusterManager { /** * Create a scheduler backend for the given SparkContext and scheduler. This is - * called after task scheduler is created using [[ExternalClusterManager.createTaskScheduler()]]. + * called after task scheduler is created using `ExternalClusterManager.createTaskScheduler()`. * @param sc SparkContext * @param masterURL the master URL * @param scheduler TaskScheduler that will be used with the scheduler backend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index c849a16023a7a..1b6bc9139f9c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -38,7 +38,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. - * It can also work with a local setup by using a [[LocalSchedulerBackend]] and setting + * It can also work with a local setup by using a `LocalSchedulerBackend` and setting * isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking * up to launch speculative tasks, etc. * @@ -704,12 +704,12 @@ private[spark] object TaskSchedulerImpl { * Used to balance containers across hosts. * * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource + * resource offers representing the order in which the offers should be used. The resource * offers are ordered such that we'll allocate one container on each host before allocating a * second container on any host, and so on, in order to reduce the damage if a host fails. * - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] + * For example, given {@literal }, {@literal } and + * {@literal }, returns {@literal [o1, o5, o4, o2, o6, o3]}. */ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { val _keyList = new ArrayBuffer[K](map.size) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 7befdb0c1f64d..0529fe9eed4da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future @@ -42,7 +43,7 @@ private[spark] class StandaloneSchedulerBackend( with Logging { private var client: StandaloneAppClient = null - private var stopping = false + private val stopping = new AtomicBoolean(false) private val launcherBackend = new LauncherBackend() { override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) } @@ -112,7 +113,7 @@ private[spark] class StandaloneSchedulerBackend( launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop(): Unit = synchronized { + override def stop(): Unit = { stop(SparkAppHandle.State.FINISHED) } @@ -125,14 +126,14 @@ private[spark] class StandaloneSchedulerBackend( override def disconnected() { notifyContext() - if (!stopping) { + if (!stopping.get) { logWarning("Disconnected from Spark cluster! Waiting for reconnection...") } } override def dead(reason: String) { notifyContext() - if (!stopping) { + if (!stopping.get) { launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. Reason: " + reason) try { @@ -206,20 +207,20 @@ private[spark] class StandaloneSchedulerBackend( registrationBarrier.release() } - private def stop(finalState: SparkAppHandle.State): Unit = synchronized { - try { - stopping = true - - super.stop() - client.stop() + private def stop(finalState: SparkAppHandle.State): Unit = { + if (stopping.compareAndSet(false, true)) { + try { + super.stop() + client.stop() - val callback = shutdownCallback - if (callback != null) { - callback(this) + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() } - } finally { - launcherBackend.setState(finalState) - launcherBackend.close() } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 01bbda0b5e6b3..cb8b1cc077637 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -125,7 +125,7 @@ abstract class SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -abstract class SerializationStream { +abstract class SerializationStream extends Closeable { /** The most general-purpose method to write an object. */ def writeObject[T: ClassTag](t: T): SerializationStream /** Writes the object representing the key of a key-value pair. */ @@ -133,7 +133,7 @@ abstract class SerializationStream { /** Writes the object representing the value of a key-value pair. */ def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) def flush(): Unit - def close(): Unit + override def close(): Unit def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { while (iter.hasNext) { @@ -149,14 +149,14 @@ abstract class SerializationStream { * A stream for reading serialized objects. */ @DeveloperApi -abstract class DeserializationStream { +abstract class DeserializationStream extends Closeable { /** The most general-purpose method to read an object. */ def readObject[T: ClassTag](): T /** Reads the object representing the key of a key-value pair. */ def readKey[T: ClassTag](): T = readObject[T]() /** Reads the object representing the value of a key-value pair. */ def readValue[T: ClassTag](): T = readObject[T]() - def close(): Unit + override def close(): Unit /** * Read the elements of this stream through an iterator. This can only be called once, as diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 63acba65d3c5b..3219969bcd06f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -66,7 +66,7 @@ private[spark] trait BlockData { /** * Returns a Netty-friendly wrapper for the block's data. * - * @see [[ManagedBuffer#convertToNetty()]] + * Please see `ManagedBuffer.convertToNetty()` for more details. */ def toNetty(): Object diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 6d03ee091e4ed..ddbcb2d19dcbb 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -243,7 +243,7 @@ private[spark] object AccumulatorSuite { import InternalAccumulator._ /** - * Create a long accumulator and register it to [[AccumulatorContext]]. + * Create a long accumulator and register it to `AccumulatorContext`. */ def createLongAccum( name: String, @@ -258,7 +258,7 @@ private[spark] object AccumulatorSuite { } /** - * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the + * Make an `AccumulableInfo` out of an [[Accumulable]] with the intent to use the * info as an accumulator update. */ def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index eb3fb99747d12..fe944031bc948 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalSh /** * This suite creates an external shuffle server and routes all shuffle fetches through it. * Note that failures in this suite may arise due to changes in Spark that invalidate expectations - * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how + * set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how * we hash files into folders. */ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 24ec99c7e5e60..1dd89bcbe36bc 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite -/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +/** Manages a local `sc` `SparkContext` variable, correctly stopping it after each test. */ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 8103983c4392a..8300607ea888b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -95,12 +95,12 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa } /** - * A map from partition -> results for all tasks of a job when you call this test framework's + * A map from partition to results for all tasks of a job when you call this test framework's * [[submit]] method. Two important considerations: * * 1. If there is a job failure, results may or may not be empty. If any tasks succeed before * the job has failed, they will get included in `results`. Instead, check for job failure by - * checking [[failure]]. (Also see [[assertDataStructuresEmpty()]]) + * checking [[failure]]. (Also see `assertDataStructuresEmpty()`) * * 2. This only gets cleared between tests. So you'll need to do special handling if you submit * more than one job in one test. diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index 4ce3b941bea55..99882bf76e29d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset /** * Tests to ensure that [[Serializer]] implementations obey the API contracts for methods that * describe properties of the serialized stream, such as - * [[Serializer.supportsRelocationOfSerializedObjects]]. + * `Serializer.supportsRelocationOfSerializedObjects`. */ class SerializerPropertiesSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index ecad0f5352e59..dfecd04c1b969 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -70,9 +70,18 @@ class RandomBlockReplicationPolicyBehavior extends SparkFunSuite } } + /** + * Returns a sequence of [[BlockManagerId]], whose rack is randomly picked from the given `racks`. + * Note that, each rack will be picked at least once from `racks`, if `count` is greater or equal + * to the number of `racks`. + */ protected def generateBlockManagerIds(count: Int, racks: Seq[String]): Seq[BlockManagerId] = { - (1 to count).map{i => - BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(racks(Random.nextInt(racks.size)))) + val randomizedRacks: Seq[String] = Random.shuffle( + racks ++ racks.length.until(count).map(_ => racks(Random.nextInt(racks.length))) + ) + + (0 until count).map { i => + BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(randomizedRacks(i))) } } } diff --git a/dev/run-tests.py b/dev/run-tests.py index 04035b33e6a6b..450b68123e1fc 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -344,6 +344,19 @@ def build_spark_sbt(hadoop_version): exec_sbt(profiles_and_goals) +def build_spark_unidoc_sbt(hadoop_version): + set_title_and_block("Building Unidoc API Documentation", "BLOCK_DOCUMENTATION") + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags + sbt_goals = ["unidoc"] + profiles_and_goals = build_profiles + sbt_goals + + print("[info] Building Spark unidoc (w/Hive 1.2.1) using SBT with these arguments: ", + " ".join(profiles_and_goals)) + + exec_sbt(profiles_and_goals) + + def build_spark_assembly_sbt(hadoop_version): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags @@ -352,6 +365,8 @@ def build_spark_assembly_sbt(hadoop_version): print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + # Make sure that Java and Scala API documentation can be generated + build_spark_unidoc_sbt(hadoop_version) def build_apache_spark(build_tool, hadoop_version): diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c425faca4c273..28942b68fa20d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -883,7 +883,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset[Row]`. -This conversion can be done using `SparkSession.read.json()` on either an RDD of String, +This conversion can be done using `SparkSession.read.json()` on either a `Dataset[String]`, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each @@ -897,7 +897,7 @@ For a regular multi-line JSON file, set the `wholeFile` option to `true`.
Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset`. -This conversion can be done using `SparkSession.read().json()` on either an RDD of String, +This conversion can be done using `SparkSession.read().json()` on either a `Dataset`, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 37a1d6189a42d..3cf7151819e2d 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,7 +8,7 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data.The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* **Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. @@ -362,7 +362,7 @@ A query on the input will generate the "Result Table". Every trigger interval (s ![Model](img/structured-streaming-model.png) -The "Output" is defined as what gets written out to the external storage. The output can be defined in different modes +The "Output" is defined as what gets written out to the external storage. The output can be defined in a different mode: - *Complete Mode* - The entire updated Result Table will be written to the external storage. It is up to the storage connector to decide how to handle writing of the entire table. diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 1a7054614b348..b66abaed66000 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -215,7 +215,7 @@ private static void runJsonDatasetExample(SparkSession spark) { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an Dataset[String] storing one JSON object per string. + // a Dataset storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); Dataset anotherPeopleDataset = spark.createDataset(jsonData, Encoders.STRING()); diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 82fd56de39847..ad74da72bd5e6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -139,7 +139,7 @@ object SQLDataSourceExample { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an Dataset[String] storing one JSON object per string + // a Dataset[String] storing one JSON object per string val otherPeopleDataset = spark.createDataset( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) val otherPeople = spark.read.json(otherPeopleDataset) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 8970ad2bafda0..77553412eda56 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -267,7 +267,7 @@ object KinesisInputDStream { getRequiredParam(checkpointAppName, "checkpointAppName"), checkpointInterval.getOrElse(ssc.graph.batchDuration), storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), - handler, + ssc.sc.clean(handler), kinesisCredsProvider.getOrElse(DefaultCredentials), dynamoDBCredsProvider, cloudWatchCredsProvider) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ed7e35805026e..341a6898cbbff 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -22,7 +22,6 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} @@ -173,11 +172,15 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . */ testIfEnabled("basic operation") { - val awsCredentials = KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => @@ -198,12 +201,17 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } testIfEnabled("custom message handling") { - val awsCredentials = KinesisTestUtils.getAWSCredentials() def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 - val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, addFive(_), - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .buildWithMessageHandler(addFive(_)) stream shouldBe a [ReceiverInputDStream[_]] @@ -233,11 +241,15 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun val localTestUtils = new KPLBasedKinesisTestUtils(1) localTestUtils.createStream() try { - val awsCredentials = KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, localAppName, localTestUtils.streamName, - localTestUtils.endpointUrl, localTestUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(localAppName) + .streamName(localTestUtils.streamName) + .endpointUrl(localTestUtils.endpointUrl) + .regionName(localTestUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => @@ -303,13 +315,17 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun ssc = new StreamingContext(sc, Milliseconds(1000)) ssc.checkpoint(checkpointDir) - val awsCredentials = KinesisTestUtils.getAWSCredentials() val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] - val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val kinesisStream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index d2ad9be555770..66c4747fec268 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkConf import org.apache.spark.SparkContext /** - * Provides a method to run tests against a {@link SparkContext} variable that is correctly stopped + * Provides a method to run tests against a `SparkContext` variable that is correctly stopped * after each test. */ trait LocalSparkContext { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d8608d885d6f1..bc0b49d48d323 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -74,7 +74,7 @@ abstract class Classifier[ * and features (`Vector`). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). - * @throws SparkException if any label is not an integer >= 0 + * @note Throws `SparkException` if any label is a non-integer or is negative */ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 327cb974ef96c..3f8d65a378e2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -144,45 +144,13 @@ class PrefixSpan private ( logInfo(s"minimum count for a frequent pattern: $minCount") // Find frequent items. - val freqItemAndCounts = data.flatMap { itemsets => - val uniqItems = mutable.Set.empty[Item] - itemsets.foreach { _.foreach { item => - uniqItems += item - }} - uniqItems.toIterator.map((_, 1L)) - }.reduceByKey(_ + _) - .filter { case (_, count) => - count >= minCount - }.collect() - val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1) + val freqItems = findFrequentItems(data, minCount) logInfo(s"number of frequent items: ${freqItems.length}") // Keep only frequent items from input sequences and convert them to internal storage. val itemToInt = freqItems.zipWithIndex.toMap - val dataInternalRepr = data.flatMap { itemsets => - val allItems = mutable.ArrayBuilder.make[Int] - var containsFreqItems = false - allItems += 0 - itemsets.foreach { itemsets => - val items = mutable.ArrayBuilder.make[Int] - itemsets.foreach { item => - if (itemToInt.contains(item)) { - items += itemToInt(item) + 1 // using 1-indexing in internal format - } - } - val result = items.result() - if (result.nonEmpty) { - containsFreqItems = true - allItems ++= result.sorted - } - allItems += 0 - } - if (containsFreqItems) { - Iterator.single(allItems.result()) - } else { - Iterator.empty - } - }.persist(StorageLevel.MEMORY_AND_DISK) + val dataInternalRepr = toDatabaseInternalRepr(data, itemToInt) + .persist(StorageLevel.MEMORY_AND_DISK) val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize) @@ -231,6 +199,67 @@ class PrefixSpan private ( @Since("1.5.0") object PrefixSpan extends Logging { + /** + * This methods finds all frequent items in a input dataset. + * + * @param data Sequences of itemsets. + * @param minCount The minimal number of sequence an item should be present in to be frequent + * + * @return An array of Item containing only frequent items. + */ + private[fpm] def findFrequentItems[Item: ClassTag]( + data: RDD[Array[Array[Item]]], + minCount: Long): Array[Item] = { + + data.flatMap { itemsets => + val uniqItems = mutable.Set.empty[Item] + itemsets.foreach(set => uniqItems ++= set) + uniqItems.toIterator.map((_, 1L)) + }.reduceByKey(_ + _).filter { case (_, count) => + count >= minCount + }.sortBy(-_._2).map(_._1).collect() + } + + /** + * This methods cleans the input dataset from un-frequent items, and translate it's item + * to their corresponding Int identifier. + * + * @param data Sequences of itemsets. + * @param itemToInt A map allowing translation of frequent Items to their Int Identifier. + * The map should only contain frequent item. + * + * @return The internal repr of the inputted dataset. With properly placed zero delimiter. + */ + private[fpm] def toDatabaseInternalRepr[Item: ClassTag]( + data: RDD[Array[Array[Item]]], + itemToInt: Map[Item, Int]): RDD[Array[Int]] = { + + data.flatMap { itemsets => + val allItems = mutable.ArrayBuilder.make[Int] + var containsFreqItems = false + allItems += 0 + itemsets.foreach { itemsets => + val items = mutable.ArrayBuilder.make[Int] + itemsets.foreach { item => + if (itemToInt.contains(item)) { + items += itemToInt(item) + 1 // using 1-indexing in internal format + } + } + val result = items.result() + if (result.nonEmpty) { + containsFreqItems = true + allItems ++= result.sorted + allItems += 0 + } + } + if (containsFreqItems) { + Iterator.single(allItems.result()) + } else { + Iterator.empty + } + } + } + /** * Find the complete set of frequent sequential patterns in the input sequences. * @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int], diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4cdbf845ae4f5..4a7e4dd80f246 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -230,7 +230,9 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +/** + * Used to test [[Pipeline]] with `MLWritable` stages + */ class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -257,7 +259,9 @@ object WritableStage extends MLReadable[WritableStage] { override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ +/** + * Used to test [[Pipeline]] with non-`MLWritable` stages + */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala index dd4dd62b8cfe9..db4f56ed60d32 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -29,8 +29,10 @@ private[ml] object LSHTest { * the following property is satisfied. * * There exist dist1, dist2, p1, p2, so that for any two elements e1 and e2, - * If dist(e1, e2) <= dist1, then Pr{h(x) == h(y)} >= p1 - * If dist(e1, e2) >= dist2, then Pr{h(x) == h(y)} <= p2 + * If dist(e1, e2) is less than or equal to dist1, then Pr{h(x) == h(y)} is greater than + * or equal to p1 + * If dist(e1, e2) is greater than or equal to dist2, then Pr{h(x) == h(y)} is less than + * or equal to p2 * * This is called locality sensitive property. This method checks the property on an * existing dataset and calculate the probabilities. @@ -38,8 +40,10 @@ private[ml] object LSHTest { * * This method hashes each elements to hash buckets using LSH, and calculate the false positive * and false negative: - * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) > distFP - * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) < distFN + * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) is greater + * than distFP + * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) is less + * than distFN * * @param dataset The dataset to verify the locality sensitive hashing property. * @param lsh The lsh instance to perform the hashing diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index aa9c53ca30eee..78a33e05e0e48 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -377,7 +377,7 @@ class ParamsSuite extends SparkFunSuite { object ParamsSuite extends SparkFunSuite { /** - * Checks common requirements for [[Params.params]]: + * Checks common requirements for `Params.params`: * - params are ordered by names * - param parent has the same UID as the object's UID * - param name is the same as the param method name diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index c90cb8ca1034c..92a236928e90b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -34,7 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite { * Convert the given data to a DataFrame, and set the features and label metadata. * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. - * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param categoricalFeatures Map: categorical feature index to number of distinct values * @param numClasses Number of classes label can take. If 0, mark as continuous. * @return DataFrame with metadata */ @@ -69,7 +69,9 @@ private[ml] object TreeTests extends SparkFunSuite { df("label").as("label", labelMetadata)) } - /** Java-friendly version of [[setMetadata()]] */ + /** + * Java-friendly version of `setMetadata()` + */ def setMetadata( data: JavaRDD[LabeledPoint], categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index bfe8f12258bb8..27d606cb05dc2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -81,20 +81,20 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Default test for Estimator, Model pairs: * - Explicitly set Params, and train model - * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model + * - Test save/load using `testDefaultReadWrite` on Estimator and Model * - Check Params on Estimator and Model * - Compare model data * - * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s. + * This requires that `Model`'s `Param`s should be a subset of `Estimator`'s `Param`s. * * @param estimator Estimator to test - * @param dataset Dataset to pass to [[Estimator.fit()]] - * @param testEstimatorParams Set of [[Param]] values to set in estimator - * @param testModelParams Set of [[Param]] values to set in model - * @param checkModelData Method which takes the original and loaded [[Model]] and compares their - * data. This method does not need to check [[Param]] values. - * @tparam E Type of [[Estimator]] - * @tparam M Type of [[Model]] produced by estimator + * @param dataset Dataset to pass to `Estimator.fit()` + * @param testEstimatorParams Set of `Param` values to set in estimator + * @param testModelParams Set of `Param` values to set in model + * @param checkModelData Method which takes the original and loaded `Model` and compares their + * data. This method does not need to check `Param` values. + * @tparam E Type of `Estimator` + * @tparam M Type of `Model` produced by estimator */ def testEstimatorAndModelReadWrite[ E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 141249a427a4c..54e363a8b9f2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -105,8 +105,8 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { private object StopwatchSuite extends SparkFunSuite { /** - * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and - * returns the duration reported by the stopwatch. + * Checks the input stopwatch on a task that takes a random time (less than 10ms) to finish. + * Validates and returns the duration reported by the stopwatch. */ def checkStopwatch(sw: Stopwatch): Long = { val ubStart = now diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 8f11bbc8e47af..50b73e0e99a22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -30,7 +30,9 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => private var _tempDir: File = _ - /** Returns the temporary directory as a [[File]] instance. */ + /** + * Returns the temporary directory as a `File` instance. + */ protected def tempDir: File = _tempDir override def beforeAll(): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 4c2376376dd2a..c2e08d078fc1a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -360,6 +360,49 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("PrefixSpan pre-processing's cleaning test") { + + // One item per itemSet + val itemToInt1 = (4 to 5).zipWithIndex.toMap + val sequences1 = Seq( + Array(Array(4), Array(1), Array(2), Array(5), Array(2), Array(4), Array(5)), + Array(Array(6), Array(7), Array(8))) + val rdd1 = sc.parallelize(sequences1, 2).cache() + + val cleanedSequence1 = PrefixSpan.toDatabaseInternalRepr(rdd1, itemToInt1).collect() + + val expected1 = Array(Array(0, 4, 0, 5, 0, 4, 0, 5, 0)) + .map(_.map(x => if (x == 0) 0 else itemToInt1(x) + 1)) + + compareInternalSequences(expected1, cleanedSequence1) + + // Multi-item sequence + val itemToInt2 = (4 to 6).zipWithIndex.toMap + val sequences2 = Seq( + Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)), + Array(Array(8, 9), Array(1, 2))) + val rdd2 = sc.parallelize(sequences2, 2).cache() + + val cleanedSequence2 = PrefixSpan.toDatabaseInternalRepr(rdd2, itemToInt2).collect() + + val expected2 = Array(Array(0, 4, 5, 0, 6, 0, 5, 0, 4, 0, 5, 6, 0)) + .map(_.map(x => if (x == 0) 0 else itemToInt2(x) + 1)) + + compareInternalSequences(expected2, cleanedSequence2) + + // Emptied sequence + val itemToInt3 = (10 to 10).zipWithIndex.toMap + val sequences3 = Seq( + Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)), + Array(Array(8, 9), Array(1, 2))) + val rdd3 = sc.parallelize(sequences3, 2).cache() + + val cleanedSequence3 = PrefixSpan.toDatabaseInternalRepr(rdd3, itemToInt3).collect() + val expected3 = Array[Array[Int]]() + + compareInternalSequences(expected3, cleanedSequence3) + } + test("model save/load") { val sequences = Seq( Array(Array(1, 2), Array(3)), @@ -409,4 +452,12 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet assert(expectedSet === actualSet) } + + private def compareInternalSequences( + expectedValue: Array[Array[Int]], + actualValue: Array[Array[Int]]): Unit = { + val expectedSet = expectedValue.map(x => x.toSeq).toSet + val actualSet = actualValue.map(x => x.toSeq).toSet + assert(expectedSet === actualSet) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 14152cdd63bc7..d0f02dd966bd5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} /** - * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + * Test suites for `GiniAggregator` and `EntropyAggregator`. */ class ImpuritySuite extends SparkFunSuite { test("Gini impurity does not support negative labels") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 6bb7ed9c9513c..720237bd2dddd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -60,7 +60,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => * A helper object for importing SQL implicits. * * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, + * This is because we create the `SQLContext` immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 291c1caaaed57..60141792d499b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1804,17 +1804,31 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, a one-element list) - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of a list) - - C{mergeCombiners}, to combine two C's into a single one. + - C{mergeCombiners}, to combine two C's into a single one (e.g., merges + the lists) + + To avoid memory allocation, both mergeValue and mergeCombiners are allowed to + modify and return their first argument instead of creating a new C. In addition, users can control the partitioning of the output RDD. .. note:: V and C can be different -- for example, one might group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> def add(a, b): return a + str(b) - >>> sorted(x.combineByKey(str, add, add).collect()) - [('a', '11'), ('b', '1')] + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 2)]) + >>> def to_list(a): + ... return [a] + ... + >>> def append(a, b): + ... a.append(b) + ... return a + ... + >>> def extend(a, b): + ... a.extend(b) + ... return a + ... + >>> sorted(x.combineByKey(to_list, append, extend).collect()) + [('a', [1, 2]), ('b', [1])] """ if numPartitions is None: numPartitions = self._defaultReducePartitions() diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index c1917d2be69d8..b5fcf7092d93a 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -24,13 +24,13 @@ import atexit import os import platform +import warnings import py4j -import pyspark +from pyspark import SparkConf from pyspark.context import SparkContext from pyspark.sql import SparkSession, SQLContext -from pyspark.storagelevel import StorageLevel if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -39,13 +39,23 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + spark = SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") spark = SparkSession.builder.getOrCreate() except TypeError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index d912f395dafce..960fb882cf901 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -173,8 +173,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads JSON files and returns the results as a :class:`DataFrame`. - `JSON Lines `_(newline-delimited JSON) is supported by default. - For JSON (one record per file), set the `wholeFile` parameter to ``true``. + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -634,7 +634,9 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): - """Saves the content of the :class:`DataFrame` in JSON format at the specified path. + """Saves the content of the :class:`DataFrame` in JSON format + (`JSON Lines text format or newline-delimited JSON `_) at the + specified path. :param path: the path in any Hadoop supported file system :param mode: specifies the behavior of the save operation when data already exists. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 3b604963415f9..65b59d480da36 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -405,8 +405,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. - `JSON Lines `_(newline-delimited JSON) is supported by default. - For JSON (one record per file), set the `wholeFile` parameter to ``true``. + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 3f25535cb5ec2..9d81025a3016b 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -239,7 +239,7 @@ trait MesosSchedulerUtils extends Logging { } /** - * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * Converts the attributes from the resource offer into a Map of name to Attribute Value * The attribute values are the mesos attribute types and they are * * @param offerAttributes the attributes offered @@ -296,7 +296,7 @@ trait MesosSchedulerUtils extends Logging { /** * Parses the attributes constraints provided to spark and build a matching data struct: - * Map[, Set[values-to-match]] + * {@literal Map[, Set[values-to-match]} * The constraints are specified as ';' separated key-value pairs where keys and values * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for * multiple values (comma separated). For example: @@ -354,7 +354,7 @@ trait MesosSchedulerUtils extends Logging { * container overheads. * * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value - * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * @return memory requirement as (0.1 * memoryOverhead) or MEMORY_OVERHEAD_MINIMUM * (whichever is larger) */ def executorMemory(sc: SparkContext): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index ec56fe7729c2a..57f7a80bedc6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -44,6 +44,3 @@ class PartitionsAlreadyExistException(db: String, table: String, specs: Seq[Tabl class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") - -class TempFunctionAlreadyExistsException(func: String) - extends AnalysisException(s"Temporary function '$func' already exists") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 7da7f55aa5d7f..3f76f26dbe4ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -139,9 +139,8 @@ object UnsupportedOperationChecker { } throwErrorIf( child.isStreaming && distinctAggExprs.nonEmpty, - "Distinct aggregations are not supported on streaming DataFrames/Datasets, unless " + - "it is on aggregated DataFrame/Dataset in Complete output mode. Consider using " + - "approximate distinct aggregation (e.g. approx_count_distinct() instead of count()).") + "Distinct aggregations are not supported on streaming DataFrames/Datasets. Consider " + + "using approx_count_distinct() instead.") case _: Command => throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index faedf5f91c3ef..1417bccf657cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1050,7 +1050,7 @@ class SessionCatalog( * * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { // TODO: at least support UDAFs here throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") } @@ -1064,18 +1064,20 @@ class SessionCatalog( } /** - * Create a temporary function. - * This assumes no database is specified in `funcDefinition`. + * Registers a temporary or permanent function into a session-specific [[FunctionRegistry]] */ - def createTempFunction( - name: String, - info: ExpressionInfo, - funcDefinition: FunctionBuilder, - ignoreIfExists: Boolean): Unit = { - if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { - throw new TempFunctionAlreadyExistsException(name) + def registerFunction( + funcDefinition: CatalogFunction, + ignoreIfExists: Boolean, + functionBuilder: Option[FunctionBuilder] = None): Unit = { + val func = funcDefinition.identifier + if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) { + throw new AnalysisException(s"Function $func already exists") } - functionRegistry.registerFunction(name, info, funcDefinition) + val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) + val builder = + functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) + functionRegistry.registerFunction(func.unquotedString, info, builder) } /** @@ -1180,12 +1182,7 @@ class SessionCatalog( // catalog. So, it is possible that qualifiedName is not exactly the same as // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). // At here, we preserve the input from the user. - val info = new ExpressionInfo( - catalogFunction.className, - qualifiedName.database.orNull, - qualifiedName.funcName) - val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className) - createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false) + registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(qualifiedName.unquotedString, children) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1049915986d9b..bb1273f5c3d84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String }) } - private[this] def cast(from: DataType, to: DataType): Any => Any = to match { - case dt if dt == from => identity[Any] - case StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case CalendarIntervalType => castToInterval(from) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) - case udt: UserDefinedType[_] - if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - identity[Any] - case _: UserDefinedType[_] => - throw new SparkException(s"Cannot cast $from to $to.") + private[this] def cast(from: DataType, to: DataType): Any => Any = { + // If the cast does not change the structure, then we don't really need to cast anything. + // We can return what the children return. Same thing should happen in the codegen path. + if (DataType.equalsStructurally(from, to)) { + identity + } else { + to match { + case dt if dt == from => identity[Any] + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case CalendarIntervalType => castToInterval(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => + castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + identity[Any] + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") + } + } } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) protected override def nullSafeEval(input: Any): Any = cast(input) + override def genCode(ctx: CodegenContext): ExprCode = { + // If the cast does not change the structure, then we don't really need to cast anything. + // We can return what the children return. Same thing should happen in the interpreted path. + if (DataType.equalsStructurally(child.dataType, dataType)) { + child.genCode(ctx) + } else { + super.genCode(ctx) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6d94764f1bfac..eed773d4cb368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -996,6 +996,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) override def foldable: Boolean = false override def nullable: Boolean = false + override def flatArguments: Iterator[Any] = Iterator(child) + private val errMsg = "Null value appeared in non-nullable field:" + walkedTypePath.mkString("\n", "\n", "\n") + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 59db28d58afce..d7b493d521ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -47,7 +47,6 @@ abstract class SubqueryExpression( plan: LogicalPlan, children: Seq[Expression], exprId: ExprId) extends PlanExpression[LogicalPlan] { - override lazy val resolved: Boolean = childrenResolved && plan.resolved override lazy val references: AttributeSet = if (plan.resolved) super.references -- plan.outputSet else super.references @@ -59,6 +58,13 @@ abstract class SubqueryExpression( children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) case _ => false } + def canonicalize(attrs: AttributeSeq): SubqueryExpression = { + // Normalize the outer references in the subquery plan. + val normalizedPlan = plan.transformAllExpressions { + case OuterReference(r) => OuterReference(QueryPlan.normalizeExprId(r, attrs)) + } + withNewPlan(normalizedPlan).canonicalized.asInstanceOf[SubqueryExpression] + } } object SubqueryExpression { @@ -236,6 +242,12 @@ case class ScalarSubquery( override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ScalarSubquery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } object ScalarSubquery { @@ -268,6 +280,12 @@ case class ListQuery( override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) override def toString: String = s"list#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ListQuery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } /** @@ -290,4 +308,10 @@ case class Exists( override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) override def toString: String = s"exists#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + Exists( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index cbd506465ae6a..c704c2e6d36bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -54,8 +54,6 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) - // TODO: Compute the set of star-joins and use them in the join enumeration - // algorithm to prune un-optimal plan choices. val result = // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. @@ -150,12 +148,15 @@ object JoinReorderDP extends PredicateHelper with Logging { case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) }.toMap) + // Build filters from the join graph to be used by the search algorithm. + val filters = JoinReorderDPFilters.buildJoinGraphInfo(conf, items, conditions, itemIndex) + // Build plans for next levels until the last level has only one plan. This plan contains // all items that can be joined, so there's no need to continue. val topOutputSet = AttributeSet(output) - while (foundPlans.size < items.length && foundPlans.last.size > 1) { + while (foundPlans.size < items.length) { // Build plans for the next level. - foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet) + foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet, filters) } val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) @@ -179,7 +180,8 @@ object JoinReorderDP extends PredicateHelper with Logging { existingLevels: Seq[JoinPlanMap], conf: SQLConf, conditions: Set[Expression], - topOutput: AttributeSet): JoinPlanMap = { + topOutput: AttributeSet, + filters: Option[JoinGraphInfo]): JoinPlanMap = { val nextLevel = mutable.Map.empty[Set[Int], JoinPlan] var k = 0 @@ -200,7 +202,7 @@ object JoinReorderDP extends PredicateHelper with Logging { } otherSideCandidates.foreach { otherSidePlan => - buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) match { + buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput, filters) match { case Some(newJoinPlan) => // Check if it's the first plan for the item set, or it's a better plan than // the existing one due to lower cost. @@ -218,14 +220,20 @@ object JoinReorderDP extends PredicateHelper with Logging { } /** - * Builds a new JoinPlan when both conditions hold: + * Builds a new JoinPlan if the following conditions hold: * - the sets of items contained in left and right sides do not overlap. * - there exists at least one join condition involving references from both sides. + * - if star-join filter is enabled, allow the following combinations: + * 1) (oneJoinPlan U otherJoinPlan) is a subset of star-join + * 2) star-join is a subset of (oneJoinPlan U otherJoinPlan) + * 3) (oneJoinPlan U otherJoinPlan) is a subset of non star-join + * * @param oneJoinPlan One side JoinPlan for building a new JoinPlan. * @param otherJoinPlan The other side JoinPlan for building a new join node. * @param conf SQLConf for statistics computation. * @param conditions The overall set of join conditions. * @param topOutput The output attributes of the final plan. + * @param filters Join graph info to be used as filters by the search algorithm. * @return Builds and returns a new JoinPlan if both conditions hold. Otherwise, returns None. */ private def buildJoin( @@ -233,13 +241,27 @@ object JoinReorderDP extends PredicateHelper with Logging { otherJoinPlan: JoinPlan, conf: SQLConf, conditions: Set[Expression], - topOutput: AttributeSet): Option[JoinPlan] = { + topOutput: AttributeSet, + filters: Option[JoinGraphInfo]): Option[JoinPlan] = { if (oneJoinPlan.itemIds.intersect(otherJoinPlan.itemIds).nonEmpty) { // Should not join two overlapping item sets. return None } + if (filters.isDefined) { + // Apply star-join filter, which ensures that tables in a star schema relationship + // are planned together. The star-filter will eliminate joins among star and non-star + // tables until the star joins are built. The following combinations are allowed: + // 1. (oneJoinPlan U otherJoinPlan) is a subset of star-join + // 2. star-join is a subset of (oneJoinPlan U otherJoinPlan) + // 3. (oneJoinPlan U otherJoinPlan) is a subset of non star-join + val isValidJoinCombination = + JoinReorderDPFilters.starJoinFilter(oneJoinPlan.itemIds, otherJoinPlan.itemIds, + filters.get) + if (!isValidJoinCombination) return None + } + val onePlan = oneJoinPlan.plan val otherPlan = otherJoinPlan.plan val joinConds = conditions @@ -327,3 +349,109 @@ object JoinReorderDP extends PredicateHelper with Logging { case class Cost(card: BigInt, size: BigInt) { def +(other: Cost): Cost = Cost(this.card + other.card, this.size + other.size) } + +/** + * Implements optional filters to reduce the search space for join enumeration. + * + * 1) Star-join filters: Plan star-joins together since they are assumed + * to have an optimal execution based on their RI relationship. + * 2) Cartesian products: Defer their planning later in the graph to avoid + * large intermediate results (expanding joins, in general). + * 3) Composite inners: Don't generate "bushy tree" plans to avoid materializing + * intermediate results. + * + * Filters (2) and (3) are not implemented. + */ +object JoinReorderDPFilters extends PredicateHelper { + /** + * Builds join graph information to be used by the filtering strategies. + * Currently, it builds the sets of star/non-star joins. + * It can be extended with the sets of connected/unconnected joins, which + * can be used to filter Cartesian products. + */ + def buildJoinGraphInfo( + conf: SQLConf, + items: Seq[LogicalPlan], + conditions: Set[Expression], + itemIndex: Seq[(LogicalPlan, Int)]): Option[JoinGraphInfo] = { + + if (conf.joinReorderDPStarFilter) { + // Compute the tables in a star-schema relationship. + val starJoin = StarSchemaDetection(conf).findStarJoins(items, conditions.toSeq) + val nonStarJoin = items.filterNot(starJoin.contains(_)) + + if (starJoin.nonEmpty && nonStarJoin.nonEmpty) { + val itemMap = itemIndex.toMap + Some(JoinGraphInfo(starJoin.map(itemMap).toSet, nonStarJoin.map(itemMap).toSet)) + } else { + // Nothing interesting to return. + None + } + } else { + // Star schema filter is not enabled. + None + } + } + + /** + * Applies the star-join filter that eliminates join combinations among star + * and non-star tables until the star join is built. + * + * Given the oneSideJoinPlan/otherSideJoinPlan, which represent all the plan + * permutations generated by the DP join enumeration, and the star/non-star plans, + * the following plan combinations are allowed: + * 1. (oneSideJoinPlan U otherSideJoinPlan) is a subset of star-join + * 2. star-join is a subset of (oneSideJoinPlan U otherSideJoinPlan) + * 3. (oneSideJoinPlan U otherSideJoinPlan) is a subset of non star-join + * + * It assumes the sets are disjoint. + * + * Example query graph: + * + * t1 d1 - t2 - t3 + * \ / + * f1 + * | + * d2 + * + * star: {d1, f1, d2} + * non-star: {t2, t1, t3} + * + * level 0: (f1 ), (d2 ), (t3 ), (d1 ), (t1 ), (t2 ) + * level 1: {t3 t2 }, {f1 d2 }, {f1 d1 } + * level 2: {d2 f1 d1 } + * level 3: {t1 d1 f1 d2 }, {t2 d1 f1 d2 } + * level 4: {d1 t2 f1 t1 d2 }, {d1 t3 t2 f1 d2 } + * level 5: {d1 t3 t2 f1 t1 d2 } + * + * @param oneSideJoinPlan One side of the join represented as a set of plan ids. + * @param otherSideJoinPlan The other side of the join represented as a set of plan ids. + * @param filters Star and non-star plans represented as sets of plan ids + */ + def starJoinFilter( + oneSideJoinPlan: Set[Int], + otherSideJoinPlan: Set[Int], + filters: JoinGraphInfo) : Boolean = { + val starJoins = filters.starJoins + val nonStarJoins = filters.nonStarJoins + val join = oneSideJoinPlan.union(otherSideJoinPlan) + + // Disjoint sets + oneSideJoinPlan.intersect(otherSideJoinPlan).isEmpty && + // Either star or non-star is empty + (starJoins.isEmpty || nonStarJoins.isEmpty || + // Join is a subset of the star-join + join.subsetOf(starJoins) || + // Star-join is a subset of join + starJoins.subsetOf(join) || + // Join is a subset of non-star + join.subsetOf(nonStarJoins)) + } +} + +/** + * Helper class that keeps information about the join graph as sets of item/plan ids. + * It currently stores the star/non-star plans. It can be + * extended with the set of connected/unconnected plans. + */ +case class JoinGraphInfo (starJoins: Set[Int], nonStarJoins: Set[Int]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 91cb004eaec46..97ee9988386dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -76,7 +76,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val emptyStarJoinPlan = Seq.empty[LogicalPlan] - if (!conf.starSchemaDetection || input.size < 2) { + if (input.size < 2) { emptyStarJoinPlan } else { // Find if the input plans are eligible for star join detection. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3008e8cb84659..2fb65bd435507 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -377,7 +377,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // As the root of the expression, Alias will always take an arbitrary exprId, we need to // normalize that for equality testing, by assigning expr id from 0 incrementally. The // alias name doesn't matter and should be erased. - Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes) + Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => // Top level `AttributeReference` may also be used for output like `Alias`, we should @@ -385,7 +386,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT id += 1 ar.withExprId(ExprId(id)) - case other => normalizeExprId(other) + case other => QueryPlan.normalizeExprId(other, allAttributes) }.withNewChildren(canonicalizedChildren) } @@ -395,23 +396,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def preCanonicalized: PlanType = this - /** - * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` - * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we - * do not use `BindReferences` here as the plan may take the expression as a parameter with type - * `Attribute`, and replace it with `BoundReference` will cause error. - */ - protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = { - e.transformUp { - case ar: AttributeReference => - val ordinal = input.indexOf(ar.exprId) - if (ordinal == -1) { - ar - } else { - ar.withExprId(ExprId(ordinal)) - } - }.canonicalized.asInstanceOf[T] - } /** * Returns true when the given query plan will return the same results as this query plan. @@ -438,3 +422,24 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ lazy val allAttributes: AttributeSeq = children.flatMap(_.output) } + +object QueryPlan { + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we + * do not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + def normalizeExprId[T <: Expression](e: T, input: AttributeSeq): T = { + e.transformUp { + case s: SubqueryExpression => s.canonicalize(input) + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized.asInstanceOf[T] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6b0f495033494..2e1798e22b9fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -736,6 +736,12 @@ object SQLConf { .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") .createWithDefault(0.7) + val JOIN_REORDER_DP_STAR_FILTER = + buildConf("spark.sql.cbo.joinReorder.dp.star.filter") + .doc("Applies star-join filter heuristics to cost based join enumeration.") + .booleanConf + .createWithDefault(false) + val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection") .doc("When true, it enables join reordering based on star schema detection. ") .booleanConf @@ -1011,6 +1017,8 @@ class SQLConf extends Serializable with Logging { def joinReorderCardWeight: Double = getConf(SQLConf.JOIN_REORDER_CARD_WEIGHT) + def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) def sortMergeJoinExecBufferSpillThreshold: Int = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 520aff5e2b677..30745c6a9d42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -288,4 +288,30 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + /** + * Returns true if the two data types share the same "shape", i.e. the types (including + * nullability) are the same, but the field names don't need to be the same. + */ + def equalsStructurally(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (left: ArrayType, right: ArrayType) => + equalsStructurally(left.elementType, right.elementType) && + left.containsNull == right.containsNull + + case (left: MapType, right: MapType) => + equalsStructurally(left.keyType, right.keyType) && + equalsStructurally(left.valueType, right.valueType) && + left.valueContainsNull == right.valueContainsNull + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields) + .forall { case (l, r) => + equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 850869799507f..8ae3ff5043e68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -117,11 +117,11 @@ object RandomDataGenerator { } /** - * Returns a function which generates random values for the given [[DataType]], or `None` if no + * Returns a function which generates random values for the given `DataType`, or `None` if no * random data generator is defined for that data type. The generated values will use an external - * representation of the data type; for example, the random generator for [[DateType]] will return - * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. - * For a [[UserDefinedType]] for a class X, an instance of class X is returned. + * representation of the data type; for example, the random generator for `DateType` will return + * instances of [[java.sql.Date]] and the generator for `StructType` will return a [[Row]]. + * For a `UserDefinedType` for a class X, an instance of class X is returned. * * @param dataType the type to generate values for * @param nullable whether null values should be generated diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index a6d90409382e5..769addf3b29e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Benchmark /** - * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields. + * Benchmark `UnsafeProjection` for fixed-length/primitive-type fields. */ object UnsafeProjectionBenchmark { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 9ba846fb25279..be8903000a0d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1162,10 +1162,10 @@ abstract class SessionCatalogSuite extends PlanTest { withBasicCatalog { catalog => val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last - val info1 = new ExpressionInfo("tempFunc1", "temp1") - val info2 = new ExpressionInfo("tempFunc2", "temp2") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction( + newFunc("temp2", None), ignoreIfExists = false, functionBuilder = Some(tempFunc2)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) @@ -1174,13 +1174,15 @@ abstract class SessionCatalogSuite extends PlanTest { catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) } val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) - val info3 = new ExpressionInfo("tempFunc3", "temp1") // Temporary function already exists - intercept[TempFunctionAlreadyExistsException] { - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) - } + val e = intercept[AnalysisException] { + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc3)) + }.getMessage + assert(e.contains("Function temp1 already exists")) // Temporary function is overridden - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = true, functionBuilder = Some(tempFunc3)) assert( catalog.lookupFunction( FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) @@ -1193,8 +1195,8 @@ abstract class SessionCatalogSuite extends PlanTest { assert(!catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) val tempFunc1 = (e: Seq[Expression]) => e.head - val info1 = new ExpressionInfo("tempFunc1", "temp1") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) // Returns true when the function is temporary assert(catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) @@ -1243,9 +1245,9 @@ abstract class SessionCatalogSuite extends PlanTest { test("drop temp function") { withBasicCatalog { catalog => - val info = new ExpressionInfo("tempFunc", "func1") val tempFunc = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) + catalog.registerFunction( + newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1284,9 +1286,9 @@ abstract class SessionCatalogSuite extends PlanTest { test("lookup temp function") { withBasicCatalog { catalog => - val info1 = new ExpressionInfo("tempFunc1", "func1") val tempFunc1 = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + catalog.registerFunction( + newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) assert(catalog.lookupFunction( FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1298,14 +1300,14 @@ abstract class SessionCatalogSuite extends PlanTest { test("list functions") { withBasicCatalog { catalog => - val info1 = new ExpressionInfo("tempFunc1", "func1") - val info2 = new ExpressionInfo("tempFunc2", "yes_me") + val funcMeta1 = newFunc("func1", None) + val funcMeta2 = newFunc("yes_me", None) val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) + catalog.registerFunction(funcMeta1, ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction(funcMeta2, ignoreIfExists = false, functionBuilder = Some(tempFunc2)) assert(catalog.listFunctions("db1", "*").map(_._1).toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 8eccadbdd8afb..a7ffa884d2286 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure) assert(cast(1.0, DateType).checkInputDataTypes().isFailure) } + + test("SPARK-20302 cast with same structure") { + val from = new StructType() + .add("a", IntegerType) + .add("b", new StructType().add("b1", LongType)) + + val to = new StructType() + .add("a1", IntegerType) + .add("b1", new StructType().add("b11", LongType)) + + val input = Row(10, Row(12L)) + + checkEvaluation(cast(Literal.create(input, from), to), input) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala new file mode 100644 index 0000000000000..a23d6266b2840 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf._ + + +class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = new SQLConf().copy( + CBO_ENABLED -> true, + JOIN_REORDER_ENABLED -> true, + JOIN_REORDER_DP_STAR_FILTER -> true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Nil + } + + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + // F1 (fact table) + attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D1 (dimension) + attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D2 (dimension) + attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D3 (dimension) + attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + + // T1 (regular table i.e. outside star) + attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T2 (regular table) + attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T3 (regular table) + attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T4 (regular table) + attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T5 (regular table) + attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T6 (regular table) + attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4) + + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val f1 = StatsTestPlan( + outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c1", "f1_c2").map(nameToAttr), + rowCount = 1000, + size = Some(1000 * (8 + 4 * 5)), + attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c1", "f1_c2") + .map(nameToColInfo))) + + // To control the layout of the join plans, keep the size for the non-fact tables constant + // and vary the rowcount and the number of distinct values of the join columns. + private val d1 = StatsTestPlan( + outputList = Seq("d1_pk", "d1_c2", "d1_c3").map(nameToAttr), + rowCount = 100, + size = Some(3000), + attributeStats = AttributeMap(Seq("d1_pk", "d1_c2", "d1_c3").map(nameToColInfo))) + + private val d2 = StatsTestPlan( + outputList = Seq("d2_pk", "d2_c2", "d2_c3").map(nameToAttr), + rowCount = 20, + size = Some(3000), + attributeStats = AttributeMap(Seq("d2_pk", "d2_c2", "d2_c3").map(nameToColInfo))) + + private val d3 = StatsTestPlan( + outputList = Seq("d3_pk", "d3_c2", "d3_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("d3_pk", "d3_c2", "d3_c3").map(nameToColInfo))) + + private val t1 = StatsTestPlan( + outputList = Seq("t1_c1", "t1_c2", "t1_c3").map(nameToAttr), + rowCount = 50, + size = Some(3000), + attributeStats = AttributeMap(Seq("t1_c1", "t1_c2", "t1_c3").map(nameToColInfo))) + + private val t2 = StatsTestPlan( + outputList = Seq("t2_c1", "t2_c2", "t2_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t2_c1", "t2_c2", "t2_c3").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3_c1", "t3_c2", "t3_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t3_c1", "t3_c2", "t3_c3").map(nameToColInfo))) + + private val t4 = StatsTestPlan( + outputList = Seq("t4_c1", "t4_c2", "t4_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t4_c1", "t4_c2", "t4_c3").map(nameToColInfo))) + + private val t5 = StatsTestPlan( + outputList = Seq("t5_c1", "t5_c2", "t5_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t5_c1", "t5_c2", "t5_c3").map(nameToColInfo))) + + private val t6 = StatsTestPlan( + outputList = Seq("t6_c1", "t6_c2", "t6_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t6_c1", "t6_c2", "t6_c3").map(nameToColInfo))) + + test("Test 1: Star query with two dimensions and two regular tables") { + + // d1 t1 + // \ / + // f1 + // / \ + // d2 t2 + // + // star: {f1, d1, d2} + // non-star: {t1, t2} + // + // level 0: (t2 ), (d2 ), (f1 ), (d1 ), (t1 ) + // level 1: {f1 d1 }, {d2 f1 } + // level 2: {d2 f1 d1 } + // level 3: {t2 d1 d2 f1 }, {t1 d1 d2 f1 } + // level 4: {f1 t1 t2 d1 d2 } + // + // Number of generated plans: 11 (vs. 20 w/o filter) + val query = + f1.join(t1).join(t2).join(d1).join(d2) + .where((nameToAttr("f1_c1") === nameToAttr("t1_c1")) && + (nameToAttr("f1_c2") === nameToAttr("t2_c1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 2: Star with a linear branch") { + // + // t1 d1 - t2 - t3 + // \ / + // f1 + // | + // d2 + // + // star: {d1, f1, d2} + // non-star: {t2, t1, t3} + // + // level 0: (f1 ), (d2 ), (t3 ), (d1 ), (t1 ), (t2 ) + // level 1: {t3 t2 }, {f1 d2 }, {f1 d1 } + // level 2: {d2 f1 d1 } + // level 3: {t1 d1 f1 d2 }, {t2 d1 f1 d2 } + // level 4: {d1 t2 f1 t1 d2 }, {d1 t3 t2 f1 d2 } + // level 5: {d1 t3 t2 f1 t1 d2 } + // + // Number of generated plans: 15 (vs 24) + val query = + d1.join(t1).join(t2).join(f1).join(d2).join(t3) + .where((nameToAttr("d1_pk") === nameToAttr("f1_fk1")) && + (nameToAttr("t1_c1") === nameToAttr("f1_c1")) && + (nameToAttr("d2_pk") === nameToAttr("f1_fk2")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("d1_c2") === nameToAttr("t2_c1")) && + (nameToAttr("t2_c2") === nameToAttr("t3_c1"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner, + Some(nameToAttr("d1_c2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 3: Star with derived branches") { + // t3 t2 + // | | + // d1 - t4 - t1 + // | + // f1 + // | + // d2 + // + // star: (d1 f1 d2 ) + // non-star: (t4 t1 t2 t3 ) + // + // level 0: (t1 ), (t3 ), (f1 ), (d1 ), (t2 ), (d2 ), (t4 ) + // level 1: {f1 d2 }, {t1 t4 }, {t1 t2 }, {f1 d1 }, {t3 t4 } + // level 2: {d1 f1 d2 }, {t2 t1 t4 }, {t1 t3 t4 } + // level 3: {t4 d1 f1 d2 }, {t3 t4 t1 t2 } + // level 4: {d1 f1 t4 d2 t3 }, {d1 f1 t4 d2 t1 } + // level 5: {d1 f1 t4 d2 t1 t2 }, {d1 f1 t4 d2 t1 t3 } + // level 6: {d1 f1 t4 d2 t1 t2 t3 } + // + // Number of generated plans: 22 (vs. 34) + val query = + d1.join(t1).join(t2).join(t3).join(t4).join(f1).join(d2) + .where((nameToAttr("t1_c1") === nameToAttr("t2_c1")) && + (nameToAttr("t3_c1") === nameToAttr("t4_c1")) && + (nameToAttr("t1_c2") === nameToAttr("t4_c2")) && + (nameToAttr("d1_c2") === nameToAttr("t4_c3")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t4, Inner, Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))), Inner, + Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))) + .join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner, + Some(nameToAttr("t1_c2") === nameToAttr("t4_c2"))) + + assertEqualPlans(query, expected) + } + + test("Test 4: Star with several branches") { + // + // d1 - t3 - t4 + // | + // f1 - d3 - t1 - t2 + // | + // d2 - t5 - t6 + // + // star: {d1 f1 d2 d3 } + // non-star: {t5 t3 t6 t2 t4 t1} + // + // level 0: (t4 ), (d2 ), (t5 ), (d3 ), (d1 ), (f1 ), (t2 ), (t6 ), (t1 ), (t3 ) + // level 1: {t5 t6 }, {t4 t3 }, {d3 f1 }, {t2 t1 }, {d2 f1 }, {d1 f1 } + // level 2: {d2 d1 f1 }, {d2 d3 f1 }, {d3 d1 f1 } + // level 3: {d2 d1 d3 f1 } + // level 4: {d1 t3 d3 f1 d2 }, {d1 d3 f1 t1 d2 }, {d1 t5 d3 f1 d2 } + // level 5: {d1 t5 d3 f1 t1 d2 }, {d1 t3 t4 d3 f1 d2 }, {d1 t5 t6 d3 f1 d2 }, + // {d1 t5 t3 d3 f1 d2 }, {d1 t3 d3 f1 t1 d2 }, {d1 t2 d3 f1 t1 d2 } + // level 6: {d1 t5 t3 t4 d3 f1 d2 }, {d1 t3 t2 d3 f1 t1 d2 }, {d1 t5 t6 d3 f1 t1 d2 }, + // {d1 t5 t3 d3 f1 t1 d2 }, {d1 t5 t2 d3 f1 t1 d2 }, ... + // ... + // level 9: {d1 t5 t3 t6 t2 t4 d3 f1 t1 d2 } + // + // Number of generated plans: 46 (vs. 82) + val query = + d1.join(t3).join(t4).join(f1).join(d2).join(t5).join(t6).join(d3).join(t1).join(t2) + .where((nameToAttr("d1_c2") === nameToAttr("t3_c1")) && + (nameToAttr("t3_c2") === nameToAttr("t4_c2")) && + (nameToAttr("d1_pk") === nameToAttr("f1_fk1")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("d2_c2") === nameToAttr("t5_c1")) && + (nameToAttr("t5_c2") === nameToAttr("t6_c2")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk")) && + (nameToAttr("d3_c2") === nameToAttr("t1_c1")) && + (nameToAttr("t1_c2") === nameToAttr("t2_c2"))) + + val expected = + f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(t4.join(t3, Inner, Some(nameToAttr("t3_c2") === nameToAttr("t4_c2"))), Inner, + Some(nameToAttr("d1_c2") === nameToAttr("t3_c1"))) + .join(t2.join(t1, Inner, Some(nameToAttr("t1_c2") === nameToAttr("t2_c2"))), Inner, + Some(nameToAttr("d3_c2") === nameToAttr("t1_c1"))) + .join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner, + Some(nameToAttr("d2_c2") === nameToAttr("t5_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 5: RI star only") { + // d1 + // | + // f1 + // / \ + // d2 d3 + // + // star: {f1, d1, d2, d3} + // non-star: {} + // level 0: (d1), (f1), (d2), (d3) + // level 1: {f1 d3 }, {f1 d2 }, {d1 f1 } + // level 2: {d1 f1 d2 }, {d2 f1 d3 }, {d1 f1 d3 } + // level 3: {d1 d2 f1 d3 } + // Number of generated plans: 11 (= 11) + val query = + d1.join(d2).join(f1).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + + val expected = + f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + + assertEqualPlans(query, expected) + } + + test("Test 6: No RI star") { + // + // f1 - t1 - t2 - t3 + // + // star: {} + // non-star: {f1, t1, t2, t3} + // level 0: (t1), (f1), (t2), (t3) + // level 1: {f1 t3 }, {f1 t2 }, {t1 f1 } + // level 2: {t1 f1 t2 }, {t2 f1 t3 }, {dt f1 t3 } + // level 3: {t1 t2 f1 t3 } + // Number of generated plans: 11 (= 11) + val query = + t1.join(f1).join(t2).join(t3) + .where((nameToAttr("f1_fk1") === nameToAttr("t1_c1")) && + (nameToAttr("f1_fk2") === nameToAttr("t2_c1")) && + (nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) + + val expected = + f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) + .join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1"))) + + assertEqualPlans(query, expected) + } + + private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val optimized = Optimize.execute(plan1.analyze) + val expected = plan2.analyze + compareJoinOrder(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index f078ef013387b..c4635c8f126af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite { checkCatalogString(ArrayType(createStruct(40))) checkCatalogString(MapType(IntegerType, StringType)) checkCatalogString(MapType(IntegerType, createStruct(40))) + + def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = { + val testName = s"equalsStructurally: (from: $from, to: $to)" + test(testName) { + assert(DataType.equalsStructurally(from, to) === expected) + } + } + + checkEqualsStructurally(BooleanType, BooleanType, true) + checkEqualsStructurally(IntegerType, IntegerType, true) + checkEqualsStructurally(IntegerType, LongType, false) + checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true) + checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false) + + checkEqualsStructurally( + new StructType().add("f1", IntegerType), + new StructType().add("f2", IntegerType), + true) + checkEqualsStructurally( + new StructType().add("f1", IntegerType), + new StructType().add("f2", IntegerType, false), + false) + + checkEqualsStructurally( + new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)), + new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), + true) + checkEqualsStructurally( + new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)), + new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), + false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 49691c15d0f7d..c1b32917415ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -268,8 +268,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (JSON Lines text format or - * newline-delimited JSON) and returns the result as a `DataFrame`. + * Loads a JSON file and returns the results as a `DataFrame`. + * * See the documentation on the overloaded `json()` method with varargs for more details. * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 074952ff7900a..7e5da012f84ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -510,7 +510,7 @@ abstract class Catalog { def refreshTable(tableName: String): Unit /** - * Invalidates and refreshes all the cached data (and the associated metadata) for any [[Dataset]] + * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` * that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate * everything that is cached. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 3a9132d74ac11..866fa98533218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} @@ -516,10 +517,10 @@ case class FileSourceScanExec( override lazy val canonicalized: FileSourceScanExec = { FileSourceScanExec( relation, - output.map(normalizeExprId(_, output)), + output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, - partitionFilters.map(normalizeExprId(_, output)), - dataFilters.map(normalizeExprId(_, output)), + partitionFilters.map(QueryPlan.normalizeExprId(_, output)), + dataFilters.map(QueryPlan.normalizeExprId(_, output)), None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 5687f9332430e..e0d0029369576 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -51,6 +51,7 @@ case class CreateFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val func = CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources) if (isTemp) { if (databaseName.isDefined) { throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + @@ -59,17 +60,13 @@ case class CreateFunctionCommand( // We first load resources and then put the builder in the function registry. // Please note that it is allowed to overwrite an existing temp function. catalog.loadFunctionResources(resources) - val info = new ExpressionInfo(className, functionName) - val builder = catalog.makeFunctionBuilder(functionName, className) - catalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + catalog.registerFunction(func, ignoreIfExists = false) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. // TODO: should we also parse "IF NOT EXISTS"? - catalog.createFunction( - CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources), - ignoreIfExists = false) + catalog.createFunction(func, ignoreIfExists = false) } Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index bda64d4b91bbc..4ec09bff429c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -324,8 +324,11 @@ object FileFormatWriter extends Logging { override def releaseResources(): Unit = { if (currentWriter != null) { - currentWriter.close() - currentWriter = null + try { + currentWriter.close() + } finally { + currentWriter = null + } } } } @@ -459,8 +462,11 @@ object FileFormatWriter extends Logging { override def releaseResources(): Unit = { if (currentWriter != null) { - currentWriter.close() - currentWriter = null + try { + currentWriter.close() + } finally { + currentWriter = null + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 8857966676ae2..bcf0d970f7ec1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -284,42 +284,38 @@ class StreamExecution( triggerExecutor.execute(() => { startTrigger() - val continueToRun = - if (isActive) { - reportTimeTaken("triggerExecution") { - if (currentBatchId < 0) { - // We'll do this initialization only once - populateStartOffsets(sparkSessionToRunBatches) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() - } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") - runBatch(sparkSessionToRunBatches) - } + if (isActive) { + reportTimeTaken("triggerExecution") { + if (currentBatchId < 0) { + // We'll do this initialization only once + populateStartOffsets(sparkSessionToRunBatches) + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + } else { + constructNextBatch() } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) if (dataAvailable) { - // Update committed offsets. - batchCommitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) + currentStatus = currentStatus.copy(isDataAvailable = true) + updateStatusMessage("Processing new data") + runBatch(sparkSessionToRunBatches) } - true + } + // Report trigger as finished and construct progress object. + finishTrigger(dataAvailable) + if (dataAvailable) { + // Update committed offsets. + batchCommitLog.add(currentBatchId) + committedOffsets ++= availableOffsets + logDebug(s"batch ${currentBatchId} committed") + // We'll increase currentBatchId after we complete processing current batch's data + currentBatchId += 1 } else { - false + currentStatus = currentStatus.copy(isDataAvailable = false) + updateStatusMessage("Waiting for data to arrive") + Thread.sleep(pollingDelayMs) } + } updateStatusMessage("Waiting for next trigger") - continueToRun + isActive }) updateStatusMessage("Stopped") } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 7a7d52b21427a..e66fe97afad45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.CleanerListener import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.execution.RDDScanExec +import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ @@ -76,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext sum } + private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { + plan.collect { + case InMemoryTableScanExec(_, _, relation) => + getNumInMemoryTablesRecursively(relation.child) + 1 + }.sum + } + test("withColumn doesn't invalidate cached dataframe") { var evalCount = 0 val myUDF = udf((x: String) => { evalCount += 1; "result" }) @@ -670,4 +677,138 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(spark.read.parquet(path).filter($"id" > 4).count() == 15) } } + + test("SPARK-19993 simple subquery caching") { + withTempView("t1", "t2") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + + // Additional predicate in the subquery plan should cause a cache miss + val cachedMissDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2 where c1 = 0) + """.stripMargin) + assert(getNumInMemoryRelations(cachedMissDs) == 0) + } + } + + test("SPARK-19993 subquery caching with correlated predicates") { + withTempView("t1", "t2") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(1).toDF("c1").createOrReplaceTempView("t2") + + // Simple correlated predicate in subquery + sql( + """ + |SELECT * FROM t1 + |WHERE + |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + } + } + + test("SPARK-19993 subquery with cached underlying relation") { + withTempView("t1") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + spark.catalog.cacheTable("t1") + + // underlying table t1 is cached as well as the query that refers to it. + val ds = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 2) + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin).cache() + assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3) + } + } + + test("SPARK-19993 nested subquery caching and scalar + predicate subqueris") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + Seq(1).toDF("c1").createOrReplaceTempView("t3") + Seq(1).toDF("c1").createOrReplaceTempView("t4") + + // Nested predicate subquery + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + + // Scalar subquery and predicate subquery + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin).cache() + + val cachedDs2 = + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs2) == 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala index 0f3d0cefe3bb5..92c5656f65bb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -56,7 +56,9 @@ object TestRegistrator { def apply(): TestRegistrator = new TestRegistrator() } -/** A [[Serializer]] that takes a [[KryoData]] and serializes it as KryoData(0). */ +/** + * A `Serializer` that takes a [[KryoData]] and serializes it as KryoData(0). + */ class ZeroKryoDataSerializer extends Serializer[KryoData] { override def write(kryo: Kryo, output: Output, t: KryoData): Unit = { output.writeInt(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 6469e501c1f68..8f9c52cb1e031 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -75,9 +75,10 @@ class CatalogSuite } private def createTempFunction(name: String): Unit = { - val info = new ExpressionInfo("className", name) val tempFunc = (e: Seq[Expression]) => e.head - sessionCatalog.createTempFunction(name, info, tempFunc, ignoreIfExists = false) + val funcMeta = CatalogFunction(FunctionIdentifier(name, None), "className", Nil) + sessionCatalog.registerFunction( + funcMeta, ignoreIfExists = false, functionBuilder = Some(tempFunc)) } private def dropFunction(name: String, db: Option[String] = None): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 26967782f77c7..2108b118bf059 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -44,8 +44,8 @@ abstract class FileStreamSourceTest import testImplicits._ /** - * A subclass [[AddData]] for adding data to files. This is meant to use the - * [[FileStreamSource]] actually being used in the execution. + * A subclass `AddData` for adding data to files. This is meant to use the + * `FileStreamSource` actually being used in the execution. */ abstract class AddFileData extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 5ab9dc2bc7763..13fe51a557733 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -569,7 +569,7 @@ class ThrowingIOExceptionLikeHadoop12074 extends FakeSource { object ThrowingIOExceptionLikeHadoop12074 { /** - * A latch to allow the user to wait until [[ThrowingIOExceptionLikeHadoop12074.createSource]] is + * A latch to allow the user to wait until `ThrowingIOExceptionLikeHadoop12074.createSource` is * called. */ @volatile var createSourceLatch: CountDownLatch = null @@ -600,7 +600,7 @@ class ThrowingInterruptedIOException extends FakeSource { object ThrowingInterruptedIOException { /** - * A latch to allow the user to wait until [[ThrowingInterruptedIOException.createSource]] is + * A latch to allow the user to wait until `ThrowingInterruptedIOException.createSource` is * called. */ @volatile var createSourceLatch: CountDownLatch = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 03aa45b616880..5bc36dd30f6d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -277,6 +277,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def threadState = if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def threadStackTrace = if (currentStream != null && currentStream.microBatchThread.isAlive) { + s"Thread stack trace: ${currentStream.microBatchThread.getStackTrace.mkString("\n")}" + } else { + "" + } def testState = s""" @@ -287,6 +292,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { |Output Mode: $outputMode |Stream state: $currentOffsets |Thread state: $threadState + |$threadStackTrace |${if (streamThreadDeathCause != null) stackTraceToString(streamThreadDeathCause) else ""} | |== Sink == diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2ebbfcd22b97c..b69536ed37463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -642,8 +642,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi * * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * When timeoutMs is less than or equal to 0, awaitTermination() is + * tested (i.e. w/o timeout) + * When timeoutMs is greater than 0, awaitTermination(timeoutMs) is + * tested * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used */ case class TestAwaitTermination( @@ -667,8 +669,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi * * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * When timeoutMs is less than or equal to 0, awaitTermination() is + * tested (i.e. w/o timeout) + * When timeoutMs is greater than 0, awaitTermination(timeoutMs) is + * tested * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used */ def assertOnQueryCondition( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index cab219216d1ca..6a4cc95d36bea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -41,11 +41,11 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} /** * Helper trait that should be extended by all SQL test suites. * - * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. - * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. * - * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ private[sql] trait SQLTestUtils @@ -65,7 +65,7 @@ private[sql] trait SQLTestUtils * A helper object for importing SQL implicits. * * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, + * This is because we create the `SQLContext` immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { @@ -73,7 +73,7 @@ private[sql] trait SQLTestUtils } /** - * Materialize the test data immediately after the [[SQLContext]] is set up. + * Materialize the test data immediately after the `SQLContext` is set up. * This is necessary if the data is accessed by name but not through direct reference. */ protected def setupTestData(): Unit = { @@ -250,8 +250,8 @@ private[sql] trait SQLTestUtils } /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier - * way to construct [[DataFrame]] directly out of local data without relying on implicits. + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { Dataset.ofRows(spark, plan) @@ -271,7 +271,9 @@ private[sql] trait SQLTestUtils } } - /** Run a test on a separate [[UninterruptibleThread]]. */ + /** + * Run a test on a separate `UninterruptibleThread`. + */ protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) (body: => Unit): Unit = { val timeoutMillis = 10000 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index b01977a23890f..959edf9a49371 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf} /** - * A special [[SparkSession]] prepared for testing. + * A special `SparkSession` prepared for testing. */ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java index b95077cd62186..0d0e3e4011b5b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java @@ -49,7 +49,7 @@ enum STATE { * The transition must be from {@link STATE#NOTINITED} to {@link STATE#INITED} unless the * operation failed and an exception was raised. * - * @param config + * @param conf * the configuration of the service */ void init(HiveConf conf); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java index a2c580d6acc71..c3219aabfc23b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java @@ -51,7 +51,7 @@ public static void ensureCurrentState(Service.STATE state, /** * Initialize a service. - *

+ * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -69,7 +69,7 @@ public static void init(Service service, HiveConf configuration) { /** * Start a service. - *

+ * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -86,7 +86,7 @@ public static void start(Service service) { /** * Initialize then start a service. - *

+ * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -102,9 +102,9 @@ public static void deploy(Service service, HiveConf configuration) { /** * Stop a service. - *

Do nothing if the service is null or not - * in a state in which it can be/needs to be stopped. - *

+ * + * Do nothing if the service is null or not in a state in which it can be/needs to be stopped. + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service or null diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java index 5021528299682..f7375ee707830 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java @@ -89,7 +89,7 @@ public static String getKerberosServiceTicket(String principal, String host, * @param clientUserName Client User name. * @return An unsigned cookie token generated from input parameters. * The final cookie generated is of the following format : - * cu=&rn=&s= + * {@code cu=&rn=&s=} */ public static String createCookieToken(String clientUserName) { StringBuffer sb = new StringBuffer(); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java index e2a6de165adc5..1af1c1d06e7f7 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java @@ -26,7 +26,7 @@ public interface PasswdAuthenticationProvider { * to authenticate users for their requests. * If a user is to be granted, return nothing/throw nothing. * When a user is to be disallowed, throw an appropriate {@link AuthenticationException}. - *

+ * * For an example implementation, see {@link LdapAuthenticationProviderImpl}. * * @param user The username received over the connection request diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java index 645e3e2bbd4e2..9a61ad49942c8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java @@ -31,12 +31,9 @@ /** * This class is responsible for setting the ipAddress for operations executed via HiveServer2. - *

- *

    - *
  • IP address is only set for operations that calls listeners with hookContext
  • - *
  • IP address is only set if the underlying transport mechanism is socket
  • - *
- *

+ * + * - IP address is only set for operations that calls listeners with hookContext + * - IP address is only set if the underlying transport mechanism is socket * * @see org.apache.hadoop.hive.ql.hooks.ExecuteWithHookContext */ diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java index 9d64b102e008d..bf2380632fa6c 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java @@ -38,7 +38,7 @@ public class CLIServiceUtils { * Convert a SQL search pattern into an equivalent Java Regex. * * @param pattern input which may contain '%' or '_' wildcard characters, or - * these characters escaped using {@link #getSearchStringEscape()}. + * these characters escaped using {@code getSearchStringEscape()}. * @return replace %/_ with regex search characters, also handle escaped * characters. */ diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java index 05a6bf938404b..af36057bdaeca 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java @@ -28,9 +28,9 @@ /** * ClassicTableTypeMapping. * Classic table type mapping : - * Managed Table ==> Table - * External Table ==> Table - * Virtual View ==> View + * Managed Table to Table + * External Table to Table + * Virtual View to View */ public class ClassicTableTypeMapping implements TableTypeMapping { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java index e392c459cf586..e59d19ea6be42 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java @@ -31,7 +31,7 @@ public interface TableTypeMapping { /** * Map hive's table type name to client's table type - * @param clientTypeName + * @param hiveTypeName * @return */ String mapToClientType(String hiveTypeName); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java index de066dd406c7a..c1b3892f52060 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java @@ -224,7 +224,9 @@ public SessionHandle openSession(TProtocolVersion protocol, String username, Str * The username passed to this method is the effective username. * If withImpersonation is true (==doAs true) we wrap all the calls in HiveSession * within a UGI.doAs, where UGI corresponds to the effective user. - * @see org.apache.hive.service.cli.thrift.ThriftCLIService#getUserName() + * + * Please see {@code org.apache.hive.service.cli.thrift.ThriftCLIService.getUserName()} for + * more details. * * @param protocol * @param username diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java index fb8141a905acb..94f8126552e9d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java @@ -30,12 +30,12 @@ * in custom cleanup code to be called before this thread is GC-ed. * Currently cleans up the following: * 1. ThreadLocal RawStore object: - * In case of an embedded metastore, HiveServer2 threads (foreground & background) + * In case of an embedded metastore, HiveServer2 threads (foreground and background) * end up caching a ThreadLocal RawStore object. The ThreadLocal RawStore object has - * an instance of PersistenceManagerFactory & PersistenceManager. + * an instance of PersistenceManagerFactory and PersistenceManager. * The PersistenceManagerFactory keeps a cache of PersistenceManager objects, * which are only removed when PersistenceManager#close method is called. - * HiveServer2 uses ExecutorService for managing thread pools for foreground & background threads. + * HiveServer2 uses ExecutorService for managing thread pools for foreground and background threads. * ExecutorService unfortunately does not provide any hooks to be called, * when a thread from the pool is terminated. * As a solution, we're using this ThreadFactory to keep a cache of RawStore objects per thread. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 6f5b923cd4f9e..4dec2f71b8a50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -53,8 +53,8 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Date * java.sql.Timestamp * Complex Types => - * Map: [[MapData]] - * List: [[ArrayData]] + * Map: `MapData` + * List: `ArrayData` * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index c917f110b90f2..377d4f2473c58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf @@ -124,13 +124,6 @@ private[sql] class HiveSessionCatalog( } private def lookupFunction0(name: FunctionIdentifier, children: Seq[Expression]): Expression = { - // TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to - // if (super.functionExists(name)) { - // super.lookupFunction(name, children) - // } else { - // // This function is a Hive builtin function. - // ... - // } val database = name.database.map(formatDatabaseName) val funcName = name.copy(database = database) Try(super.lookupFunction(funcName, children)) match { @@ -164,10 +157,11 @@ private[sql] class HiveSessionCatalog( } } val className = functionInfo.getFunctionClass.getName - val builder = makeFunctionBuilder(functionName, className) + val functionIdentifier = + FunctionIdentifier(functionName.toLowerCase(Locale.ROOT), database) + val func = CatalogFunction(functionIdentifier, className, Nil) // Put this Hive built-in function to our function registry. - val info = new ExpressionInfo(className, functionName) - createTempFunction(functionName, info, builder, ignoreIfExists = false) + registerFunction(func, ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(functionName, children) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index fab0d7fa84827..666548d1a490b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ @@ -203,9 +204,9 @@ case class HiveTableScanExec( override lazy val canonicalized: HiveTableScanExec = { val input: AttributeSeq = relation.output HiveTableScanExec( - requestedAttributes.map(normalizeExprId(_, input)), + requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), relation.canonicalized.asInstanceOf[CatalogRelation], - partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession) + partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession) } override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 197110f4912a7..73383ae4d4118 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -22,7 +22,9 @@ import scala.concurrent.duration._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.hive.HiveSessionCatalog import org.apache.spark.sql.hive.execution.TestingTypedCount @@ -217,9 +219,9 @@ class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingle private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = { val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName) - val info = new ExpressionInfo(clazz.getName, functionName) - sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + val functionIdentifier = FunctionIdentifier(functionName, database = None) + val func = CatalogFunction(functionIdentifier, clazz.getName, resources = Nil) + sessionCatalog.registerFunction(func, ignoreIfExists = false) } private def percentile_approx( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index e772324a57ab8..bb4ce6d3aa3f1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * TestSuites that derive from this class must provide a map of testCaseName to testCaseFiles * that should be included. Additionally, there is support for whitelisting and blacklisting * tests as development progresses. */ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 7226ed521ef32..a2f08c5ba72c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -43,7 +43,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { } /** - * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * Writes `data` to a Orc file and reads it back as a `DataFrame`, * which is then passed to `f`. The Orc file will be deleted after `f` returns. */ protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] @@ -53,7 +53,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { } /** - * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * Writes `data` to a Orc file, reads it back as a `DataFrame` and registers it as a * temporary table named `tableName`, then call `f`. The temporary table together with the * Orc file will be dropped/deleted after `f` returns. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 58b7031d5ea6a..15d3c7e54b8dd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils /** - * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a + * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a `StateMap` and a * sequence of records returned by the mapping function of `mapWithState`. */ private[streaming] case class MapWithStateRDDRecord[K, S, E]( @@ -111,7 +111,7 @@ private[streaming] class MapWithStateRDDPartition( /** * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data. * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a - * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping + * `StateMap` (containing the keyed-states) and the sequence of records returned by the mapping * function of `mapWithState`. * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD * will be created diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index a73e6cc2cd9c1..dc02062b9eb44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging * case of Spark Streaming the error is the difference between the measured processing * rate (number of elements/processing delay) and the previous rate. * - * @see https://en.wikipedia.org/wiki/PID_controller + * @see PID controller (Wikipedia) * * @param batchIntervalMillis the batch duration, in milliseconds * @param proportional how much the correction should depend on the current diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 7b2ef6881d6f7..e4b9dffee04f4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.Duration * A component that estimates the rate at which an `InputDStream` should ingest * records, based on updates at every batch completion. * - * @see [[org.apache.spark.streaming.scheduler.RateController]] + * Please see `org.apache.spark.streaming.scheduler.RateController` for more details. */ private[streaming] trait RateEstimator extends Serializable {